Coverage for src/flag_gems/runtime/backend/_cambricon/ops/var_mean.py: 0%

160 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10 

11from ..utils import ( 

12 MAX_NRAM_SIZE, 

13 TOTAL_CORE_NUM, 

14 cfggen_reduce_op, 

15 count_divisible_by_2, 

16) 

17 

# Module logger nested under the package-level "flag_gems" logger; lstrip(".")
# guards against a leading dot in __name__ when imported relatively.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

19 

20 

@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    """Merge two Welford partial aggregates (mean, count, M2) into one.

    Combine function for ``tl.reduce``: given two partial statistics
    ``(mean, count, M)`` — where ``M`` is the sum of squared deviations —
    return the statistics of the union of both sample sets.  A zero total
    count is clamped to 1 so the mean division never divides by zero
    (the resulting mean is 0 in that case since both sums are 0).
    """
    sum_x = mean_x * count_x
    sum_y = mean_y * count_y
    total = count_x + count_y
    safe_total = tl.maximum(total, 1)
    merged_mean = (sum_x + sum_y) / safe_total
    # M2 of the union: E[x^2]-style recombination of both partials.
    merged_M = (
        M_x + sum_x * mean_x + M_y + sum_y * mean_y - total * merged_mean * merged_mean
    )
    return merged_mean, total, merged_M

30 

31 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise var/mean over the last (compressed) axis of X.

    X is treated as an (M, N) matrix: each of the M rows is reduced over its
    N elements using Welford's online algorithm, writing one variance and one
    mean per row into Var / Mean (both of length M).  ``correction`` is the
    Bessel correction subtracted from N in the variance denominator.
    """
    # Map the program id to the row of X it should compute.  Rows are tiled
    # in chunks of BLOCK_M and the chunk grid is strided across programs.
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    iter_num = tl.cdiv(task_num, num_prog)
    for i in range(0, iter_num):
        # (BLOCK_M, 1) column of row indices handled in this iteration.
        pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        X_ptr = X + pid * N
        Var_ptr = Var + pid
        Mean_ptr = Mean + pid
        row_mask = pid < M

        # Per-lane Welford state, accumulated in fp32 regardless of input
        # dtype: running mean, sum of squared deviations, and element count.
        _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            x = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)

            # Welford update; masked-out lanes contribute 0 to count and acc.
            count = _count + mask
            # Clamp to 1 so the division is safe before any element arrives.
            cnt = tl.maximum(count, 1)
            cur_mean = (_mean * _count + x) / cnt
            _acc += (x - cur_mean) * (x - _mean) * mask
            _mean = cur_mean
            _count = count

        # Merge the BLOCK_N per-lane partials of each row into one statistic.
        mean, _, acc = tl.reduce((_mean, _count, _acc), axis=1, combine_fn=welford_func)
        var = acc / (N - correction)
        mean = mean[:, None]
        var = var[:, None]
        # Write mean / var
        tl.store(Mean_ptr, mean, row_mask)
        tl.store(Var_ptr, var, row_mask)

82 

83 

def prune_varmean_config(configs, named_args, **kwargs):
    """Early-prune autotune configs for the single-value var_mean reduction.

    Drops configs whose BLOCK_SIZE exceeds the problem size M, configs that
    would leave cores idle while using multiple stages, and configs whose
    f32 working buffers would overflow NRAM.  Falls back to a safe default
    config when everything was pruned (e.g. very small M).
    """
    m = named_args["M"]
    kept = []
    for cfg in configs:
        block_size = cfg.kwargs["BLOCK_SIZE"]
        blocks = m // block_size
        if blocks < 1:
            # BLOCK_SIZE is larger than the whole problem — useless config.
            continue
        if blocks >= TOTAL_CORE_NUM:
            buffer_count = 6
        else:
            # A core must process a BLOCK_SIZE of data.
            if cfg.num_stages > 1:
                continue
            buffer_count = 3
        # Set f32 (4 bytes) as the default type for the NRAM budget check.
        if block_size * 4 * buffer_count < MAX_NRAM_SIZE:
            kept.append(cfg)
    # If M < 512, append the default config.
    if not kept:
        kept.append(triton.Config({"BLOCK_SIZE": 512}, num_warps=1, num_stages=1))
    return kept

109 

110 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    prune_configs_by={"early_config_prune": prune_varmean_config},
    key=["M"],
    reset_to_zero=["Acc", "Average", "Count"],
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit
def var_mean_kernel_1(
    X, Acc, Average, Count, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Stage 1 of the full (all-dims) var/mean reduction.

    Each program reduces its share of the M elements of X and writes one
    partial statistic per program: sum of squared deviations (Acc), mean
    (Average), and element count (Count), each at index ``pid``.  Partials
    are combined by var_mean_kernel_2.
    """
    # Map the program id to the row of X it should compute.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    count = 0.0
    average = 0.0
    acc = 0.0
    if ONE_TILE_PER_CTA:
        # One tile covers this program's whole share: reduce it directly.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
        count = tl.sum(mask.to(tl.float32))
        average = tl.sum(x) / count
        # Sum of squared deviations via sum(x^2) - n*mean^2.
        acc = tl.sum(x * x) - count * average * average
    else:
        # Stride over M in steps of num_programs * BLOCK_SIZE, accumulating
        # elementwise sums / sums of squares, then reduce once at the end.
        _tmp1 = tl.zeros([BLOCK_SIZE], tl.float32)
        _tmp2 = tl.zeros([BLOCK_SIZE], tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
            _count = tl.sum(mask.to(tl.float32))
            count = count + _count
            _tmp1 = _tmp1 + x
            _tmp2 = _tmp2 + x * x
        # Clamp so the division is safe if this program saw no elements.
        count = tl.maximum(count, 1)
        average = tl.sum(_tmp1) / count
        acc = tl.sum(_tmp2) - count * average * average

    # Each program stores its partials at slot ``pid``.
    Acc = Acc + pid
    Average = Average + pid
    Count = Count + pid

    tl.store(Average, average)
    tl.store(Acc, acc)
    tl.store(Count, count)

166 

167 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("var_mean"))
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    Mean,
    M,
    correction,
    BLOCK_NUM: tl.constexpr,
    ITER_NUM: tl.constexpr,
):
    """Stage 2: combine the per-program partials from var_mean_kernel_1.

    Launched on a single program.  Loads BLOCK_NUM partial (acc, average,
    count) triples and merges them with the Welford combine, writing the
    final scalar mean and variance.  ``correction`` is subtracted from M in
    the variance denominator.
    """
    offset = tl.arange(0, BLOCK_NUM)
    Acc = Acc + offset
    Average = Average + offset
    Count = Count + offset
    acc = tl.load(Acc)
    average = tl.load(Average)
    count = tl.load(Count)

    # Pairwise halving tree: each step folds the upper half of the active
    # prefix into the lower half with the Welford combine.  The caller sets
    # ITER_NUM = count_divisible_by_2(BLOCK_NUM) + 1, i.e. the halving only
    # runs while the prefix length stays even.
    for x in tl.static_range(1, ITER_NUM, 1):
        (
            average[: BLOCK_NUM // (2**x)],
            count[: BLOCK_NUM // (2**x)],
            acc[: BLOCK_NUM // (2**x)],
        ) = welford_func(
            average[: BLOCK_NUM // (2**x)],
            count[: BLOCK_NUM // (2**x)],
            acc[: BLOCK_NUM // (2**x)],
            average[BLOCK_NUM // (2**x) : (BLOCK_NUM // (2**x)) * 2],
            count[BLOCK_NUM // (2**x) : (BLOCK_NUM // (2**x)) * 2],
            acc[BLOCK_NUM // (2**x) : (BLOCK_NUM // (2**x)) * 2],
        )
    # Reduce whatever prefix remains after the halving tree.
    mean, _, nvar = tl.reduce(
        (
            average[: BLOCK_NUM // (2 ** (ITER_NUM - 1))],
            count[: BLOCK_NUM // (2 ** (ITER_NUM - 1))],
            acc[: BLOCK_NUM // (2 ** (ITER_NUM - 1))],
        ),
        axis=0,
        combine_fn=welford_func,
    )

    # FIXME: Reset to original reduce programming mode after optimizing the tl.reduce.
    # mean, _, nvar = tl.reduce((average, count, acc), axis=0, combine_fn=welford_func)

    var = nvar / (M - correction)
    tl.store(Mean, mean)
    tl.store(Var, var)

219 

220 

def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Compute the variance and mean of ``x``, mirroring ``torch.var_mean``.

    Args:
        x: input tensor.
        dim: dimension(s) to reduce.  ``None`` reduces over all dimensions.
            Accepts a single int as well as a list/tuple of ints (previously
            a bare int crashed on ``len(dim)``); negative dims are allowed.
        correction: Bessel's correction (difference between sample size and
            degrees of freedom).  Defaults to 1, i.e. the sample variance.
        keepdim: if True, the reduced dimensions are kept with size 1.

    Returns:
        Tuple ``(var, mean)`` with the same dtype as ``x``.
    """
    logger.debug("GEMS_CAMBRICON VAR MEAN")
    if correction is None:
        correction = 1.0
    # torch.var_mean accepts a bare int dim; normalize it to a list so the
    # len()/iteration logic below works uniformly.
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == x.ndim:
        # Full reduction: two-stage kernel.  Stage 1 writes one partial
        # (acc, average, count) per core; stage 2 merges them on one core.
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        M = x.numel()

        grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        # fp32 scratch buffers, one slot per core (autotune resets them to 0).
        acc = torch.zeros([TOTAL_CORE_NUM], dtype=torch.float, device=x.device)
        average = torch.zeros([TOTAL_CORE_NUM], dtype=torch.float, device=x.device)
        count = torch.zeros([TOTAL_CORE_NUM], dtype=torch.float, device=x.device)
        # Halving-tree depth for stage 2: how often TOTAL_CORE_NUM halves, +1.
        loop_num = count_divisible_by_2(TOTAL_CORE_NUM) + 1

        with torch_device_fn.device(x.device):
            var_mean_kernel_1[grid](x, acc, average, count, M)
            var_mean_kernel_2[(1,)](
                acc,
                average,
                count,
                var,
                mean,
                M,
                correction,
                BLOCK_NUM=TOTAL_CORE_NUM,
                ITER_NUM=loop_num,
            )
    else:
        # Partial reduction: move the reduced dims to the end so each output
        # element owns a contiguous run of N inputs, then reduce row-wise.
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        x = dim_compress(x, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = x.numel() // N
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)

        grid = lambda META: (min(triton.cdiv(M, META["BLOCK_M"]), TOTAL_CORE_NUM),)
        with torch_device_fn.device(x.device):
            var_mean_welford_kernel[grid](x, var, mean, M, N, correction)

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean