Coverage for src/flag_gems/runtime/backend/_nvidia/fused/fused_add_rms_norm.py: 32%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2import math
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance processes exactly one row of X and R.
    row = tle.program_id(0)
    x_row = X + row * x_stride_r
    r_row = R + row * r_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    col_mask = cols < N

    # Load both rows, accumulating in float32 for numerical stability.
    x_val = tl.load(x_row + cols * x_stride_c, mask=col_mask, other=0.0).to(tl.float32)
    r_val = tl.load(r_row + cols * r_stride_c, mask=col_mask, other=0.0).to(tl.float32)

    fused = x_val + r_val
    # The residual buffer receives x + residual.
    tl.store(r_row + cols * r_stride_c, fused, mask=col_mask)

    # Root-mean-square statistic of the fused row.
    mean_sq = tl.sum(fused * fused / N, axis=0)
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    weight = tl.load(W + cols, mask=col_mask, other=0.0)
    out = (fused * inv_rms).to(X.dtype.element_ty) * weight
    # The input buffer receives the normalized, weighted result.
    tl.store(x_row + cols * x_stride_c, out, mask=col_mask)
def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """Fused residual-add followed by RMSNorm, computed in one kernel launch.

    In-place semantics (on the contiguous views that are launched):
        residual <- x + residual
        x        <- rms_norm(x + residual, weight, eps)

    Args:
        x: Input tensor; its trailing dims must match ``normalized_shape``.
        residual: Residual tensor with the same shape as ``x``.
        normalized_shape: Trailing shape over which the norm is taken.
        weight: Scale tensor of shape ``normalized_shape``.
        eps: Small constant added to the variance for numerical stability.

    Returns:
        Tuple ``(x, residual)`` holding the normalized output and the
        updated residual, respectively.
    """
    # NVIDIA-backend-specific fused op; lazy %-style logging instead of a
    # hard-coded debug print so normal runs stay quiet.
    logger.debug("GEMS fused_add_rms_norm: multibackend-specific fused op")

    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])  # number of rows to normalize
    N = math.prod(normalized_shape)  # elements per row

    BLOCK_SIZE = triton.next_power_of_2(N)
    # Kernel assumes row-major (stride N, 1) layout, so force contiguity.
    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()

    with torch_device_fn.device(x.device):
        # Positional args: X, R, W, x_stride_r, x_stride_c,
        # r_stride_r, r_stride_c, N, eps, BLOCK_SIZE.
        fused_add_rms_norm_kernel[M,](
            x, residual, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )
    return x, residual