Coverage for src/flag_gems/ops/layernorm.py: 32%

241 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a
    # (ceil(a / b) - 1 full tiles of width b land just below a).
    return (tl.cdiv(a, b) - 1) * b

20 

21 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # One program normalizes one full row of the (M, N) input: the entire
    # row fits in a single 1d tile of TILE_N elements, which keeps the
    # indexing trivial.
    row = tle.program_id(0)

    cols = tl.arange(0, TILE_N)
    in_bounds = cols < N

    # Out-of-range lanes load 0.0 so they do not perturb the sums.
    vals = tl.load(in_ptr + row * N + cols, in_bounds, other=0.0).to(tl.float32)
    mean = tl.sum(vals) / N
    centered = vals - mean  # deviation from the row mean
    sq = tl.where(in_bounds, centered * centered, 0)
    var = tl.sum(sq) / N  # biased variance
    rstd = tl.math.rsqrt(var + eps)

    # Persist per-row statistics for the backward pass.
    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    # The affine transform is optional: a missing weight behaves as 1 and a
    # missing bias as 0.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + cols, mask=in_bounds)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + cols, mask=in_bounds)
    result = centered * rstd * w + b

    tl.store(out_ptr + row * N + cols, result, mask=in_bounds)

69 

70 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
):
    # Persistent layer norm for short rows: each program normalizes a block
    # of TILE_M consecutive rows at once, each row held in a TILE_N-wide tile.
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    # Out-of-range elements load as 0.0 so they do not perturb the sums.
    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N  # per-row mean
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased variance
    rstd = tl.math.rsqrt(var + eps)

    # Persist per-row statistics for the backward pass.
    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    # The affine transform is optional: a missing weight behaves as 1 and a
    # missing bias as 0.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + n_offsets, mask=n_mask)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + n_offsets, mask=n_mask)
    out = (x - m[:, None]) * rstd[:, None] * w + b

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)

123 

124 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_loop"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # Layer norm for long rows that do not fit in one tile: each program
    # sweeps its row in TILE_N-element chunks, maintaining TILE_N parallel
    # Welford accumulators that are merged into a single mean/variance at
    # the end, then sweeps again to normalize.
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)

    # Compute mean
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)
    # Every step but the last is a full tile, so no masking is needed here.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        # Welford update per lane; each lane has seen `step` elements so far.
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step (possibly a partial tile: masked-off lanes keep their
    # previous state and do not advance their counts)
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Merge the TILE_N lane statistics (parallel variance combination):
    # weight each lane mean by its element count, then correct each lane's
    # sum of squares for the shift from its lane mean to the global mean.
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    # reverse the order of the second sweep
    # (the tail tile was loaded last above, so it is revisited first here —
    # presumably so recently-touched data is still resident; TODO confirm)
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        # A missing weight behaves as 1, a missing bias as 0.
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets, mask=mask)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets, mask=mask)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining tiles walk backwards from prev_multiple and are always fully
    # in range, so no masking is required.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)

220 

221 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_backward"),
    key=["M", "N"],
)
@triton.jit
def layer_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,
    Rstd,
    dX,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Input gradient of layer norm. Per row, with x_hat = (x - mean) * rstd
    # and dx_hat = dy * w:
    #   dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / N)
    # Two column sweeps: the first accumulates the two row sums, the second
    # applies the formula and stores dX.
    pid = tle.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = pid < M
    # Shift base pointers to this program's block of rows.
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask).to(tl.float32)

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # First sweep: accumulate sum(dy * w) and sum(dy * w * x_hat) per row.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        # Zero masked entries so they do not contribute to the sums.
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        # A missing weight behaves as all-ones.
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    # Reduce the per-column partial sums to a single value per row.
    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Second sweep: recompute x_hat / dx_hat per tile and store the gradient.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)

288 

289 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    key=["N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Weight and bias gradients of layer norm, reduced over the M rows:
    #   dW[j] = sum_i dy[i, j] * (x[i, j] - mean[i]) * rstd[i]
    #   dB[j] = sum_i dy[i, j]
    # Each program owns BLOCK_COL_SIZE columns and loops over row blocks,
    # keeping 2d partial sums that are reduced over rows at the end.
    pid = tle.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)
    col_mask = pid < N
    dY += pid[None, :]
    X += pid[None, :]
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
        row_mask = rows < M
        mask = row_mask and col_mask[None, :]
        dy = tl.load(dY + rows * N, mask).to(tl.float32)
        x = tl.load(X + rows * N, mask).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M).to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M).to(tl.float32)
        # Zero out-of-range entries so they do not contribute to the sums.
        x = tl.where(mask, x - mean, 0.0)
        accW += dy * x * rstd
        accB += dy
    # dW / dB may be null pointers when that gradient was not requested;
    # the truthiness test skips the store in that case.
    if dW:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw, mask=col_mask)
    if dB:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db, mask=col_mask)

331 

332 

def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Layer normalization forward.

    Normalizes over the trailing ``normalized_shape`` dimensions of
    ``input`` and applies the optional affine transform ``weight`` /
    ``bias``.

    Args:
        input: input tensor; flattened to (M, N) with N = prod(normalized_shape).
        normalized_shape: trailing dims to normalize over.
        weight: optional scale of shape ``normalized_shape``.
        bias: optional shift of shape ``normalized_shape``.
        eps: small constant added to the variance for numerical stability.

    Returns:
        (y, mean, rstd): normalized output with input's shape, plus the
        per-row mean and reciprocal standard deviation saved for backward.
    """
    logger.debug("GEMS LAYERNORM FORWARD")

    # Flatten to an (M, N) problem: N elements per normalized row.
    N = math.prod(normalized_shape)
    M = input.numel() // N

    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    y = torch.empty_like(input)

    # NOTE: when the input is half-precision (either float16 or bfloat16),
    # the statistics saved for backward are kept in single precision, as the
    # kernels compute them in float32. (BUG FIX: previously they were stored
    # in the input dtype, discarding precision and contradicting this note.)
    acc_dtype = (
        torch.float32
        if input.dtype in (torch.float16, torch.bfloat16)
        else input.dtype
    )
    mean = torch.empty(M, dtype=acc_dtype, device=input.device)
    rstd = torch.empty(M, dtype=acc_dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if N <= 128:
            # Very short rows: pack several rows per program (2d tile).
            TILE_N = triton.next_power_of_2(N)
            TILE_M = triton.cdiv(1024, TILE_N)
            grid = (triton.cdiv(M, TILE_M), 1, 1)
            layer_norm_persistent_kernel_multiline[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_M,
                TILE_N,
            )
        elif N <= 4096:
            # Row fits in one tile: one program per row, no inner loop.
            TILE_N = triton.next_power_of_2(N)
            grid = (M, 1, 1)
            layer_norm_persistent_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_N,
            )
        else:
            # Long rows: loop kernel sweeps each row in chunks.
            grid = (M, 1, 1)
            layer_norm_loop_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
            )
    return y, mean, rstd

396 

397 

def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """Layer normalization backward.

    Args:
        grad_out: gradient w.r.t. the forward output.
        input: the forward input tensor.
        normalized_shape: trailing dims that were normalized.
        mean, rstd: per-row statistics saved by the forward pass.
        weight, bias: optional affine parameters from the forward pass.
        output_mask: 3-sequence of bools selecting which of
            (grad_input, grad_weight, grad_bias) to compute; None means
            "every gradient whose parameter exists".

    Returns:
        (in_grad, weight_grad, bias_grad), with None for unselected outputs.
    """
    logger.debug("GEMS LAYERNORM BACKWARD")

    # BUG FIX: the documented default output_mask=None used to crash on
    # subscripting; default to all gradients that have a receiving tensor.
    if output_mask is None:
        output_mask = (True, weight is not None, bias is not None)

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    # BUG FIX: M must mirror the forward's flattening (numel // prod of the
    # normalized dims). The previous M = input.shape[0] under-counted rows —
    # and mis-indexed mean/rstd — whenever input has more than two dims with
    # only the trailing dims normalized.
    N = math.prod(normalized_shape)
    M = input.numel() // N

    if output_mask[0]:
        in_grad = torch.empty_like(input)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
        with torch_device_fn.device(input.device):
            layer_norm_backward_kernel[grid](
                grad_out, input, weight, mean, rstd, in_grad, M, N
            )
    else:
        in_grad = None

    if not (output_mask[1] or output_mask[2]):
        return in_grad, None, None

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
    # Guard against a requested gradient for a parameter that was not given
    # (torch.empty_like(None) would raise); the kernel skips null pointers.
    weight_grad = (
        torch.empty_like(weight) if output_mask[1] and weight is not None else None
    )
    bias_grad = torch.empty_like(bias) if output_mask[2] and bias is not None else None
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel[grid](
            grad_out, input, mean, rstd, weight_grad, bias_grad, M, N
        )
    return in_grad, weight_grad, bias_grad