Coverage for src/flag_gems/runtime/backend/_ascend/ops/groupnorm.py: 0%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, tl_extra_shim 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module logger namespaced under the flag_gems Ascend runtime; the suffix is
# this module's leaf name ("groupnorm") so log filtering works per-op.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

# Device-specific reciprocal square root provided by the extra-op shim layer.
rsqrt = tl_extra_shim.rsqrt

14 

15 

@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,
    X,
    W,
    Mean,
    Rstd,
    num_groups,
    group_size,
    grad_x,
    C,
    HW,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr = 128,
):
    """Group-norm backward w.r.t. the input.

    One program handles one (batch, group) pair: `pid` enumerates
    N * num_groups programs, and Mean/Rstd are indexed directly by `pid`
    (i.e. laid out as (N, num_groups)).

    Implements
        dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / M)
    with x_hat = rstd * (x - mean), dx_hat = weight * dy, and
    M = group_size * HW the element count of one group.
    """
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW

    # Channel indices of this group within [0, C).
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    wb_offset = group * group_size + group_offset

    wb_mask = wb_offset < C
    # Per-(batch, group) statistics saved by the forward pass.
    rstd = tl.load(Rstd + pid).to(tl.float32)
    mean = tl.load(Mean + pid).to(tl.float32)

    if W is None:
        # No affine weight: dx_hat == dy (scalar 1 broadcasts below).
        weight = 1
    else:
        # Per-channel weight, broadcast over the HW axis.
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]

    # First sweep: accumulate sum(dx_hat) and sum(dx_hat * x_hat) over the
    # whole group, tiling the HW axis by BLOCK_HW_SIZE.
    dx_part2 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        # Flat offset into a contiguous (N, C, HW) tensor.
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]
        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        # Zero padded positions so they do not contribute to the reductions.
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    # Reduce the tile accumulators to the two group-wide scalars.
    dx_2 = tl.sum(dx_part2)
    dx_3 = tl.sum(dx_part3)

    # Second sweep: reload inputs and emit the final gradient.
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]

        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / num_elements)

        tl.store(grad_x + xy_offset, dx, xy_mask)

80 

81 

@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    num_groups,
    group_size,
    N,
    C,
    HW,
    BLOCK_N: tl.constexpr,
    BLOCK_HW: tl.constexpr,
):
    """Group-norm backward w.r.t. weight and bias.

    One program per channel (`pid` in [0, C)); reduces over the batch (N)
    and spatial (HW) axes:
        dW[c] = sum_{n,hw} (x - mean) * rstd * dy
        dB[c] = sum_{n,hw} dy
    Mean/Rstd are laid out as (N, num_groups) and indexed by the group this
    channel belongs to. dW / dB may each be None when not requested.
    """
    pid = tle.program_id(0)
    group = pid // group_size
    n_offset = tl.arange(0, BLOCK_N)
    mr_mask = n_offset < N
    # Gather this channel's group statistics for every batch element.
    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]

    # Tile width along HW for each reduction step.
    SUB_BLOCK_HW: tl.constexpr = 64

    dw_sum = 0.0
    db_sum = 0.0

    for hw_off in range(0, BLOCK_HW, SUB_BLOCK_HW):
        hw_offset = hw_off + tl.arange(0, SUB_BLOCK_HW)
        # FIX: combine the tensor masks with elementwise `&` instead of the
        # Python `and` keyword — `and` is not a reliable elementwise operation
        # on Triton tensors, and `&` matches group_norm_backward_kernel.
        xy_mask = (n_offset[:, None] < N) & (hw_offset[None, :] < HW)
        # Flat offsets into contiguous (N, C, HW): fixed channel `pid`,
        # varying batch (rows) and spatial position (columns).
        dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
        x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
        grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
        x = tl.load(x_ptr, mask=xy_mask, other=0.0)
        x_f32 = x.to(tl.float32)
        if dW is not None:
            dw_sum = dw_sum + tl.sum((x_f32 - mean) * rstd * grad_y)
        if dB is not None:
            db_sum = db_sum + tl.sum(grad_y)

    if dW is not None:
        dw = dw_sum
        tl.store(dW + pid, dw.to(mean.dtype))
    if dB is not None:
        db = db_sum
        tl.store(dB + pid, db.to(mean.dtype))

132 

133 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SUB_HW_SIZE": 32}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 64}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 128}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 256}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 512}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 1024}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 2048}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 4096}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 8192}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 16384}),
    ],
    key=["HW", "group_size"],
)
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,
    Y,
    W,
    B,
    Mean,
    Rstd,
    group_size,
    C,
    HW,
    num_groups,
    eps,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr,
    BLOCK_SUB_HW_SIZE: tl.constexpr,
):
    """Group-norm forward: one program per (batch, group) pair.

    Three passes over the group's data: (1) mean, (2) variance, (3) normalize
    with optional affine transform, then store the (mean, rstd) statistics
    into (N, num_groups)-shaped Mean/Rstd for the backward pass.

    NOTE(review): BLOCK_HW_SIZE is accepted but unused here — the HW tiling
    is driven entirely by the autotuned BLOCK_SUB_HW_SIZE.
    """
    pid = tl.program_id(0)
    batch_idx = pid // num_groups
    group_idx = pid % num_groups

    # Starting position of this group within the whole (contiguous) tensor.
    batch_offset = batch_idx * C * HW
    group_start_channel = group_idx * group_size

    num_elements = group_size * HW

    # First pass: compute the mean.
    X_sum = 0.0
    for hw_start in range(0, HW, BLOCK_SUB_HW_SIZE):
        hw_offsets = hw_start + tl.arange(0, BLOCK_SUB_HW_SIZE)
        hw_mask = hw_offsets < HW

        # Iterate contiguously along HW first, then over the channel axis.
        for c_idx in range(BLOCK_GROUP_SIZE):
            # Skip padded channels (both block padding and tail past C).
            if c_idx < group_size and (group_start_channel + c_idx) < C:
                channel_offset = group_start_channel + c_idx
                # Contiguous access along the HW axis.
                base_offset = batch_offset + channel_offset * HW + hw_offsets
                X_vals = tl.load(X + base_offset, mask=hw_mask, other=0.0).to(
                    tl.float32
                )
                X_sum += tl.sum(X_vals)

    mean = X_sum / num_elements

    # Second pass: compute the variance.
    X_var_sum = 0.0
    for hw_start in range(0, HW, BLOCK_SUB_HW_SIZE):
        hw_offsets = hw_start + tl.arange(0, BLOCK_SUB_HW_SIZE)
        hw_mask = hw_offsets < HW

        for c_idx in range(BLOCK_GROUP_SIZE):
            if c_idx < group_size and (group_start_channel + c_idx) < C:
                channel_offset = group_start_channel + c_idx
                base_offset = batch_offset + channel_offset * HW + hw_offsets
                # other=mean makes masked lanes contribute exactly zero to
                # the centered sum below.
                X_vals = tl.load(X + base_offset, mask=hw_mask, other=mean).to(
                    tl.float32
                )
                x_centered = X_vals - mean
                X_var_sum += tl.sum(x_centered * x_centered)

    var = X_var_sum / num_elements
    rstd = rsqrt(var + eps)

    # Third pass: normalize and write back.
    for hw_start in range(0, HW, BLOCK_SUB_HW_SIZE):
        hw_offsets = hw_start + tl.arange(0, BLOCK_SUB_HW_SIZE)
        hw_mask = hw_offsets < HW

        for c_idx in range(BLOCK_GROUP_SIZE):
            if c_idx < group_size and (group_start_channel + c_idx) < C:
                channel_offset = group_start_channel + c_idx
                base_offset = batch_offset + channel_offset * HW + hw_offsets

                # Load the input tile.
                X_vals = tl.load(X + base_offset, mask=hw_mask, other=0.0).to(
                    tl.float32
                )

                # Normalize and apply the optional affine transform.
                x_normalized = (X_vals - mean) * rstd
                if W is not None:
                    w_val = tl.load(W + channel_offset)
                    x_normalized = x_normalized * w_val
                if B is not None:
                    b_val = tl.load(B + channel_offset)
                    x_normalized = x_normalized + b_val

                # Store the result.
                tl.store(Y + base_offset, x_normalized, mask=hw_mask)

    # Store the mean and reciprocal standard deviation for backward.
    mean_rstd_offset = batch_idx * num_groups + group_idx
    tl.store(Mean + mean_rstd_offset, mean)
    tl.store(Rstd + mean_rstd_offset, rstd)

246 

247 

def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """Group-norm forward entry point for the Ascend backend.

    Launches one kernel program per (batch, group) pair and returns the
    normalized output together with the per-(N, group) mean and reciprocal
    standard deviation needed by the backward pass.
    """
    logger.debug("ASCEND GEMS GROUPNORM FORWARD")
    channels_per_group = triton.cdiv(C, group)

    # Kernels index with flat (N, C, HW) offsets, so force contiguity.
    x = input.contiguous()
    w = weight.contiguous() if weight is not None else None
    b = bias.contiguous() if bias is not None else None

    out = torch.empty_like(x)
    stats_shape = (N, group)
    mean = torch.empty(stats_shape, dtype=x.dtype, device=x.device)
    rstd = torch.empty(stats_shape, dtype=x.dtype, device=x.device)

    with torch_device_fn.device(x.device):
        group_norm_kernel[(N * group,)](
            x,
            out,
            w,
            b,
            mean,
            rstd,
            channels_per_group,
            C,
            HxW,
            group,
            eps,
            BLOCK_GROUP_SIZE=triton.next_power_of_2(channels_per_group),
            BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
        )
    return out, mean, rstd

277 

278 

def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """Group-norm backward entry point for the Ascend backend.

    Args:
        grad_out: upstream gradient, same layout as ``input``.
        input: forward-pass input tensor.
        mean, rstd: per-(N, group) statistics saved by the forward pass.
        weight: optional affine weight of shape (C,), or None.
        N, C, HxW, group: batch size, channel count, spatial size, group count.
        output_mask: 3-element sequence selecting which of
            (grad_input, grad_weight, grad_bias) to compute.

    Returns:
        Tuple ``(grad_input, grad_weight, grad_bias)``; entries not requested
        via ``output_mask`` are None.
    """
    logger.debug("ASCEND GEMS GROUPNORM BACKWARD")
    # Kernels index with flat (N, C, HW) offsets, so force contiguity.
    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    group_size = triton.cdiv(C, group)

    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        grid = (N * group,)
        with torch_device_fn.device(input.device):
            group_norm_backward_kernel[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                group,
                group_size,
                grad_inp,
                C,
                HxW,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
            )
    else:
        grad_inp = None

    # FIX: use truthiness instead of `is False` identity checks so integer /
    # numpy-bool entries in output_mask are handled correctly.
    if not (output_mask[1] or output_mask[2]):
        return grad_inp, None, None

    def _param_grad_buffer():
        # FIX: the original allocated both buffers with empty_like(weight),
        # which raised when a bias gradient was requested with weight=None;
        # fall back to an input-based (C,) buffer in that case.
        if weight is not None:
            return torch.empty_like(weight)
        return torch.empty((C,), dtype=input.dtype, device=input.device)

    weight_grad = _param_grad_buffer() if output_mask[1] else None
    bias_grad = _param_grad_buffer() if output_mask[2] else None
    with torch_device_fn.device(input.device):
        # One program per channel; each reduces over N and HW.
        weight_bias_backward_kernel[(C, 1, 1)](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            group,
            group_size,
            N,
            C,
            HxW,
            BLOCK_N=triton.next_power_of_2(N),
            BLOCK_HW=triton.next_power_of_2(HxW),
        )
    return grad_inp, weight_grad, bias_grad