Coverage for src/flag_gems/runtime/backend/_arm/fused/fused_add_rms_norm.py: 0%
57 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
"""ARM-optimized fused residual-add + RMS normalization.

Replaces 5 separate Triton kernel launches (add, pow, mean, rsqrt, mul)
with a single kernel launch. For decode shapes (M=1, N=896) this reduces
overhead from 5 × ~9μs ≈ 45μs to 1 × ~9μs ≈ 9μs per layer.

Uses a two-pass tiled approach with small BLOCK_SIZE (128) to avoid
extremely slow LLVM compilation with large vector widths on ARM.
"""
11import logging
12import math
13import os
15import torch
16import triton
17import triton.language as tl
19from flag_gems.utils import triton_lang_extension as tle
logger = logging.getLogger(__name__)

# Process-wide latch: set to True by _maybe_prewarm() once the prewarm has
# been attempted (or skipped), so it is not attempted again.
_PREWARM_DONE = False
# Prewarm is on by default; it is enabled only when the environment variable
# GEMS_ARM_FUSED_RMS_PREWARM is unset or exactly "1". Any other value
# disables it.
_PREWARM_ENABLED = os.environ.get("GEMS_ARM_FUSED_RMS_PREWARM", "1") == "1"

# Use small block size to keep LLVM compilation fast (~seconds not minutes).
# This value is passed as BLOCK_SIZE to every launch of the fused kernel.
_TILE_SIZE = 128
@triton.jit(do_not_specialize=["eps"])
def _fused_add_rms_norm_kernel(
    input_ptr,
    residual_ptr,
    weight_ptr,
    in_stride_r,
    r_stride_r,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused: residual += input; output = rms_norm(residual) * weight.

    Each program instance handles one row, selected by program id 0.

    Two-pass tiled approach:
    Pass 1: Load tiles, compute x=input+residual, store residual, accumulate x^2
    Pass 2: Load tiles of x (from residual), compute normalized output

    Args:
        input_ptr: Pointer to the input rows; overwritten with the normalized
            output.
        residual_ptr: Pointer to the residual rows; overwritten with
            input + residual.
        weight_ptr: Pointer to the per-column scale, indexed by column only.
        in_stride_r: Element offset between consecutive rows of the input.
        r_stride_r: Element offset between consecutive rows of the residual.
        N: Number of columns normalized per row.
        eps: Value added to the mean square before the square root.
        BLOCK_SIZE: Compile-time tile width used by both passes.
    """
    pid = tle.program_id(0)
    in_row = input_ptr + pid * in_stride_r
    r_row = residual_ptr + pid * r_stride_r

    # Pass 1: fused add + store residual + accumulate variance
    # The running sum of squares is kept in float32 regardless of input dtype.
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N

        # Masked-out lanes load 0.0, so they add nothing to the sum of squares.
        x = tl.load(in_row + cols, mask=mask, other=0.0).to(tl.float32)
        r = tl.load(r_row + cols, mask=mask, other=0.0).to(tl.float32)

        x = x + r

        # Store updated residual
        tl.store(r_row + cols, x.to(residual_ptr.dtype.element_ty), mask=mask)

        # NOTE(review): the square is taken from the float32 sum, before the
        # rounding applied by the store above, while pass 2 normalizes the
        # rounded value that was stored. Confirm this is the intended numerics.
        sum_sq += tl.sum(x * x, axis=0)

    # Compute rrms
    # rrms = 1 / sqrt(mean(x^2) + eps)
    var = sum_sq / N
    rrms = 1.0 / tl.sqrt(var + eps)

    # Pass 2: load residual (=x+r), normalize, multiply by weight, store output
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N

        # Read back the updated residual (which is x+r in original dtype)
        x = tl.load(r_row + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + cols, mask=mask, other=0.0)

        # The normalized value is cast to the input element type before the
        # multiply by weight; the result overwrites the input row.
        y = (x * rrms).to(input_ptr.dtype.element_ty) * w
        tl.store(in_row + cols, y, mask=mask)
# Note: standalone _rms_norm_kernel (without residual add) was removed after
# A/B measurement showed zero E2E benefit vs ATen's Qwen3RMSNorm on BF16 M=1
# (see test_tle_phase1_plus.py ENABLE_RMSNORM_PATCH A/B, 3 rounds:
# ON=9.93 tok/s, OFF=9.97 tok/s — within noise).
# The fused add+rmsnorm path is kept because it saves a residual-add memory
# roundtrip and is used by vLLM's forward_cpu when residual is present.
def _maybe_prewarm():
    """Launch the fused kernel once on dummy data, at most once per process.

    Best effort: any exception raised by the dummy launch is logged at debug
    level and otherwise ignored. Whether the launch ran, failed, or was
    skipped (already done, or disabled through _PREWARM_ENABLED), the
    _PREWARM_DONE latch ends up True so later calls return immediately.
    """
    global _PREWARM_DONE

    skip = (not _PREWARM_ENABLED) or _PREWARM_DONE
    if skip:
        _PREWARM_DONE = True
        return

    single_row = (1, _TILE_SIZE)
    grid = (1,)
    try:
        # Only float32 is warmed; the tuple leaves room for more dtypes.
        for dtype in (torch.float32,):
            dummy_in = torch.ones(single_row, dtype=dtype, device="cpu")
            dummy_res = torch.ones(single_row, dtype=dtype, device="cpu")
            dummy_w = torch.ones(_TILE_SIZE, dtype=dtype, device="cpu")
            # Row strides and N all equal the tile width: one row, one tile.
            _fused_add_rms_norm_kernel[grid](
                dummy_in,
                dummy_res,
                dummy_w,
                _TILE_SIZE,
                _TILE_SIZE,
                _TILE_SIZE,
                1e-6,
                BLOCK_SIZE=_TILE_SIZE,
                num_warps=1,
                num_stages=1,
            )
    except Exception:
        logger.debug("GEMS ARM fused RMSNorm prewarm failed", exc_info=True)

    _PREWARM_DONE = True
def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """Fused residual-add + RMS normalization (in-place).

    Modifies both x and residual tensors in-place:
        residual = x + residual
        x = rms_norm(residual) * weight

    The kernel runs on contiguous buffers. When ``x`` or ``residual`` is not
    contiguous, ``.contiguous()`` yields a copy; the results are then copied
    back into the caller's tensor so the in-place contract also holds for
    non-contiguous inputs.

    Args:
        x: Input tensor; its trailing dimensions are ``normalized_shape``.
        residual: Residual tensor, indexed with the same row layout as ``x``.
        normalized_shape: Sizes of the trailing dimensions that are normalized
            together.
        weight: Scale applied per normalized element.
        eps: Value added to the mean square before the square root.

    Returns:
        ``(x, residual)`` as contiguous tensors holding the results. These are
        the very objects passed in whenever the inputs were already contiguous.
    """
    _maybe_prewarm()

    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    # .contiguous() returns the same tensor object when no copy is needed,
    # which is what the identity checks below rely on.
    x_buf = x.contiguous()
    residual_buf = residual.contiguous()
    weight = weight.contiguous()

    # With no rows or no elements per row there is nothing to compute, so the
    # launch (including an empty grid) is skipped.
    if M > 0 and N > 0:
        _fused_add_rms_norm_kernel[(M,)](
            x_buf,
            residual_buf,
            weight,
            N,  # in_stride_r (contiguous: stride = N)
            N,  # r_stride_r
            N,
            eps,
            BLOCK_SIZE=_TILE_SIZE,
            num_warps=1,
            num_stages=1,
        )

    # Propagate results to the caller's tensors if the kernel ran on copies.
    if x_buf is not x:
        x.copy_(x_buf)
    if residual_buf is not residual:
        residual.copy_(residual_buf)

    return x_buf, residual_buf
# rms_norm_forward() (standalone RMSNorm without residual) removed: A/B
# measurement on Qwen3-1.7B INT8 decode showed no measurable benefit over
# ATen's native Qwen3RMSNorm.forward (9.93 vs 9.97 tok/s, within noise).