Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/rms_norm.py: 0%
175 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
# Module logger, nested under the package-wide "flag_gems" logger so backend
# logging inherits the package's handlers/level configuration.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Widest row (element count) the single-block forward/backward kernels handle
# on-chip; wider rows fall back to the *_C_split kernels that tile the row.
# NOTE(review): presumably derived from the device's NRAM capacity — confirm
# against the tsingmicro hardware specs.
MAX_NRAM_C_FORWARD = 16384 * 2
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row; the whole row fits in a single BLOCK_SIZE block.
    row = tl.program_id(0)
    x_row = X + row * x_stride_r
    y_row = Y + row * y_stride_r

    offs = tl.arange(0, BLOCK_SIZE)
    valid = offs < N
    x_val = tl.load(x_row + offs * x_stride_c, valid, other=0.0).to(tl.float32)

    # Mean of squares over the row, then the reciprocal RMS.
    mean_sq = tl.sum(x_val * x_val, axis=0) / N
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    # Normalize in fp32, cast to the output dtype, then apply the weight.
    weight = tl.load(W + offs, mask=valid, other=0.0)
    out = (x_val * inv_rms).to(Y.dtype.element_ty) * weight
    tl.store(y_row + offs * y_stride_c, out, mask=valid)
    # Cache the per-row scalar for the backward pass.
    tl.store(INV_RMS + row, inv_rms)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("common_reduce_ops"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel_C_split(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row; the row is too wide for one block, so it is
    # processed in BLOCK_SIZE-wide column tiles (two passes over the row).
    row = tl.program_id(0)
    Y += row * y_stride_r
    X += row * x_stride_r

    # Pass 1: accumulate per-lane sums of squares across the tiles.
    sq_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for tile_start in range(0, N, BLOCK_SIZE):
        offs = tile_start + tl.arange(0, BLOCK_SIZE)
        valid = offs < N
        chunk = tl.load(X + offs * x_stride_c, valid, other=0.0).to(tl.float32)
        sq_acc += chunk * chunk

    mean_sq = tl.sum(sq_acc, axis=0) / N
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    # Pass 2: normalize, scale by the weight, and store each tile.
    for tile_start in range(0, N, BLOCK_SIZE):
        offs = tile_start + tl.arange(0, BLOCK_SIZE)
        valid = offs < N
        weight = tl.load(W + offs, mask=valid, other=0.0)
        chunk = tl.load(X + offs * x_stride_c, valid, other=0.0).to(tl.float32)
        out = (chunk * inv_rms).to(Y.dtype.element_ty) * weight
        tl.store(Y + offs * y_stride_c, out, mask=valid)

    # Cache the per-row scalar for the backward pass.
    tl.store(INV_RMS + row, inv_rms)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row; the whole row fits in a single BLOCK_SIZE block.
    row = tl.program_id(0)
    dx_row = DX + row * dx_stride_r
    x_row = X + row * x_stride_r
    dy_row = DY + row * x_stride_r  # DY shares X's layout

    offs = tl.arange(0, BLOCK_SIZE)
    valid = offs < N

    x_val = tl.load(x_row + offs * x_stride_c, valid, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + row).to(tl.float32)
    grad = tl.load(dy_row + offs * x_stride_c, valid, other=0.0).to(tl.float32)
    weight = tl.load(W + offs, mask=valid, other=0.0)

    # Fold the weight into the upstream gradient first.
    grad = grad * weight

    normalized = x_val * inv_rms
    # Row-wise dot product of normalized(x) with the weighted gradient.
    row_dot = tl.sum(normalized * grad, axis=0)

    # dx = (dy*w - normalized/N * <normalized, dy*w>) * inv_rms
    dx_val = (grad - (normalized / N) * row_dot) * inv_rms

    tl.store(dx_row + offs * dx_stride_c, dx_val, mask=valid)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("common_reduce_ops"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel_C_split(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row; the row is processed in BLOCK_SIZE-wide column
    # tiles (two passes) because it does not fit on-chip in one block.
    pid = tl.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r  # DY shares X's layout
    # inv_rms is a per-row scalar: load it once and reuse it in both passes
    # (previously it was redundantly re-loaded on every loop iteration).
    inv_rms = tl.load(INV_RMS + pid).to(tl.float32)

    # Pass 1: accumulate sum over the row of normalized(x) * (dy * w).
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for m_idx in range(0, N, BLOCK_SIZE):
        cols = m_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask=mask, other=0.0)
        acc += (x * inv_rms) * (dy * w)

    row_sum_stats = tl.sum(acc, axis=0)

    # Pass 2: dx = (dy*w - normalized(x)/N * row_sum) * inv_rms, tile by tile.
    for m_idx in range(0, N, BLOCK_SIZE):
        cols = m_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask=mask, other=0.0)
        dy = dy * w
        norm_val = (x * inv_rms) / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the partial-sum output, shape (grid0, N)
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    # 2D grid: each program reduces one ROW_BLOCK_SIZE x COL_BLOCK_SIZE tile
    # over its rows and writes one partial dW slice per row-block; the host
    # finishes the reduction across row-blocks.
    rb = tl.program_id(0)
    cb = tl.program_id(1)

    base_row = rb * ROW_BLOCK_SIZE
    base_col = cb * COL_BLOCK_SIZE

    tile_origin = base_row * x_stride_r + base_col * x_stride_c
    X += tile_origin
    DY += tile_origin
    INV_RMS += base_row

    r = tl.arange(0, ROW_BLOCK_SIZE)
    c = tl.arange(0, COL_BLOCK_SIZE)

    r_ok = (base_row + r) < M
    c_ok = (base_col + c) < N
    tile_mask = r_ok[:, None] & c_ok[None, :]
    tile_offs = r[:, None] * x_stride_r + c[None, :] * x_stride_c

    x_tile = tl.load(X + tile_offs, tile_mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + r, r_ok, other=0.0).to(tl.float32)
    dy_tile = tl.load(DY + tile_offs, tile_mask, other=0.0).to(tl.float32)

    # dW contribution of each element is x * dy * inv_rms; reduce over rows.
    partial = tl.sum(x_tile * dy_tile * inv_rms[:, None], axis=0)

    tl.store(
        DW + rb * N + base_col + c,
        partial,
        mask=c_ok,
    )
class RmsNorm(torch.autograd.Function):
    """RMSNorm with custom Triton forward/backward kernels.

    Forward normalizes the trailing ``normalized_shape`` dims of ``x`` by the
    row-wise inverse RMS and scales by ``weight``; the per-row inverse RMS is
    cached on the ctx for the backward pass.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        logger.debug("GEMS_TSINGMICRO RMSNORM FORWARD")
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])  # rows to normalize
        N = math.prod(normalized_shape)  # elements per row

        # This backend processes whole rows in one block when they fit
        # on-chip, so BLOCK_SIZE is N itself, not triton.next_power_of_2(N).
        BLOCK_SIZE = N
        x = x.contiguous()
        weight = weight.contiguous()
        y = torch.empty_like(x)
        inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

        with torch_device_fn.device(x.device):
            if BLOCK_SIZE <= MAX_NRAM_C_FORWARD:
                logger.debug("GEMS_TSINGMICRO RMSNORM FORWARD NOT USING C SPLIT")
                rms_norm_kernel[M,](
                    y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
                )
            else:
                # Row too wide for one on-chip block: tile along columns.
                logger.debug("GEMS_TSINGMICRO RMSNORM FORWARD USING C SPLIT")
                rms_norm_kernel_C_split[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps)

        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        logger.debug("GEMS_TSINGMICRO RMSNORM BACKWARD")
        x, inv_rms, weight = ctx.saved_tensors
        normalized_shape = ctx.normalized_shape
        eps = ctx.eps

        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        BLOCK_SIZE = N  # see forward: whole-row blocks on this backend
        # All backward kernels index dy with dense strides (N, 1), but the
        # upstream gradient is NOT guaranteed contiguous by autograd.
        dy = dy.contiguous()
        x = x.contiguous()
        weight = weight.contiguous()
        dx = torch.empty_like(x)

        with torch_device_fn.device(x.device):
            if BLOCK_SIZE <= MAX_NRAM_C_FORWARD:
                logger.debug("GEMS_TSINGMICRO RMSNORM BACKWARD NOT USING C SPLIT")
                rms_norm_grad_dx_kernel[M,](
                    x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
                )
            else:
                logger.debug("GEMS_TSINGMICRO RMSNORM BACKWARD USING C SPLIT")
                rms_norm_grad_dx_kernel_C_split[M,](
                    x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps
                )

        # dW: each (row-block, col-block) program writes a partial sum into
        # partial_buffer; the reduction over row-blocks happens on the host.
        ROW_BLOCK_SIZE = 16
        COL_BLOCK_SIZE = 256
        row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
        col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

        partial_buffer = torch.empty(
            (row_block_num, N), dtype=torch.float32, device=x.device
        )

        with torch_device_fn.device(x.device):
            rms_norm_grad_dw_kernel[row_block_num, col_block_num](
                x,
                dy,
                inv_rms,
                partial_buffer,
                N,
                1,
                N,
                1,
                M,
                N,
                ROW_BLOCK_SIZE,
                COL_BLOCK_SIZE,
            )
        # Reduce the partials and restore the weight's own shape (a bare
        # reshape(-1) would hand autograd a flat gradient when
        # normalized_shape is multi-dimensional; identical for the 1-D case).
        dw = torch.sum(partial_buffer, dim=0, dtype=x.dtype).reshape(weight.shape)

        return dx, None, dw, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMSNorm to `x` over `normalized_shape` via the autograd Function."""
    return RmsNorm.apply(x, normalized_shape, weight, eps)