Coverage for src/flag_gems/experimental_ops/smooth_l1_loss.py: 0%

149 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def smooth_l1_elementwise_kernel(
    x_ptr,          # pointer to input tensor (contiguous)
    y_ptr,          # pointer to target tensor (contiguous)
    out_ptr,        # pointer to output buffer (contiguous, same length)
    n_elements,     # total number of elements to process
    beta,  # scalar: quadratic/linear transition point of Smooth L1
    BLOCK_SIZE: tl.constexpr,
):
    """Per-element Smooth L1 loss: out[i] = smooth_l1(x[i] - y[i], beta).

    Each program instance handles one BLOCK_SIZE chunk of the flattened
    tensors; out-of-range lanes are masked on both load and store.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked lanes read 0; their results are discarded by the same mask below.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)

    diff = x - y
    ad = tl.abs(diff)

    # Broadcast beta to vector shape
    beta_vec = tl.full(ad.shape, beta, x.dtype)

    # Smooth L1 piecewise (for beta > 0):
    #   0.5 * x^2 / beta   if |x| < beta
    #   |x| - 0.5 * beta   otherwise
    loss_beta = 0.5 * diff * diff / beta_vec
    loss_piecewise = tl.where(ad < beta_vec, loss_beta, ad - 0.5 * beta_vec)

    # If beta <= 0, fall back to L1: |x|
    # Use vectorized condition to avoid divide-by-zero
    # NOTE(review): the division in loss_beta is still evaluated when
    # beta == 0 (producing NaN/Inf lanes); tl.where then selects `ad`,
    # so those lanes never reach the output.
    use_piecewise = beta_vec > 0
    loss = tl.where(use_piecewise, loss_piecewise, ad)

    tl.store(out_ptr + offsets, loss, mask=mask)

40 

41 

42def _normalize_reduction(reduction): 

43 if reduction is None: 

44 return "mean" 

45 if isinstance(reduction, str): 

46 reduction = reduction.lower() 

47 if reduction in ("none", "mean", "sum"): 

48 return reduction 

49 raise ValueError(f"Invalid reduction: {reduction}") 

50 if isinstance(reduction, int): 

51 mapping = {0: "none", 1: "mean", 2: "sum"} 

52 if reduction in mapping: 

53 return mapping[reduction] 

54 raise ValueError(f"Invalid reduction code: {reduction}") 

55 raise ValueError(f"Unsupported reduction type: {type(reduction)}") 

56 

57 

def _parse_smooth_l1_args(args, kwargs, out_variant=False):
    """Parse the flexible (input, target, [reduction], [beta]) calling convention.

    Keyword arguments 'reduction', 'beta' (and 'out' for the out-variant)
    are popped from ``kwargs``; any extra positional arguments are matched
    to reduction (str or int code 0/1/2) and beta (numeric) in either order.

    Returns:
        (x, y, reduction, beta, out, leftover_kwargs) where reduction is
        normalized to 'none'/'mean'/'sum' and beta is a float (default 1.0).

    Raises:
        TypeError: if fewer than two positional arguments are given.
        ValueError: if reduction is unrecognized (via _normalize_reduction).
    """
    if len(args) < 2:
        raise TypeError("smooth_l1_loss requires at least input and target tensors")

    x = args[0]
    y = args[1]

    beta = kwargs.pop("beta", None)
    reduction = kwargs.pop("reduction", None)
    out = kwargs.pop("out", None) if out_variant else None

    # Parse remaining positional arguments flexibly
    rest = list(args[2:])

    # Try to infer reduction and beta from positional args
    def maybe_set_reduction(val):
        nonlocal reduction
        if reduction is not None:
            return False
        if isinstance(val, str):
            reduction = val
            return True
        # bool is a subclass of int; True/False must not be taken as
        # the reduction codes 1/0.
        if isinstance(val, int) and not isinstance(val, bool) and val in (0, 1, 2):
            reduction = val
            return True
        return False

    def maybe_set_beta(val):
        nonlocal beta
        if beta is not None:
            return False
        # Exclude bool here as well (it is an int subclass).
        if isinstance(val, (float, int)) and not isinstance(val, bool):
            beta = float(val)
            return True
        return False

    # Accept either order for the two optional parameters
    for val in rest:
        if not maybe_set_reduction(val):
            maybe_set_beta(val)

    if beta is None:
        beta = 1.0
    reduction = _normalize_reduction(reduction)

    return x, y, reduction, float(beta), out, kwargs

104 

105 

def _launch_smooth_l1_elementwise(x, y, out_buf, beta):
    """Launch the element-wise Smooth L1 Triton kernel, writing into out_buf."""
    total = out_buf.numel()
    if total == 0:
        # Empty tensors require no kernel launch.
        return

    block = 1024
    grid = (triton.cdiv(total, block),)
    smooth_l1_elementwise_kernel[grid](x, y, out_buf, total, beta, BLOCK_SIZE=block)

117 

118 

119def _prepare_tensors_for_elementwise(x, y, dtype=None): 

120 if dtype is None: 

121 dtype = torch.result_type(x, y) 

122 if not (dtype.is_floating_point or dtype.is_complex): 

123 dtype = torch.get_default_dtype() 

124 device = x.device 

125 if x.device != y.device: 

126 raise ValueError("input and target must be on the same device") 

127 if not device.type == "cuda": 

128 return None, None, None, None # signal fallback 

129 

130 # Broadcast to a common shape 

131 bshape = torch.broadcast_shapes(tuple(x.shape), tuple(y.shape)) 

132 xb = x.to(dtype).expand(bshape).contiguous() 

133 yb = y.to(dtype).expand(bshape).contiguous() 

134 out_buf = torch.empty(bshape, device=device, dtype=dtype) 

135 return xb, yb, out_buf, bshape 

136 

137 

def smooth_l1_loss(*args, **kwargs):
    """Compute Smooth L1 loss via a Triton kernel on CUDA tensors.

    Accepts ``(input, target, reduction=..., beta=...)`` where reduction
    may be a string ('none'/'mean'/'sum') or an integer code (0/1/2) and
    beta is the quadratic/linear transition point (default 1.0).

    Returns the per-element loss for reduction='none', otherwise a scalar
    tensor. Falls back to the aten implementation on non-CUDA tensors.

    Raises:
        TypeError: for unexpected keyword arguments.
        ValueError: for an invalid reduction spec.
    """
    x, y, reduction, beta, _, leftover = _parse_smooth_l1_args(
        args, kwargs, out_variant=False
    )
    if leftover:
        raise TypeError(f"Unexpected keyword arguments: {list(leftover.keys())}")

    prep = _prepare_tensors_for_elementwise(x, y)
    if prep[0] is None:
        # Fallback to PyTorch if not CUDA.
        # BUG FIX: the aten schema takes an *integer* reduction code
        # (0=none, 1=mean, 2=sum), not the normalized string — passing
        # the string made the fallback raise.
        red_code = {"none": 0, "mean": 1, "sum": 2}[reduction]
        return torch.ops.aten.smooth_l1_loss(x, y, reduction=red_code, beta=beta)

    xb, yb, tmp, _ = prep
    _launch_smooth_l1_elementwise(xb, yb, tmp, beta)

    if reduction == "none":
        return tmp
    elif reduction == "mean":
        return tmp.mean()
    elif reduction == "sum":
        return tmp.sum()
    else:
        # Unreachable: _normalize_reduction already validated the value.
        raise ValueError(f"Invalid reduction: {reduction}")

161 

162 

def smooth_l1_loss_out(*args, **kwargs):
    """Out-variant of smooth_l1_loss: optionally writes the result into `out`.

    Same calling convention as smooth_l1_loss plus an ``out=`` keyword.
    For reduction='none', `out` must match the broadcast shape, dtype and
    device; for 'mean'/'sum', it must be a one-element tensor.

    Returns `out` when given, otherwise a freshly allocated result.

    Raises:
        TypeError: for unexpected keyword arguments.
        ValueError: for an invalid reduction or a mismatched `out` tensor.
    """
    x, y, reduction, beta, out, leftover = _parse_smooth_l1_args(
        args, kwargs, out_variant=True
    )
    if leftover:
        raise TypeError(f"Unexpected keyword arguments: {list(leftover.keys())}")

    # BUG FIX: the aten schema takes an *integer* reduction code
    # (0=none, 1=mean, 2=sum), not the normalized string.
    red_code = {"none": 0, "mean": 1, "sum": 2}[reduction]

    # Fallback if not CUDA
    if x.device.type != "cuda" or y.device.type != "cuda":
        res = torch.ops.aten.smooth_l1_loss(x, y, reduction=red_code, beta=beta)
        if out is None:
            return res
        out.copy_(res)
        return out

    xb, yb, tmp, bshape = _prepare_tensors_for_elementwise(x, y)
    if xb is None:
        # Should not happen due to device check above
        res = torch.ops.aten.smooth_l1_loss(x, y, reduction=red_code, beta=beta)
        if out is None:
            return res
        out.copy_(res)
        return out

    _launch_smooth_l1_elementwise(xb, yb, tmp, beta)

    if reduction == "none":
        if out is None:
            return tmp
        # Validate 'out' shape/device/dtype for 'none'
        if out.device != tmp.device:
            raise ValueError("out tensor device mismatch")
        if out.dtype != tmp.dtype:
            raise ValueError("out tensor dtype mismatch")
        if tuple(out.shape) != tuple(bshape):
            raise ValueError("out tensor shape mismatch for reduction='none'")
        # BUG FIX: copy_ handles non-contiguous destinations; the old
        # out.reshape(-1).copy_(...) could copy_ into a temporary when
        # reshape had to materialize a copy, silently losing the result.
        out.copy_(tmp)
        return out

    if reduction == "mean":
        res = tmp.mean()
    elif reduction == "sum":
        res = tmp.sum()
    else:
        # Unreachable: _normalize_reduction already validated the value.
        raise ValueError(f"Invalid reduction: {reduction}")
    if out is None:
        return res
    # For reduced results, expect out to be a scalar tensor (numel == 1)
    if out.device != res.device:
        raise ValueError("out tensor device mismatch")
    if out.dtype != res.dtype:
        raise ValueError("out tensor dtype mismatch")
    if out.numel() != 1:
        raise ValueError("out tensor must have one element for reduced output")
    out.copy_(res)
    return out