Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mse_loss.py: 0%
72 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
3from enum import Enum
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13from ..utils.pointwise_dynamic import pointwise_dynamic
15logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def kernel_1(inp, target, mid, M, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
    # Stage-1 reduction: each program reduces one BLOCK_SIZE chunk of the
    # flattened inputs to a single partial value and writes it to mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    target_ptrs = target + offset
    # Mask the tail block so out-of-range lanes contribute 0 to the sum.
    mask = offset < M

    # Accumulate in float32 regardless of the input dtype for stability.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(tl.float32)
    target_val = tl.load(target_ptrs, mask=mask, other=0).to(tl.float32)
    sub = inp_val - target_val
    pow_val = sub * sub
    # Reduction.MEAN.value: 1 Reduction.SUM.value: 2
    if reduction == 1:
        # Scale each partial sum by 1/M here; summing the partials in
        # kernel_2 then yields the overall mean directly.
        sum_val = tl.sum(pow_val) / M
    else:
        sum_val = tl.sum(pow_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)
@libentry()
@triton.jit
def kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage-2 reduction: a single program sums all mid_size partial results
    # produced by kernel_1 into the scalar output.
    # BLOCK_MID is the next power of two >= mid_size, so one tl.arange covers
    # the whole buffer; the mask zeroes the padding lanes.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(tl.float32)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def func(x, y):
    # Elementwise squared difference; used for the reduction="none" path.
    diff = x - y
    return diff * diff
class Reduction(Enum):
    """Integer reduction codes accepted by mse_loss's `reduction` argument."""

    NONE = 0  # return the elementwise loss, no reduction
    MEAN = 1  # average of squared differences over all elements
    SUM = 2  # sum of squared differences over all elements
def mse_loss(inp, target, reduction=Reduction.MEAN.value):
    """Compute the mean squared error loss between `inp` and `target`.

    Args:
        inp: input tensor.
        target: target tensor, same number of elements as `inp`.
        reduction: integer code — 0 (none), 1 (mean, default), 2 (sum).

    Returns:
        With reduction == 0, the elementwise squared difference tensor;
        otherwise a 0-d tensor holding the mean or sum of squared differences.
    """
    logger.debug("GEMS MSE LOSS")
    if reduction == Reduction.NONE.value:
        return func(inp, target)

    inp = inp.contiguous()
    target = target.contiguous()
    M = inp.numel()
    dtype = inp.dtype

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # For large bfloat16 mean-reductions, switch to a few large blocks and a
    # float32 intermediate buffer to limit accumulated rounding error.
    # BUGFIX: decide this ONCE, before mutating mid_size. The original code
    # re-evaluated `mid_size > 1024` after resetting mid_size to 12, which is
    # always False, so the float32 buffer was never actually allocated.
    use_fp32_mid = (
        dtype == torch.bfloat16
        and mid_size > 1024
        and reduction == Reduction.MEAN.value
    )
    if use_fp32_mid:
        mid_size = 12
        block_size = triton.next_power_of_2(triton.cdiv(M, mid_size))
        block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty(
        (mid_size,),
        dtype=torch.float32 if use_fp32_mid else dtype,
        device=inp.device,
    )
    out = torch.empty([], dtype=dtype, device=inp.device)

    import os

    # Kunlunxin-specific simulator switch required while these kernels run.
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    try:
        with torch_device_fn.device(inp.device):
            kernel_1[(mid_size, 1, 1)](inp, target, mid, M, block_size, reduction)
            if mid_size == 1:
                # A single partial result already is the final answer.
                # (use_fp32_mid forces mid_size == 12, so mid has the input
                # dtype on this path.)
                return mid.reshape([])
            kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    finally:
        # BUGFIX: always restore the environment — the original early return
        # for mid_size == 1 skipped the cleanup and leaked the variable.
        os.environ.pop("TRITONXPU_OTHER_SIM", None)

    return out