Coverage for src/flag_gems/experimental_ops/huber_loss.py: 0%

102 statements  

coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# Elementwise Huber loss: 0.5 * d^2 where |d| <= delta, and the linear
# continuation delta * (|d| - 0.5 * delta) outside that band. One program
# handles one BLOCK_SIZE-wide slice of the flattened tensors.
@triton.jit
def huber_loss_element_kernel(
    x_ptr,  # pointer to input tensor (broadcasted, contiguous, flattened)
    y_ptr,  # pointer to target tensor (broadcasted, contiguous, flattened)
    out_ptr,  # pointer to output tensor (contiguous, flattened)
    n_elements,  # number of elements
    delta,  # huber delta (scalar)
    BLOCK_SIZE: tl.constexpr,
):
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds)
    b = tl.load(y_ptr + idx, mask=in_bounds)
    d = a - b
    abs_d = tl.abs(d)

    # Quadratic inside the delta band, linear outside; tl.where selects
    # per element so both branches are evaluated vectorized.
    quad = 0.5 * d * d
    lin = delta * (abs_d - 0.5 * delta)
    tl.store(out_ptr + idx, tl.where(abs_d <= delta, quad, lin), mask=in_bounds)

31 

32 

# Block-wise sum reduction: each program sums its BLOCK_SIZE slice in
# float32 and atomically folds the partial sum into one scalar accumulator.
@triton.jit
def reduce_sum_kernel(
    x_ptr,  # pointer to input tensor (contiguous, flattened)
    out_ptr,  # pointer to single scalar (float32) to accumulate sum into
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = idx < n_elements
    # Out-of-range lanes load 0.0 so they do not perturb the sum.
    chunk = tl.load(x_ptr + idx, mask=live, other=0.0).to(tl.float32)
    tl.atomic_add(out_ptr, tl.sum(chunk, axis=0))

48 

49 

50def _normalize_reduction(reduction): 

51 if isinstance(reduction, str): 

52 r = reduction.lower() 

53 if r == "none": 

54 return 0 

55 elif r == "mean": 

56 return 1 

57 elif r == "sum": 

58 return 2 

59 else: 

60 raise ValueError(f"Unsupported reduction: {reduction}") 

61 elif isinstance(reduction, int): 

62 if reduction in (0, 1, 2): 

63 return reduction 

64 else: 

65 raise ValueError(f"Unsupported reduction: {reduction}") 

66 else: 

67 raise ValueError(f"Unsupported reduction type: {type(reduction)}") 

68 

69 

def huber_loss(input, target, reduction=1, delta=1.0):
    """Compute the Huber loss of `input` against `target` with Triton kernels.

    Args:
        input: CUDA tensor.
        target: CUDA tensor, broadcastable against `input`.
        reduction: 'none'/'mean'/'sum' or the codes 0/1/2 (default 1 = mean).
        delta: transition point between the quadratic and linear regions.

    Returns:
        For reduction 'none', the elementwise loss tensor (broadcast shape,
        promoted dtype); otherwise a 0-d tensor holding the mean or sum.

    Raises:
        ValueError: for an unrecognized reduction spec.
        AssertionError: if either tensor is not on CUDA.
    """
    # Normalize the reduction spec inline: 0='none', 1='mean', 2='sum'.
    if isinstance(reduction, str):
        code = {"none": 0, "mean": 1, "sum": 2}.get(reduction.lower())
        if code is None:
            raise ValueError(f"Unsupported reduction: {reduction}")
        reduction = code
    elif isinstance(reduction, int):
        if reduction not in (0, 1, 2):
            raise ValueError(f"Unsupported reduction: {reduction}")
    else:
        raise ValueError(f"Unsupported reduction type: {type(reduction)}")

    if not (input.is_cuda and target.is_cuda):
        raise AssertionError("Triton kernels require CUDA tensors")

    device = input.device
    # Promote dtype per PyTorch's type-promotion rules.
    out_dtype = torch.result_type(input, target)

    # Broadcast to a common shape; the kernels index flat contiguous memory.
    a, b = torch.broadcast_tensors(input.to(out_dtype), target.to(out_dtype))
    a = a.contiguous()
    b = b.contiguous()
    total = a.numel()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    # The elementwise loss is needed in every reduction mode.
    elementwise = torch.empty_like(a, dtype=out_dtype, device=device)
    huber_loss_element_kernel[grid](
        a, b, elementwise, total, float(delta), BLOCK_SIZE=BLOCK_SIZE
    )
    if reduction == 0:  # 'none'
        return elementwise

    # Reduce to a scalar using a float32 accumulator for accuracy.
    acc = torch.zeros((), dtype=torch.float32, device=device)
    reduce_sum_kernel[grid](elementwise, acc, total, BLOCK_SIZE=BLOCK_SIZE)
    if reduction == 1:  # mean
        return (acc / total).to(out_dtype)
    return acc.to(out_dtype)  # sum

107 

108 

def huber_loss_out(input, target, reduction=1, delta=1.0, out=None):
    """Huber loss written into a caller-provided `out` tensor.

    Mirrors `huber_loss` but stores the result in `out` (the `.out=`
    variant convention): for reduction 'none', `out` must already have the
    broadcast shape; for 'mean'/'sum' it must hold exactly one element.

    Args:
        input: CUDA tensor.
        target: CUDA tensor, broadcastable against `input`.
        reduction: 'none'/'mean'/'sum' or the codes 0/1/2 (default 1 = mean).
        delta: transition point between the quadratic and linear regions.
        out: pre-allocated CUDA tensor receiving the result (required).

    Returns:
        `out`.

    Raises:
        ValueError: missing `out`, bad reduction spec, or `out` shape mismatch.
        AssertionError: if any tensor is not on CUDA.
    """
    if out is None:
        raise ValueError("huber_loss_out requires an 'out' tensor")
    reduction = _normalize_reduction(reduction)
    if not (input.is_cuda and target.is_cuda and out.is_cuda):
        raise AssertionError("Triton kernels require CUDA tensors")

    device = input.device
    # Math runs in the promoted dtype; the final result matches out.dtype
    # to follow the .out behavior.
    promoted_dtype = torch.result_type(input, target)
    result_dtype = out.dtype

    # Broadcast tensors to a common shape; kernels index flat contiguous memory.
    x_b, y_b = torch.broadcast_tensors(
        input.to(promoted_dtype), target.to(promoted_dtype)
    )
    x_b = x_b.contiguous()
    y_b = y_b.contiguous()
    numel = x_b.numel()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if reduction == 0:  # 'none'
        # Ensure out has the correct shape before doing any work.
        if out.numel() != numel or out.shape != x_b.shape:
            raise ValueError(
                f"'out' tensor must have shape {tuple(x_b.shape)} for reduction='none'"
            )
        # The kernel needs a contiguous destination; `out.dtype != result_dtype`
        # was dropped from the original condition because result_dtype IS
        # out.dtype, so that clause could never be true.
        if out.is_contiguous():
            dst = out
        else:
            dst = torch.empty_like(x_b, dtype=result_dtype, device=device)
        huber_loss_element_kernel[grid](
            x_b.to(result_dtype),
            y_b.to(result_dtype),
            dst,
            numel,
            float(delta),
            BLOCK_SIZE=BLOCK_SIZE,
        )
        if dst is not out:
            out.copy_(dst)
        return out
    else:
        # Fail fast: validate `out` BEFORE launching any kernels. The
        # original only checked after both kernels had already run, wasting
        # GPU work when `out` was the wrong shape.
        if out.numel() != 1 or out.dim() > 1:
            raise ValueError(
                "For reduction='mean' or 'sum', 'out' must be a scalar (0-d or 1-element) tensor"
            )
        # Elementwise loss in the promoted dtype, then reduce to a scalar
        # with a float32 accumulator for accuracy.
        tmp = torch.empty_like(x_b, dtype=promoted_dtype, device=device)
        huber_loss_element_kernel[grid](
            x_b, y_b, tmp, numel, float(delta), BLOCK_SIZE=BLOCK_SIZE
        )
        acc = torch.zeros((), dtype=torch.float32, device=device)
        reduce_sum_kernel[grid](tmp, acc, numel, BLOCK_SIZE=BLOCK_SIZE)
        if reduction == 1:  # mean
            val = (acc / numel).to(result_dtype)
        else:  # sum
            val = acc.to(result_dtype)
        out.copy_(val)
        return out