Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/layernorm.py: 0%
295 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
3import os
5import torch
6import triton
7import triton.language as tl
9# from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a:
    # a - b when b divides a, otherwise floor(a / b) * b.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_persistent"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel(
    in_ptr,  # (M, N) input, row-major contiguous
    out_ptr,  # (M, N) output
    weight_ptr,  # (N,) scale, or None (treated as 1)
    bias_ptr,  # (N,) shift, or None (treated as 0)
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # number of rows
    N,  # row length (normalized extent)
    eps,  # variance epsilon
    TILE_N: tl.constexpr,  # power-of-two tile covering a whole row (>= N)
):
    """Normalize one row per program, loading the entire row in one tile."""
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    # Statistics are computed in float32 regardless of the input dtype;
    # padding lanes load 0 so they do not perturb the sum.
    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.sum(x) / N
    d = x - m  # deviation
    s = tl.where(mask, d * d, 0)  # zero padding lanes before reducing
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    # None-valued weight/bias degenerate to the identity affine transform.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + n_offsets, mask=mask)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + n_offsets, mask=mask)
    out = (x - m) * rstd * w + b

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_persistent"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel_multiline(
    in_ptr,  # (M, N) input, row-major contiguous
    out_ptr,  # (M, N) output
    weight_ptr,  # (N,) scale, or None (treated as 1)
    bias_ptr,  # (N,) shift, or None (treated as 0)
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # number of rows
    N,  # row length (normalized extent)
    eps,  # variance epsilon
    TILE_M: tl.constexpr,  # rows handled per program
    TILE_N: tl.constexpr,  # power-of-two tile covering a whole row (>= N)
):
    """Like layer_norm_persistent_kernel, but each program handles TILE_M rows."""
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    # Statistics are computed in float32; padding lanes load 0.
    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N  # per-row mean
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)  # zero padding lanes before reducing
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    # None-valued weight/bias degenerate to the identity affine transform.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + n_offsets, mask=n_mask)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + n_offsets, mask=n_mask)
    out = (x - m[:, None]) * rstd[:, None] * w + b

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_loop"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def layer_norm_loop_kernel(
    in_ptr,  # (M, N) input, row-major contiguous
    out_ptr,  # (M, N) output
    weight_ptr,  # (N,) scale, or None (treated as 1)
    bias_ptr,  # (N,) shift, or None (treated as 0)
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M: tl.constexpr,
    N: tl.constexpr,
    eps,
    TILE_N: tl.constexpr,
):
    """Normalize one row per program, streaming over it in TILE_N chunks.

    Pass 1 keeps per-lane Welford accumulators (running mean m, sum of
    squared deviations s, sample count cnt) and merges the TILE_N lanes at
    the end with the parallel-variance combination formula (Chan et al.).
    Pass 2 re-reads the row in reverse tile order — presumably so the tail
    tile just touched in pass 1 is consumed first — and writes the output.
    """
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)

    # Compute mean
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)
    # All full tiles: no masking needed, every lane receives a sample.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        # Welford update: (step + 1) is this lane's sample count so far.
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step (possibly partial tile: only in-range lanes update)
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Merge the TILE_N per-lane accumulators into row statistics.
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m

    # reverse the order of the second sweep
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed (this handles the row's tail tile)
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets, mask=mask)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets, mask=mask)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining (full) tiles, walked backwards from the tail; no masking.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)

    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)
@triton.jit
def layernorm_fwd_kernel(
    X,  # (xnumel, rnumel) input, row-major contiguous
    Y,  # (xnumel, rnumel) output
    W,  # (rnumel,) scale, or None (treated as 1)
    B,  # (rnumel,) shift, or None (treated as 0)
    eps,  # variance epsilon
    MEAN,  # (xnumel,) per-row mean output
    RSTRD,  # (xnumel,) per-row 1/std output
    xnumel: tl.constexpr,  # number of rows (M)
    rnumel: tl.constexpr,  # row length (N)
    XBLOCK: tl.constexpr,  # rows per program
    RBLOCK: tl.constexpr,  # columns per inner-loop step
):
    """Normalize XBLOCK rows per program, looping over columns in RBLOCK chunks.

    Variance is computed as E[x^2] - E[x]^2.  The accumulators are float32,
    but x * x itself is computed in the input dtype.
    NOTE(review): this one-pass form can lose precision versus the deviation
    form when |mean| is large relative to the spread — confirm acceptable.
    """
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _mean = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum of x
    _var = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum of x*x

    # Pass 1: accumulate per-lane sums; masked-off elements contribute 0.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        x = tl.load(X + (rindex + (rnumel * xindex)), rmask & xmask, other=0.0)
        _mean = _mean + tl.broadcast_to(x, [XBLOCK, RBLOCK])
        _var = _var + tl.broadcast_to(x * x, [XBLOCK, RBLOCK])

    mean = tl.sum(_mean, 1)[:, None] / rnumel
    var = tl.sum(_var, 1)[:, None] / rnumel
    var_mean = var - mean * mean  # E[x^2] - E[x]^2
    rstd = 1 / tl.sqrt(var_mean + eps)
    # rstd = tl.math.rsqrt(var_mean + eps)

    tl.store(MEAN + xindex, mean, xmask)
    tl.store(RSTRD + xindex, rstd, xmask)

    # Pass 2: normalize and apply the affine transform.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        x = tl.load(X + (rindex + (rnumel * xindex)), rmask & xmask, other=0.0)
        if W is None:
            w = 1
        else:
            w = tl.load(W + (rindex), rmask)
        if B is None:
            b = 0
        else:
            b = tl.load(B + (rindex), rmask)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        tl.store(Y + (rindex + (rnumel * xindex)), y, rmask & xmask)
def layer_norm_backward_kernel_heur_block_row_size(args):
    """Row-tile size for the backward kernel: the M rows split into roughly
    12 blocks, rounded up to a power of two."""
    rows_per_block = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_block)
def layer_norm_backward_kernel_heur_block_col_size(args):
    """Column-tile size for the backward kernel: full row width capped at 8192,
    with shape-specific caps where 8192 triggers a legalize error."""
    import builtins

    m, n = args["M"], args["N"]
    fp32_edge_case = args["dX"].dtype == torch.float32 and m == 1 and n == 40999
    # An 8192-wide tile causes a legalize error for these shapes; cap at 4096.
    if fp32_edge_case or (m == 100 and n == 40499):
        return 4096
    # builtins.min defends against `min` being shadowed in this module.
    return builtins.min(n, 8192)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("layer_norm_backward"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": layer_norm_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": layer_norm_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def layer_norm_backward_kernel(
    dY,  # (M, N) upstream gradient
    X,  # (M, N) forward input
    W,  # (N,) scale, or None (treated as 1)
    Mean,  # (M,) forward-saved per-row mean
    Rstd,  # (M,) forward-saved per-row 1/std
    dX,  # (M, N) input-gradient output
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Compute dX; each program handles BLOCK_ROW_SIZE rows.

    Uses the standard layer-norm input gradient
        dx = rstd * (dy*w - (sum(dy*w) + x_hat * sum(dy*w*x_hat)) / N)
    with the two per-row sums accumulated in a first sweep over the columns.
    """
    pid = tle.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = pid < M
    # Advance the base pointers to this program's rows.
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    # Math is done in float32 regardless of the storage dtype.
    mean = tl.load(Mean, mask=row_mask).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask).to(tl.float32)

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Sweep 1: accumulate sum(dy*w) and sum(dy*w*x_hat) per row.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        # NOTE(review): `and` between tensors relies on this backend's Triton
        # build accepting it; upstream Triton would use `&` here — confirm.
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)  # zero padding so it adds nothing
        x_hat = x * rstd
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]  # per-row sum(dy*w)
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]  # per-row sum(dy*w*x_hat)

    # Sweep 2: re-read the inputs and emit dx.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)
def weight_bias_backward_kernel_heur_block_row_size(args):
    """Row-tile size for the weight/bias backward kernel: always one row
    per inner-loop step (args is accepted for the heuristics interface)."""
    return 1
def weight_bias_backward_kernel_heur_block_col_size(args):
    """Column-tile size for the weight/bias backward kernel: the full row
    width, capped at 8192."""
    import builtins

    # builtins.min defends against `min` being shadowed in this module.
    return builtins.min(args["N"], 8192)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_bias_backward"),
#     key=["N"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": weight_bias_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": weight_bias_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def weight_bias_backward_kernel(
    dY,  # (M, N) upstream gradient
    X,  # (M, N) forward input
    Mean,  # (M,) forward-saved per-row mean
    Rstd,  # (M,) forward-saved per-row 1/std
    dW,  # (N,) weight-gradient output, or None to skip
    dB,  # (N,) bias-gradient output, or None to skip
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Row-reduce gradients: dW = sum_rows(dy * x_hat), dB = sum_rows(dy).

    Each program owns BLOCK_COL_SIZE columns and loops over the M rows.
    """
    pid = tle.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)[None, :]
    col_mask = pid < N
    # Advance the base pointers to this program's columns.
    dY += pid
    X += pid
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = rows[:, None] < M
        # NOTE(review): `and` between tensors relies on this backend's Triton
        # build accepting it; upstream Triton would use `&` here — confirm.
        mask = row_mask and col_mask
        dy = tl.load(dY + rows[:, None] * N, mask).to(tl.float32)
        x = tl.load(X + rows[:, None] * N, mask).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M)[:, None].to(tl.float32)
        # NOTE(review): only col_mask is applied here; out-of-range rows rely
        # on the masked loads above yielding zeros (no `other=` given) — confirm.
        x = tl.where(col_mask, x - mean, 0.0)
        x_hat = x * rstd
        accW += dy * x_hat
        accB += dy
    if dW is not None:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw[None, :], mask=col_mask)
    if dB is not None:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db[None, :], mask=col_mask)
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """LayerNorm forward for the kunlunxin backend.

    Returns (y, mean, rstd): the normalized output plus the per-row mean
    and reciprocal standard deviation (shape (M,)) saved for backward.
    """
    logger.debug("GEMS LAYERNORM FORWARD")

    N = math.prod(normalized_shape)  # normalized extent (trailing dims)
    M = input.numel() // N  # number of independent rows

    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    y = torch.empty_like(input)

    # NOTE: when the input is half-precision(either float16 or bfloat16)
    # these statistical data saved for backward is in single precision
    # NOTE(review): the allocations below actually use input.dtype, which
    # contradicts the note above for half-precision inputs — confirm intent.
    mean = torch.empty(M, dtype=input.dtype, device=input.device)
    rstd = torch.empty(M, dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        # The extra keyword arguments on the launches below
        # (isCloseUnrollControl, buffer_size_limit, ...) are kunlunxin
        # Triton-compiler extensions.
        if input.dtype == torch.float16 and input.shape == (4096, 100):
            # Shape-specific tuned path: one program per row.
            TILE_N = 8192  # triton.next_power_of_2(N)
            grid = (M, 1, 1)
            layer_norm_loop_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_N,
                isCloseUnrollControl=True,
            )
        else:
            # Fixed 12-program grid; each program covers ceil(M / 12) rows.
            grid = (12, 1, 1)
            layernorm_fwd_kernel[grid](
                input,
                y,
                weight,
                bias,
                eps,
                mean,
                rstd,
                M,
                N,
                XBLOCK=triton.next_power_of_2(triton.cdiv(M, 12)),
                RBLOCK=8192,
                isCloseUnrollControl=True,
                buffer_size_limit=512,
            )

    return y, mean, rstd
def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """LayerNorm backward for the kunlunxin backend.

    Args:
        grad_out: (..., *normalized_shape) upstream gradient.
        input: forward input of the same shape.
        normalized_shape: trailing dims the forward normalized over.
        mean, rstd: per-row statistics saved by the forward pass, shape (M,).
        weight, bias: optional affine parameters of shape normalized_shape.
        output_mask: 3-sequence of bools selecting which of
            (in_grad, weight_grad, bias_grad) to compute.

    Returns:
        (in_grad, weight_grad, bias_grad); entries whose mask flag is False
        are None.
    """
    logger.debug("GEMS LAYERNORM BACKWARD")

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    # Match the forward pass: N is the normalized extent and M the number of
    # normalized rows.  (Using input.shape[0] for M would disagree with the
    # forward-saved mean/rstd layout whenever input has more than two dims.)
    N = math.prod(normalized_shape)
    M = input.numel() // N

    if output_mask[0]:
        in_grad = torch.empty_like(input)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
        # Backend-specific compiler switches; restored below even if the
        # launch raises, so they cannot leak into unrelated compilations.
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        os.environ["TRITONXPU_DTYPE_CONVERT"] = "1"

        # This shape hits a known compiler issue unless unroll control and
        # core tiling are disabled.
        if M == 100 and N == 40499:
            isCloseUnrollControl = True
            isCloseCoreTiling = True
        else:
            isCloseUnrollControl = False
            isCloseCoreTiling = False

        try:
            with torch_device_fn.device(input.device):
                layer_norm_backward_kernel[grid](
                    grad_out,
                    input,
                    weight,
                    mean,
                    rstd,
                    in_grad,
                    M,
                    N,
                    isCloseUnrollControl=isCloseUnrollControl,
                    isCloseCoreTiling=isCloseCoreTiling,
                    isCloseVectorization=True,
                )
        finally:
            for env_var in (
                "TRITONXPU_OTHER_SIM",
                "TRITONXPU_STORE_MASK_SIM",
                "TRITONXPU_DTYPE_CONVERT",
            ):
                os.environ.pop(env_var, None)
    else:
        in_grad = None

    if output_mask[1] is False and output_mask[2] is False:
        return in_grad, None, None

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
    weight_grad = torch.empty_like(weight) if output_mask[1] else None
    bias_grad = torch.empty_like(bias) if output_mask[2] else None
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel[grid](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            M,
            N,
            isCloseCoreTiling=True,
            isCloseUnrollControl=True,
            isCloseVectorization=True,
        )
    return in_grad, weight_grad, bias_grad