Coverage for src/flag_gems/experimental_ops/smooth_l1_loss.py: 0%
149 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def smooth_l1_elementwise_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    beta,  # scalar
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise smooth-L1 loss over flattened tensors.

    Each program instance handles one contiguous BLOCK_SIZE-wide chunk of
    the (assumed contiguous, same-length) x / y / out buffers; `mask`
    guards the final partial block.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)

    diff = x - y
    ad = tl.abs(diff)

    # Broadcast beta to vector shape (cast to the input dtype so the
    # comparisons and arithmetic below stay in one dtype)
    beta_vec = tl.full(ad.shape, beta, x.dtype)

    # Smooth L1 piecewise (for beta > 0):
    # 0.5 * x^2 / beta if |x| < beta
    # |x| - 0.5 * beta otherwise
    loss_beta = 0.5 * diff * diff / beta_vec
    loss_piecewise = tl.where(ad < beta_vec, loss_beta, ad - 0.5 * beta_vec)

    # If beta <= 0, fall back to L1: |x|
    # Use vectorized condition to avoid divide-by-zero
    # NOTE(review): when beta == 0 the division above still produces
    # inf/nan lanes, but tl.where discards them in favor of `ad` — the
    # stored result is well-defined.
    use_piecewise = beta_vec > 0
    loss = tl.where(use_piecewise, loss_piecewise, ad)

    tl.store(out_ptr + offsets, loss, mask=mask)
42def _normalize_reduction(reduction):
43 if reduction is None:
44 return "mean"
45 if isinstance(reduction, str):
46 reduction = reduction.lower()
47 if reduction in ("none", "mean", "sum"):
48 return reduction
49 raise ValueError(f"Invalid reduction: {reduction}")
50 if isinstance(reduction, int):
51 mapping = {0: "none", 1: "mean", 2: "sum"}
52 if reduction in mapping:
53 return mapping[reduction]
54 raise ValueError(f"Invalid reduction code: {reduction}")
55 raise ValueError(f"Unsupported reduction type: {type(reduction)}")
def _parse_smooth_l1_args(args, kwargs, out_variant=False):
    """Parse positional/keyword args for smooth_l1_loss variants.

    Args:
        args: positional arguments; ``args[0]``/``args[1]`` are input and
            target, the remainder may supply reduction and/or beta in
            either order.
        kwargs: keyword arguments; ``beta``, ``reduction`` (and ``out``
            when ``out_variant``) are popped, anything else is left for
            the caller to reject.
        out_variant: whether to also extract an ``out=`` keyword.

    Returns:
        ``(x, y, reduction, beta, out, leftover_kwargs)`` with reduction
        normalized to "none"/"mean"/"sum" and beta as a float
        (default 1.0).

    Raises:
        TypeError: if fewer than two positional args are given, or a
            positional arg cannot be interpreted as reduction or beta.
        ValueError: for an invalid reduction spec.
    """
    if len(args) < 2:
        raise TypeError("smooth_l1_loss requires at least input and target tensors")

    x = args[0]
    y = args[1]

    beta = kwargs.pop("beta", None)
    reduction = kwargs.pop("reduction", None)
    out = kwargs.pop("out", None) if out_variant else None

    # Parse remaining positional arguments flexibly
    rest = list(args[2:])

    # Try to infer reduction and beta from positional args
    def maybe_set_reduction(val):
        nonlocal reduction
        if reduction is not None:
            return False
        if isinstance(val, str):
            reduction = val
            return True
        # bool is a subclass of int — exclude it so True/False are not
        # silently misread as ATen reduction codes.
        if isinstance(val, int) and not isinstance(val, bool) and val in (0, 1, 2):
            reduction = val
            return True
        return False

    def maybe_set_beta(val):
        nonlocal beta
        if beta is not None:
            return False
        # Exclude bool here too: float(True) == 1.0 would silently pass.
        if isinstance(val, (float, int)) and not isinstance(val, bool):
            beta = float(val)
            return True
        return False

    # Accept either order for the two optional parameters; reject values
    # that can be consumed by neither instead of silently dropping them.
    for val in rest:
        if not (maybe_set_reduction(val) or maybe_set_beta(val)):
            raise TypeError(f"Unexpected positional argument: {val!r}")

    if beta is None:
        beta = 1.0
    reduction = _normalize_reduction(reduction)

    return x, y, reduction, float(beta), out, kwargs
def _launch_smooth_l1_elementwise(x, y, out_buf, beta):
    """Launch the Triton elementwise smooth-L1 kernel.

    ``x``, ``y`` and ``out_buf`` are expected to be contiguous tensors of
    equal element count; the result is written into ``out_buf``. An empty
    output is a no-op (an empty grid cannot be launched).
    """
    total = out_buf.numel()
    if total == 0:
        return

    block = 1024
    num_programs = triton.cdiv(total, block)
    smooth_l1_elementwise_kernel[(num_programs,)](
        x, y, out_buf, total, beta, BLOCK_SIZE=block
    )
119def _prepare_tensors_for_elementwise(x, y, dtype=None):
120 if dtype is None:
121 dtype = torch.result_type(x, y)
122 if not (dtype.is_floating_point or dtype.is_complex):
123 dtype = torch.get_default_dtype()
124 device = x.device
125 if x.device != y.device:
126 raise ValueError("input and target must be on the same device")
127 if not device.type == "cuda":
128 return None, None, None, None # signal fallback
130 # Broadcast to a common shape
131 bshape = torch.broadcast_shapes(tuple(x.shape), tuple(y.shape))
132 xb = x.to(dtype).expand(bshape).contiguous()
133 yb = y.to(dtype).expand(bshape).contiguous()
134 out_buf = torch.empty(bshape, device=device, dtype=dtype)
135 return xb, yb, out_buf, bshape
def smooth_l1_loss(*args, **kwargs):
    """Smooth-L1 loss with a Triton-backed CUDA path.

    Argument handling mirrors ``torch.nn.functional.smooth_l1_loss`` via
    ``_parse_smooth_l1_args``; non-CUDA inputs are delegated to the ATen
    implementation.
    """
    x, y, reduction, beta, _, extra = _parse_smooth_l1_args(
        args, kwargs, out_variant=False
    )
    if extra:
        raise TypeError(f"Unexpected keyword arguments: {list(extra.keys())}")

    xb, yb, tmp, _ = _prepare_tensors_for_elementwise(x, y)
    if xb is None:
        # Non-CUDA device: fall back to the eager PyTorch op.
        return torch.ops.aten.smooth_l1_loss(x, y, reduction=reduction, beta=beta)

    _launch_smooth_l1_elementwise(xb, yb, tmp, beta)

    if reduction == "none":
        return tmp
    if reduction == "mean":
        return tmp.mean()
    if reduction == "sum":
        return tmp.sum()
    raise ValueError(f"Invalid reduction: {reduction}")
def smooth_l1_loss_out(*args, **kwargs):
    """Out-variant of ``smooth_l1_loss``.

    Behaves like ``smooth_l1_loss`` but, when an ``out`` tensor is
    supplied, writes the result into it and returns it. Non-CUDA inputs
    are delegated to the ATen implementation.

    Raises:
        TypeError: on unexpected keyword arguments.
        ValueError: for an invalid reduction, or when ``out`` mismatches
            the result's device/dtype/shape (or is not a single element
            for reduced output).
    """
    x, y, reduction, beta, out, leftover = _parse_smooth_l1_args(
        args, kwargs, out_variant=True
    )
    if leftover:
        raise TypeError(f"Unexpected keyword arguments: {list(leftover.keys())}")

    def _return_or_copy(res):
        # Shared tail for the fallback paths: copy into `out` if given.
        if out is None:
            return res
        out.copy_(res)
        return out

    # Fallback if not CUDA
    if x.device.type != "cuda" or y.device.type != "cuda":
        return _return_or_copy(
            torch.ops.aten.smooth_l1_loss(x, y, reduction=reduction, beta=beta)
        )

    xb, yb, tmp, bshape = _prepare_tensors_for_elementwise(x, y)
    if xb is None:
        # Should not happen due to device check above
        return _return_or_copy(
            torch.ops.aten.smooth_l1_loss(x, y, reduction=reduction, beta=beta)
        )

    _launch_smooth_l1_elementwise(xb, yb, tmp, beta)

    if reduction == "none":
        if out is None:
            return tmp
        # Validate 'out' shape/device/dtype for 'none'
        if out.device != tmp.device:
            raise ValueError("out tensor device mismatch")
        if out.dtype != tmp.dtype:
            raise ValueError("out tensor dtype mismatch")
        if tuple(out.shape) != tuple(bshape):
            raise ValueError("out tensor shape mismatch for reduction='none'")
        # Tensor.copy_ honors arbitrary strides, so non-contiguous `out`
        # needs no special case. (The previous `out.reshape(-1).copy_(...)`
        # path could write into a temporary *copy* for non-contiguous
        # tensors — reshape returns a copy when a flat view is impossible —
        # silently discarding the result.)
        out.copy_(tmp)
        return out

    if reduction == "mean":
        res = tmp.mean()
    elif reduction == "sum":
        res = tmp.sum()
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

    if out is None:
        return res
    # For reduced results, expect out to be a scalar tensor (numel == 1)
    if out.device != res.device:
        raise ValueError("out tensor device mismatch")
    if out.dtype != res.dtype:
        raise ValueError("out tensor dtype mismatch")
    if out.numel() != 1:
        raise ValueError("out tensor must have one element for reduced output")
    out.copy_(res)
    return out