Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/layernorm.py: 0% of 295 statements covered.

Report generated by coverage.py v7.6.9 on 2026-03-19 02:32 +0800.

1import logging 

2import math 

3import os 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9# from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

15 

16 

@triton.jit
def prev_multiple_of(a, b):
    # The largest multiple of b that is strictly less than a, i.e.
    # (ceil(a / b) - 1) * b.  Assumes a > 0 and b > 0.
    return tl.cdiv(a, b) * b - b

21 

22 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_persistent"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel(
    in_ptr,  # input, (M, N) row-major
    out_ptr,  # output, (M, N) row-major
    weight_ptr,  # scale (N,), or None for no scaling
    bias_ptr,  # shift (N,), or None for no shift
    out_mean_ptr,  # pointer to the mean, one value per row
    out_rstd_ptr,  # pointer to the 1/std, one value per row
    M,  # number of rows
    N,  # normalized row length
    eps,  # variance epsilon
    TILE_N: tl.constexpr,  # tile width; must be >= N (whole row in one tile)
):
    """LayerNorm forward, one program per row, whole row in a single 1D tile.

    Statistics are computed in float32 regardless of the input dtype; the
    per-row mean and rstd are written out for use by the backward pass.
    """
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    # Out-of-range lanes load 0.0, so a plain sum over the tile is the row sum.
    m = tl.sum(x) / N
    d = x - m  # deviation
    # Masked lanes hold d == -m (not 0), so zero them before reducing.
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N  # biased (1/N) variance
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    # Optional affine transform: missing weight acts as 1, missing bias as 0.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + n_offsets, mask=mask)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + n_offsets, mask=mask)
    out = (x - m) * rstd * w + b

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

70 

71 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_persistent"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel_multiline(
    in_ptr,  # input, (M, N) row-major
    out_ptr,  # output, (M, N) row-major
    weight_ptr,  # scale (N,), or None for no scaling
    bias_ptr,  # shift (N,), or None for no shift
    out_mean_ptr,  # pointer to the mean, one value per row
    out_rstd_ptr,  # pointer to the 1/std, one value per row
    M,  # number of rows
    N,  # normalized row length
    eps,  # variance epsilon
    TILE_M: tl.constexpr,  # rows handled by one program
    TILE_N: tl.constexpr,  # tile width; must be >= N (whole row in one tile)
):
    """LayerNorm forward over a 2D (TILE_M, TILE_N) tile.

    Like layer_norm_persistent_kernel, but each program normalizes TILE_M
    consecutive rows at once; statistics are reduced along axis 1 (the row
    dimension) in float32 and written out for the backward pass.
    """
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N  # per-row mean; masked lanes contribute 0
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)  # zero masked lanes before reducing
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased (1/N) variance
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    # Optional affine transform: missing weight acts as 1, missing bias as 0.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + n_offsets, mask=n_mask)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + n_offsets, mask=n_mask)
    out = (x - m[:, None]) * rstd[:, None] * w + b

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)

124 

125 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_loop"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def layer_norm_loop_kernel(
    in_ptr,  # input, (M, N) row-major
    out_ptr,  # output, (M, N) row-major
    weight_ptr,  # scale (N,), or None for no scaling
    bias_ptr,  # shift (N,), or None for no shift
    out_mean_ptr,  # pointer to the mean, one value per row
    out_rstd_ptr,  # pointer to the 1/std, one value per row
    M: tl.constexpr,
    N: tl.constexpr,
    eps,
    TILE_N: tl.constexpr,
):
    """LayerNorm forward for rows wider than one tile (looped version).

    One program per row.  Pass 1 streams the row in TILE_N chunks while each
    of the TILE_N lanes keeps Welford running statistics (mean, M2, count);
    the per-lane statistics are then merged with the standard parallel
    variance combination.  Pass 2 re-reads the row tile by tile in reverse
    order and writes the normalized output.
    """
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)

    # ---- Pass 1: per-lane Welford accumulation ----
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # per-lane running mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # per-lane sum((x - m)^2), Welford M2
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)  # per-lane sample count
    num_steps = tl.cdiv(N, TILE_N)
    # All tiles except the last are guaranteed full, so no masking here.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        new_m = m + (x - m) / (step + 1)  # Welford mean update
        new_s = s + (x - new_m) * (x - m)  # Welford M2 update
        cnt += 1
        m = new_m
        s = new_s

    # The last step: the tile may be partial, so masked lanes keep their state.
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Merge the TILE_N per-lane statistics (sum(cnt) == N): the combined mean
    # is the count-weighted lane mean; the combined variance adds each lane's
    # M2 plus its between-lane mean-shift term.
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m

    # reverse the order of the second sweep
    # (presumably so recently loaded tail tiles are still cached — the loads
    # below use eviction_policy="evict_first")
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # The (possibly partial) tail tile — this loop runs exactly once and is
    # the only step that needs masking.
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets, mask=mask)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets, mask=mask)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining full tiles, walking backwards from the tail — no masking.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)

    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

222 

223 

@triton.jit
def layernorm_fwd_kernel(
    X,  # input, viewed as (xnumel, rnumel) row-major
    Y,  # output, same layout as X
    W,  # scale (rnumel,), or None for no scaling
    B,  # shift (rnumel,), or None for no shift
    eps,  # variance epsilon
    MEAN,  # per-row mean output, (xnumel,)
    RSTRD,  # per-row 1/std output, (xnumel,)  ("RSTD" misspelled; kept for interface)
    xnumel: tl.constexpr,  # number of rows
    rnumel: tl.constexpr,  # row length (reduction size)
    XBLOCK: tl.constexpr,  # rows per program
    RBLOCK: tl.constexpr,  # columns per loop step
):
    """LayerNorm forward tiled over both rows (XBLOCK) and columns (RBLOCK).

    Uses the one-pass moment method var = E[x^2] - E[x]^2, which is cheaper
    than Welford but numerically less robust when |mean| is large relative
    to the standard deviation.
    """
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]  # (XBLOCK, 1) row ids
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]  # (1, RBLOCK) column ids
    _mean = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum(x)
    _var = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum(x * x)

    # First sweep: accumulate raw first and second moments per lane.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        x = tl.load(X + (rindex + (rnumel * xindex)), rmask & xmask, other=0.0)
        _mean = _mean + tl.broadcast_to(x, [XBLOCK, RBLOCK])
        _var = _var + tl.broadcast_to(x * x, [XBLOCK, RBLOCK])

    mean = tl.sum(_mean, 1)[:, None] / rnumel  # E[x]
    var = tl.sum(_var, 1)[:, None] / rnumel  # E[x^2]
    var_mean = var - mean * mean  # Var[x] = E[x^2] - E[x]^2
    rstd = 1 / tl.sqrt(var_mean + eps)
    # rstd = tl.math.rsqrt(var_mean + eps)

    tl.store(MEAN + xindex, mean, xmask)
    tl.store(RSTRD + xindex, rstd, xmask)

    # Second sweep: re-read x and write y = (x - mean) * rstd * w + b.
    # Missing weight acts as 1, missing bias as 0.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        x = tl.load(X + (rindex + (rnumel * xindex)), rmask & xmask, other=0.0)
        if W is None:
            w = 1
        else:
            w = tl.load(W + (rindex), rmask)
        if B is None:
            b = 0
        else:
            b = tl.load(B + (rindex), rmask)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        tl.store(Y + (rindex + (rnumel * xindex)), y, rmask & xmask)

276 

277 

def layer_norm_backward_kernel_heur_block_row_size(args):
    """Heuristic: rows per program for layer_norm_backward_kernel.

    Splits the M rows across 12 programs (12 matches the fixed grid used by
    the forward dispatch in this file — presumably the device's cluster
    count; TODO confirm) and rounds up to a power of two, as tl.arange
    requires power-of-two sizes.
    """
    return triton.next_power_of_2(triton.cdiv(args["M"], 12))

283 

284 

def layer_norm_backward_kernel_heur_block_col_size(args):
    """Heuristic: columns per program for layer_norm_backward_kernel.

    Default is the whole row, capped at 8192.  Two specific shapes are
    pinned to 4096 because 8192 triggers a legalize error in the XPU
    compiler for them.
    """
    # builtins.min is used explicitly so the heuristic never picks up a
    # shadowed min from the surrounding (Triton-heavy) module scope.
    import builtins

    if args["dX"].dtype == torch.float32 and args["M"] == 1 and args["N"] == 40999:
        return 4096  # 8192 causes a legalize error for this shape

    if args["M"] == 100 and args["N"] == 40499:
        return 4096  # 8192 causes a legalize error for this shape

    return builtins.min(args["N"], 8192)

295 

296 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_backward"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": layer_norm_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": layer_norm_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def layer_norm_backward_kernel(
    dY,  # upstream grad, (M, N) row-major
    X,  # forward input, (M, N) row-major
    W,  # weight (N,), or None (treated as all-ones)
    Mean,  # per-row mean saved by the forward, (M,)
    Rstd,  # per-row 1/std saved by the forward, (M,)
    dX,  # input-grad output, (M, N)
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,  # rows per program
    BLOCK_COL_SIZE: tl.constexpr,  # columns per loop step
):
    """Input gradient of LayerNorm.

    Per row, with x_hat = (x - mean) * rstd and dx_hat = dy * w:

        dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / N)

    Pass 1 accumulates the two row sums; pass 2 re-reads the row and writes
    dx.  Each program owns BLOCK_ROW_SIZE rows and loops over columns in
    BLOCK_COL_SIZE chunks.
    """
    # Row ids owned by this program, shaped (BLOCK_ROW_SIZE, 1).
    pid = tle.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = pid < M
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask).to(tl.float32)

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # ---- Pass 1: row sums of dx_hat and dx_hat * x_hat ----
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        # NOTE(review): Python "and" between tensor masks relies on Triton's
        # BoolOp lowering; "&" is the conventional elementwise operator —
        # confirm this backend compiles it as elementwise.
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)  # zero invalid lanes before reducing
        x_hat = x * rstd
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]  # per-row sum(dx_hat)
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]  # per-row sum(dx_hat * x_hat)

    # ---- Pass 2: recompute x_hat and write dx ----
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)

369 

370 

def weight_bias_backward_kernel_heur_block_row_size(args):
    """Heuristic: rows accumulated per loop step in weight_bias_backward_kernel.

    Always one row at a time; `args` (the kernel launch arguments) is
    accepted for interface parity with the other heuristics but unused.
    """
    return 1

373 

374 

def weight_bias_backward_kernel_heur_block_col_size(args):
    """Heuristic: columns per program for weight_bias_backward_kernel.

    The whole row if it fits, capped at 8192 (larger tiles have caused
    legalize errors in the XPU compiler).
    """
    # builtins.min is used explicitly so the heuristic never picks up a
    # shadowed min from the surrounding (Triton-heavy) module scope.
    import builtins

    return builtins.min(args["N"], 8192)

384 

385 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_bias_backward"),
#     key=["N"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": weight_bias_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": weight_bias_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream grad, (M, N) row-major
    X,  # forward input, (M, N) row-major
    Mean,  # per-row mean saved by the forward, (M,)
    Rstd,  # per-row 1/std saved by the forward, (M,)
    dW,  # weight-grad output (N,), or None if not requested
    dB,  # bias-grad output (N,), or None if not requested
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,  # rows per loop step
    BLOCK_COL_SIZE: tl.constexpr,  # columns per program
):
    """Weight and bias gradients of LayerNorm.

    Each program owns BLOCK_COL_SIZE columns and reduces over all M rows:

        dW[j] = sum_i dY[i, j] * x_hat[i, j]   with x_hat = (x - mean) * rstd
        dB[j] = sum_i dY[i, j]
    """
    # Column ids owned by this program, shaped (1, BLOCK_COL_SIZE).
    pid = tle.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)[None, :]
    col_mask = pid < N
    dY += pid
    X += pid
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = rows[:, None] < M
        # NOTE(review): Python "and" between tensor masks relies on Triton's
        # BoolOp lowering; "&" is the conventional elementwise operator —
        # confirm this backend compiles it as elementwise.
        mask = row_mask and col_mask
        dy = tl.load(dY + rows[:, None] * N, mask).to(tl.float32)
        x = tl.load(X + rows[:, None] * N, mask).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M)[:, None].to(tl.float32)
        # Invalid rows load 0 for dy, x, mean and rstd, so they add 0 below.
        x = tl.where(col_mask, x - mean, 0.0)
        x_hat = x * rstd
        accW += dy * x_hat
        accB += dy
    if dW is not None:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw[None, :], mask=col_mask)
    if dB is not None:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db[None, :], mask=col_mask)

434 

435 

def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """LayerNorm forward entry point.

    Args:
        input: tensor whose trailing dimensions equal normalized_shape.
        normalized_shape: shape of the trailing dimensions to normalize over.
        weight: optional scale of shape normalized_shape.
        bias: optional shift of shape normalized_shape.
        eps: variance epsilon.

    Returns:
        (y, mean, rstd): normalized output plus per-row statistics saved for
        the backward pass.
    """
    logger.debug("GEMS LAYERNORM FORWARD")

    # Flatten to (M, N): N = normalized size, M = number of rows.
    N = math.prod(normalized_shape)
    M = input.numel() // N

    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    y = torch.empty_like(input)

    # NOTE(review): an earlier comment here claimed the saved statistics are
    # kept in float32 for half-precision inputs, but they are allocated with
    # the input's dtype; the backward kernels up-cast them to float32 on
    # load.  Confirm whether float32 storage was intended.
    mean = torch.empty(M, dtype=input.dtype, device=input.device)
    rstd = torch.empty(M, dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        # Shape-specific dispatch — presumably tuned for this backend.
        if input.dtype == torch.float16 and input.shape == (4096, 100):
            TILE_N = 8192  # triton.next_power_of_2(N)
            grid = (M, 1, 1)  # one program per row
            layer_norm_loop_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_N,
                # XPU-specific compile option (not standard Triton).
                isCloseUnrollControl=True,
            )
        else:
            grid = (12, 1, 1)  # fixed 12 programs; XBLOCK splits M across them
            layernorm_fwd_kernel[grid](
                input,
                y,
                weight,
                bias,
                eps,
                mean,
                rstd,
                M,
                N,
                XBLOCK=triton.next_power_of_2(triton.cdiv(M, 12)),
                RBLOCK=8192,
                # XPU-specific compile options (not standard Triton).
                isCloseUnrollControl=True,
                buffer_size_limit=512,
            )

    return y, mean, rstd

488 

489 

def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """LayerNorm backward entry point.

    Args:
        grad_out: gradient w.r.t. the forward output, same shape as input.
        input: the forward input tensor.
        normalized_shape: trailing normalized shape (unused here; kept for
            interface parity with the forward).
        mean: per-row mean saved by the forward.
        rstd: per-row 1/std saved by the forward.
        weight: optional scale used in the forward.
        bias: optional shift used in the forward.
        output_mask: three booleans selecting which of (input grad,
            weight grad, bias grad) to compute.

    Returns:
        (in_grad, weight_grad, bias_grad); entries not requested are None.
    """
    logger.debug("GEMS LAYERNORM BACKWARD")

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    # Rows = leading dim; columns = everything that was normalized.
    M = input.shape[0]
    N = input.numel() // M

    if output_mask[0]:
        in_grad = torch.empty_like(input)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
        # XPU compiler switches are process-wide environment variables; they
        # must be unset afterwards even if the launch raises, hence the
        # try/finally below (the original code leaked them on error).
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        os.environ["TRITONXPU_DTYPE_CONVERT"] = "1"
        # Workaround for one known-problematic shape — presumably device
        # tuning; TODO confirm.
        if M == 100 and N == 40499:
            isCloseUnrollControl = True
            isCloseCoreTiling = True
        else:
            isCloseUnrollControl = False
            isCloseCoreTiling = False

        try:
            with torch_device_fn.device(input.device):
                layer_norm_backward_kernel[grid](
                    grad_out,
                    input,
                    weight,
                    mean,
                    rstd,
                    in_grad,
                    M,
                    N,
                    # XPU-specific compile options (not standard Triton).
                    isCloseUnrollControl=isCloseUnrollControl,
                    isCloseCoreTiling=isCloseCoreTiling,
                    isCloseVectorization=True,
                )
        finally:
            os.environ.pop("TRITONXPU_OTHER_SIM", None)
            os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
            os.environ.pop("TRITONXPU_DTYPE_CONVERT", None)
    else:
        in_grad = None

    if output_mask[1] is False and output_mask[2] is False:
        return in_grad, None, None

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
    weight_grad = torch.empty_like(weight) if output_mask[1] else None
    bias_grad = torch.empty_like(bias) if output_mask[2] else None
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel[grid](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            M,
            N,
            # XPU-specific compile options (not standard Triton).
            isCloseCoreTiling=True,
            isCloseUnrollControl=True,
            isCloseVectorization=True,
        )
    return in_grad, weight_grad, bias_grad

568 return in_grad, weight_grad, bias_grad