Coverage for src/flag_gems/runtime/backend/_sunrise/fused/fused_add_rms_norm.py: 0%
72 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; fused_add_rms_norm emits its debug trace through it.
logger = logging.getLogger(__name__)

@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
    input_ptr,  # input base pointer; overwritten with the normalized result
    residual_ptr,  # residual base pointer; overwritten with input + residual
    w_ptr,  # weight pointer, read with unit stride
    in_stride_r,  # input: element offset between consecutive rows
    in_stride_c,  # input: element offset between consecutive columns
    r_stride_r,  # residual: element offset between consecutive rows
    r_stride_c,  # residual: element offset between consecutive columns
    N,  # number of columns per row
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Single-block fused residual add + RMS norm; one program per row.

    For the row selected by program id 0:
        residual <- input + residual
        input    <- (input + residual) * rrms * w,
        rrms = 1 / sqrt(sum((input + residual) ** 2 / N) + eps)

    Only columns in [0, BLOCK_SIZE) are touched, so BLOCK_SIZE is expected
    to be >= N for the whole row to be processed.
    """
    elem_ty = input_ptr.dtype.element_ty
    # fp16/bf16 are upcast to fp32 for the arithmetic; other dtypes are used as-is.
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(elem_ty == tl.bfloat16):
        cdtype = tl.float32
    else:
        cdtype = elem_ty

    row = ext.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < N

    x_ptrs = input_ptr + row * in_stride_r + offs * in_stride_c
    r_ptrs = residual_ptr + row * r_stride_r + offs * r_stride_c

    h = tl.load(x_ptrs, mask=in_bounds, other=0.0).to(cdtype)
    h += tl.load(r_ptrs, mask=in_bounds, other=0.0).to(cdtype)

    # The un-normalized sum is what lands in the residual buffer.
    tl.store(r_ptrs, h, mask=in_bounds)

    # Masked-out lanes loaded 0.0, so they contribute nothing to the sum.
    inv_rms = 1 / tl.sqrt(tl.sum(h * h / N, axis=0) + eps)

    gamma = tl.load(w_ptr + offs, mask=in_bounds, other=0.0)
    tl.store(x_ptrs, (h * inv_rms * gamma).to(cdtype), mask=in_bounds)


@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_c_split_kernel(
    input_ptr,  # input base pointer; overwritten with the normalized result
    residual_ptr,  # residual base pointer; overwritten with input + residual
    w_ptr,  # weight pointer, read with unit stride
    in_stride_r,  # input: element offset between consecutive rows
    in_stride_c,  # input: element offset between consecutive columns
    r_stride_r,  # residual: element offset between consecutive rows
    r_stride_c,  # residual: element offset between consecutive columns
    N,  # number of columns per row
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Two-pass (column-split) fused residual add + RMS norm; one program per row.

    The row is swept in BLOCK_SIZE-wide chunks twice:
      pass 1 accumulates sum((input + residual) ** 2 / N);
      pass 2 recomputes input + residual, writes it to the residual buffer and
      writes (input + residual) * rrms * w to the input buffer, with
      rrms = 1 / sqrt(accumulated + eps).

    Column strides are honored for both input and residual on every access.
    """
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid = ext.program_id(0)
    input_ptr += pid * in_stride_r
    residual_ptr += pid * r_stride_r

    # Accumulate in the compute dtype. A fixed fp32 accumulator would be
    # promoted to fp64 inside the loop for fp64 inputs, and Triton rejects
    # loop-carried values whose type changes. For fp16/bf16/fp32, cdtype is fp32.
    _var = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(input_ptr + cols * in_stride_c, mask, other=0.0).to(cdtype)
        r = tl.load(residual_ptr + cols * r_stride_c, mask, other=0.0).to(cdtype)
        x += r
        _var += x * x / N

    var = tl.sum(_var, axis=0)
    rrms = 1 / tl.sqrt(var + eps)

    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Recompute the sum in the compute dtype rather than re-reading the
        # (possibly lower-precision) residual buffer, matching pass 1's values.
        x = tl.load(input_ptr + cols * in_stride_c, mask, other=0.0).to(cdtype)
        r = tl.load(residual_ptr + cols * r_stride_c, mask, other=0.0).to(cdtype)
        x += r
        tl.store(residual_ptr + cols * r_stride_c, x, mask=mask)
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        y = (x * rrms * w).to(cdtype)
        tl.store(input_ptr + cols * in_stride_c, y, mask=mask)


def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """
    This function performs fused residual addition and RMS normalization **in-place**.
    Both `x` and `residual` tensors will be modified. Use with caution if these tensors
    are reused elsewhere or require gradients.

    The trailing ``normalized_shape`` dims are flattened into N columns. For
    each row:
        residual <- x + residual
        x        <- (x + residual) * rrms * weight,
        rrms = 1 / sqrt(mean((x + residual) ** 2) + eps)

    Args:
        x: input tensor whose trailing ``normalized_shape`` dims are normalized.
        residual: tensor added to ``x``; assumed to have ``x``'s shape (not checked).
        normalized_shape: trailing dims to normalize over; their product is N.
        weight: per-column scale; assumed to hold at least N elements (not checked).
        eps: value added to the mean square before the square root.

    Returns:
        Tuple ``(x, residual)`` of contiguous tensors holding the results. If an
        input was not contiguous, its result is also copied back into the tensor
        the caller passed, so the in-place contract holds either way.
    """
    logger.debug(
        "GEMS FUSED_ADD_RMS_NORM FORWARD, [input shape]: %s, [residual shape]: %s, [weight shape]: %s",
        x.size(),
        residual.size(),
        weight.size(),
    )
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    # .contiguous() returns a copy for non-contiguous tensors; keep the caller's
    # tensors so the results can be written back to them afterwards.
    x_orig, residual_orig = x, residual
    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()

    if M == 0 or N == 0:
        # Nothing to normalize. N == 0 would also give BLOCK_SIZE == 0,
        # which is not a valid tl.arange extent.
        return x, residual

    BLOCK_SIZE = triton.next_power_of_2(N)

    with torch_device_fn.device(x.device):
        if BLOCK_SIZE <= 1024:
            # The whole row fits in one block: single-pass kernel.
            fused_add_rms_norm_kernel[M,](
                x, residual, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
            )
        else:
            # Wide rows: sweep each row in 1024-column chunks.
            fused_add_rms_norm_c_split_kernel[M,](
                x,
                residual,
                weight,
                N,
                1,
                N,
                1,
                N,
                eps,
                1024,
                num_warps=8,
            )

    # Honor the in-place contract when .contiguous() had to make copies.
    if x is not x_orig:
        x_orig.copy_(x)
    if residual is not residual_orig:
        residual_orig.copy_(residual)
    return x, residual