Coverage for src/flag_gems/ops/mse_loss.py: 66%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2import math 

3from enum import Enum 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, pointwise_dynamic 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module-level logger, namespaced to this module per the standard logging convention.
logger = logging.getLogger(__name__)

14 

15 

@libentry()
@triton.jit
def kernel_1(inp, target, mid, M, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
    """First stage of a two-stage MSE reduction.

    Each program instance handles one BLOCK_SIZE-wide chunk of the flattened
    inputs, computes the sum of squared differences over that chunk in fp32,
    and writes a single partial result to ``mid[pid]``.

    Args:
        inp: pointer to the (contiguous) input tensor.
        target: pointer to the (contiguous) target tensor.
        mid: pointer to the per-block partial-result buffer (fp32).
        M: total number of elements; also used as the tail mask bound.
        BLOCK_SIZE: compile-time chunk size per program.
        reduction: compile-time reduction mode (1 = mean, 2 = sum); the branch
            below is specialized away at compile time since it is a constexpr.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    target_ptrs = target + offset
    # Mask out-of-range lanes in the final (partial) block.
    mask = offset < M

    # Accumulate in fp32 regardless of input dtype; masked lanes load 0 and
    # therefore contribute nothing to the sum.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(tl.float32)
    target_val = tl.load(target_ptrs, mask=mask, other=0).to(tl.float32)
    sub = inp_val - target_val
    pow_val = sub * sub
    # Reduction.MEAN.value: 1 Reduction.SUM.value: 2
    if reduction == 1:
        # Pre-divide each partial by M so the second-stage plain sum of
        # partials yields total / M, i.e. the mean.
        sum_val = tl.sum(pow_val) / M
    else:
        sum_val = tl.sum(pow_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)

36 

37 

@libentry()
@triton.jit
def kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the two-stage MSE reduction.

    Launched as a single program: loads all ``mid_size`` partial results from
    ``mid`` in one masked vector load and stores their sum to the scalar
    ``out``. Requires BLOCK_MID >= mid_size (the caller passes
    ``next_power_of_2(mid_size)``), since the whole buffer is read at once.
    """
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    # Lanes past mid_size load 0 and do not affect the sum.
    mask = offset < mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(tl.float32)
    sum_val = tl.sum(mid_val)
    # ``out`` is a 0-d tensor; the store may downcast from fp32 to its dtype.
    tl.store(out, sum_val)

47 

48 

# Elementwise squared difference, used for the reduction='none' path.
# pointwise_dynamic generates the launch/broadcast machinery; "DEFAULT"
# promotion picks the output dtype from the two inputs.
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def func(x, y):
    return (x - y) * (x - y)

53 

54 

class Reduction(Enum):
    """Reduction modes for mse_loss.

    The integer values are relied upon by kernel_1's constexpr branch
    (1 = mean, 2 = sum) and presumably mirror PyTorch's internal reduction
    codes — confirm against the callers before changing them.
    """

    NONE = 0
    MEAN = 1
    SUM = 2

59 

60 

def mse_loss(inp, target, reduction=Reduction.MEAN.value):
    """Mean-squared-error loss computed with Triton kernels.

    Args:
        inp: input tensor.
        target: target tensor (assumed same numel as ``inp`` — the kernels
            index both flat buffers with one offset; verify at call sites).
        reduction: integer reduction mode — Reduction.NONE.value returns the
            elementwise squared difference, MEAN/SUM return a 0-d tensor.

    Returns:
        A tensor of squared differences (reduction == NONE) or a 0-d tensor
        holding the mean/sum, in ``inp``'s dtype.
    """
    logger.debug("GEMS MSE LOSS")

    # Elementwise path: no reduction to perform.
    if reduction == Reduction.NONE.value:
        return func(inp, target)

    flat_inp = inp.contiguous()
    flat_target = target.contiguous()
    n_elements = flat_inp.numel()
    out_dtype = flat_inp.dtype

    # Size both stages at roughly sqrt(N) so neither the per-block pass nor
    # the final partial-sum pass dominates.
    elems_per_block = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    num_blocks = triton.cdiv(n_elements, elems_per_block)
    final_block = triton.next_power_of_2(num_blocks)

    # Partials stay in fp32 for accumulation accuracy; the scalar result is
    # cast back to the input dtype on the final store.
    partials = torch.empty((num_blocks,), dtype=torch.float32, device=flat_inp.device)
    result = torch.empty([], dtype=out_dtype, device=flat_inp.device)

    with torch_device_fn.device(flat_inp.device):
        kernel_1[(num_blocks, 1, 1)](
            flat_inp, flat_target, partials, n_elements, elems_per_block, reduction
        )
        kernel_2[(1, 1, 1)](partials, result, num_blocks, final_block)
    return result