Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/fused_add_rms_norm.py: 0%
66 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import builtins
2import logging
3import math
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rmsnorm_kernel(
    # Single-block kernel: one program handles one full row, which must fit
    # in BLOCK_SIZE elements. Updates X and R in place.
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Each program id selects one row of X and R.
    pid = tle.program_id(0)
    X += pid * x_stride_r
    R += pid * r_stride_r

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    # Accumulate in float32 regardless of the storage dtype for accuracy.
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)

    # Fused residual addition: x <- x + r.
    x += r
    # write back to residual (R now holds the pre-normalization sum)
    tl.store(R + cols * r_stride_c, x, mask=mask)

    # Mean of squares; dividing inside the sum keeps intermediate values small.
    # Masked lanes contribute 0 (other=0.0 above), so the sum is exact.
    var = tl.sum(x * x / N, axis=0)
    rrms = 1 / tl.sqrt(var + eps)

    # NOTE(review): W is indexed without a column stride — assumes the weight
    # tensor is contiguous (the wrapper calls .contiguous() before launch).
    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)
    # Cast the normalized value back to the storage dtype before scaling by w,
    # matching the input precision for the weight multiply.
    y = (x * rrms).to(X.dtype.element_ty) * w
    # write back to input (X now holds the normalized, weighted output)
    tl.store(X + cols * x_stride_c, y, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rmsnorm_kernel_tile(
    # Tiled variant for rows longer than BLOCK_SIZE: two passes over the row,
    # first to accumulate the variance, then to normalize and store.
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weight
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Each program id selects one row of X and R.
    pid = tl.program_id(0)
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Pass 1: accumulate per-lane partial sums of (x + r)^2 / N across tiles,
    # then reduce to a scalar. Masked lanes load 0.0 and contribute nothing.
    _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # NOTE(review): loads index with `X + cols` / `R + cols` (no column
        # stride) while the stores below use `cols * stride_c` — these only
        # agree when x_stride_c == r_stride_c == 1, which is how the wrapper
        # launches this kernel (contiguous rows). Confirm before reusing with
        # non-unit strides.
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        _var_base += x * x / N
    var = tl.sum(_var_base)
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: recompute x + r per tile (cheaper than staging the whole row),
    # normalize, scale by the weight, and write both outputs back.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        w = tl.load(W + cols, mask, other=0.0)
        # Cast back to the storage dtype before the weight multiply.
        y = (x * rrms).to(X.dtype.element_ty) * w
        # write back to residual and input
        tl.store(R + cols * r_stride_c, x, mask=mask)
        tl.store(X + cols * x_stride_c, y, mask=mask)
def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """
    Perform fused residual addition and RMS normalization **in-place**.

    Both `x` and `residual` tensors will be modified: `residual` receives
    the sum `x + residual`, and `x` receives the RMS-normalized, weighted
    result. Use with caution if these tensors are reused elsewhere or
    require gradients.
    """
    logger.debug("GEMS FUSED_ADD_RMS_NORM FORWARD")
    # Collapse leading dims into rows; the normalized dims form the columns.
    batch_ndim = x.ndim - len(normalized_shape)
    n_rows = math.prod(x.shape[:batch_ndim])
    n_cols = math.prod(normalized_shape)

    # core_num * buffer_size_limit — the largest block a single program handles.
    tile_limit = 64 * 128
    block_size = builtins.min(tile_limit, triton.next_power_of_2(n_cols))

    # Kernels index rows contiguously, so flatten the layouts up front.
    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()

    # Rows that exceed one block take the two-pass tiled kernel.
    if n_cols > tile_limit:
        kernel = fused_add_rmsnorm_kernel_tile
    else:
        kernel = fused_add_rmsnorm_kernel

    with torch_device_fn.device(x.device):
        kernel[n_rows,](
            x, residual, weight, n_cols, 1, n_cols, 1, n_cols, eps, block_size
        )
    return x, residual