Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/rms_norm.py: 0%

213 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import builtins 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

16@libentry() 

17@triton.jit 

18def rms_norm_kernel( 

19 Y, # pointer to the output 

20 INV_RMS, # pointer to inverse rms 

21 X, # pointer to the input 

22 W, # pointer to the weights 

23 y_stride_r, 

24 y_stride_c, 

25 x_stride_r, # how much to increase the pointer when moving by 1 row 

26 x_stride_c, # how much to increase the pointer when moving by 1 col 

27 M: tl.constexpr, # number of rows in X 

28 N: tl.constexpr, # number of columns in X 

29 eps: tl.constexpr, # epsilon to avoid division by zero 

30 BLOCK_SIZE: tl.constexpr, 

31): 

32 pid = tle.program_id(0) 

33 Y += pid * y_stride_r 

34 X += pid * x_stride_r 

35 

36 colMask = tl.arange(0, BLOCK_SIZE) < M 

37 mask = tl.arange(0, BLOCK_SIZE) < N 

38 cols = tl.arange(0, BLOCK_SIZE) 

39 x = tl.load(X + cols * x_stride_c, mask & colMask, other=0.0).to(tl.float32) 

40 

41 var = tl.sum(x * x, axis=0) / N 

42 rrms = 1 / tl.sqrt(var + eps) 

43 

44 w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0) 

45 y = (x * rrms).to(Y.dtype.element_ty) * w 

46 tl.store(Y + cols * y_stride_c, y, mask=mask) 

47 tl.store(INV_RMS + pid, rrms) 

48 

49 

50@libentry() 

51@triton.jit 

52def rms_norm_kerne_tile( 

53 Y, # pointer to the output 

54 INV_RMS, # pointer to inverse rms 

55 X, # pointer to the input 

56 W, # pointer to the weights 

57 y_stride_r, 

58 y_stride_c, 

59 x_stride_r, # how much to increase the pointer when moving by 1 row 

60 x_stride_c, # how much to increase the pointer when moving by 1 col 

61 M: tl.constexpr, # number of rows in X 

62 N: tl.constexpr, # number of columns in X 

63 eps: tl.constexpr, # epsilon to avoid division by zero 

64 BLOCK_SIZE: tl.constexpr, 

65): 

66 pid = tl.program_id(0) 

67 Y += pid * y_stride_r 

68 X += pid * x_stride_r 

69 

70 # mask = tl.arange(0, BLOCK_SIZE) < N 

71 # cols = tl.arange(0, BLOCK_SIZE) 

72 # x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32) 

73 

74 # var = tl.sum(x * x, axis=0) / N 

75 # rrms = 1 / tl.sqrt(var + eps) 

76 

77 colMask = tl.arange(0, BLOCK_SIZE) < M 

78 

79 _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 

80 for off in range(0, N, BLOCK_SIZE): 

81 cols = off + tl.arange(0, BLOCK_SIZE) 

82 mask = cols < N 

83 x = tl.load(X + cols, mask & colMask, other=0.0).to(tl.float32) 

84 _var_base += x * x / N 

85 var = tl.sum(_var_base) 

86 rrms = 1 / tl.sqrt(var + eps) 

87 

88 # w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0) 

89 # y = (x * rrms).to(Y.dtype.element_ty) * w 

90 # tl.store(Y + cols * y_stride_c, y, mask=mask) 

91 for off in range(0, N, BLOCK_SIZE): 

92 cols = off + tl.arange(0, BLOCK_SIZE) 

93 mask = cols < N 

94 x = tl.load(X + cols, mask, other=0.0).to(tl.float32) 

95 w = tl.load(W + cols, mask, other=0.0) 

96 y = (x * rrms).to(Y.dtype.element_ty) * w 

97 tl.store(Y + cols * y_stride_c, y, mask=mask) 

98 

99 tl.store(INV_RMS + pid, rrms) 

100 

101 

102@libentry() 

103@triton.jit(do_not_specialize=["eps"]) 

104def rms_norm_grad_dx_kernel( 

105 X, # pointer to the input 

106 DY, 

107 INV_RMS, # pointer to inverse rms 

108 DX, # pointer to the output 

109 W, # pointer to the weights 

110 dx_stride_r, 

111 dx_stride_c, 

112 x_stride_r, # how much to increase the pointer when moving by 1 row 

113 x_stride_c, # how much to increase the pointer when moving by 1 col 

114 N, # number of columns in X 

115 eps, # epsilon to avoid division by zero 

116 BLOCK_SIZE: tl.constexpr, 

117): 

118 pid = tle.program_id(0) 

119 DX += pid * dx_stride_r 

120 X += pid * x_stride_r 

121 DY += pid * x_stride_r 

122 INV_RMS += pid 

123 

124 mask = tl.arange(0, BLOCK_SIZE) < N 

125 cols = tl.arange(0, BLOCK_SIZE) 

126 x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32) 

127 inv_rms = tl.load(INV_RMS).to(tl.float32) 

128 dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32) 

129 w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0) 

130 

131 dy = dy * w 

132 

133 normalized_buf = x * inv_rms 

134 row_sum_stats = tl.sum(normalized_buf * dy, axis=0) 

135 

136 norm_val = normalized_buf / N 

137 dx = (dy - norm_val * row_sum_stats) * inv_rms 

138 

139 tl.store(DX + cols * dx_stride_c, dx, mask=mask) 

140 

141 

142@libentry() 

143@triton.jit(do_not_specialize=["eps"]) 

144def rms_norm_grad_dx_kernel_tile( 

145 X, # pointer to the input 

146 DY, 

147 INV_RMS, # pointer to inverse rms 

148 DX, # pointer to the output 

149 W, # pointer to the weights 

150 dx_stride_r, 

151 dx_stride_c, 

152 x_stride_r, # how much to increase the pointer when moving by 1 row 

153 x_stride_c, # how much to increase the pointer when moving by 1 col 

154 N, # number of columns in X 

155 eps, # epsilon to avoid division by zero 

156 BLOCK_SIZE: tl.constexpr, 

157): 

158 pid = tle.program_id(0) 

159 DX += pid * dx_stride_r 

160 X += pid * x_stride_r 

161 DY += pid * x_stride_r 

162 INV_RMS += pid 

163 

164 # mask = tl.arange(0, BLOCK_SIZE) < N 

165 # cols = tl.arange(0, BLOCK_SIZE) 

166 # x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32) 

167 inv_rms = tl.load(INV_RMS).to(tl.float32) 

168 # dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32) 

169 # w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0) 

170 

171 # dy = dy * w 

172 

173 # normalized_buf = x * inv_rms 

174 # row_sum_stats = tl.sum(normalized_buf * dy, axis=0) 

175 

176 row_sum_stats_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 

177 for off in range(0, N, BLOCK_SIZE): 

178 cols = off + tl.arange(0, BLOCK_SIZE) 

179 mask = cols < N 

180 x = tl.load(X + cols, mask, other=0.0).to(tl.float32) 

181 dy = tl.load(DY + cols, mask, other=0.0).to(tl.float32) 

182 w = tl.load(W + cols, mask, other=0.0).to(tl.float32) 

183 

184 dy = dy * w 

185 

186 normalized_buf = x * inv_rms 

187 

188 row_sum_stats_base += normalized_buf * dy 

189 row_sum_stats = tl.sum(row_sum_stats_base) 

190 

191 # norm_val = normalized_buf / N 

192 # dx = (dy - norm_val * row_sum_stats) * inv_rms 

193 

194 for off in range(0, N, BLOCK_SIZE): 

195 cols = off + tl.arange(0, BLOCK_SIZE) 

196 mask = cols < N 

197 x = tl.load(X + cols, mask, other=0.0).to(tl.float32) 

198 dy = tl.load(DY + cols, mask, other=0.0).to(tl.float32) 

199 w = tl.load(W + cols, mask, other=0.0).to(tl.float32) 

200 

201 dy = dy * w 

202 

203 normalized_buf = x * inv_rms 

204 norm_val = normalized_buf / N 

205 dx = (dy - norm_val * row_sum_stats) * inv_rms 

206 

207 tl.store(DX + cols * dx_stride_c, dx, mask=mask) 

208 

209 

210@libentry() 

211@triton.jit 

212def rms_norm_grad_dw_kernel( 

213 X, # pointer to the input 

214 DY, 

215 INV_RMS, # pointer to inverse rms 

216 DW, # pointer to the output 

217 dx_stride_r, 

218 dx_stride_c, 

219 x_stride_r, # how much to increase the pointer when moving by 1 row 

220 x_stride_c, # how much to increase the pointer when moving by 1 col 

221 M, # number of rows in X 

222 N, # number of columns in X 

223 ROW_BLOCK_SIZE: tl.constexpr, 

224 COL_BLOCK_SIZE: tl.constexpr, 

225): 

226 row_pid = tl.program_id(0) 

227 col_pid = tl.program_id(1) 

228 

229 row_start = row_pid * ROW_BLOCK_SIZE 

230 col_start = col_pid * COL_BLOCK_SIZE 

231 

232 offset = row_start * x_stride_r + col_start * x_stride_c 

233 X += offset 

234 DY += offset 

235 INV_RMS += row_start 

236 

237 rows = tl.arange(0, ROW_BLOCK_SIZE) 

238 cols = tl.arange(0, COL_BLOCK_SIZE) 

239 

240 row_mask = (row_start + rows) < M 

241 col_mask = (col_start + cols) < N 

242 

243 x = tl.load( 

244 X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c, 

245 row_mask[:, None] & col_mask[None, :], 

246 other=0.0, 

247 ).to(tl.float32) 

248 inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32) 

249 dy = tl.load( 

250 DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c, 

251 row_mask[:, None] & col_mask[None, :], 

252 other=0.0, 

253 ).to(tl.float32) 

254 

255 d_weight = x * dy * inv_rms[:, None] 

256 partial_dweight_sum = tl.sum(d_weight, axis=0) 

257 

258 tl.store( 

259 DW + row_pid * N + col_start + cols, 

260 partial_dweight_sum, 

261 mask=col_mask, 

262 ) 

263 

264 

265@libentry() 

266@triton.jit 

267def rms_norm_grad_kernel( 

268 X, 

269 DY, 

270 DX, 

271 W, 

272 INV_RMS, 

273 DW, 

274 M: tl.constexpr, 

275 N: tl.constexpr, 

276 eps: tl.constexpr, 

277 BLOCK_SIZE: tl.constexpr, 

278): 

279 row_idx = tl.program_id(0) 

280 

281 cols = tl.arange(0, BLOCK_SIZE) 

282 mask = cols < N 

283 

284 x_ptr = X + row_idx * N + cols 

285 dy_ptr = DY + row_idx * N + cols 

286 w_ptr = W + cols 

287 

288 x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32) 

289 dy = tl.load(dy_ptr, mask=mask, other=0.0).to(tl.float32) 

290 weight = tl.load(w_ptr, mask=mask, other=0.0).to(tl.float32) 

291 inv_rms = tl.load(INV_RMS + row_idx).to(tl.float32) 

292 

293 dy_w = dy * weight 

294 x_inv_rms = x * inv_rms 

295 m_grad = tl.sum(dy_w * x, axis=0) 

296 dx = inv_rms * (dy_w - x_inv_rms * (m_grad / N)) 

297 dx_ptr = DX + row_idx * N + cols 

298 tl.store(dx_ptr, dx, mask=mask) 

299 dw_partial = dy * x_inv_rms 

300 dw_ptr = DW + cols 

301 tl.store(dw_ptr, dw_partial, mask=mask) 

302 

303 

304def rms_norm_forward(x, normalized_shape, weight, eps=1e-5): 

305 logger.debug("GEMS RMS_NORM FORWARD") 

306 dim = x.ndim - len(normalized_shape) 

307 M = math.prod(x.shape[:dim]) 

308 N = math.prod(normalized_shape) 

309 

310 # BLOCK_SIZE = triton.next_power_of_2(N) 

311 BLOCK_SIZE = builtins.min( 

312 64 * 128, triton.next_power_of_2(N) 

313 ) # core_num * buffer_size_limit 

314 

315 x = x.contiguous() 

316 weight = weight.contiguous() 

317 y = torch.empty_like(x) 

318 inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32) 

319 

320 with torch_device_fn.device(x.device): 

321 if N > 64 * 128: 

322 rms_norm_kerne_tile[M,]( 

323 y, inv_rms, x, weight, N, 1, N, 1, M, N, eps, BLOCK_SIZE 

324 ) 

325 else: 

326 rms_norm_kernel[M,]( 

327 y, inv_rms, x, weight, N, 1, N, 1, M, N, eps, BLOCK_SIZE 

328 ) 

329 

330 return y, inv_rms 

331 

332 

333def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5): 

334 logger.debug("GEMS RMS_NORM BACKWARD") 

335 

336 dim = x.ndim - len(normalized_shape) 

337 M = math.prod(x.shape[:dim]) 

338 N = math.prod(normalized_shape) 

339 

340 BLOCK_SIZE = triton.next_power_of_2(N) 

341 x = x.contiguous() 

342 dy = dy.contiguous() 

343 weight = weight.contiguous() 

344 dx = torch.empty_like(x) 

345 

346 with torch_device_fn.device(x.device): 

347 if N > 64 * 128: 

348 BLOCK_SIZE = 8192 

349 rms_norm_grad_dx_kernel_tile[M,]( 

350 x, 

351 dy, 

352 inv_rms, 

353 dx, 

354 weight, 

355 N, 

356 1, 

357 N, 

358 1, 

359 N, 

360 eps, 

361 BLOCK_SIZE, 

362 isCloseUnrollControl=True, 

363 isCloseVectorization=True, 

364 ) 

365 else: 

366 rms_norm_grad_dx_kernel[M,]( 

367 x, 

368 dy, 

369 inv_rms, 

370 dx, 

371 weight, 

372 N, 

373 1, 

374 N, 

375 1, 

376 N, 

377 eps, 

378 BLOCK_SIZE, 

379 isCloseUnrollControl=True, 

380 ) 

381 

382 ROW_BLOCK_SIZE = 1 

383 COL_BLOCK_SIZE = 256 

384 row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE) 

385 col_block_num = triton.cdiv(N, COL_BLOCK_SIZE) 

386 

387 partial_buffer = torch.empty( 

388 (row_block_num, N), dtype=torch.float32, device=x.device 

389 ) 

390 

391 with torch_device_fn.device(x.device): 

392 rms_norm_grad_dw_kernel[row_block_num, col_block_num]( 

393 x, 

394 dy, 

395 inv_rms, 

396 partial_buffer, 

397 N, 

398 1, 

399 N, 

400 1, 

401 M, 

402 N, 

403 ROW_BLOCK_SIZE, 

404 COL_BLOCK_SIZE, 

405 isCloseUnrollControl=True, 

406 isCloseCoreTiling=True, 

407 ) 

408 dw = torch.sum(partial_buffer, dim=0, dtype=x.dtype).reshape(-1) 

409 return dx, dw 

410 

411 

412def rms_norm_backward_fusion(dy, x, inv_rms, normalized_shape, weight, eps=1e-5): 

413 logger.debug("GEMS RMS_NORM BACKWARD") 

414 

415 dim = x.ndim - len(normalized_shape) 

416 M = math.prod(x.shape[:dim]) # Batch dimension 

417 N = math.prod(normalized_shape) # Feature dimension 

418 

419 x = x.contiguous() 

420 dy = dy.contiguous() 

421 weight = weight.contiguous() 

422 

423 dx = torch.empty_like(x) 

424 dw = torch.empty_like(weight) 

425 

426 BLOCK_SIZE = 64 

427 

428 with torch_device_fn.device(x.device): 

429 rms_norm_grad_kernel[(M,)]( 

430 x, 

431 dy, 

432 dx, 

433 weight, 

434 inv_rms, 

435 dw, 

436 M, 

437 N, 

438 eps, 

439 BLOCK_SIZE=BLOCK_SIZE, 

440 ) 

441 return dx, dw 

442 

443 

444class RmsNorm(torch.autograd.Function): 

445 @staticmethod 

446 def forward(ctx, x, normalized_shape, weight, eps=1e-5): 

447 y, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps) 

448 ctx.save_for_backward(x, inv_rms, weight) 

449 ctx.normalized_shape = normalized_shape 

450 ctx.eps = eps 

451 return y 

452 

453 @staticmethod 

454 def backward(ctx, dy): 

455 x, inv_rms, weight = ctx.saved_tensors 

456 normalized_shape = ctx.normalized_shape 

457 eps = ctx.eps 

458 

459 # dx, dw = rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps) 

460 dx, dw = rms_norm_backward_fusion(dy, x, inv_rms, normalized_shape, weight, eps) 

461 return dx, None, dw, None 

462 

463 

464def rms_norm(x, normalized_shape, weight, eps=1e-5): 

465 return RmsNorm.apply(x, normalized_shape, weight, eps)