Coverage report for src/flag_gems/ops/rms_norm.py: 37% of 119 statements covered.
Generated by coverage.py v7.6.9 at 2026-03-20 02:31 +0800.
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms (one float32 per row, reused by backward)
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program normalizes one row: y = x / rms(x) * w with
    # rms(x) = sqrt(mean(x^2) + eps). The reciprocal rms is also stored
    # so the backward kernels can reuse it instead of recomputing.
    # Accumulate in float32 for half-precision inputs to avoid precision
    # loss/overflow in the reduction.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    # use tle.program_id for consistency with the backward kernels below
    pid = tle.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)

    # mean of squares over the row; masked lanes were loaded as 0
    var = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
    tl.store(INV_RMS + pid, rrms)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream output gradient
    INV_RMS,  # pointer to inverse rms saved by the forward kernel
    DX,  # pointer to the output (input gradient)
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero (unused here: inv_rms comes from forward)
    BLOCK_SIZE: tl.constexpr,
):
    # One program computes the input gradient for one row:
    #   dx = (dy*w - (x/rms) * mean((x/rms) * dy*w)) / rms
    # using the per-row inverse rms saved during the forward pass.
    pid = tle.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r
    INV_RMS += pid

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    # compute in float32 regardless of the storage dtype for accuracy
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS).to(tl.float32)
    dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)

    dy = dy * w  # gradient w.r.t. the normalized activation

    normalized_buf = x * inv_rms  # the normalized input, x / rms
    # row-level reduction term shared by every element of the row
    row_sum_stats = tl.sum(normalized_buf * dy, axis=0)

    norm_val = normalized_buf / N
    dx = (dy - norm_val * row_sum_stats) * inv_rms

    tl.store(DX + cols * dx_stride_c, dx, mask=mask)
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream output gradient
    INV_RMS,  # pointer to inverse rms saved by the forward kernel
    DW,  # pointer to the output buffer of per-row-block partial sums
    dx_stride_r,  # NOTE(review): unused in this kernel
    dx_stride_c,  # NOTE(review): unused in this kernel
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    # Weight gradient: dw[j] = sum over rows i of x[i, j] * dy[i, j] / rms[i].
    # Each program reduces one (ROW_BLOCK_SIZE, COL_BLOCK_SIZE) tile over its
    # row axis and writes one partial row per row-block into DW; the host
    # finishes the reduction by summing DW over its first dimension.
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N

    # accumulate in float32 regardless of the storage dtype
    x = tl.load(
        X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(
        DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    d_weight = x * dy * inv_rms[:, None]
    # Sum over rows (axis=0) - masked rows are 0 (from other=0.0 in load), so sum is correct
    # The mask ensures invalid rows contribute 0 to the sum
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )
def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """Fused RMSNorm forward.

    Treats ``x`` as an (M, N) matrix where N is the product of
    ``normalized_shape`` and M the product of the leading dims, launches
    one kernel instance per row, and returns the normalized output plus
    the per-row inverse rms (float32, consumed by the backward pass).
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    batch_dims = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:batch_dims])
    N = math.prod(normalized_shape)
    BLOCK_SIZE = triton.next_power_of_2(N)

    # kernels assume dense row-major (M, N) layouts
    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

    grid = (M,)
    with torch_device_fn.device(x.device):
        rms_norm_kernel[grid](y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)
    return y, inv_rms
def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Fused RMSNorm backward; returns ``(dx, dw)``.

    ``dx`` is produced row-wise by one kernel launch. ``dw`` is reduced
    in two stages: a tiled kernel writes per-row-block partial sums into
    a float32 buffer, which is then summed over blocks on the host.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    batch_dims = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:batch_dims])
    N = math.prod(normalized_shape)
    BLOCK_SIZE = triton.next_power_of_2(N)

    # kernels assume dense row-major (M, N) layouts
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    with torch_device_fn.device(x.device):
        rms_norm_grad_dx_kernel[(M,)](
            x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )

    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_blocks = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_blocks = triton.cdiv(N, COL_BLOCK_SIZE)

    # one partial dW row per row-block, accumulated in float32
    partial_dw = torch.empty((row_blocks, N), dtype=torch.float32, device=x.device)

    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[(row_blocks, col_blocks)](
            x,
            dy,
            inv_rms,
            partial_dw,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )

    # finish the dW reduction on the host, then cast back to the input dtype
    dw = partial_dw.sum(dim=0, dtype=torch.float32).to(x.dtype).reshape(-1)
    return dx, dw
class RmsNorm(torch.autograd.Function):
    """Autograd wrapper wiring the fused RMSNorm forward/backward together."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        # stash everything the backward pass needs
        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return out

    @staticmethod
    def backward(ctx, dy):
        saved_x, saved_inv_rms, saved_weight = ctx.saved_tensors
        dx, dw = rms_norm_backward(
            dy, saved_x, saved_inv_rms, ctx.normalized_shape, saved_weight, ctx.eps
        )
        # gradients only for x and weight; normalized_shape and eps are non-tensors
        return dx, None, dw, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    # Public entry point: RMS-normalize the trailing ``normalized_shape``
    # dims of ``x``, scaled by ``weight``, with full autograd support.
    return RmsNorm.apply(x, normalized_shape, weight, eps)