Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/var_mean.py: 0%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger: a child of the "flag_gems" logger, named after this module
# (any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

def heur_block_m(args):
    """Choose BLOCK_M (rows per program) for ``var_mean_welford_kernel``.

    The row count ``M`` is divided into 12 chunks (rounding up), and the chunk
    size is then rounded up to the next power of two.
    """
    # NOTE(review): 12 presumably matches the number of compute clusters on the
    # Kunlunxin device, so the launch grid has at most ~12 programs -- TODO
    # confirm.
    rows_per_chunk = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_chunk)

17 

18 

def heur_block_n(args):
    """Choose BLOCK_N (column tile width) for ``var_mean_welford_kernel``.

    Returns ``next_power_of_2(N)``, capped at 8192.

    NOTE(review): a later function in this module reuses the name
    ``heur_block_n``. The welford kernel is unaffected because its
    ``@triton.heuristics`` dict captured this function object at decoration
    time.
    """
    tile = triton.next_power_of_2(args["N"])
    return tile if tile < 8192 else 8192

21 

22 

@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    """Merge two Welford partial states ``(mean, count, M)`` into one.

    Used as the ``combine_fn`` of ``tl.reduce``. ``M`` is the sum of squared
    deviations from the mean. The merge uses the pairwise update

        delta = mean_y - mean_x
        mean  = mean_x + delta * count_y / count
        M     = M_x + M_y + delta**2 * count_x * count_y / count

    This is algebraically equal to the expanded form
    ``M_x + n_x*mean_x**2 + M_y + n_y*mean_y**2 - n*mean**2``. Unlike that
    form, it never subtracts large, nearly equal terms, so it stays accurate
    in float32 when |mean| is large relative to the spread of the data.

    Returns:
        (mean, count, M) of the combined partial.
    """
    count = count_x + count_y
    # Avoid 0/0 when both partials are empty (masked / padding lanes).
    _count = tl.maximum(count, 1)
    delta = mean_y - mean_x
    # Weight of the y partial in the merged mean; 0 when both are empty.
    frac_y = count_y / _count
    mean = mean_x + delta * frac_y
    M = M_x + M_y + delta * delta * count_x * frac_y
    return mean, count, M

32 

33 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise variance and mean of an M x N row-major view of ``X``.

    Each program handles BLOCK_M rows. Each row is streamed in BLOCK_N-wide
    tiles. Every lane of the tile keeps its own running Welford state
    (mean, count, M2), and the lanes are merged with ``welford_func`` at the
    end. Results are written to ``Mean[row]`` and ``Var[row]``.

    Args:
        X: input; row r starts at ``X + r * N``.
        Var, Mean: outputs with one element per row.
        M: number of rows.
        N: number of elements reduced in each row.
        correction: subtracted from N to form the variance divisor. The
            divisor is clamped at 0, as documented for ``torch.var``.
    """
    # Row indices handled by this program, shaped (BLOCK_M, 1).
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + rows * N
    Var = Var + rows
    Mean = Mean + rows
    row_mask = rows < M

    # Per-lane Welford state: running mean, sum of squared deviations, count.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask & col_mask

        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)

        # Masked lanes load 0 and add 0 to the count, so their mean is unchanged.
        count = _count + mask
        cnt = tl.maximum(count, 1)
        cur_mean = (_mean * _count + x) / cnt
        # Welford M2 update; multiplying by mask zeroes inactive lanes.
        _acc += (x - cur_mean) * (x - _mean) * mask
        _mean = cur_mean
        _count = count

    mean, _, acc = tl.reduce((_mean, _count, _acc), axis=1, combine_fn=welford_func)
    # Match PyTorch: divide by max(N - correction, 0), so correction >= N gives
    # inf/nan instead of a negative or sign-flipped variance.
    divisor = tl.maximum((N - correction).to(tl.float32), 0.0)
    var = acc / divisor
    mean = mean[:, None]
    var = var[:, None]
    # Write mean / var
    tl.store(Mean, mean, row_mask)
    tl.store(Var, var, row_mask)

84 

85 

@libentry()
@triton.jit
def var_mean_kernel_1(
    X,
    Acc,
    Average,
    Count,
    N,
    BLOCK_N: tl.constexpr,
):
    """First pass of the full reduction: one partial Welford state per block.

    Program ``pid`` reads the BLOCK_N-element slice of the flattened input
    that starts at ``pid * BLOCK_N``. It writes the slice's element count,
    mean, and sum of squared deviations from that mean to ``Count[pid]``,
    ``Average[pid]``, and ``Acc[pid]``.

    The caller launches ``cdiv(N, BLOCK_N)`` programs, so every slice holds
    at least one valid element and ``count`` is never 0.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offset < N

    x = tl.load(X + offset, mask, other=0.0).to(tl.float32)

    count = tl.sum(mask.to(tl.float32))
    average = tl.sum(x) / count
    # Two-pass M2: sum((x - mean)^2) over valid lanes only. The equivalent
    # sum(x*x) - count*mean^2 form cancels catastrophically in float32 when
    # |mean| is large compared with the spread.
    diff = tl.where(mask, x - average, 0.0)
    acc = tl.sum(diff * diff)

    tl.store(Average + pid, average)
    tl.store(Acc + pid, acc)
    tl.store(Count + pid, count)

115 

116 

def heur_block_n(args):
    """Choose BLOCK_N for ``var_mean_kernel_2``: the next power of two >= BLOCK_NUM.

    NOTE(review): this rebinds the module-level name ``heur_block_n``. The
    earlier ``var_mean_welford_kernel`` is unaffected because its
    ``@triton.heuristics`` dict captured the previous function object when it
    was decorated. A distinct name would make that less surprising.
    """
    num_partials = args["BLOCK_NUM"]
    return triton.next_power_of_2(num_partials)

119 

120 

@libentry()
# @triton.heuristics(runtime.get_heuristic_config("var_mean"))
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    Mean,
    N,
    correction,
    BLOCK_NUM,
    BLOCK_N: tl.constexpr,
):
    """Second pass of the full reduction: merge the BLOCK_NUM partials.

    Loads the per-block (mean, count, M2) triples written by
    ``var_mean_kernel_1`` and merges them with ``welford_func``. Stores the
    final mean to ``Mean[0]`` and the variance to ``Var[0]``.

    The caller launches a single program. BLOCK_N must be >= BLOCK_NUM,
    otherwise partials beyond BLOCK_N are ignored.
    """
    offset = tl.arange(0, BLOCK_N)
    mask = offset < BLOCK_NUM
    # Padding lanes load count 0, so welford_func gives them no weight.
    acc = tl.load(Acc + offset, mask, other=0.0).to(tl.float32)
    average = tl.load(Average + offset, mask, other=0.0).to(tl.float32)
    count = tl.load(Count + offset, mask, other=0.0).to(tl.float32)

    mean, _, nvar = tl.reduce((average, count, acc), axis=0, combine_fn=welford_func)

    # Match PyTorch: divide by max(N - correction, 0), so correction >= N gives
    # inf/nan instead of a negative or sign-flipped variance.
    divisor = tl.maximum((N - correction).to(tl.float32), 0.0)
    var = nvar / divisor
    tl.store(Mean, mean)
    tl.store(Var, var)

154 

155 

def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Compute the variance and mean of ``x`` over ``dim``.

    Args:
        x: input tensor.
        dim: int or sequence of ints to reduce over. ``None``, or a sequence
            covering every dimension, reduces the whole tensor.
        correction: difference between the sample size and the variance
            divisor; defaults to 1.
        keepdim: if true, reduced dimensions are kept with size 1.

    Returns:
        ``(var, mean)``, both allocated with ``x.dtype``.
    """
    logger.debug("GEMS VAR MEAN")
    if correction is None:
        correction = 1.0
    # Accept a bare int as well as a sequence of dims.
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == x.ndim:
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        N = x.numel()
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        if N == 0:
            # Nothing to reduce (BLOCK_NUM would be 0); PyTorch returns nan.
            var.fill_(float("nan"))
            mean.fill_(float("nan"))
        else:
            # var_mean_kernel_1 indexes the input linearly, so it needs a
            # dense buffer; a strided view would otherwise read wrong elements.
            x = x.contiguous()
            BLOCK_N = 1024
            BLOCK_NUM = triton.cdiv(N, BLOCK_N)
            # Keep the partials in float32 whatever x.dtype is. In fp16/bf16
            # they lose precision, M2 can overflow fp16, and bf16 cannot
            # represent every per-block count up to 1024 exactly.
            acc = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
            average = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
            count = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)

            with torch_device_fn.device(x.device):
                var_mean_kernel_1[(BLOCK_NUM,)](
                    x, acc, average, count, N, BLOCK_N=BLOCK_N
                )
                var_mean_kernel_2[(1,)](
                    acc,
                    average,
                    count,
                    var,
                    mean,
                    N,
                    correction,
                    BLOCK_NUM,
                    isCloseUnrollControl=True,
                )
    else:
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        # N = number of reduced elements per output; reduced dims become 1.
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        if x.numel() == 0:
            # Either a reduced dim has size 0 (PyTorch gives nan) or the
            # outputs are empty; in both cases no kernel launch is needed.
            var.fill_(float("nan"))
            mean.fill_(float("nan"))
        else:
            x = dim_compress(x, dim)
            M = x.numel() // N

            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            with torch_device_fn.device(x.device):
                var_mean_welford_kernel[grid](
                    x, var, mean, M, N, correction, isCloseUnrollControl=True
                )

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean