Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/groupnorm.py: 0%

187 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, tl_extra_shim 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger, nested under the shared "flag_gems" package logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Device-specific reciprocal square root primitive from the extra-op shim layer.
rsqrt = tl_extra_shim.rsqrt

14 

15 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,  # input pointer, logically (N, C, HW), contiguous
    Y,  # output pointer, same layout as X
    W,  # optional per-channel weight pointer (C,) or None
    B,  # optional per-channel bias pointer (C,) or None
    Mean,  # output per-(sample, group) mean, shape (N, num_groups)
    Rstd,  # output per-(sample, group) reciprocal std, shape (N, num_groups)
    group_size,  # channels per group, cdiv(C, num_groups)
    C,
    HW,
    num_groups,
    eps,  # variance epsilon for numerical stability
    BLOCK_GROUP_SIZE: tl.constexpr,  # power-of-2 >= group_size
    BLOCK_HW_SIZE: tl.constexpr,  # power-of-2 >= HW (whole spatial plane in one tile)
):
    # One program instance normalizes one (sample, group) pair; the grid is
    # launched with N * num_groups programs, so pid already indexes the flat
    # (sample, group) dimension and Mean/Rstd are stored at offset pid.
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)

    # Channel indices covered by this group, masked against C for the last
    # (possibly ragged) group.
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C

    # 2D tile of flat element offsets: rows = channels in group, cols = HW.
    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW

    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid

    X_ptr = X + xy_offset
    Y_ptr = Y + xy_offset

    # Statistics are computed in fp32 regardless of the input dtype.
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    mean = tl.sum(X_val) / num_elements
    # Zero padded lanes so they do not contribute to the variance sum.
    x = tl.where(xy_mask, X_val - mean, 0.0)

    var = tl.sum(x * x) / num_elements
    rstd = rsqrt(var + eps)
    x_hat = x * rstd

    # Optional affine transform; identity (weight=1, bias=0) when absent.
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0)[:, None]
    if B is None:
        bias = 0
    else:
        bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None]
    Y_val = x_hat * weight + bias

    tl.store(Y_ptr, Y_val, mask=xy_mask)
    # Persist statistics for the backward pass.
    tl.store(Mean_ptr, mean)
    tl.store(Rstd_ptr, rstd)

72 

73 

@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,  # upstream gradient dL/dY, logically (N, C, HW), contiguous
    X,  # forward input, same layout as grad_y
    W,  # optional per-channel weight pointer (C,) or None
    Mean,  # per-(sample, group) mean saved by the forward pass
    Rstd,  # per-(sample, group) reciprocal std saved by the forward pass
    num_groups,
    group_size,  # channels per group
    grad_x,  # output gradient dL/dX, same layout as X
    C,
    HW: tl.constexpr,
    BLOCK_GROUP_SIZE: tl.constexpr,  # power-of-2 >= group_size
    BLOCK_HW_SIZE: tl.constexpr,  # spatial tile width for the two loops below
):
    # One program instance handles one (sample, group) pair.
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    wb_offset = group * group_size + group_offset

    wb_mask = wb_offset < C

    # Per-group statistics; all backward math is done in fp32.
    rstd = tl.load(Rstd + pid).to(tl.float32)
    mean = tl.load(Mean + pid).to(tl.float32)
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]

    # Pass 1: accumulate sum(dx_hat) and sum(dx_hat * x_hat) over the whole
    # group, tiling HW in chunks of BLOCK_HW_SIZE.
    dx_part2 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]

        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        # Masked lanes are zeroed so they do not pollute the reductions.
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2)
    dx_3 = tl.sum(dx_part3)

    # Pass 2: dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat*x_hat)) / M)
    # where M = num_elements; recompute x_hat/dx_hat per tile instead of
    # keeping them live across the first loop.
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]

        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / num_elements)
        # Redirect masked-out lanes to offset -1; the store mask below keeps
        # them from writing (presumably a backend-specific masked-store
        # workaround — TODO confirm against the XPU Triton backend).
        grad_x_offset = tl.where(xy_mask, xy_offset, -1)

        tl.store(grad_x + grad_x_offset, dx, xy_mask)

140 

141 

@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream gradient, logically (N, C, HW), contiguous
    X,  # forward input, same layout
    Mean,  # per-(sample, group) mean, shape (N, num_groups)
    Rstd,  # per-(sample, group) reciprocal std, shape (N, num_groups)
    dW,  # output weight gradient (C,) or None
    dB,  # output bias gradient (C,) or None
    num_groups,
    group_size,  # channels per group
    N,
    C,
    HW,
    BLOCK_N: tl.constexpr,  # power-of-2 >= N (whole batch in one tile)
    BLOCK_HW: tl.constexpr,  # power-of-2 >= HW (whole spatial plane in one tile)
):
    # One program instance reduces one channel (pid in [0, C)) over the
    # batch and spatial dimensions in a single tile.
    pid = tle.program_id(0)
    group = pid // group_size  # the group this channel belongs to
    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW)
    xy_mask = n_offset[:, None] < N and hw_offset[None, :] < HW
    mr_mask = n_offset < N

    # Mean/Rstd are laid out (N, num_groups); pick this channel's group
    # statistic for every sample in the batch.
    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups

    # Flat offsets for channel `pid`: rows = samples, cols = spatial positions.
    dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
    x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]

    grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    x = tl.load(x_ptr, mask=xy_mask, other=0.0)
    x_f32 = x.to(tl.float32)
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]

    # dW[c] = sum_{n,hw} x_hat * dY ; dB[c] = sum_{n,hw} dY.
    # Results are cast back to the input dtype before storing.
    if dW is not None:
        dw = tl.sum((x_f32 - mean) * rstd * grad_y)
        tl.store(dW + pid, dw.to(x.dtype))
    if dB is not None:
        db = tl.sum(grad_y)
        tl.store(dB + pid, db.to(x.dtype))

184 

185 

@libentry()
@triton.jit
def weight_bias_backward_kernel_loop(
    dY,  # upstream gradient, logically (N, C, HW), contiguous
    X,  # forward input, same layout
    Mean,  # per-(sample, group) mean, shape (N, num_groups)
    Rstd,  # per-(sample, group) reciprocal std, shape (N, num_groups)
    dW,  # output weight gradient (C,) — unlike the non-loop kernel, this
    dB,  # output bias gradient (C,)  — variant stores both unconditionally,
    num_groups,  # so the caller must pass non-None dW and dB
    group_size,  # channels per group
    N,
    C,
    HW,
    BLOCK_N: tl.constexpr,  # batch tile size for the outer loop
    BLOCK_HW: tl.constexpr,  # spatial tile size for the inner loop
):
    # Tiled variant of weight_bias_backward_kernel for shapes whose (N, HW)
    # plane does not fit a single tile: one program per channel, looping over
    # batch and spatial tiles.
    pid = tle.program_id(0)
    group = pid // group_size

    # Elementwise accumulators; reduced to scalars after the loops.
    grad_y_tile = tl.zeros((BLOCK_N, BLOCK_HW), dtype=tl.float32)  # grad_y_tile
    dw_tile = tl.zeros((BLOCK_N, BLOCK_HW), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)

        # This channel's group statistic for each sample in the batch tile.
        mean_ptr = Mean + group + n_offset * num_groups
        rstd_ptr = Rstd + group + n_offset * num_groups
        mr_mask = n_offset < N
        mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
        rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]

        for start_hw in range(0, HW, BLOCK_HW):
            hw_offset = start_hw + tl.arange(0, BLOCK_HW)
            xy_mask = n_offset[:, None] < N and hw_offset[None, :] < HW
            dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
            grad_y_tile += grad_y

            x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            x = tl.load(x_ptr, mask=xy_mask, other=0.0)
            x_f32 = x.to(tl.float32)
            # dW contribution: x_hat * dY with x_hat = (x - mean) * rstd.
            dw_tile += (x_f32 - mean) * rstd * grad_y

    # Final scalar reductions; stored in fp32 (no cast back to input dtype,
    # unlike the non-loop kernel).
    dw = tl.sum(dw_tile)
    db = tl.sum(grad_y_tile)
    tl.store(dW + pid, dw)
    tl.store(dB + pid, db)

233 

234 

def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """GroupNorm forward.

    Normalizes `input` (logically (N, C, HxW)) over each of `group` channel
    groups, then applies the optional per-channel affine transform.

    Args:
        input: input tensor; made contiguous before launch.
        weight: optional per-channel scale of shape (C,), or None.
        bias: optional per-channel shift of shape (C,), or None.
        N, C, HxW: batch size, channel count, flattened spatial size.
        group: number of channel groups.
        eps: variance epsilon.

    Returns:
        (y, mean, rstd) — the normalized output plus the per-(sample, group)
        mean and reciprocal std (shape (N, group)) needed by the backward pass.
    """
    logger.debug("GEMS GROUPNORM FORWARD")

    group_size = triton.cdiv(C, group)
    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    y = torch.empty_like(input)
    mean = torch.empty((N, group), dtype=input.dtype, device=input.device)
    rstd = torch.empty((N, group), dtype=input.dtype, device=input.device)

    # One program per (sample, group) pair.
    grid = (N * group,)
    with torch_device_fn.device(input.device):
        # Backend-specific workaround for this exact shape (XPU simulation
        # flags consumed by the Triton XPU compiler).
        if N == 1 and C == 64 and HxW == 1024 and group == 64:
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        try:
            group_norm_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                group_size,
                C,
                HxW,
                group,
                eps,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
                BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
            )
        finally:
            # Fix: the original only removed these env vars on the success
            # path, leaking process-wide state if the launch raised.
            os.environ.pop("TRITONXPU_OTHER_SIM", None)
            os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)

    return y, mean, rstd

273 

274 

def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """GroupNorm backward.

    Args:
        grad_out: upstream gradient, same logical shape as `input` (N, C, HxW).
        input: forward input tensor.
        mean, rstd: per-(sample, group) statistics saved by the forward pass.
        weight: optional per-channel scale (C,), or None.
        N, C, HxW, group: problem dimensions (see forward).
        output_mask: 3-tuple of bools selecting which gradients to compute:
            (grad_input, grad_weight, grad_bias).

    Returns:
        (grad_inp, weight_grad, bias_grad); entries not requested are None.
    """
    logger.debug("GEMS GROUPNORM BACKWARD")

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    group_size = triton.cdiv(C, group)

    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        grid = (N * group,)
        with torch_device_fn.device(input.device):
            # Backend-specific XPU simulation flags; `os` is already imported
            # at module level (the original re-imported it here).
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            try:
                group_norm_backward_kernel[grid](
                    grad_out,
                    input,
                    weight,
                    mean,
                    rstd,
                    group,
                    group_size,
                    grad_inp,
                    C,
                    HxW,
                    BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
                    BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
                    isCloseUnrollControl=True,
                )
            finally:
                # Fix: cleanup now also runs if the launch raises, instead of
                # leaking the env vars process-wide.
                os.environ.pop("TRITONXPU_OTHER_SIM", None)
                os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
    else:
        grad_inp = None

    if output_mask[1] is False and output_mask[2] is False:
        return grad_inp, None, None

    weight_grad = torch.empty_like(weight) if output_mask[1] else None
    # Fix: the original used `weight` as the allocation template for the bias
    # gradient too, crashing when only the bias grad is requested and weight
    # is None; fall back to a (C,)-shaped buffer in the input dtype.
    if output_mask[2]:
        bias_grad = (
            torch.empty_like(weight)
            if weight is not None
            else torch.empty(C, dtype=input.dtype, device=input.device)
        )
    else:
        bias_grad = None
    with torch_device_fn.device(input.device):
        if N == 32 and C == 32 and HxW == 1024 and group == 8:
            # Tiled kernel tuned for this exact shape.
            weight_bias_backward_kernel_loop[(C, 1, 1)](
                grad_out,
                input,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                group,
                group_size,
                N,
                C,
                HxW,
                BLOCK_N=1,
                BLOCK_HW=triton.next_power_of_2(HxW),
                isCloseUnrollControl=True,
                isCloseCoreTiling=True,
            )
        else:
            # Fix: the original assigned this flag only when BOTH grads were
            # requested, raising NameError on the kernel call when exactly one
            # of weight/bias gradients was selected. Default to False and
            # enable it only in the both-grads case, matching prior tuning.
            isCloseUnrollControl = output_mask[1] is True and output_mask[2] is True
            weight_bias_backward_kernel[(C, 1, 1)](
                grad_out,
                input,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                group,
                group_size,
                N,
                C,
                HxW,
                BLOCK_N=triton.next_power_of_2(N),
                BLOCK_HW=triton.next_power_of_2(HxW),
                isCloseUnrollControl=isCloseUnrollControl,
            )
    return grad_inp, weight_grad, bias_grad