Coverage for src/flag_gems/ops/mse_loss.py: 66%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2import math
3from enum import Enum
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, pointwise_dynamic
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def kernel_1(inp, target, mid, M, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
    """Stage 1 of a two-stage reduction: each program reduces one
    BLOCK_SIZE slice of the element-wise squared difference and writes
    its partial result to ``mid[pid]``.

    ``reduction`` uses the integer codes from ``Reduction`` (1 == MEAN,
    2 == SUM). For MEAN, every partial sum is pre-divided by the total
    element count M, so stage 2 only needs to add the partials.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    lhs = tl.load(inp + idx, mask=in_bounds, other=0).to(tl.float32)
    rhs = tl.load(target + idx, mask=in_bounds, other=0).to(tl.float32)
    diff = lhs - rhs
    sq_err = diff * diff
    # Reduction.MEAN.value == 1; anything else here is SUM (== 2).
    if reduction == 1:
        partial = tl.sum(sq_err) / M
    else:
        partial = tl.sum(sq_err)
    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the reduction: a single program loads all ``mid_size``
    partial results from ``mid`` and stores their sum into ``out``."""
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    partials = tl.load(mid + idx, mask=valid, other=0).to(tl.float32)
    total = tl.sum(partials)
    tl.store(out, total)
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def func(x, y):
    # Element-wise squared error; used for reduction='none'.
    diff = x - y
    return diff * diff
class Reduction(Enum):
    """Loss-reduction modes, mirroring torch's at::Reduction integer codes."""

    NONE = 0  # return the per-element squared-error tensor
    MEAN = 1  # average the squared errors over all elements
    SUM = 2  # total of the squared errors
def mse_loss(inp, target, reduction=Reduction.MEAN.value):
    """Mean-squared-error loss computed with a two-stage Triton reduction.

    Args:
        inp: input tensor.
        target: target tensor; assumed to have the same number of
            elements as ``inp`` (TODO confirm callers guarantee this).
        reduction: integer code per ``Reduction`` — 0 returns the
            element-wise squared errors, 1 the mean, 2 the sum.

    Returns:
        For reduction == 0, a tensor of per-element squared errors;
        otherwise a 0-d tensor holding the reduced value in ``inp``'s dtype.
    """
    logger.debug("GEMS MSE LOSS")
    if reduction == Reduction.NONE.value:
        return func(inp, target)

    inp = inp.contiguous()
    target = target.contiguous()
    n_elements = inp.numel()

    # Block size near sqrt(N), rounded up to a power of two, balances the
    # work between the two reduction stages.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    grid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(grid_size)

    # Partial sums are accumulated in float32 regardless of input dtype;
    # the final scalar is stored back in the input dtype.
    partials = torch.empty((grid_size,), dtype=torch.float32, device=inp.device)
    result = torch.empty([], dtype=inp.dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        kernel_1[(grid_size, 1, 1)](inp, target, partials, n_elements, block_size, reduction)
        kernel_2[(1, 1, 1)](partials, result, grid_size, block_mid)
    return result