Coverage for src/flag_gems/experimental_ops/mse_loss.py: 0%
134 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _mse_elemwise_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Write the elementwise squared error (x - y)**2 into out, one block per program."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    a = tl.load(x_ptr + idx, mask=in_bounds)
    b = tl.load(y_ptr + idx, mask=in_bounds)
    err = a - b
    tl.store(out_ptr + idx, err * err, mask=in_bounds)
@triton.jit
def _mse_reduce_kernel(
    x_ptr, y_ptr, acc_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr
):
    """Accumulate sum(scale * (x - y)**2) across programs into the scalar at acc_ptr."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Promote to float32 so fp16/bf16 inputs accumulate stably; masked lanes read 0
    # and therefore contribute nothing to the partial sum.
    a = tl.load(x_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    err = a - b
    partial = tl.sum(err * err * scale, axis=0)
    tl.atomic_add(acc_ptr, partial)
39def _parse_reduction(reduction):
40 # Accept both strings and integers consistent with ATen Reduction enum:
41 # 0: 'none', 1: 'mean', 2: 'sum'
42 if isinstance(reduction, str):
43 r = reduction.lower()
44 if r == "none":
45 return 0
46 if r == "mean":
47 return 1
48 if r == "sum":
49 return 2
50 raise ValueError(f"Invalid reduction string: {reduction}")
51 # Assume integer
52 if reduction in (0, 1, 2):
53 return int(reduction)
54 raise ValueError(f"Invalid reduction value: {reduction}")
57def _ensure_supported_dtype(t: torch.Tensor, op_name="mse_loss"):
58 if t.dtype not in (torch.float16, torch.bfloat16, torch.float32):
59 raise TypeError(
60 f"{op_name} Triton kernel supports float16, bfloat16, and float32 dtypes, got {t.dtype}."
61 )
def _launch_mse_elemwise(x, y, out):
    """Launch the elementwise squared-error kernel, writing (x - y)**2 into out."""
    total = out.numel()
    block = 1024

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _mse_elemwise_kernel[grid](x, y, out, total, BLOCK_SIZE=block)
def _launch_mse_reduce(x, y, n_elements, scale):
    """Launch the reduction kernel; return a 0-dim float32 tensor holding
    sum(scale * (x - y)**2) over all n_elements."""
    block = 1024
    result = torch.zeros((), device=x.device, dtype=torch.float32)

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _mse_reduce_kernel[grid](x, y, result, n_elements, float(scale), BLOCK_SIZE=block)
    return result
def mse_loss(*args, **kwargs):
    """Triton-backed mse_loss(input, target, reduction=1).

    Mirrors aten::mse_loss. `reduction` may be 'none'/'mean'/'sum' or the
    ATen enum ints 0/1/2 (default 1 == 'mean'); a `reduction` keyword takes
    precedence over a third positional argument, as in the original contract.

    Returns:
        reduction 'none': tensor shaped like *input* holding (input - target)**2.
        reduction 'sum'/'mean': a 0-dim tensor in input's dtype.

    Raises:
        TypeError: non-tensor inputs or unsupported dtypes.
        ValueError: mismatched element counts, mismatched devices, or an
            unrecognized reduction spec.
    """
    if len(args) < 2:
        raise TypeError(
            "mse_loss requires at least 2 positional arguments: (input, target)"
        )
    inp = args[0]
    target = args[1]
    reduction = kwargs.get("reduction", args[2] if len(args) > 2 else 1)
    reduction = _parse_reduction(reduction)

    if not isinstance(inp, torch.Tensor) or not isinstance(target, torch.Tensor):
        raise TypeError("mse_loss expects tensor inputs")
    if inp.numel() != target.numel():
        raise ValueError(
            "mse_loss: input and target must have the same number of elements"
        )
    if inp.device != target.device:
        raise ValueError("mse_loss: input and target must be on the same device")

    _ensure_supported_dtype(inp, "mse_loss")
    _ensure_supported_dtype(target, "mse_loss")

    x = inp.contiguous()
    y = target.contiguous()
    n_elements = x.numel()

    if reduction == 0:  # 'none'
        # x is contiguous, so empty_like(x) is contiguous as well; the Triton
        # kernel can write into it directly (no staging copy needed).
        out = torch.empty_like(x)
        if n_elements > 0:
            _launch_mse_elemwise(x, y, out)
        return out.reshape_as(inp)

    # sum or mean -> scalar
    if n_elements == 0:
        # Match ATen semantics: sum over an empty tensor is 0, while
        # mean is 0/0 == NaN (torch.nn.functional.mse_loss returns nan).
        value = 0.0 if reduction == 2 else float("nan")
        return torch.full((), value, device=x.device, dtype=inp.dtype)

    scale = 1.0 if reduction == 2 else (1.0 / float(n_elements))  # sum vs mean
    acc = _launch_mse_reduce(x, y, n_elements, scale)
    return acc.to(dtype=inp.dtype)
def mse_loss_out(*args, **kwargs):
    """Out-variant: mse_loss_out(input, target, reduction=1, *, out).

    Writes the result into the provided `out` tensor and returns it.
    `reduction` may be 'none'/'mean'/'sum' or ATen enum ints 0/1/2; a keyword
    takes precedence over the matching positional, as in the original contract.

    Raises:
        TypeError: missing/invalid `out`, non-tensor inputs, or dtype problems.
        ValueError: mismatched element counts/devices, bad reduction spec, or
            an `out` tensor of the wrong size for the chosen reduction.
    """
    if len(args) < 2:
        raise TypeError(
            "mse_loss_out requires at least 2 positional arguments: (input, target)"
        )
    inp = args[0]
    target = args[1]
    reduction = kwargs.get("reduction", args[2] if len(args) > 2 else 1)
    out = kwargs.get("out", args[3] if len(args) > 3 else None)

    if out is None:
        raise TypeError("mse_loss_out requires an 'out' tensor")

    reduction = _parse_reduction(reduction)

    if not isinstance(inp, torch.Tensor) or not isinstance(target, torch.Tensor):
        raise TypeError("mse_loss_out expects tensor inputs")
    if inp.numel() != target.numel():
        raise ValueError(
            "mse_loss_out: input and target must have the same number of elements"
        )
    if inp.device != target.device:
        raise ValueError("mse_loss_out: input and target must be on the same device")

    _ensure_supported_dtype(inp, "mse_loss_out")
    _ensure_supported_dtype(target, "mse_loss_out")

    x = inp.contiguous()
    y = target.contiguous()
    n_elements = x.numel()

    if reduction == 0:  # 'none'
        # out must match the input element count, device, and dtype.
        if out.numel() != n_elements:
            raise ValueError(
                "mse_loss_out (reduction='none'): 'out' must have the same number of elements as input"
            )
        if out.device != x.device:
            raise ValueError("mse_loss_out: 'out' must be on the same device as input")
        if out.dtype != inp.dtype:
            raise TypeError(
                "mse_loss_out (reduction='none'): 'out' dtype must match input dtype"
            )

        if out.is_contiguous():
            _launch_mse_elemwise(x, y, out)
        else:
            # Triton writes flat contiguous memory; stage into a contiguous
            # buffer, then scatter back through out's strides.
            tmp = torch.empty_like(x, memory_format=torch.contiguous_format)
            _launch_mse_elemwise(x, y, tmp)
            out.copy_(tmp)
        return out

    # sum or mean -> scalar out
    if out.device != x.device:
        raise ValueError("mse_loss_out: 'out' must be on the same device as input")
    if out.numel() != 1:
        raise ValueError(
            "mse_loss_out (reduction in ['sum','mean']): 'out' must be a scalar tensor"
        )
    # out dtype must be a supported float dtype
    if out.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError(
            "mse_loss_out: 'out' dtype must be one of float16, bfloat16, or float32 for Triton kernel"
        )

    if n_elements == 0:
        # Match ATen semantics: empty sum -> 0, empty mean -> NaN (0/0).
        out.fill_(0 if reduction == 2 else float("nan"))
        return out

    scale = 1.0 if reduction == 2 else (1.0 / float(n_elements))
    acc = _launch_mse_reduce(x, y, n_elements, scale)
    out.fill_(acc.to(dtype=out.dtype))
    return out