Coverage for src/flag_gems/runtime/backend/_cambricon/fused/fused_add_rms_norm.py: 0%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
import itertools
import logging
import math

import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

from ..utils import TOTAL_CORE_NUM
# Module-level logger, named after this module per the usual logging convention.
logger = logging.getLogger(__name__)
def get_configs():
    """Build the autotuning search space for ``fused_add_rms_norm_kernel``.

    Returns:
        list[triton.Config]: one config per combination of ``BLOCK_SIZE``
        (columns processed per inner iteration), ``M_BLOCK`` (rows per tile)
        and ``num_stages`` (software-pipeline depth); ``num_warps`` is fixed
        at 1 for this backend.
    """
    block_sizes = [2048, 1024, 512]
    m_blocks = range(1, 10, 2)  # odd tile heights: 1, 3, 5, 7, 9
    stage_counts = [1, 5]
    # itertools.product preserves the original nesting order
    # (BLOCK_SIZE outermost, num_stages innermost).
    return [
        triton.Config(
            {"M_BLOCK": m_block, "BLOCK_SIZE": block_size},
            num_stages=num_stages,
            num_warps=1,
        )
        for block_size, m_block, num_stages in itertools.product(
            block_sizes, m_blocks, stage_counts
        )
    ]
@triton.autotune(
    configs=get_configs(),
    key=["M", "N_COLS"],
    # The kernel overwrites x_ptr and r_ptr in place, so the autotuner must
    # restore both buffers between candidate runs to keep inputs identical.
    restore_value=["x_ptr", "r_ptr"],
)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
    x_ptr,  # in/out: input rows; overwritten with the normalized result
    r_ptr,  # in/out: residual rows; overwritten with x + residual
    w_ptr,  # weight (gain) vector of length N_COLS
    eps,  # small constant added to the mean square for numerical stability
    stride,  # element stride between consecutive rows of x/r
    M,  # total number of rows
    N_COLS: tl.constexpr,  # row length (number of elements normalized per row)
    BLOCK_SIZE: tl.constexpr,  # columns handled per inner-loop iteration
    M_BLOCK: tl.constexpr,  # rows handled per tile
):
    """Fused in-place residual-add + RMS norm over row tiles.

    Pass 1 stores ``x + r`` into ``r_ptr`` while accumulating per-row sums
    of squares; pass 2 re-reads ``x + r``, scales by the reciprocal RMS and
    the weight, and stores the result into ``x_ptr``. Rows are statically
    partitioned into contiguous chunks across the launched programs.
    """
    pid = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    # Contiguous chunk of rows [lb, ub) owned by this program.
    M_OUT_BLOCK = tl.cdiv(M, pnum)
    lb = pid * M_OUT_BLOCK
    ub = tl.minimum((pid + 1) * M_OUT_BLOCK, M)
    for m_start in range(lb, ub, M_BLOCK):
        m_offset = m_start + tl.arange(0, M_BLOCK)
        mx_ptr = x_ptr + stride * m_offset
        mr_ptr = r_ptr + stride * m_offset
        # Per-row sum of squares of (x + r), spread across BLOCK_SIZE lanes;
        # reduced to one scalar per row after the column loop.
        _mean = tl.zeros([M_BLOCK, BLOCK_SIZE], dtype=tl.float32)
        for offset in range(0, N_COLS, BLOCK_SIZE):
            cols = offset + tl.arange(0, BLOCK_SIZE)
            row_mask = m_offset < ub  # guards the ragged last row tile
            col_mask = cols < N_COLS  # guards the ragged last column chunk
            mask = row_mask[:, None] & col_mask[None, :]
            x = tl.load(mx_ptr[:, None] + cols[None, :], mask=mask, other=0.0).to(
                tl.float32
            )
            r = tl.load(mr_ptr[:, None] + cols[None, :], mask=mask, other=0.0).to(
                tl.float32
            )
            xpr = x + r
            # First output: the residual buffer now holds x + r.
            tl.store(mr_ptr[:, None] + cols[None, :], xpr, mask=mask)
            _mean += xpr * xpr
        # Since `_mean * (1 / N_COLS)` performs better, make this change.
        # var = tl.sum(_mean / N_COLS, axis=1)
        var = tl.sum(_mean * (1.0 / N_COLS), axis=1)
        rrms = 1.0 / tl.sqrt(var + eps)
        for offset in range(0, N_COLS, BLOCK_SIZE):
            cols = offset + tl.arange(0, BLOCK_SIZE)
            row_mask = m_offset < ub
            col_mask = cols < N_COLS
            mask = row_mask[:, None] & col_mask[None, :]
            # Re-read the x + r values written during the first pass.
            xpr = tl.load(mr_ptr[:, None] + cols[None, :], mask=mask, other=0.0).to(
                tl.float32
            )
            w = tl.load(w_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)
            y = xpr * rrms[:, None]
            y = y * w
            y = y.to(x_ptr.dtype.element_ty)
            # Second output: the x buffer now holds the normalized, scaled rows.
            tl.store(mx_ptr[:, None] + cols[None, :], y, mask=mask)
def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """Apply fused residual addition and RMS normalization **in-place**.

    Both ``x`` and ``residual`` are overwritten by the kernel: ``residual``
    receives ``x + residual`` and ``x`` receives the normalized, weighted
    result. Use with caution if these tensors are reused elsewhere or
    require gradients.
    """
    logger.debug(
        "GEMS_CAMBRICON FUSED_ADD_RMS_NORM FORWARD, [input shape]: %s, [residual shape]: %s, [weight shape]: %s",
        x.size(),
        residual.size(),
        weight.size(),
    )
    # Collapse the leading (batch) dims into rows and the normalized dims
    # into a single row length for the kernel's 2-D view of the data.
    batch_ndim = x.ndim - len(normalized_shape)
    num_rows = math.prod(x.shape[:batch_ndim])
    row_len = math.prod(normalized_shape)

    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()

    with torch_device_fn.device(x.device):
        fused_add_rms_norm_kernel[TOTAL_CORE_NUM,](
            x, residual, weight, eps, x.stride(batch_ndim - 1), num_rows, row_len
        )
    return x, residual