Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mean.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import builtins 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from ..utils.block_size_utils import get_block_size_1d 

14 

# Module logger parented under the shared "flag_gems" logger hierarchy.
# lstrip(".") strips any leading dot from __name__ (a no-op for normal
# absolute module names) so getChild() receives a clean relative name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

16 

17 

@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage: each program sums one BLOCK_SIZE-wide chunk
    of the flattened input and writes its partial sum to mid[pid].

    Out-of-range lanes load 0.0, so the tail block sums correctly.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
    partial_sum = tl.sum(vals, axis=0)
    tl.store(mid + block_id, partial_sum)

34 

35 

@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second reduction stage: a single program sums the MID_SIZE partial
    sums, divides by the total element count M, and stores the scalar mean.

    BLOCK_MID is the next power of two >= MID_SIZE; lanes past MID_SIZE
    contribute 0.0 via the load mask.
    """
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < MID_SIZE
    partials = tl.load(mid + idx, mask=valid, other=0.0)
    tl.store(out, tl.sum(partials, axis=0) / M)

45 

46 

def mean(inp, *, dtype=None):
    """Compute the mean over all elements of `inp`.

    Uses a two-stage block reduction: mean_kernel_1 produces one partial
    sum per block, mean_kernel_2 combines them and divides by the count.

    Args:
        inp: input tensor (reduced as a flat buffer of inp.numel() items).
        dtype: output dtype; defaults to inp.dtype.

    Returns:
        A 0-dim tensor holding the mean.
    """
    logger.debug("GEMS MEAN")
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    block_size = get_block_size_1d(numel, inp.element_size())
    num_blocks = triton.cdiv(numel, block_size)
    # mean_kernel_2 needs a power-of-two arange covering all partial sums.
    block_mid = triton.next_power_of_2(num_blocks)

    partials = torch.empty((num_blocks,), dtype=dtype, device=inp.device)
    result = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(num_blocks, 1, 1)](
            inp, partials, numel, block_size, buffer_size_limit=2048
        )
        if num_blocks == 1:
            # Only one partial sum: finish the division on the host side.
            return (partials / numel).reshape([])
        mean_kernel_2[(1, 1, 1)](
            partials, result, numel, num_blocks, block_mid, buffer_size_limit=2048
        )
    return result

68 

69 

def heur_m_block_size(args):
    """Heuristic row-block size: split the M rows across 12 clusters
    (Kunlunxin cluster_num), rounded up to the next power of two as
    required by tl.arange."""
    rows_per_cluster = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_cluster)

72 

73 

def heur_n_block_size(args):
    """Heuristic column-block size: the full row width N, capped at 8192."""
    n = args["N"]
    return n if n <= 8192 else 8192

76 

77 

@libentry()
# Autotuning is disabled on this backend in favor of fixed heuristics:
# @triton.autotune(
#     configs=runtime.get_tuned_config("mean"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise mean over the last (compressed) axis of an (M, N) view.

    Each program handles BLOCK_M consecutive rows; columns are swept in
    BLOCK_N-wide tiles, accumulating in float32 before dividing by N.
    """
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ok = rows < M
    x_rows = X + rows * N

    # Accumulate the row sums tile by tile in float32 for precision.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)[None, :]
        tile_mask = row_ok & (cols < N)
        acc += tl.load(x_rows + cols, mask=tile_mask, other=0.0).to(tl.float32)

    row_mean = tl.sum(acc, axis=1) / N
    tl.store(Mean + rows, row_mean[:, None], mask=row_ok)

109 

110 

def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Compute the mean of `x` over the dimensions listed in `dim`.

    Args:
        x: input tensor.
        dim: iterable of dims to reduce (negative dims allowed), or None
            to reduce over all dims.
        keepdim: if True, keep each reduced dim as a size-1 axis.
        dtype: output dtype; defaults to x.dtype.

    Returns:
        Tensor with the reduced dims removed, or kept as size 1 when
        keepdim is True. For dim=None and keepdim=False the result is a
        0-dim scalar tensor, matching torch.mean semantics.
    """
    logger.debug("GEMS MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        # Full reduction: mean() returns a 0-dim scalar tensor.
        out = mean(x, dtype=dtype)
        # BUGFIX: condition was inverted (`if not keepdim`), which reshaped
        # the keepdim=False result to ndim size-1 axes and left the
        # keepdim=True result as a scalar. torch.mean(x, dim=None,
        # keepdim=True) must yield shape [1] * x.ndim; keepdim=False must
        # yield shape [].
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]  # normalize negative dims
    # Move the reduced dims to the end so the kernel sees an (M, N) layout.
    x = dim_compress(x, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1  # reduced dims become size 1 in the output shape
    M = x.numel() // N
    out = torch.empty(shape, dtype=dtype, device=x.device)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(x.device):
        mean_dim_kernel[grid](x, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim)
    return out