Coverage for src/flag_gems/runtime/backend/_cambricon/ops/layernorm.py: 0%

341 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils.type_utils import get_accumulator_dtype 

12 

13from ..utils import TOTAL_CORE_NUM 

14 

# Module logger; lstrip any leading dots so the child-logger name is clean.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Max row width N handled by the single-tile persistent forward kernel;
# wider rows fall back to the column-looped kernel.
MAX_C_MLU_LAYERNORM_FORWARD = 8192
# Max row width N handled by the fused backward kernel (dX + dW/dB in one
# launch); wider rows use the two-kernel backward path.
MAX_C_MLU_LAYERNORM_BACKWARD = 5120

18 

19 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel_middle_n(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weight (may be None)
    B,  # pointer to the bias (may be None)
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows
    eps,  # epsilon added to the variance
    N: tl.constexpr,  # row width; the whole row fits in one tile
    BLOCK_ROW_SIZE: tl.constexpr,
):
    # Persistent forward kernel for moderate N (N <= MAX_C_MLU_LAYERNORM_FORWARD):
    # each program grid-strides over blocks of rows, normalizes full rows in a
    # single tile, and writes per-row mean/rstd for the backward pass.
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE

    cols_n = tl.arange(0, N)
    X += cols_n[None, :]
    Y += cols_n[None, :]
    cols_off = tl.arange(0, N)[None, :]
    # Weight/bias are row-invariant, so load them once before the row loop.
    if W is None:
        w = 1
    else:
        w = tl.load(W + cols_off)
    if B is None:
        b = 0
    else:
        b = tl.load(B + cols_off)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mask = row_off[:, None] < M
        off = row_off[:, None] * N
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)

        # TODO: Use the following code as a fallback once the optimization for trans is complete.
        # mean = tl.sum(x_v, axis=1) / N
        # var = tl.sum(x_v * x_v, axis=1) / N - (mean * mean)
        # mean_bc = mean[:, None]

        # Transpose so the row reduction runs along axis 0 -- presumably
        # faster than an axis-1 reduction on this backend (see TODO above).
        x_v = tl.view(x, (BLOCK_ROW_SIZE, N))
        x_trans = tl.trans(x_v)
        mean = tl.sum(x_trans, axis=0) / N
        mean_bc = mean[:, None]
        tl.store(Mean + row_off[:, None], mean_bc, mask)
        # Biased variance: E[x^2] - E[x]^2.
        var = tl.sum(x_trans * x_trans, axis=0) / N - (mean * mean)
        var = var[:, None]
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Rstd + row_off[:, None], rstd, mask)
        x = x - mean_bc
        x_hat = x * rstd
        y = x_hat * w + b
        tl.store(Y + off, y, mask=mask)

79 

80 

def config_prune(configs, named_args, **kwargs):
    """Keep only autotune configs whose BLOCK_ROW_SIZE suits the row count M.

    Large workloads (M >= 1024) keep the configs with BLOCK_ROW_SIZE >= 22;
    smaller workloads keep the configs with BLOCK_ROW_SIZE < 22.
    """
    m = named_args["M"]

    def _keep(cfg):
        # Pair big row-blocks with big M, and small row-blocks with small M.
        return (cfg.kwargs["BLOCK_ROW_SIZE"] >= 22) == (m >= 1024)

    return [cfg for cfg in configs if _keep(cfg)]

89 

90 

def cfggen():
    """Candidate launch configs for the single-pass layer-norm kernel.

    One config per row-block size; warps and stages are pinned to 1.
    """
    row_blocks = (2, 8, 14, 22, 32)
    return [
        triton.Config({"BLOCK_ROW_SIZE": m}, num_warps=1, num_stages=1)
        for m in row_blocks
    ]

100 

101 

@libentry()
@triton.autotune(
    configs=cfggen(),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel_non_inner(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weight (may be None)
    B,  # pointer to the bias (may be None)
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows
    N,  # row width
    eps,  # epsilon added to the variance
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Single-pass forward kernel: loads each row once and computes mean,
    # rstd, and the normalized output from that one tile.
    # NOTE(review): no column loop, so this assumes BLOCK_COL_SIZE covers the
    # whole row (BLOCK_COL_SIZE >= N) -- confirm against the tuned configs.
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    row = pid * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = row < M
    X += row * N
    Y += row * N
    # BLOCK_COL_SIZE = N

    # Compute mean
    _mean = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Compute variance
    _var = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # for off in range(0, N, BLOCK_COL_SIZE):
    cols = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    col_mask = cols < N
    mask = row_mask and col_mask
    a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
    _mean += a
    _var += a * a
    mean = tl.sum(_mean, axis=1) / N
    mean_bc = mean[:, None]

    # Center in-bounds lanes; out-of-bounds lanes stay zero.
    a = tl.where(col_mask, a - mean_bc, 0.0)
    # Write mean / rstd
    tl.store(Mean + row, mean_bc, row_mask)
    # Biased variance: E[x^2] - E[x]^2.
    var = tl.sum(_var, axis=1) / N - (mean * mean)
    var = var[:, None]
    rstd = 1 / tl.sqrt(var + eps)
    x_hat = a * rstd
    tl.store(Rstd + row, rstd, row_mask)

    # Normalize and apply linear transformation
    if W is None:
        w = 1
    else:
        w = tl.load(W + cols, col_mask)
    if B is None:
        b = 0
    else:
        b = tl.load(B + cols, col_mask)
    y = x_hat * w + b
    # Write output
    tl.store(Y + cols, y, mask=mask)

165 

166 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_loop"),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel_inner(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weight (may be None)
    B,  # pointer to the bias (may be None)
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows
    eps,  # epsilon added to the variance
    N: tl.constexpr,  # row width
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Column-looped forward kernel for wide rows (N > MAX_C_MLU_LAYERNORM_FORWARD):
    # pass 1 accumulates sum and sum-of-squares tile by tile; pass 2 reloads
    # each tile, normalizes, and applies the affine transform.
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    row = pid * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = row < M
    X += row * N
    Y += row * N

    # Compute mean
    _mean = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Compute variance
    _var = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    block_col_size = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: per-lane partial sums across the column tiles.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + block_col_size
        col_mask = cols < N
        mask = row_mask and col_mask
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _mean += a
        _var += a * a

    mean = tl.sum(_mean, axis=1) / N
    mean_bc = mean[:, None]

    # Biased variance: E[x^2] - E[x]^2.
    var = tl.sum(_var, axis=1) / N - (mean * mean)
    var = var[:, None]
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean_bc, row_mask)
    tl.store(Rstd + row, rstd, row_mask)

    # Normalize and apply linear transformation
    # Pass 2: reload each tile, normalize, scale/shift, and store.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + block_col_size
        col_mask = cols < N
        mask = row_mask and col_mask
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, col_mask)
        if B is None:
            b = 0
        else:
            b = tl.load(B + cols, col_mask)
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        x = tl.where(col_mask, x - mean_bc, 0.0)
        x_hat = x * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)

236 

237 

def prune_in_wb_config(configs, named_args, **kwargs):
    """Drop autotune configs whose BLOCK_ROW_SIZE exceeds the row count M."""
    m = named_args["M"]
    # Keep a config only when at least one full row-block fits in M
    # (i.e. M // BLOCK_ROW_SIZE >= 1).
    return [cfg for cfg in configs if cfg.kwargs["BLOCK_ROW_SIZE"] <= m]

247 

248 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    prune_configs_by={"early_config_prune": prune_in_wb_config},
    key=["M", "N"],
)
@triton.jit
def input_backward_kernel(
    dY,  # pointer to the output gradient
    X,  # pointer to the input
    W,  # pointer to the weight (may be None)
    Mean,  # pointer to the per-row mean
    Rstd,  # pointer to the per-row 1/std
    dX,  # pointer to the input gradient (output)
    M,  # number of rows
    N,  # row width
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Backward w.r.t. the input only (used when N is too large for the fused
    # kernel). Per row-block: pass 1 accumulates the two row reductions
    # sum(w*dy) and sum(w*dy*x_hat); pass 2 applies
    #   dx = rstd * (w*dy - (sum(w*dy) + x_hat * sum(w*dy*x_hat)) / N).
    pid = tl.program_id(0)

    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE

    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mean = tl.load(Mean + row_off, mask=row_off < M, other=0.0)[:, None].to(
            tl.float32
        )
        rstd = tl.load(Rstd + row_off, mask=row_off < M, other=0.0)[:, None].to(
            tl.float32
        )

        row_mask = row_off[:, None] < M
        off = row_off[:, None] * N
        new_dY = dY + off
        new_X = X + off
        new_DX = dX + off

        dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

        # Pass 1: per-lane partial sums of w*dy and w*dy*x_hat.
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other=0.0).to(tl.float32)
            x_hat = (x - mean) * rstd
            if W is None:
                wdy = dy
            else:
                w = tl.load(W + cols, mask=cols < N).to(tl.float32)
                wdy = dy * w
            dx_part2 += wdy
            dx_part3 += wdy * x_hat

        # Collapse the column lanes; transposed so the reduction runs along
        # axis 0, matching the other kernels in this file.
        dx_part2_trans = tl.trans(dx_part2)
        dx_2 = tl.sum(dx_part2_trans, axis=0)[:, None]
        dx_part3_trans = tl.trans(dx_part3)
        dx_3 = tl.sum(dx_part3_trans, axis=0)[:, None]

        # Pass 2: reload each tile and emit dx.
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other=0.0).to(tl.float32)
            if W is None:
                wdy = dy
            else:
                w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
                wdy = dy * w
            x_hat = (x - mean) * rstd
            dx = rstd * (wdy - (dx_2 + x_hat * dx_3) / N)
            tl.store(new_DX + cols, dx.to(x.dtype), mask=mask)

326 

327 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    prune_configs_by={"early_config_prune": prune_in_wb_config},
    key=["M", "N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,  # pointer to the output gradient
    X,  # pointer to the input
    Mean,  # pointer to the per-row mean
    Rstd,  # pointer to the per-row 1/std
    dW,  # pointer to the weight gradient (output)
    dB,  # pointer to the bias gradient (output)
    M,  # number of rows
    N,  # row width
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Column-parallel reduction over the rows:
    #   dW[c] = sum_r dY[r, c] * x_hat[r, c]
    #   dB[c] = sum_r dY[r, c]
    # Each program grid-strides over blocks of columns.
    pid = tl.program_id(0)

    col_start = pid * BLOCK_COL_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_COL_SIZE

    for col in range(col_start, N, step):
        col_off = col + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = col_off < N

        new_dY = dY + col_off
        new_X = X + col_off
        new_dW = dW + col_off
        new_dB = dB + col_off

        accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

        # Accumulate per-lane partial sums over the row blocks.
        for off in range(0, M, BLOCK_ROW_SIZE):
            rows = off + tl.arange(0, BLOCK_ROW_SIZE)
            row_mask = rows[:, None] < M
            mask = row_mask and col_mask
            dy = tl.load(new_dY + rows[:, None] * N, mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + rows[:, None] * N, mask, other=0.0).to(tl.float32)
            mean = tl.load(Mean + rows, mask=rows < M, other=0.0)[:, None].to(
                tl.float32
            )
            rstd = tl.load(Rstd + rows, mask=rows < M, other=0.0)[:, None].to(
                tl.float32
            )
            x_hat = (x - mean) * rstd
            accW += dy * x_hat
            accB += dy
        # Collapse the row axis and store one value per column.
        dw = tl.sum(accW, axis=0)
        db = tl.sum(accB, axis=0)
        tl.store(new_dW, dw[None, :], mask=col_mask)
        tl.store(new_dB, db[None, :], mask=col_mask)

384 

385 

def cfggen_bw_middle_n():
    """Candidate launch configs for the fused middle-N backward kernel.

    Sweeps the row-block size and pipeline stages; warps are pinned to 1.
    """
    configs = []
    for row_block in [1, 2, 4, 8, 12, 18, 22, 32]:
        for stages in [1, 3]:
            configs.append(
                triton.Config(
                    {"BLOCK_ROW_SIZE": row_block},
                    num_warps=1,
                    num_stages=stages,
                )
            )
    return configs

404 

405 

def pre_hook(args, reset_only=True):
    """Zero the DW/DB accumulators before each autotuning run.

    autotune's reset_to_zero option can't be used here because DW/DB
    may be None (weight/bias gradients are optional).
    """
    for key in ("DW", "DB"):
        buf = args[key]
        if buf is not None:
            buf.zero_()

411 

412 

@libentry()
@triton.autotune(
    configs=cfggen_bw_middle_n(),
    key=["M", "N"],
    pre_hook=pre_hook,
)
@triton.jit
def layer_norm_backward_kernel_middle_n(
    DX,  # pointer to the input gradient
    DY,  # pointer to the output gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    X,  # pointer to the input
    W,  # pointer to the weights
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows in X
    N: tl.constexpr,  # number of columns in X
    BLOCK_ROW_SIZE: tl.constexpr,
):
    # Fused backward for moderate N: computes DX and, when DW/DB are given,
    # accumulates the weight/bias gradients with atomic adds (pre_hook has
    # already zeroed the DW/DB buffers).
    pid = tl.program_id(0)

    row_start = pid * BLOCK_ROW_SIZE
    cols = tl.arange(0, N)
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE

    X += cols[None, :]
    DY += cols[None, :]
    DX += cols[None, :]
    # Weight is row-invariant, so load it once before the row loop.
    if W is None:
        w = 1
    else:
        W += cols[None, :]
        w = tl.load(W).to(tl.float32)

    if DW is not None:
        partial_dw = tl.zeros([BLOCK_ROW_SIZE, N], dtype=tl.float32)
    if DB is not None:
        partial_db = tl.zeros([BLOCK_ROW_SIZE, N], dtype=tl.float32)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mask = row_off[:, None] < M
        # Load data to SRAM
        off = row_off[:, None] * N
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + off, mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean + row_off, mask=row_off < M)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + row_off, mask=row_off < M)[:, None].to(tl.float32)
        # Compute dx
        x_hat = (x - mean) * rstd
        wdy = w * dy
        # Row reductions c1 = sum(x_hat*w*dy) and c2 = sum(w*dy), transposed
        # so the reduction runs along axis 0 on this backend.
        x_hat_dy = x_hat * wdy
        x_hat_dy = tl.view(x_hat_dy, (BLOCK_ROW_SIZE, N))
        x_hat_dy_trans = tl.trans(x_hat_dy)
        c1 = tl.sum(x_hat_dy_trans, axis=0)[:, None]

        wdy_v = tl.view(wdy, (BLOCK_ROW_SIZE, N))
        wdy_v_trans = tl.trans(wdy_v)
        c2 = tl.sum(wdy_v_trans, axis=0)[:, None]
        dx = (wdy - (x_hat * c1 + c2) / N) * rstd
        # Write dx
        tl.store(DX + off, dx.to(x.dtype), mask=mask)

        # Accumulate partial sums for dw/db
        if DW is not None:
            partial_dw += (dy * x_hat).to(tl.float32)
        if DB is not None:
            partial_db += (dy).to(tl.float32)

    # Fold this program's partials into the global DW/DB buffers; atomic
    # because multiple programs contribute to the same columns.
    if DW is not None:
        dw = tl.sum(partial_dw, axis=0)
        tl.atomic_add(DW + cols, dw)
    if DB is not None:
        db = tl.sum(partial_db, axis=0)
        tl.atomic_add(DB + cols, db)

489 

490 

def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Layer-norm forward over the trailing `normalized_shape` dims.

    Returns (output, mean, rstd); mean/rstd are per-row accumulator-dtype
    tensors saved for the backward pass.
    """
    logger.debug("GEMS_CAMBRICON LAYERNORM FORWARD")
    N = math.prod(normalized_shape)  # elements normalized per row
    M = input.numel() // N  # number of rows

    input = input.contiguous()
    weight = weight.contiguous() if weight is not None else None
    bias = bias.contiguous() if bias is not None else None

    out = torch.empty_like(input)
    acc_dtype = get_accumulator_dtype(input.dtype)
    mean = torch.empty(M, dtype=acc_dtype, device=input.device)
    rstd = torch.empty(M, dtype=acc_dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if N <= MAX_C_MLU_LAYERNORM_FORWARD:
            # Persistent kernel: cap the grid at the physical core count.
            grid = lambda META: (
                min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
            )
            layer_norm_kernel_middle_n[grid](
                input, out, weight, bias, mean, rstd, M, eps, N
            )
        else:
            # Wide rows: the kernel loops over column tiles internally.
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
            layer_norm_kernel_inner[grid](
                input, out, weight, bias, mean, rstd, M, eps, N
            )
    return out, mean, rstd

519 

520 

def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """Layer-norm backward; returns (grad_input, grad_weight, grad_bias).

    grad_weight / grad_bias are None when the corresponding forward
    parameter was None. `mean` and `rstd` are the per-row statistics
    saved by the forward pass.
    """
    logger.debug("GEMS_CAMBRICON LAYERNORM BACKWARD")
    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    M = input.shape[0]
    N = input.numel() // M

    if N <= MAX_C_MLU_LAYERNORM_BACKWARD:
        # Fused path: one kernel computes dX and accumulates dW/dB.
        in_grad = torch.empty_like(grad_out)
        if weight is None:
            weight_grad = None
        else:
            # float32 accumulator for the kernel's atomic adds.
            weight_grad = torch.zeros(
                (weight.shape[0],), dtype=torch.float, device=weight.device
            )
        if bias is None:
            bias_grad = None
        else:
            # BUGFIX: size/device must come from `bias`, not `weight`; the
            # old code raised AttributeError when bias was given without
            # weight.
            bias_grad = torch.zeros(
                (bias.shape[0],), dtype=torch.float, device=bias.device
            )
        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        grid = lambda META: (
            min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
        )
        with torch_device_fn.device(input.device):
            layer_norm_backward_kernel_middle_n[grid](
                in_grad,
                grad_out,
                weight_grad,
                bias_grad,
                input,
                weight,
                mean,
                rstd,
                M=M,
                N=N,
            )
        # Cast the float32 accumulators back to the input dtype.
        if weight_grad is not None:
            weight_grad = weight_grad.to(input.dtype)
        if bias_grad is not None:
            bias_grad = bias_grad.to(input.dtype)
    else:
        # Wide rows: compute dX first, then dW/dB with a separate
        # column-parallel kernel.
        in_grad = torch.empty_like(input)
        grid = lambda META: (
            min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
        )
        # Run on the input's device, consistent with every other launch in
        # this file (the original launched this kernel outside the context).
        with torch_device_fn.device(input.device):
            input_backward_kernel[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                in_grad,
                M,
                N,
            )
        if weight is None and bias is None:
            return in_grad, None, None

        with torch_device_fn.device(input.device):
            grid = lambda META: (
                min(triton.cdiv(N, META["BLOCK_COL_SIZE"]), TOTAL_CORE_NUM),
            )
            weight_grad = None if weight is None else torch.empty_like(weight)
            bias_grad = None if bias is None else torch.empty_like(bias)
            weight_bias_backward_kernel[grid](
                grad_out,
                input,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                M,
                N,
            )

    return in_grad, weight_grad, bias_grad