Coverage for src/flag_gems/experimental_ops/huber_loss.py: 0%
102 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def huber_loss_element_kernel(
    x_ptr,  # pointer to input tensor (broadcasted, contiguous, flattened)
    y_ptr,  # pointer to target tensor (broadcasted, contiguous, flattened)
    out_ptr,  # pointer to output tensor (contiguous, flattened)
    n_elements,  # number of elements
    delta,  # huber delta (scalar)
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise Huber loss over one BLOCK_SIZE-wide slice of the flat tensors:
    #   |x - y| <= delta : 0.5 * (x - y)^2
    #   otherwise        : delta * (|x - y| - 0.5 * delta)
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    pred = tl.load(x_ptr + idx, mask=in_bounds)
    tgt = tl.load(y_ptr + idx, mask=in_bounds)

    residual = pred - tgt
    abs_residual = tl.abs(residual)
    quadratic = 0.5 * residual * residual
    linear = delta * (abs_residual - 0.5 * delta)

    # Select quadratic branch inside the delta band, linear branch outside.
    result = tl.where(abs_residual <= delta, quadratic, linear)
    tl.store(out_ptr + idx, result, mask=in_bounds)
@triton.jit
def reduce_sum_kernel(
    x_ptr,  # pointer to input tensor (contiguous, flattened)
    out_ptr,  # pointer to single scalar (float32) to accumulate sum into
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program sums its BLOCK_SIZE slice in float32 and atomically folds
    # the partial sum into the single scalar at out_ptr. Out-of-range lanes
    # contribute 0.0 so the masked tail does not perturb the total.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    chunk = tl.load(x_ptr + idx, mask=idx < n_elements, other=0.0)
    block_total = tl.sum(chunk.to(tl.float32), axis=0)
    tl.atomic_add(out_ptr, block_total)
50def _normalize_reduction(reduction):
51 if isinstance(reduction, str):
52 r = reduction.lower()
53 if r == "none":
54 return 0
55 elif r == "mean":
56 return 1
57 elif r == "sum":
58 return 2
59 else:
60 raise ValueError(f"Unsupported reduction: {reduction}")
61 elif isinstance(reduction, int):
62 if reduction in (0, 1, 2):
63 return reduction
64 else:
65 raise ValueError(f"Unsupported reduction: {reduction}")
66 else:
67 raise ValueError(f"Unsupported reduction type: {type(reduction)}")
def huber_loss(input, target, reduction=1, delta=1.0):
    """Compute the Huber loss between *input* and *target* with Triton kernels.

    Args:
        input: CUDA tensor.
        target: CUDA tensor, broadcastable against *input*.
        reduction: 'none'/'mean'/'sum' (case-insensitive) or the integer
            codes 0/1/2.
        delta: positive threshold at which the loss switches from the
            quadratic to the linear regime.

    Returns:
        The elementwise loss tensor (broadcast shape) for reduction='none',
        otherwise a 0-d tensor holding the mean/sum. The reduction is
        accumulated in float32 and cast back to the promoted dtype.

    Raises:
        AssertionError: if *input* or *target* is not a CUDA tensor.
        ValueError: for an unsupported reduction or a non-positive delta.
    """
    reduction = _normalize_reduction(reduction)
    # torch.nn.functional.huber_loss requires delta > 0; a non-positive
    # delta would silently yield meaningless (e.g. negative) losses here.
    if delta <= 0:
        raise ValueError(f"huber_loss: delta must be positive, got {delta}")
    if not (input.is_cuda and target.is_cuda):
        raise AssertionError("Triton kernels require CUDA tensors")
    device = input.device
    # Promote dtype similar to PyTorch type promotion rules
    result_dtype = torch.result_type(input, target)

    # Broadcast to a common shape; kernels assume contiguous flat layout.
    x_b, y_b = torch.broadcast_tensors(input.to(result_dtype), target.to(result_dtype))
    x_b = x_b.contiguous()
    y_b = y_b.contiguous()
    numel = x_b.numel()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if reduction == 0:  # 'none'
        out = torch.empty_like(x_b, dtype=result_dtype, device=device)
        # Guard: a zero-element input would produce a (0,) grid, and a
        # zero-block launch is rejected by the CUDA driver.
        if numel > 0:
            huber_loss_element_kernel[grid](
                x_b, y_b, out, numel, float(delta), BLOCK_SIZE=BLOCK_SIZE
            )
        return out

    # mean / sum: elementwise loss into a temporary, then a float32 reduction.
    acc = torch.zeros((), dtype=torch.float32, device=device)
    if numel > 0:  # skip zero-block launches (see guard above)
        tmp = torch.empty_like(x_b, dtype=result_dtype, device=device)
        huber_loss_element_kernel[grid](
            x_b, y_b, tmp, numel, float(delta), BLOCK_SIZE=BLOCK_SIZE
        )
        reduce_sum_kernel[grid](tmp, acc, numel, BLOCK_SIZE=BLOCK_SIZE)
    if reduction == 1:  # mean (0/0 -> NaN for empty input, matching PyTorch)
        return (acc / numel).to(result_dtype)
    return acc.to(result_dtype)  # sum
def huber_loss_out(input, target, reduction=1, delta=1.0, out=None):
    """Compute the Huber loss like :func:`huber_loss`, writing into *out*.

    Args:
        input: CUDA tensor.
        target: CUDA tensor, broadcastable against *input*.
        reduction: 'none'/'mean'/'sum' (case-insensitive) or the integer
            codes 0/1/2.
        delta: positive threshold between the quadratic and linear regimes.
        out: destination CUDA tensor. For reduction='none' it must match
            the broadcast shape; for 'mean'/'sum' it must hold exactly one
            element (0-d or 1-element).

    Returns:
        *out*, filled with the loss cast to ``out.dtype``.

    Raises:
        ValueError: if *out* is missing or mis-shaped, the reduction is
            unsupported, or delta is not positive.
        AssertionError: if any tensor is not on CUDA.
    """
    if out is None:
        raise ValueError("huber_loss_out requires an 'out' tensor")
    reduction = _normalize_reduction(reduction)
    # torch.nn.functional.huber_loss requires delta > 0; a non-positive
    # delta would silently yield meaningless (e.g. negative) losses here.
    if delta <= 0:
        raise ValueError(f"huber_loss_out: delta must be positive, got {delta}")
    if not (input.is_cuda and target.is_cuda and out.is_cuda):
        raise AssertionError("Triton kernels require CUDA tensors")

    device = input.device
    # Compute in the promoted dtype but deliver the result in out.dtype,
    # mirroring the behavior of PyTorch .out variants.
    promoted_dtype = torch.result_type(input, target)
    result_dtype = out.dtype

    # Broadcast to a common shape; kernels assume contiguous flat layout.
    x_b, y_b = torch.broadcast_tensors(
        input.to(promoted_dtype), target.to(promoted_dtype)
    )
    x_b = x_b.contiguous()
    y_b = y_b.contiguous()
    numel = x_b.numel()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if reduction == 0:  # 'none'
        # Shape equality also implies equal numel, so one check suffices.
        if out.shape != x_b.shape:
            raise ValueError(
                f"'out' tensor must have shape {tuple(x_b.shape)} for reduction='none'"
            )
        # Guard: a zero-element input would produce a (0,) grid, and a
        # zero-block launch is rejected by the CUDA driver.
        if numel == 0:
            return out
        # Route through a temporary when out cannot be written directly
        # (non-contiguous storage, or a dtype differing from out.dtype).
        if (not out.is_contiguous()) or (out.dtype != result_dtype):
            tmp = torch.empty_like(x_b, dtype=result_dtype, device=device)
            huber_loss_element_kernel[grid](
                x_b.to(result_dtype),
                y_b.to(result_dtype),
                tmp,
                numel,
                float(delta),
                BLOCK_SIZE=BLOCK_SIZE,
            )
            out.copy_(tmp)
        else:
            huber_loss_element_kernel[grid](
                x_b.to(result_dtype),
                y_b.to(result_dtype),
                out,
                numel,
                float(delta),
                BLOCK_SIZE=BLOCK_SIZE,
            )
        return out

    # mean / sum: validate out up front (fail fast, before any kernel work).
    if out.numel() != 1 or out.dim() > 1:
        raise ValueError(
            "For reduction='mean' or 'sum', 'out' must be a scalar (0-d or 1-element) tensor"
        )
    # Elementwise loss in the promoted dtype, reduced in a float32 accumulator.
    acc = torch.zeros((), dtype=torch.float32, device=device)
    if numel > 0:  # skip zero-block launches (see guard above)
        tmp = torch.empty_like(x_b, dtype=promoted_dtype, device=device)
        huber_loss_element_kernel[grid](
            x_b, y_b, tmp, numel, float(delta), BLOCK_SIZE=BLOCK_SIZE
        )
        reduce_sum_kernel[grid](tmp, acc, numel, BLOCK_SIZE=BLOCK_SIZE)
    if reduction == 1:  # mean (0/0 -> NaN for empty input, matching PyTorch)
        out.copy_((acc / numel).to(result_dtype))
    else:  # sum
        out.copy_(acc.to(result_dtype))
    return out