Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/rms_norm.py: 0%
213 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import builtins
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
# Child logger under the shared "flag_gems" root logger; lstrip(".") strips a
# leading dot so the child name is never empty/relative.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def rms_norm_kernel(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M: tl.constexpr,  # number of rows in X
    N: tl.constexpr,  # number of columns in X
    eps: tl.constexpr,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Single-tile RMS-norm forward: one program normalizes one row of X.

    Requires N <= BLOCK_SIZE (the whole row fits in one tile).  Writes the
    normalized row to Y and the row's inverse RMS (float32) to INV_RMS[pid]
    so the backward pass can reuse it.
    """
    pid = tle.program_id(0)
    # Advance to this program's row.
    Y += pid * y_stride_r
    X += pid * x_stride_r

    # NOTE(review): colMask compares *column* offsets against the row count M,
    # and is ANDed into the load mask below.  If M < N this zeroes real row
    # elements and skews the variance — possibly a backend-specific guard,
    # but verify; the column mask `mask` (cols < N) already bounds the row.
    colMask = tl.arange(0, BLOCK_SIZE) < M
    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + cols * x_stride_c, mask & colMask, other=0.0).to(tl.float32)

    # Mean of squares in float32 for stability, then inverse RMS.
    var = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)
    # Normalize in float32, cast to output dtype, then apply the weight.
    y = (x * rrms).to(Y.dtype.element_ty) * w
    tl.store(Y + cols * y_stride_c, y, mask=mask)
    tl.store(INV_RMS + pid, rrms)
@libentry()
@triton.jit
def rms_norm_kerne_tile(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M: tl.constexpr,  # number of rows in X
    N: tl.constexpr,  # number of columns in X
    eps: tl.constexpr,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Tiled RMS-norm forward for rows longer than one block (N > BLOCK_SIZE).

    Two passes over the row in BLOCK_SIZE tiles: first accumulate the mean of
    squares, then normalize and apply the weight.  Writes the row's inverse
    RMS to INV_RMS[pid].

    NOTE(review): tile loads address `X + cols` (and `W + cols`) directly,
    i.e. they assume x_stride_c == 1 (contiguous rows), while the store still
    honors y_stride_c.  The callers in this file always pass stride 1 — verify
    before reusing with non-contiguous inputs.
    """
    pid = tl.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r

    # NOTE(review): column offsets compared against the row count M, ANDed
    # into the first-pass load mask — looks suspicious (see rms_norm_kernel);
    # confirm it is a deliberate backend guard.
    colMask = tl.arange(0, BLOCK_SIZE) < M

    # Pass 1: per-lane partial sums of x^2 / N, reduced after the loop.
    _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask & colMask, other=0.0).to(tl.float32)
        _var_base += x * x / N
    var = tl.sum(_var_base)
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: reload each tile, normalize in float32, weight, and store.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0)
        y = (x * rrms).to(Y.dtype.element_ty) * w
        tl.store(Y + cols * y_stride_c, y, mask=mask)

    tl.store(INV_RMS + pid, rrms)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero (unused in this body)
    BLOCK_SIZE: tl.constexpr,
):
    """Single-tile input-gradient of RMS norm: one program handles one row.

    Computes dx = inv_rms * (dy*w - (x*inv_rms/N) * sum(x*inv_rms * dy*w))
    using the inverse RMS saved by the forward pass.  Requires N <= BLOCK_SIZE.

    NOTE(review): DY is addressed with X's strides (x_stride_r/x_stride_c),
    so dy must share x's layout — true for the contiguous callers here.
    """
    pid = tle.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r
    INV_RMS += pid

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS).to(tl.float32)
    dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)

    # Fold the weight into the upstream gradient: d(y)/d(x_hat) = w.
    dy = dy * w

    # x_hat = x * inv_rms; the row-wide correction term couples all columns.
    normalized_buf = x * inv_rms
    row_sum_stats = tl.sum(normalized_buf * dy, axis=0)

    norm_val = normalized_buf / N
    dx = (dy - norm_val * row_sum_stats) * inv_rms

    tl.store(DX + cols * dx_stride_c, dx, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel_tile(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero (unused in this body)
    BLOCK_SIZE: tl.constexpr,
):
    """Tiled input-gradient of RMS norm for long rows (N > BLOCK_SIZE).

    Pass 1 accumulates the row-wide correction term sum(x_hat * dy * w);
    pass 2 reloads each tile and writes
    dx = inv_rms * (dy*w - (x_hat/N) * row_sum).

    NOTE(review): tile loads address `X + cols` / `DY + cols` / `W + cols`
    directly, assuming stride-1 (contiguous) rows; the store honors
    dx_stride_c.  Callers in this file always pass stride 1.
    """
    pid = tle.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r
    INV_RMS += pid

    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Pass 1: per-lane partial sums of x_hat * (dy * w), reduced afterwards.
    row_sum_stats_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)

        dy = dy * w

        normalized_buf = x * inv_rms

        row_sum_stats_base += normalized_buf * dy
    row_sum_stats = tl.sum(row_sum_stats_base)

    # Pass 2: recompute per-tile quantities and store dx tile by tile.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)

        dy = dy * w

        normalized_buf = x * inv_rms
        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms

        tl.store(DX + cols * dx_stride_c, dx, mask=mask)
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the output (partial-sum buffer, row_block_num x N)
    dx_stride_r,  # NOTE(review): unused in this body
    dx_stride_c,  # NOTE(review): unused in this body
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Weight-gradient partial sums over a 2-D grid of (row, col) tiles.

    Each program reduces its ROW_BLOCK_SIZE x COL_BLOCK_SIZE tile of
    x * dy * inv_rms over the row axis and writes the COL_BLOCK_SIZE-wide
    partial to DW[row_pid, col_start:col_start+COL_BLOCK_SIZE]; the caller
    finishes the reduction with torch.sum over dim 0.
    """
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    # Advance base pointers to this tile's top-left element.
    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    # Edge tiles: mask rows/cols past the tensor bounds.
    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N

    x = tl.load(
        X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(
        DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    # dL/dw contribution of each element is dy * x_hat = dy * x * inv_rms.
    d_weight = x * dy * inv_rms[:, None]
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    # DW is laid out as (row_block_num, N), stride N between row blocks.
    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )
@libentry()
@triton.jit
def rms_norm_grad_kernel(
    X,  # pointer to the input, contiguous (M, N)
    DY,  # pointer to the upstream gradient, same layout as X
    DX,  # pointer to the input-gradient output
    W,  # pointer to the weights (N,)
    INV_RMS,  # pointer to per-row inverse RMS saved by the forward pass
    DW,  # pointer to the weight-gradient accumulator (N,), MUST be zeroed
    M: tl.constexpr,  # number of rows
    N: tl.constexpr,  # number of columns; requires N <= BLOCK_SIZE
    eps: tl.constexpr,  # unused here; inv_rms already folds it in
    BLOCK_SIZE: tl.constexpr,
):
    """Fused RMS-norm backward: one program computes dx for one row and
    atomically accumulates that row's weight-gradient contribution into DW.

    dx = inv_rms * (dy*w - x_hat * sum(dy*w * x_hat) / N), x_hat = x*inv_rms
    dw += dy * x_hat   (summed over rows via tl.atomic_add)
    """
    row_idx = tl.program_id(0)

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    x_ptr = X + row_idx * N + cols
    dy_ptr = DY + row_idx * N + cols
    w_ptr = W + cols

    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
    dy = tl.load(dy_ptr, mask=mask, other=0.0).to(tl.float32)
    weight = tl.load(w_ptr, mask=mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + row_idx).to(tl.float32)

    dy_w = dy * weight
    x_inv_rms = x * inv_rms
    # BUGFIX: the correction term must be summed over x_hat = x*inv_rms, not
    # raw x (compare rms_norm_grad_dx_kernel); the old code dropped one
    # factor of inv_rms from dx.
    m_grad = tl.sum(dy_w * x_inv_rms, axis=0)
    dx = inv_rms * (dy_w - x_inv_rms * (m_grad / N))
    dx_ptr = DX + row_idx * N + cols
    tl.store(dx_ptr, dx, mask=mask)

    # BUGFIX: dL/dw is a sum over all rows, but every row-program used to
    # tl.store its own contribution to the same DW slot — a race where the
    # last writer won.  Accumulate atomically instead (DW zero-initialized
    # by the caller).
    dw_partial = dy * x_inv_rms
    tl.atomic_add(DW + cols, dw_partial.to(DW.dtype.element_ty), mask=mask)
def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """RMS-norm forward: returns (y, inv_rms).

    Flattens x into (rows, cols) where cols = prod(normalized_shape), picks
    the single-tile kernel when a row fits in one block and the tiled kernel
    otherwise.  inv_rms is a float32 tensor of shape (rows,) kept for the
    backward pass.
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    lead = x.ndim - len(normalized_shape)
    n_rows = math.prod(x.shape[:lead])
    n_cols = math.prod(normalized_shape)

    # Cap the tile at core_num * buffer_size_limit (64 * 128) elements.
    tile = builtins.min(64 * 128, triton.next_power_of_2(n_cols))

    x = x.contiguous()
    weight = weight.contiguous()
    out = torch.empty_like(x)
    inv_rms = torch.empty((n_rows,), device=x.device, dtype=torch.float32)

    # Rows longer than one tile take the two-pass tiled kernel.
    kernel = rms_norm_kerne_tile if n_cols > 64 * 128 else rms_norm_kernel
    with torch_device_fn.device(x.device):
        kernel[n_rows,](
            out, inv_rms, x, weight, n_cols, 1, n_cols, 1, n_rows, n_cols, eps, tile
        )

    return out, inv_rms
def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Unfused RMS-norm backward: returns (dx, dw).

    dx is computed row-per-program (single-tile or tiled kernel depending on
    row length); dw is computed as per-row-block partial sums in a float32
    buffer and then reduced with torch.sum on the host.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")

    lead = x.ndim - len(normalized_shape)
    n_rows = math.prod(x.shape[:lead])
    n_cols = math.prod(normalized_shape)

    tile = triton.next_power_of_2(n_cols)
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    with torch_device_fn.device(x.device):
        if n_cols > 64 * 128:
            # Row exceeds the buffer limit: fixed 8192-wide tiles, two passes.
            tile = 8192
            rms_norm_grad_dx_kernel_tile[n_rows,](
                x,
                dy,
                inv_rms,
                dx,
                weight,
                n_cols,
                1,
                n_cols,
                1,
                n_cols,
                eps,
                tile,
                isCloseUnrollControl=True,
                isCloseVectorization=True,
            )
        else:
            rms_norm_grad_dx_kernel[n_rows,](
                x,
                dy,
                inv_rms,
                dx,
                weight,
                n_cols,
                1,
                n_cols,
                1,
                n_cols,
                eps,
                tile,
                isCloseUnrollControl=True,
            )

    # dw: 2-D grid of (row-block, col-block) programs writing partial sums.
    row_tile = 1
    col_tile = 256
    n_row_blocks = triton.cdiv(n_rows, row_tile)
    n_col_blocks = triton.cdiv(n_cols, col_tile)

    partials = torch.empty(
        (n_row_blocks, n_cols), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[n_row_blocks, n_col_blocks](
            x,
            dy,
            inv_rms,
            partials,
            n_cols,
            1,
            n_cols,
            1,
            n_rows,
            n_cols,
            row_tile,
            col_tile,
            isCloseUnrollControl=True,
            isCloseCoreTiling=True,
        )
    # Final reduction over row blocks, cast back to x's dtype.
    dw = torch.sum(partials, dim=0, dtype=x.dtype).reshape(-1)
    return dx, dw
def rms_norm_backward_fusion(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Fused RMS-norm backward: returns (dx, dw) from a single kernel launch.

    One program per row computes dx and accumulates the row's weight-gradient
    contribution into dw, so dw must start zeroed and a whole row must fit in
    one tile (N <= BLOCK_SIZE).
    """
    logger.debug("GEMS RMS_NORM BACKWARD")

    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])  # Batch dimension
    N = math.prod(normalized_shape)  # Feature dimension

    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()

    dx = torch.empty_like(x)
    # BUGFIX: zero-initialize — the kernel accumulates per-row contributions
    # into dw rather than overwriting it.
    dw = torch.zeros_like(weight)

    # BUGFIX: the block must cover a whole row; the old fixed BLOCK_SIZE = 64
    # silently dropped every column past 64 for larger feature dims.
    # NOTE(review): very large N may exceed the backend's per-core buffer
    # (64 * 128, see rms_norm_forward) — use rms_norm_backward for those.
    BLOCK_SIZE = triton.next_power_of_2(N)

    with torch_device_fn.device(x.device):
        rms_norm_grad_kernel[(M,)](
            x,
            dy,
            dx,
            weight,
            inv_rms,
            dw,
            M,
            N,
            eps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return dx, dw
class RmsNorm(torch.autograd.Function):
    """Autograd binding for the Triton RMS-norm forward/backward kernels."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        """Run the forward kernel and stash what backward needs."""
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return out

    @staticmethod
    def backward(ctx, dy):
        """Gradients for (x, normalized_shape, weight, eps) — only the
        tensor inputs x and weight receive gradients."""
        saved_x, saved_inv_rms, saved_w = ctx.saved_tensors
        # The fused path is used by default; rms_norm_backward is the
        # unfused (tiled) alternative.
        dx, dw = rms_norm_backward_fusion(
            dy, saved_x, saved_inv_rms, ctx.normalized_shape, saved_w, ctx.eps
        )
        return dx, None, dw, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMS normalization over `normalized_shape` with autograd support.

    Thin entry point delegating to the RmsNorm autograd Function; returns
    the normalized tensor with the same shape as x.
    """
    return RmsNorm.apply(x, normalized_shape, weight, eps)