Coverage for src/flag_gems/runtime/backend/_ascend/fused/fused_add_rms_norm.py: 0%
45 statements
« prev ^ index » next    coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2import math
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
10logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weight
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Fused residual-add + RMSNorm for one row per program.

    Each program handles row `pid`: it computes ``s = x + r`` for that row,
    stores ``s`` back into R, and stores ``RMSNorm(s) * W`` back into X.
    Both X and R are modified in place. Accumulation is done in float32.
    """
    pid = tl.program_id(0)
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Pass 1: accumulate mean of (x + r)^2 in fp32 across the row.
    _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Fix: loads previously used `X + cols` / `R + cols`, ignoring the
        # column strides that the stores below apply — correct only when
        # stride_c == 1. Apply the strides consistently on load and store.
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        _var_base += x * x / N
    var = tl.sum(_var_base)
    # Reciprocal root-mean-square; eps guards against division by zero.
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: recompute x + r, normalize, apply weight, and write back.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        w = tl.load(W + cols, mask, other=0.0)
        # Cast the normalized value back to X's element type before the
        # weight multiply, matching the original numerics.
        y = (x * rrms).to(X.dtype.element_ty) * w
        # write back to residual and input
        tl.store(R + cols * r_stride_c, x, mask=mask)
        tl.store(X + cols * x_stride_c, y, mask=mask)
def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """
    This function performs fused residual addition and RMS normalization **in-place**.
    Both `x` and `residual` tensors will be modified. Use with caution if these tensors
    are reused elsewhere or require gradients.

    Args:
        x: input tensor; receives the normalized output in place.
        residual: residual tensor; receives ``x + residual`` in place.
        normalized_shape: trailing shape over which RMS is computed.
        weight: per-element scale of size ``prod(normalized_shape)``.
        eps: small constant added to the variance for numerical stability.

    Returns:
        Tuple ``(x, residual)`` — the (possibly re-materialized) updated tensors.
    """
    logger.debug("GEMS_ASCEND FUSED_ADD_RMS_NORM FORWARD")
    # Rows = product of the leading (non-normalized) dims; one kernel program per row.
    dim = x.ndim - len(normalized_shape)
    # NOTE(review): the clamp to 65535 looks like a grid-size limit, but the
    # kernel processes exactly one row per program, so any rows beyond 65535
    # would be left unprocessed — confirm inputs never exceed this, or add a
    # row loop in the kernel.
    M = min(math.prod(x.shape[:dim]), 65535)
    N = math.prod(normalized_shape)
    BLOCK_SIZE = min(triton.next_power_of_2(N), 8192)
    # NOTE(review): .contiguous() copies non-contiguous inputs, in which case
    # the caller's original tensors are NOT updated in place despite the
    # docstring — callers must use the returned tensors.
    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()
    with torch_device_fn.device(x.device):
        # Positional args: X, R, W, x_stride_r, x_stride_c, r_stride_r,
        # r_stride_c, N, eps, BLOCK_SIZE — rows are contiguous (stride N, 1).
        fused_add_rms_norm_kernel[M,](
            x, residual, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )
    return x, residual