Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mse_loss.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3from enum import Enum 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from ..utils.pointwise_dynamic import pointwise_dynamic 

14 

15logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

16 

17 

@libentry()
@triton.jit
def kernel_1(inp, target, mid, M, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
    """Stage-1 reduction: per-block partial sum of (inp - target)**2.

    Each program reduces one BLOCK_SIZE slice of the flattened inputs and
    writes a single partial value to mid[pid]. For MEAN the partial is
    pre-divided by M so stage 2 only has to add the partials up.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M

    # Accumulate in fp32 regardless of the input dtype.
    x = tl.load(inp + idx, mask=in_bounds, other=0).to(tl.float32)
    y = tl.load(target + idx, mask=in_bounds, other=0).to(tl.float32)
    diff = x - y
    sq = diff * diff
    # reduction codes follow the torch convention: 1 == MEAN, 2 == SUM.
    if reduction == 1:
        partial = tl.sum(sq) / M
    else:
        partial = tl.sum(sq)
    tl.store(mid + block_id, partial)

38 

39 

@libentry()
@triton.jit
def kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage-2 reduction: collapse the partial sums in `mid` into scalar `out`."""
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    partials = tl.load(mid + idx, mask=valid, other=0).to(tl.float32)
    tl.store(out, tl.sum(partials))

49 

50 

@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def func(x, y):
    # Elementwise squared error; used for the reduction == NONE path.
    diff = x - y
    return diff * diff

55 

56 

class Reduction(Enum):
    """Loss-reduction modes, numerically matching torch's reduction codes."""

    NONE = 0  # return the elementwise squared error, no reduction
    MEAN = 1  # sum of squared errors divided by the element count
    SUM = 2  # plain sum of squared errors

61 

62 

def mse_loss(inp, target, reduction=Reduction.MEAN.value):
    """Compute MSE loss with a two-stage Triton reduction.

    Args:
        inp: input tensor.
        target: target tensor with the same number of elements as ``inp``.
        reduction: numeric reduction code — 0 (NONE) returns the elementwise
            squared error, 1 (MEAN) the mean over all elements, 2 (SUM) the sum.

    Returns:
        A tensor of elementwise losses for NONE; otherwise a 0-d tensor with
        ``inp``'s dtype.
    """
    logger.debug("GEMS MSE LOSS")
    if reduction == Reduction.NONE.value:
        return func(inp, target)

    inp = inp.contiguous()
    target = target.contiguous()
    M = inp.numel()
    dtype = inp.dtype

    # Stage 1 launches ~sqrt(M) blocks; stage 2 reduces the per-block partials.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # For bf16 MEAN with many partials, accumulate the partials in fp32 to
    # limit precision loss. Evaluate the condition ONCE before mutating
    # mid_size: the original re-checked `mid_size > 1024` after setting
    # mid_size = 12, so the float32 buffer was never actually selected.
    use_fp32_mid = (
        dtype == torch.bfloat16
        and mid_size > 1024
        and reduction == Reduction.MEAN.value
    )
    if use_fp32_mid:
        mid_size = 12
        block_size = triton.next_power_of_2(triton.cdiv(M, mid_size))
        block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty(
        (mid_size,),
        dtype=torch.float32 if use_fp32_mid else dtype,
        device=inp.device,
    )
    out = torch.empty([], dtype=dtype, device=inp.device)

    import os

    # NOTE(review): vendor-specific simulator switch; presumably consumed by
    # the XPU Triton backend at kernel launch — confirm with backend docs.
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    try:
        with torch_device_fn.device(inp.device):
            kernel_1[(mid_size, 1, 1)](inp, target, mid, M, block_size, reduction)
            if mid_size == 1:
                # Single partial already holds the final value.
                return mid.reshape([])
            kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    finally:
        # Always restore the environment — including on the mid_size == 1
        # early return, where the original leaked the variable.
        os.environ.pop("TRITONXPU_OTHER_SIM", None)

    return out