Coverage for src/flag_gems/runtime/backend/_cambricon/ops/layernorm.py: 0%
341 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils.type_utils import get_accumulator_dtype
13from ..utils import TOTAL_CORE_NUM
# Module logger, nested under the shared "flag_gems" logger hierarchy.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Max row width (N) handled by the single-pass "middle_n" kernels; wider rows
# fall back to the looped column-chunk kernels.
MAX_C_MLU_LAYERNORM_FORWARD = 8192
MAX_C_MLU_LAYERNORM_BACKWARD = 5120
# Forward layernorm for rows that fit on-chip (launched when
# N <= MAX_C_MLU_LAYERNORM_FORWARD).  Each program normalizes
# BLOCK_ROW_SIZE rows per step and grid-strides over the remaining rows,
# storing per-row mean and rstd for the backward pass.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel_middle_n(
    X,  # pointer to the input, contiguous (M, N)
    Y,  # pointer to the output, same shape as X
    W,  # pointer to the weight (N,), may be None
    B,  # pointer to the bias (N,), may be None
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows
    eps,  # epsilon added to the variance for numerical stability
    N: tl.constexpr,  # row width; the whole row is kept on-chip
    BLOCK_ROW_SIZE: tl.constexpr,  # rows handled per program per step
):
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE
    cols_n = tl.arange(0, N)
    # Pre-offset the base pointers by the column index once; only the row
    # offset changes inside the loop.
    X += cols_n[None, :]
    Y += cols_n[None, :]
    cols_off = tl.arange(0, N)[None, :]
    # Weight/bias cover the whole row, so load them once and reuse them for
    # every row block.
    if W is None:
        w = 1
    else:
        w = tl.load(W + cols_off)
    if B is None:
        b = 0
    else:
        b = tl.load(B + cols_off)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mask = row_off[:, None] < M
        off = row_off[:, None] * N
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)
        # TODO: Use the following code as a fallback once the optimization for trans is complete.
        # mean = tl.sum(x_v, axis=1) / N
        # var = tl.sum(x_v * x_v, axis=1) / N - (mean * mean)
        # mean_bc = mean[:, None]
        # Row-wise reductions are done as transpose + axis-0 sum, which is
        # faster on this backend than reducing along axis 1 (see TODO above).
        x_v = tl.view(x, (BLOCK_ROW_SIZE, N))
        x_trans = tl.trans(x_v)
        mean = tl.sum(x_trans, axis=0) / N
        mean_bc = mean[:, None]
        tl.store(Mean + row_off[:, None], mean_bc, mask)
        # Population variance via E[x^2] - E[x]^2.
        var = tl.sum(x_trans * x_trans, axis=0) / N - (mean * mean)
        var = var[:, None]
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Rstd + row_off[:, None], rstd, mask)
        x = x - mean_bc
        x_hat = x * rstd
        y = x_hat * w + b
        tl.store(Y + off, y, mask=mask)
def config_prune(configs, named_args, **kwargs):
    """Keep only autotune configs whose BLOCK_ROW_SIZE suits the row count.

    Large problems (M >= 1024) keep the large row blocks (>= 22); smaller
    problems keep the small ones (< 22).
    """
    m = named_args["M"]
    wants_big_blocks = m >= 1024
    return [
        cfg
        for cfg in configs
        if (cfg.kwargs["BLOCK_ROW_SIZE"] >= 22) == wants_big_blocks
    ]
def cfggen():
    """Candidate autotune configs for layer_norm_kernel_non_inner."""
    row_block_sizes = (2, 8, 14, 22, 32)
    return [
        triton.Config({"BLOCK_ROW_SIZE": size}, num_warps=1, num_stages=1)
        for size in row_block_sizes
    ]
# Single-pass layernorm forward: each program normalizes one tile of
# BLOCK_ROW_SIZE rows, loading the entire row in a single masked load.
# NOTE(review): the column loop is commented out, so this kernel assumes
# BLOCK_COL_SIZE >= N -- confirm against the launch configuration.
@libentry()
@triton.autotune(
    configs=cfggen(),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel_non_inner(
    X,  # pointer to the input, contiguous (M, N)
    Y,  # pointer to the output, same shape as X
    W,  # pointer to the weight (N,), may be None
    B,  # pointer to the bias (N,), may be None
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows
    N,  # row width
    eps,  # epsilon added to the variance for numerical stability
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,  # column tile; expected to cover the row
):
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    row = pid * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = row < M
    X += row * N
    Y += row * N
    # BLOCK_COL_SIZE = N
    # Compute mean
    _mean = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Compute variance
    _var = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # for off in range(0, N, BLOCK_COL_SIZE):
    cols = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    col_mask = cols < N
    mask = row_mask and col_mask
    a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
    _mean += a
    _var += a * a
    mean = tl.sum(_mean, axis=1) / N
    mean_bc = mean[:, None]
    # Center in-bounds lanes; masked lanes are forced to 0 so padding never
    # reaches the output.
    a = tl.where(col_mask, a - mean_bc, 0.0)
    # Write mean / rstd
    tl.store(Mean + row, mean_bc, row_mask)
    # Population variance via E[x^2] - E[x]^2.
    var = tl.sum(_var, axis=1) / N - (mean * mean)
    var = var[:, None]
    rstd = 1 / tl.sqrt(var + eps)
    x_hat = a * rstd
    tl.store(Rstd + row, rstd, row_mask)
    # Normalize and apply linear transformation
    if W is None:
        w = 1
    else:
        w = tl.load(W + cols, col_mask)
    if B is None:
        b = 0
    else:
        b = tl.load(B + cols, col_mask)
    y = x_hat * w + b
    # Write output
    tl.store(Y + cols, y, mask=mask)
# Looped layernorm forward for rows too wide to keep on-chip: two passes
# over each row block, first accumulating the statistics in column chunks,
# then normalizing chunk by chunk.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_loop"),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel_inner(
    X,  # pointer to the input, contiguous (M, N)
    Y,  # pointer to the output, same shape as X
    W,  # pointer to the weight (N,), may be None
    B,  # pointer to the bias (N,), may be None
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows
    eps,  # epsilon added to the variance for numerical stability
    N: tl.constexpr,  # row width
    BLOCK_ROW_SIZE: tl.constexpr,  # rows per program
    BLOCK_COL_SIZE: tl.constexpr,  # columns processed per loop iteration
):
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    row = pid * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = row < M
    X += row * N
    Y += row * N
    # Compute mean
    _mean = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Compute variance
    _var = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    block_col_size = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: accumulate sum(x) and sum(x^2) one column chunk at a time;
    # masked lanes load 0 and therefore do not perturb the sums.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + block_col_size
        col_mask = cols < N
        mask = row_mask and col_mask
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _mean += a
        _var += a * a
    mean = tl.sum(_mean, axis=1) / N
    mean_bc = mean[:, None]
    # Population variance via E[x^2] - E[x]^2.
    var = tl.sum(_var, axis=1) / N - (mean * mean)
    var = var[:, None]
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean_bc, row_mask)
    tl.store(Rstd + row, rstd, row_mask)
    # Pass 2: normalize and apply linear transformation
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + block_col_size
        col_mask = cols < N
        mask = row_mask and col_mask
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, col_mask)
        if B is None:
            b = 0
        else:
            b = tl.load(B + cols, col_mask)
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        # Zero masked lanes after centering so padding never reaches Y.
        x = tl.where(col_mask, x - mean_bc, 0.0)
        x_hat = x * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)
def prune_in_wb_config(configs, named_args, **kwargs):
    """Drop autotune configs whose BLOCK_ROW_SIZE exceeds the row count M."""
    m = named_args["M"]
    return [cfg for cfg in configs if m // cfg.kwargs["BLOCK_ROW_SIZE"] >= 1]
# Input-gradient kernel for the large-N backward path:
#   dx = rstd * (w*dy - (sum(w*dy) + x_hat * sum(w*dy * x_hat)) / N)
# Each program grid-strides over blocks of rows; each row block is
# traversed twice (once for the two row-wise sums, once to emit dx).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    prune_configs_by={"early_config_prune": prune_in_wb_config},
    key=["M", "N"],
)
@triton.jit
def input_backward_kernel(
    dY,  # pointer to the output gradient, (M, N)
    X,  # pointer to the forward input, (M, N)
    W,  # pointer to the weight (N,), may be None
    Mean,  # pointer to the saved per-row mean, (M,)
    Rstd,  # pointer to the saved per-row 1/std, (M,)
    dX,  # pointer to the input gradient written by this kernel
    M,  # number of rows
    N,  # row width
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mean = tl.load(Mean + row_off, mask=row_off < M, other=0.0)[:, None].to(
            tl.float32
        )
        rstd = tl.load(Rstd + row_off, mask=row_off < M, other=0.0)[:, None].to(
            tl.float32
        )
        row_mask = row_off[:, None] < M
        off = row_off[:, None] * N
        # Row-block base pointers.
        new_dY = dY + off
        new_X = X + off
        new_DX = dX + off
        dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        # Pass 1: accumulate sum(w*dy) and sum(w*dy * x_hat) per row.
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other=0.0).to(tl.float32)
            x_hat = (x - mean) * rstd
            if W is None:
                wdy = dy
            else:
                # NOTE(review): this load has no `other=` (unlike the second
                # loop below); masked lanes multiply dy == 0 here, but an
                # explicit other=0.0 would be safer -- confirm.
                w = tl.load(W + cols, mask=cols < N).to(tl.float32)
                wdy = dy * w
            dx_part2 += wdy
            dx_part3 += wdy * x_hat
        # Row-wise reductions via transpose + axis-0 sum (same trick as the
        # forward kernels on this backend).
        dx_part2_trans = tl.trans(dx_part2)
        dx_2 = tl.sum(dx_part2_trans, axis=0)[:, None]
        dx_part3_trans = tl.trans(dx_part3)
        dx_3 = tl.sum(dx_part3_trans, axis=0)[:, None]
        # Pass 2: recompute wdy / x_hat chunk by chunk and emit dx.
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other=0.0).to(tl.float32)
            if W is None:
                wdy = dy
            else:
                w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
                wdy = dy * w
            x_hat = (x - mean) * rstd
            dx = rstd * (wdy - (dx_2 + x_hat * dx_3) / N)
            tl.store(new_DX + cols, dx.to(x.dtype), mask=mask)
# Weight/bias-gradient kernel for the large-N backward path:
#   dW[j] = sum_i dy[i, j] * x_hat[i, j],   dB[j] = sum_i dy[i, j]
# Programs grid-stride over column blocks; each block reduces over all M
# rows in BLOCK_ROW_SIZE chunks.
# NOTE(review): dW and dB are dereferenced unconditionally, so this kernel
# appears to require both to be non-None -- confirm the caller guarantees
# that when only one of weight/bias exists.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    prune_configs_by={"early_config_prune": prune_in_wb_config},
    key=["M", "N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,  # pointer to the output gradient, (M, N)
    X,  # pointer to the forward input, (M, N)
    Mean,  # pointer to the saved per-row mean, (M,)
    Rstd,  # pointer to the saved per-row 1/std, (M,)
    dW,  # pointer to the weight gradient, (N,)
    dB,  # pointer to the bias gradient, (N,)
    M,  # number of rows
    N,  # row width
    BLOCK_ROW_SIZE: tl.constexpr,  # rows reduced per inner-loop iteration
    BLOCK_COL_SIZE: tl.constexpr,  # columns handled per program per step
):
    pid = tl.program_id(0)
    col_start = pid * BLOCK_COL_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_COL_SIZE
    for col in range(col_start, N, step):
        col_off = col + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = col_off < N
        # Column-block base pointers.
        new_dY = dY + col_off
        new_X = X + col_off
        new_dW = dW + col_off
        new_dB = dB + col_off
        accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        # Reduce over all rows; masked lanes load dy == 0 and so contribute
        # nothing to either accumulator.
        for off in range(0, M, BLOCK_ROW_SIZE):
            rows = off + tl.arange(0, BLOCK_ROW_SIZE)
            row_mask = rows[:, None] < M
            mask = row_mask and col_mask
            dy = tl.load(new_dY + rows[:, None] * N, mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + rows[:, None] * N, mask, other=0.0).to(tl.float32)
            mean = tl.load(Mean + rows, mask=rows < M, other=0.0)[:, None].to(
                tl.float32
            )
            rstd = tl.load(Rstd + rows, mask=rows < M, other=0.0)[:, None].to(
                tl.float32
            )
            x_hat = (x - mean) * rstd
            accW += dy * x_hat
            accB += dy
        # Collapse the row dimension and store this column block.
        dw = tl.sum(accW, axis=0)
        db = tl.sum(accB, axis=0)
        tl.store(new_dW, dw[None, :], mask=col_mask)
        tl.store(new_dB, db[None, :], mask=col_mask)
def cfggen_bw_middle_n():
    """Candidate autotune configs for layer_norm_backward_kernel_middle_n."""
    configs = []
    # Same enumeration order as before: row-block size outermost, then the
    # stage count (num_warps is fixed at 1 on this backend).
    for block_m in (1, 2, 4, 8, 12, 18, 22, 32):
        for stages in (1, 3):
            configs.append(
                triton.Config(
                    {"BLOCK_ROW_SIZE": block_m},
                    num_warps=1,
                    num_stages=stages,
                )
            )
    return configs
# Set [DW, DB] to zero, can't use reset_to_zero here for DW/DB could be None.
def pre_hook(args, reset_only=True):
    """Zero the dW/dB accumulators before each autotuned launch.

    The backward kernel folds partial sums into DW/DB with atomic adds, so
    both buffers must start at zero; triton's reset_to_zero can't be used
    because either buffer may be None.
    """
    for grad in (args["DW"], args["DB"]):
        if grad is not None:
            grad.zero_()
# Fused backward for rows that fit on-chip: computes dX directly and keeps
# per-program partial dW/dB sums, folding them into the global buffers with
# atomic adds at the end (the autotune pre_hook zeroed those buffers).
@libentry()
@triton.autotune(
    configs=cfggen_bw_middle_n(),
    key=["M", "N"],
    pre_hook=pre_hook,
)
@triton.jit
def layer_norm_backward_kernel_middle_n(
    DX,  # pointer to the input gradient
    DY,  # pointer to the output gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    X,  # pointer to the input
    W,  # pointer to the weights
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,  # number of rows in X
    N: tl.constexpr,  # number of columns in X
    BLOCK_ROW_SIZE: tl.constexpr,  # rows handled per program per step
):
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    cols = tl.arange(0, N)
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE
    # Pre-offset base pointers by the column index once.
    X += cols[None, :]
    DY += cols[None, :]
    DX += cols[None, :]
    if W is None:
        w = 1
    else:
        W += cols[None, :]
        w = tl.load(W).to(tl.float32)
    if DW is not None:
        partial_dw = tl.zeros([BLOCK_ROW_SIZE, N], dtype=tl.float32)
    if DB is not None:
        partial_db = tl.zeros([BLOCK_ROW_SIZE, N], dtype=tl.float32)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mask = row_off[:, None] < M
        # Load data to SRAM
        off = row_off[:, None] * N
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + off, mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean + row_off, mask=row_off < M)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + row_off, mask=row_off < M)[:, None].to(tl.float32)
        # Compute dx; row-wise reductions are done as transpose + axis-0
        # sums (same trick as the forward kernel).
        x_hat = (x - mean) * rstd
        wdy = w * dy
        x_hat_dy = x_hat * wdy
        x_hat_dy = tl.view(x_hat_dy, (BLOCK_ROW_SIZE, N))
        x_hat_dy_trans = tl.trans(x_hat_dy)
        c1 = tl.sum(x_hat_dy_trans, axis=0)[:, None]
        wdy_v = tl.view(wdy, (BLOCK_ROW_SIZE, N))
        wdy_v_trans = tl.trans(wdy_v)
        c2 = tl.sum(wdy_v_trans, axis=0)[:, None]
        # dx = (w*dy - (x_hat * sum(x_hat*w*dy) + sum(w*dy)) / N) * rstd
        dx = (wdy - (x_hat * c1 + c2) / N) * rstd
        # Write dx
        tl.store(DX + off, dx.to(x.dtype), mask=mask)
        # Accumulate partial sums for dw/db
        if DW is not None:
            partial_dw += (dy * x_hat).to(tl.float32)
        if DB is not None:
            partial_db += (dy).to(tl.float32)
    # Fold this program's partials into the global dW/dB buffers.
    if DW is not None:
        dw = tl.sum(partial_dw, axis=0)
        tl.atomic_add(DW + cols, dw)
    if DB is not None:
        db = tl.sum(partial_db, axis=0)
        tl.atomic_add(DB + cols, db)
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """LayerNorm forward over the trailing normalized_shape dims.

    Returns (output, mean, rstd); mean/rstd are per-row statistics in the
    accumulator dtype, saved for the backward pass.
    """
    logger.debug("GEMS_CAMBRICON LAYERNORM FORWARD")
    N = math.prod(normalized_shape)
    M = input.numel() // N
    input = input.contiguous()
    weight = weight.contiguous() if weight is not None else None
    bias = bias.contiguous() if bias is not None else None
    out = torch.empty_like(input)
    acc_dtype = get_accumulator_dtype(input.dtype)
    mean = torch.empty(M, dtype=acc_dtype, device=input.device)
    rstd = torch.empty(M, dtype=acc_dtype, device=input.device)
    with torch_device_fn.device(input.device):
        if N <= MAX_C_MLU_LAYERNORM_FORWARD:
            # Persistent kernel: the whole row fits on-chip; cap the grid at
            # the core count and let the kernel grid-stride over rows.
            grid = lambda META: (
                min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
            )
            layer_norm_kernel_middle_n[grid](
                input, out, weight, bias, mean, rstd, M, eps, N
            )
        else:
            # Looped kernel: iterate over the row in BLOCK_COL_SIZE chunks.
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
            layer_norm_kernel_inner[grid](
                input, out, weight, bias, mean, rstd, M, eps, N
            )
    return out, mean, rstd
def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """LayerNorm backward.

    Args:
        grad_out: gradient w.r.t. the forward output, same shape as input.
        input: the forward input tensor.
        normalized_shape: trailing dims normalized in the forward pass.
        mean / rstd: per-row statistics saved by the forward pass, (M,).
        weight / bias: optional affine parameters, shape (N,).
        output_mask: unused; kept for interface compatibility.

    Returns:
        (in_grad, weight_grad, bias_grad); the parameter gradients are None
        when the corresponding parameter is None.
    """
    logger.debug("GEMS_CAMBRICON LAYERNORM BACKWARD")
    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    # Match the forward pass: N spans the normalized (trailing) dims and M is
    # the number of normalized rows.  Deriving M from input.shape[0] would
    # disagree with the (M,)-sized mean/rstd buffers whenever input.ndim > 2.
    N = math.prod(normalized_shape)
    M = input.numel() // N
    if N <= MAX_C_MLU_LAYERNORM_BACKWARD:
        in_grad = torch.empty_like(grad_out)
        # float32 zero-initialized accumulators: the fused kernel folds its
        # partial dW/dB sums in with atomic adds.
        if weight is None:
            weight_grad = None
        else:
            weight_grad = torch.zeros(
                (weight.shape[0],), dtype=torch.float, device=weight.device
            )
        if bias is None:
            bias_grad = None
        else:
            # Size/device must come from bias: weight may be None here.
            bias_grad = torch.zeros(
                (bias.shape[0],), dtype=torch.float, device=bias.device
            )
        # One fused kernel computes dX and the partial dW/dB sums.
        grid = lambda META: (
            min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
        )
        with torch_device_fn.device(input.device):
            layer_norm_backward_kernel_middle_n[grid](
                in_grad,
                grad_out,
                weight_grad,
                bias_grad,
                input,
                weight,
                mean,
                rstd,
                M=M,
                N=N,
            )
        if weight_grad is not None:
            weight_grad = weight_grad.to(input.dtype)
        if bias_grad is not None:
            bias_grad = bias_grad.to(input.dtype)
    else:
        # Large-N path: dX and dW/dB are computed by separate kernels.
        in_grad = torch.empty_like(input)
        grid = lambda META: (
            min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
        )
        # Launch under the input's device context, consistent with every
        # other kernel launch in this module.
        with torch_device_fn.device(input.device):
            input_backward_kernel[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                in_grad,
                M,
                N,
            )
            if weight is None and bias is None:
                return in_grad, None, None
            grid = lambda META: (
                min(triton.cdiv(N, META["BLOCK_COL_SIZE"]), TOTAL_CORE_NUM),
            )
            weight_grad = None if weight is None else torch.empty_like(weight)
            bias_grad = None if bias is None else torch.empty_like(bias)
            weight_bias_backward_kernel[grid](
                grad_out,
                input,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                M,
                N,
            )
    return in_grad, weight_grad, bias_grad