Coverage for src/flag_gems/experimental_ops/mse_loss.py: 0%

134 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _mse_elemwise_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Write the elementwise squared error (x - y) ** 2 into ``out_ptr``."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    err = lhs - rhs
    tl.store(out_ptr + idx, err * err, mask=in_bounds)

18 

19 

@triton.jit
def _mse_reduce_kernel(
    x_ptr, y_ptr, acc_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr
):
    """Atomically add this block's sum of scale * (x - y) ** 2 into ``acc_ptr``."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Promote to float32 so accumulation of fp16/bf16 inputs stays stable.
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    err = lhs - rhs
    scaled_sq = (err * err) * scale
    partial_sum = tl.sum(scaled_sq, axis=0)
    tl.atomic_add(acc_ptr, partial_sum)

37 

38 

39def _parse_reduction(reduction): 

40 # Accept both strings and integers consistent with ATen Reduction enum: 

41 # 0: 'none', 1: 'mean', 2: 'sum' 

42 if isinstance(reduction, str): 

43 r = reduction.lower() 

44 if r == "none": 

45 return 0 

46 if r == "mean": 

47 return 1 

48 if r == "sum": 

49 return 2 

50 raise ValueError(f"Invalid reduction string: {reduction}") 

51 # Assume integer 

52 if reduction in (0, 1, 2): 

53 return int(reduction) 

54 raise ValueError(f"Invalid reduction value: {reduction}") 

55 

56 

57def _ensure_supported_dtype(t: torch.Tensor, op_name="mse_loss"): 

58 if t.dtype not in (torch.float16, torch.bfloat16, torch.float32): 

59 raise TypeError( 

60 f"{op_name} Triton kernel supports float16, bfloat16, and float32 dtypes, got {t.dtype}." 

61 ) 

62 

63 

def _launch_mse_elemwise(x, y, out):
    """Launch the elementwise squared-error kernel, writing into *out*."""
    total = out.numel()
    block = 1024

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _mse_elemwise_kernel[grid](x, y, out, total, BLOCK_SIZE=block)

69 

70 

def _launch_mse_reduce(x, y, n_elements, scale):
    """Run the fused scaled squared-error reduction.

    Returns a 0-dim float32 tensor on x's device holding
    sum(scale * (x - y) ** 2).
    """
    block = 1024
    acc = torch.zeros((), device=x.device, dtype=torch.float32)

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _mse_reduce_kernel[grid](x, y, acc, n_elements, float(scale), BLOCK_SIZE=block)
    return acc

77 

78 

def mse_loss(*args, **kwargs):
    """Mean squared error loss computed with Triton kernels.

    Expected calling pattern: ``mse_loss(input, target, reduction=1)``.
    ``reduction`` may be the string 'none' / 'mean' / 'sum' or the ATen
    Reduction enum integer 0 / 1 / 2 (default 1, i.e. 'mean'); a keyword
    ``reduction`` takes precedence over the third positional argument.

    Returns:
        For reduction='none', a tensor of per-element squared errors shaped
        like ``input``; otherwise a 0-dim scalar tensor in ``input``'s dtype.

    Raises:
        TypeError: fewer than two positional args, non-tensor inputs, or
            unsupported dtypes (only fp16 / bf16 / fp32 are accepted).
        ValueError: invalid ``reduction``, mismatched element counts, or
            mismatched devices.
    """
    if len(args) < 2:
        raise TypeError(
            "mse_loss requires at least 2 positional arguments: (input, target)"
        )
    inp = args[0]
    target = args[1]
    reduction = kwargs.get("reduction", args[2] if len(args) > 2 else 1)
    reduction = _parse_reduction(reduction)

    if not isinstance(inp, torch.Tensor) or not isinstance(target, torch.Tensor):
        raise TypeError("mse_loss expects tensor inputs")

    if inp.numel() != target.numel():
        raise ValueError(
            "mse_loss: input and target must have the same number of elements"
        )

    if inp.device != target.device:
        raise ValueError("mse_loss: input and target must be on the same device")

    _ensure_supported_dtype(inp, "mse_loss")
    _ensure_supported_dtype(target, "mse_loss")

    # The kernels index flat contiguous memory.
    x = inp.contiguous()
    y = target.contiguous()

    n_elements = x.numel()

    if reduction == 0:  # 'none'
        # empty_like of a contiguous tensor is contiguous, so the kernel can
        # write straight into it.  (The previous non-contiguous fallback that
        # staged through a temp buffer was unreachable dead code.)
        out = torch.empty_like(x)
        _launch_mse_elemwise(x, y, out)
        return out.reshape_as(inp)

    # sum or mean -> scalar
    if n_elements == 0:
        # Follow a simple convention: return 0 for empty tensors
        return torch.zeros((), device=x.device, dtype=inp.dtype)

    # 'mean' folds 1/N into the kernel's per-element scale; 'sum' uses 1.
    scale = 1.0 if reduction == 2 else (1.0 / float(n_elements))
    acc = _launch_mse_reduce(x, y, n_elements, scale)
    # Accumulation runs in float32; cast back to the input dtype.
    return acc.to(dtype=inp.dtype)

130 

131 

def mse_loss_out(*args, **kwargs):
    # Expected calling pattern: mse_loss_out(self, target, reduction=Mean, *, out)
    """Compute the MSE loss of (input, target) into a caller-provided tensor.

    Positional arguments: ``(input, target[, reduction[, out]])``;
    ``reduction`` and ``out`` may also be given as keywords (keywords take
    precedence over their positional counterparts).  ``reduction`` follows
    the ATen Reduction enum: 0/'none', 1/'mean', 2/'sum' (default 1).

    Returns:
        ``out``, filled with the per-element squared errors for
        reduction='none', or with the scalar sum/mean otherwise.

    Raises:
        TypeError: missing positional args, missing ``out``, non-tensor
            inputs, or unsupported dtypes.
        ValueError: invalid ``reduction``, mismatched element counts or
            devices, or an ``out`` of the wrong size/device.
    """
    if len(args) < 2:
        raise TypeError(
            "mse_loss_out requires at least 2 positional arguments: (input, target)"
        )
    inp = args[0]
    target = args[1]
    # Keyword values win over the third/fourth positional arguments.
    reduction = kwargs.get("reduction", args[2] if len(args) > 2 else 1)
    out = kwargs.get("out", args[3] if len(args) > 3 else None)

    if out is None:
        raise TypeError("mse_loss_out requires an 'out' tensor")

    reduction = _parse_reduction(reduction)

    if not isinstance(inp, torch.Tensor) or not isinstance(target, torch.Tensor):
        raise TypeError("mse_loss_out expects tensor inputs")

    if inp.numel() != target.numel():
        raise ValueError(
            "mse_loss_out: input and target must have the same number of elements"
        )

    if inp.device != target.device:
        raise ValueError("mse_loss_out: input and target must be on the same device")

    _ensure_supported_dtype(inp, "mse_loss_out")
    _ensure_supported_dtype(target, "mse_loss_out")

    # The kernels index flat contiguous memory.
    x = inp.contiguous()
    y = target.contiguous()
    n_elements = x.numel()

    if reduction == 0:  # 'none'
        # out must have same shape as input
        if out.numel() != n_elements:
            raise ValueError(
                "mse_loss_out (reduction='none'): 'out' must have the same number of elements as input"
            )
        if out.device != x.device:
            raise ValueError("mse_loss_out: 'out' must be on the same device as input")
        if out.dtype != inp.dtype:
            raise TypeError(
                "mse_loss_out (reduction='none'): 'out' dtype must match input dtype"
            )

        if out.is_contiguous():
            _launch_mse_elemwise(x, y, out)
        else:
            # Kernel writes linearly, so stage into a contiguous temp and
            # copy back through out's strides.
            tmp = torch.empty_like(x, memory_format=torch.contiguous_format)
            _launch_mse_elemwise(x, y, tmp)
            out.copy_(tmp)
        return out

    # sum or mean
    if out.device != x.device:
        raise ValueError("mse_loss_out: 'out' must be on the same device as input")
    if out.numel() != 1:
        raise ValueError(
            "mse_loss_out (reduction in ['sum','mean']): 'out' must be a scalar tensor"
        )
    # out dtype must be a supported float dtype
    if out.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError(
            "mse_loss_out: 'out' dtype must be one of float16, bfloat16, or float32 for Triton kernel"
        )

    if n_elements == 0:
        # Empty input reduces to 0 by convention.
        out.fill_(0)
        return out

    # 'mean' folds 1/N into the kernel's per-element scale; 'sum' uses 1.
    scale = 1.0 if reduction == 2 else (1.0 / float(n_elements))
    acc = _launch_mse_reduce(x, y, n_elements, scale)
    # acc is a 0-dim float32 tensor; fill_ casts it into out's dtype.
    out.fill_(acc.to(dtype=out.dtype))
    return out