Coverage for src/flag_gems/runtime/backend/_cambricon/ops/groupnorm.py: 0%

343 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, tl_extra_shim 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13rsqrt = tl_extra_shim.rsqrt 

14 

15 

def group_norm_kernel_opt_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs for the forward group-norm kernel.

    Keeps only the configs that make sense for the problem size:
      * very large rows (HW > 4096): keep big-tile (BLOCK_HW_SIZE >= 4096),
        unsplit (SPLIT <= 1) configs;
      * otherwise: keep configs whose tile covers HW in one shot and whose
        SPLIT does not exceed the number of groups, but only the tightest
        such tile — a config is dropped when some other candidate size sits
        strictly between HW and its BLOCK_HW_SIZE.

    Args:
        configs: candidate ``triton.Config`` objects (each carries
            ``BLOCK_HW_SIZE`` and ``SPLIT`` in ``kwargs``).
        named_args: kernel arguments by name; ``num_groups`` is read here.
        **kwargs: constexpr kernel arguments; ``HW`` is read here.

    Returns:
        The filtered list of configs, in the original order.
    """
    hw = kwargs["HW"]
    num_groups = named_args["num_groups"]
    # All distinct candidate tile sizes; only membership/size comparisons
    # are performed, so a set is sufficient.
    all_sizes = {config.kwargs["BLOCK_HW_SIZE"] for config in configs}

    pruned_configs = []
    for config in configs:
        block_hw = config.kwargs["BLOCK_HW_SIZE"]
        split = config.kwargs["SPLIT"]
        if hw > 4096 and block_hw >= 4096 and split <= 1:
            pruned_configs.append(config)
        elif block_hw >= hw and split <= num_groups:
            # Prefer the smallest tile that still covers the whole row:
            # drop this config if a smaller candidate also covers HW.
            if not any(hw < size < block_hw for size in all_sizes):
                pruned_configs.append(config)
    return pruned_configs

39 

40 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"SPLIT": s, "BLOCK_HW_SIZE": size}, num_stages=3, num_warps=1)
        for size in [64, 256, 512, 1024, 2048, 4096, 5120]
        for s in [1, 4, 6, 8, 16]
    ],
    key=["X", "group_size", "C", "HW", "num_groups"],
    prune_configs_by={"early_config_prune": group_norm_kernel_opt_prune},
)
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel_opt(
    X,
    Y,
    W,
    B,
    Mean,
    Rstd,
    group_size,
    C,
    num_groups,
    eps,
    HW: tl.constexpr,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr,
    SPLIT: tl.constexpr,
):
    # Forward group norm. Grid is N * ceil(num_groups / SPLIT) programs:
    # each program normalizes SPLIT consecutive groups of one batch sample
    # (the last program of a sample handles the num_groups % SPLIT remainder).
    # W/B are optional per-channel affine parameters; Mean/Rstd receive the
    # per-(sample, group) statistics for the backward pass.
    pid = tl.program_id(0)
    div_v = tl.cdiv(num_groups, SPLIT)
    div_mod = num_groups % SPLIT
    split_group = pid % div_v      # which chunk of SPLIT groups
    split_n = pid // div_v         # which batch sample
    real_num_elements = group_size * HW  # elements per group

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)
    if BLOCK_HW_SIZE >= HW:
        # Whole row fits in one tile; shrink the HW axis to exactly HW.
        hw_offset = tl.arange(0, HW)
    hw_iter = tl.cdiv(HW, BLOCK_HW_SIZE)

    if W is None:
        W_ptr = None
    else:
        # First channel of this program's first group.
        W_ptr = W + split_group * SPLIT * group_size
    if B is None:
        B_ptr = None
    else:
        B_ptr = B + split_group * SPLIT * group_size

    Mean_ptr = Mean + split_n * num_groups + split_group * SPLIT
    Rstd_ptr = Rstd + split_n * num_groups + split_group * SPLIT

    # 2-D offsets into the contiguous (N, C, HW) input for the first group.
    xy_offset = (
        split_n * C * HW
        + split_group * SPLIT * real_num_elements
        + group_offset[:, None] * HW
        + hw_offset[None, :]
    )

    ub = SPLIT
    if (div_mod != 0) and ((split_group + 1) == div_v):
        # Last chunk of a sample only owns the remainder groups.
        ub = div_mod
    for idx in range(0, ub):
        if BLOCK_HW_SIZE >= HW:
            # Fast path: the whole (group_size, HW) tile is resident at once.
            tmp = tl.load(X + xy_offset, cache_modifier=".cg").to(tl.float32)
            mean = tl.sum(tmp) / real_num_elements
            x = tmp - mean
            var = tl.sum(x * x) / real_num_elements
            var = tl.rsqrt(var + eps)  # `var` now holds rstd, not variance

            tl.store(Mean_ptr + idx, mean)
            tl.store(Rstd_ptr + idx, var)

            if W_ptr is None:
                weight = 1
            else:
                weight = tl.load(W_ptr + group_offset, cache_modifier=".cg")[:, None]
            if B_ptr is None:
                bias = 0
            else:
                bias = tl.load(B_ptr + group_offset, cache_modifier=".cg")[:, None]
            tmp = (tmp - mean) * var
            tmp = tmp * weight + bias
            tl.store(Y + xy_offset, tmp)
        else:
            # Chunked path: accumulate sum and sum-of-squares over HW chunks,
            # then derive mean/variance (E[x^2] - E[x]^2 formulation).
            # NOTE(review): the prune hook only allows this path with
            # SPLIT <= 1, which keeps `group_offset < group_size` meaningful
            # (group_offset advances by group_size per idx iteration, so the
            # mask would be all-false for idx >= 1) — confirm if configs change.
            mean = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)
            var = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)
            for idy in range(0, hw_iter):
                xy_mask = (
                    group_offset[:, None] < group_size
                    and (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                tmp = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                mean += tmp
                var += tmp * tmp
            mean = tl.sum(mean) / real_num_elements
            var = tl.sum(var) / real_num_elements - (mean * mean)
            var = tl.rsqrt(var + eps)  # rstd
            tl.store(Mean_ptr + idx, mean)
            tl.store(Rstd_ptr + idx, var)

            if W_ptr is None:
                weight = 1
            else:
                weight = tl.load(W_ptr + group_offset, cache_modifier=".cg")[:, None]
            if B_ptr is None:
                bias = 0
            else:
                bias = tl.load(B_ptr + group_offset, cache_modifier=".cg")[:, None]

            # Second pass: normalize and apply the affine transform chunk by chunk.
            for idy in range(0, hw_iter):
                xy_mask = (
                    group_offset[:, None] < group_size
                    and (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                tmp = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                tmp = (tmp - mean) * var
                tmp = tmp * weight + bias
                tl.store(Y + idy * BLOCK_HW_SIZE + xy_offset, tmp, mask=xy_mask)

        # Advance to the next group owned by this program: data pointer moves
        # one group forward, channel offsets move to the next group's channels.
        xy_offset += real_num_elements
        group_offset += group_size

173 

174 

def group_norm_backward_kernel_opt_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs for the group-norm backward (dx) kernel.

    Mirrors the forward prune hook with a lower threshold:
      * large rows (HW > 2048): keep big-tile (BLOCK_HW_SIZE >= 2048),
        unsplit (SPLIT <= 1) configs;
      * otherwise: keep configs whose tile is strictly larger than HW, but
        only the tightest such tile — a config is dropped when another
        candidate size sits strictly between HW and its BLOCK_HW_SIZE.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name (unused here).
        **kwargs: constexpr kernel arguments; ``HW`` is read here.

    Returns:
        The filtered list of configs, in the original order.
    """
    hw = kwargs["HW"]
    all_sizes = {config.kwargs["BLOCK_HW_SIZE"] for config in configs}

    pruned_configs = []
    for config in configs:
        block_hw = config.kwargs["BLOCK_HW_SIZE"]
        split = config.kwargs["SPLIT"]
        if hw > 2048 and block_hw >= 2048 and split <= 1:
            pruned_configs.append(config)
        elif block_hw > hw:
            # Keep only the smallest tile strictly larger than HW.
            if not any(hw < size < block_hw for size in all_sizes):
                pruned_configs.append(config)
    return pruned_configs

196 

197 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"SPLIT": s, "BLOCK_HW_SIZE": size}, num_stages=3, num_warps=1)
        for s in [1, 4, 6, 8]
        for size in [64, 256, 512, 1024, 2048]
    ],
    prune_configs_by={"early_config_prune": group_norm_backward_kernel_opt_prune},
    key=["X", "group_size", "C", "HW", "num_groups"],
)
@triton.jit()
def group_norm_backward_kernel_opt(
    grad_y,
    X,
    W,
    Mean,
    Rstd,
    num_groups,
    group_size,
    grad_x,
    C,
    HW: tl.constexpr,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr,
    SPLIT: tl.constexpr,
):
    # Input-gradient (dx) of group norm. Same grid decomposition as the
    # forward kernel: each program handles SPLIT consecutive groups of one
    # batch sample, using the saved per-(sample, group) Mean/Rstd.
    # Formula per group (m = group_size * HW elements):
    #   x_hat = rstd * (x - mean);  dx_hat = w * dy
    #   dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / m)
    pid = tl.program_id(0)
    div_v = tl.cdiv(num_groups, SPLIT)
    div_mod = num_groups % SPLIT
    split_group = pid % div_v      # which chunk of SPLIT groups
    split_n = pid // div_v         # which batch sample
    real_num_elements = group_size * HW
    hw_iter = tl.cdiv(HW, BLOCK_HW_SIZE)

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    if BLOCK_HW_SIZE >= HW:
        # Whole row fits in one tile.
        hw_offset = tl.arange(0, HW)
    else:
        hw_offset = tl.arange(0, BLOCK_HW_SIZE)

    if W is None:
        W_ptr = None
    else:
        # First channel of this program's first group.
        W_ptr = W + split_group * SPLIT * group_size

    Mean_ptr = Mean + split_n * num_groups + split_group * SPLIT
    Rstd_ptr = Rstd + split_n * num_groups + split_group * SPLIT

    # 2-D offsets into the contiguous (N, C, HW) tensors for the first group;
    # note real_num_elements * num_groups == C * HW.
    xy_offset = (
        split_n * real_num_elements * num_groups
        + split_group * SPLIT * real_num_elements
        + group_offset[:, None] * HW
        + hw_offset[None, :]
    )

    ub = SPLIT
    if (div_mod != 0) and ((split_group + 1) == div_v):
        # Last chunk of a sample only owns the remainder groups.
        ub = div_mod
    for idx in range(0, ub):
        # Channel-bound mask; group_offset advances by group_size per group,
        # so `< C` stays valid across the idx loop (unlike the forward kernel).
        wb_mask = group_offset < C

        if W_ptr is None:
            weight = 1
        else:
            weight = tl.load(
                W_ptr + group_offset, mask=wb_mask, other=0.0, cache_modifier=".cg"
            ).to(tl.float32)[:, None]
        rstd = tl.load(Rstd_ptr + idx).to(tl.float32)
        mean = tl.load(Mean_ptr + idx).to(tl.float32)

        if BLOCK_HW_SIZE >= HW:
            # Fast path: whole (group_size, HW) tile resident; single pass.
            dY_val = tl.load(grad_y + xy_offset, cache_modifier=".cg").to(tl.float32)
            X_val = tl.load(X + xy_offset, cache_modifier=".cg").to(tl.float32)

            x_hat = rstd * (X_val - mean)
            dx_hat = weight * dY_val

            grad_dx_hat_sum = tl.sum(dx_hat)
            grad_x_hat_sum = tl.sum(dx_hat * x_hat)

            dx = rstd * (
                dx_hat - (grad_dx_hat_sum + x_hat * grad_x_hat_sum) / real_num_elements
            )

            tl.store(grad_x + xy_offset, dx)
        else:
            # Chunked path, pass 1: accumulate the two reduction terms
            # sum(dx_hat) and sum(dx_hat * x_hat) over HW chunks.
            grad_dx_hat_accum = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)
            grad_x_hat_accum = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)

            for idy in range(0, hw_iter):
                xy_mask = (group_offset[:, None] < C) & (
                    (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                dY_val = tl.load(
                    grad_y + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                X_val = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)

                # Zero out-of-range lanes so they don't pollute the sums.
                x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
                dx_hat = weight * dY_val
                grad_dx_hat_accum += dx_hat
                grad_x_hat_accum += dx_hat * x_hat

            grad_dx_hat_total = tl.sum(grad_dx_hat_accum)
            grad_x_hat_total = tl.sum(grad_x_hat_accum)

            # Pass 2: recompute x_hat/dx_hat per chunk and emit dx.
            for idy in range(0, hw_iter):
                xy_mask = (group_offset[:, None] < C) & (
                    (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                dY_val = tl.load(
                    grad_y + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                X_val = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)

                x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
                dx_hat = weight * dY_val
                dx = rstd * (
                    dx_hat
                    - (grad_dx_hat_total + x_hat * grad_x_hat_total) / real_num_elements
                )

                tl.store(grad_x + idy * BLOCK_HW_SIZE + xy_offset, dx, mask=xy_mask)

        # Advance to the next group owned by this program.
        xy_offset += real_num_elements
        group_offset += group_size

340 

341 

def weight_bias_backward_kernel_opt_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs for the dW/dB backward kernel.

    Two passes:
      1. size-based filter, mirroring the other prune hooks: for large rows
         (HW > 2048) keep big-tile (>= 2048) configs with small BLOCK_N
         (<= 4); otherwise keep only the tightest tile strictly larger
         than HW;
      2. keep only configs whose ``BLOCK_N`` evenly divides the batch size
         ``N`` (the kernel's full-row path loads without an N mask).

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name; ``N`` is read here.
        **kwargs: constexpr kernel arguments; ``HW`` is read here.

    Returns:
        The filtered list of configs, in the original order.
    """
    n = named_args["N"]
    hw = kwargs["HW"]
    all_sizes = {config.kwargs["BLOCK_HW_SIZE"] for config in configs}

    size_filtered = []
    for config in configs:
        block_hw = config.kwargs["BLOCK_HW_SIZE"]
        block_n = config.kwargs["BLOCK_N"]
        if hw > 2048 and block_hw >= 2048 and block_n <= 4:
            size_filtered.append(config)
        elif block_hw > hw:
            # Keep only the smallest tile strictly larger than HW.
            if not any(hw < size < block_hw for size in all_sizes):
                size_filtered.append(config)

    # Second pass: BLOCK_N must divide N exactly.
    return [config for config in size_filtered if n % config.kwargs["BLOCK_N"] == 0]

370 

371 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": bn, "BLOCK_HW_SIZE": size}, num_stages=3, num_warps=1)
        for bn in [1, 4, 8, 16]
        for size in [512, 1024, 2048]
    ],
    prune_configs_by={"early_config_prune": weight_bias_backward_kernel_opt_prune},
    key=["X", "N", "C", "HW", "num_groups"],
)
@triton.jit
def weight_bias_backward_kernel_opt(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    num_groups,
    group_size,
    N,
    C,
    HW: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr,
):
    # Parameter gradients of group norm, one scalar pair per channel c:
    #   dB[c] = sum_{n,hw} dy            dW[c] = sum_{n,hw} x_hat * dy
    # Channels are distributed evenly over the launched programs; within a
    # channel the (N, HW) plane is reduced in BLOCK_N x BLOCK_HW_SIZE tiles.
    pid = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    C_SPLIT = tl.cdiv(C, pnum)           # channels per program
    N_SPLIT = tl.cdiv(N, BLOCK_N)        # batch tiles per channel
    hw_iter = tl.cdiv(HW, BLOCK_HW_SIZE)

    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)
    if BLOCK_HW_SIZE >= HW:
        # Whole row fits in one tile; shrink the HW axis to exactly HW.
        hw_offset = tl.arange(0, HW)

    lb = pid * C_SPLIT
    ub = tl.minimum((pid + 1) * C_SPLIT, C)
    for c_start in range(lb, ub):
        if BLOCK_HW_SIZE >= HW:
            # Fast path: unmasked loads. Safe only because the prune hook
            # guarantees N % BLOCK_N == 0 and the tile covers all of HW.
            dY_ptr = dY + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            x_ptr = X + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            grad_y = tl.load(dY_ptr, cache_modifier=".cg").to(tl.float32)

            x = tl.load(x_ptr, cache_modifier=".cg")
            x_f32 = x.to(tl.float32)

            # Per-sample stats of the group containing this channel.
            mean_ptr = Mean + c_start // group_size + n_offset * num_groups
            rstd_ptr = Rstd + c_start // group_size + n_offset * num_groups

            mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
            rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

            dB_val = tl.sum(grad_y)
            dW_val = tl.sum((x_f32 - mean) * rstd * grad_y)

            # Remaining batch tiles for this channel.
            for n_start in range(1, N_SPLIT):
                new_n_offset = n_start * BLOCK_N + n_offset

                dY_ptr = (
                    dY
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                x_ptr = (
                    X
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                grad_y = tl.load(dY_ptr, cache_modifier=".cg").to(tl.float32)

                x = tl.load(x_ptr, cache_modifier=".cg")
                x_f32 = x.to(tl.float32)

                mean_ptr = Mean + c_start // group_size + new_n_offset * num_groups
                rstd_ptr = Rstd + c_start // group_size + new_n_offset * num_groups

                mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
                rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

                dB_val += tl.sum(grad_y)
                dW_val += tl.sum((x_f32 - mean) * rstd * grad_y)

            if dW is not None:
                tl.store(dW + c_start, dW_val)
            if dB is not None:
                tl.store(dB + c_start, dB_val)
        else:
            # Chunked path: HW is tiled; tail chunks along HW are masked.
            # NOTE(review): this first xy_mask is not applied to the loads
            # just below — the first chunk is fully in range and BLOCK_N
            # divides N by construction, so the unmasked loads look safe;
            # confirm if the prune rules ever change.
            xy_mask = (n_offset[:, None] < N) & (hw_offset[None, :] < HW)

            dY_ptr = dY + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            x_ptr = X + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            grad_y = tl.load(dY_ptr, cache_modifier=".cg").to(tl.float32)

            x = tl.load(x_ptr, cache_modifier=".cg")
            x_f32 = x.to(tl.float32)

            mean_ptr = Mean + c_start // group_size + n_offset * num_groups
            rstd_ptr = Rstd + c_start // group_size + n_offset * num_groups

            mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
            rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

            dB_val = tl.sum(grad_y)
            dW_val = tl.sum((x_f32 - mean) * rstd * grad_y)

            # Remaining HW chunks of the first batch tile.
            for idx in range(1, hw_iter):
                xy_mask = (n_offset[:, None] < N) & (
                    (idx * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                dY_ptr = (
                    dY
                    + c_start * HW
                    + n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                    + idx * BLOCK_HW_SIZE
                )
                x_ptr = (
                    X
                    + c_start * HW
                    + n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                    + idx * BLOCK_HW_SIZE
                )

                grad_y = tl.load(
                    dY_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg"
                ).to(tl.float32)
                x = tl.load(x_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg")
                x_f32 = x.to(tl.float32)
                dB_val += tl.sum(grad_y)
                # Zero masked lanes before the (x - mean) product so padding
                # does not contribute to dW.
                x_f32 = tl.where(xy_mask, x_f32 - mean, 0.0)
                dW_val += tl.sum(x_f32 * rstd * grad_y)

            # Remaining batch tiles, each again chunked over HW.
            for n_start in range(1, N_SPLIT):
                new_n_offset = n_start * BLOCK_N + n_offset
                xy_mask = (new_n_offset[:, None] < N) & (hw_offset[None, :] < HW)

                dY_ptr = (
                    dY
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                x_ptr = (
                    X
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                grad_y = tl.load(
                    dY_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg"
                ).to(tl.float32)

                x = tl.load(x_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg")
                x_f32 = x.to(tl.float32)

                mean_ptr = Mean + c_start // group_size + new_n_offset * num_groups
                rstd_ptr = Rstd + c_start // group_size + new_n_offset * num_groups

                mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
                rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

                dB_val += tl.sum(grad_y)
                dW_val += tl.sum((x_f32 - mean) * rstd * grad_y)

                for idx in range(1, hw_iter):
                    xy_mask = (new_n_offset[:, None] < N) & (
                        (idx * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                    )
                    dY_ptr = (
                        dY
                        + c_start * HW
                        + new_n_offset[:, None] * C * HW
                        + hw_offset[None, :]
                        + idx * BLOCK_HW_SIZE
                    )
                    x_ptr = (
                        X
                        + c_start * HW
                        + new_n_offset[:, None] * C * HW
                        + hw_offset[None, :]
                        + idx * BLOCK_HW_SIZE
                    )

                    grad_y = tl.load(
                        dY_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg"
                    ).to(tl.float32)
                    x = tl.load(x_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg")
                    x_f32 = x.to(tl.float32)
                    dB_val += tl.sum(grad_y)
                    x_f32 = tl.where(xy_mask, x_f32 - mean, 0.0)
                    dW_val += tl.sum(x_f32 * rstd * grad_y)
            if dW is not None:
                tl.store(dW + c_start, dW_val)
            if dB is not None:
                tl.store(dB + c_start, dB_val)

572 

573 

def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """Forward group normalization over a contiguous (N, C, HxW) input.

    Launches ``group_norm_kernel_opt`` with one program per
    (sample, group-chunk) pair and returns the normalized output together
    with the saved per-(sample, group) statistics needed by backward.

    Args:
        input: input tensor, viewed as (N, C, HxW).
        weight: optional per-channel scale, shape (C,).
        bias: optional per-channel shift, shape (C,).
        N, C, HxW: batch, channel, and spatial-product sizes.
        group: number of normalization groups (C must be divisible by it).
        eps: numerical-stability term added to the variance.

    Returns:
        (output, mean, rstd) — output matches ``input``; mean/rstd have
        shape (N, group) and ``input``'s dtype.
    """
    logger.debug("GEMS_CAMBRICON GROUPNORM FORWARD")
    channels_per_group = C // group
    input = input.contiguous()
    weight = weight.contiguous() if weight is not None else None
    bias = bias.contiguous() if bias is not None else None

    out = torch.empty_like(input)
    mean = torch.empty((N, group), dtype=input.dtype, device=input.device)
    rstd = torch.empty((N, group), dtype=input.dtype, device=input.device)

    def grid(meta):
        # One program per sample per chunk of SPLIT groups.
        return (N * triton.cdiv(group, meta["SPLIT"]),)

    with torch_device_fn.device(input.device):
        group_norm_kernel_opt[grid](
            input,
            out,
            weight,
            bias,
            mean,
            rstd,
            channels_per_group,
            C,
            group,
            eps,
            HW=HxW,
            BLOCK_GROUP_SIZE=channels_per_group,
        )
    return out, mean, rstd

603 

604 

def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """Backward pass of group normalization.

    Computes, per ``output_mask``: the input gradient (dx kernel), and the
    per-channel weight/bias gradients (dW/dB kernel).

    Args:
        grad_out: upstream gradient, same shape as ``input``.
        input: forward input tensor, viewed as (N, C, HxW).
        mean, rstd: per-(sample, group) statistics saved by the forward pass.
        weight: optional per-channel scale used in the forward pass (may be None).
        N, C, HxW: batch, channel, and spatial-product sizes.
        group: number of normalization groups.
        output_mask: 3-tuple of bools selecting (grad_input, grad_weight, grad_bias).

    Returns:
        (grad_input, grad_weight, grad_bias); entries not requested are None.
    """
    logger.debug("GEMS_CAMBRICON GROUPNORM BACKWARD")

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    group_size = triton.cdiv(C, group)

    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        # One program per sample per chunk of SPLIT groups.
        grid = lambda meta: (N * triton.cdiv(group, meta["SPLIT"]),)
        with torch_device_fn.device(input.device):
            group_norm_backward_kernel_opt[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                group,
                group_size,
                grad_inp,
                C,
                HW=HxW,
                BLOCK_GROUP_SIZE=group_size,
            )
    else:
        grad_inp = None

    if not output_mask[1] and not output_mask[2]:
        return grad_inp, None, None

    def _alloc_channel_grad():
        # dW/dB are (C,)-shaped. Match the affine parameter when it exists;
        # fall back to the input's dtype/device so that requesting dW/dB
        # without an affine weight does not crash on empty_like(None).
        if weight is not None:
            return torch.empty_like(weight)
        return torch.empty((C,), dtype=input.dtype, device=input.device)

    weight_grad = _alloc_channel_grad() if output_mask[1] else None
    bias_grad = _alloc_channel_grad() if output_mask[2] else None
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel_opt[(TOTAL_CORE_NUM, 1, 1)](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            group,
            group_size,
            N,
            C,
            HW=HxW,
        )
    return grad_inp, weight_grad, bias_grad
656 return grad_inp, weight_grad, bias_grad