Coverage for src/flag_gems/runtime/backend/_arm/fused/fused_add_rms_norm.py: 0%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""ARM-optimized fused residual-add + RMS normalization. 

2 

3Replaces 5 separate Triton kernel launches (add, pow, mean, rsqrt, mul) 

4with a single kernel launch. For decode shapes (M=1, N=896) this reduces 

5overhead from 5 × ~9μs ≈ 45μs to 1 × ~9μs ≈ 9μs per layer. 

6 

7Uses a two-pass tiled approach with small BLOCK_SIZE (128) to avoid 

8extremely slow LLVM compilation with large vector widths on ARM. 

9""" 

10 

11import logging 

12import math 

13import os 

14 

15import torch 

16import triton 

17import triton.language as tl 

18 

19from flag_gems.utils import triton_lang_extension as tle 

20 

# Module-level logger; used below to report prewarm failures at DEBUG level.
logger = logging.getLogger(__name__)

# Set to True by _maybe_prewarm() once the prewarm step has been attempted
# (or skipped), so it runs at most once per process.
_PREWARM_DONE = False
# Prewarm is on by default; it is enabled only when the environment variable
# GEMS_ARM_FUSED_RMS_PREWARM is unset or exactly "1".
_PREWARM_ENABLED = os.environ.get("GEMS_ARM_FUSED_RMS_PREWARM", "1") == "1"

# Use small block size to keep LLVM compilation fast (~seconds not minutes).
# Passed as BLOCK_SIZE to every launch of the fused kernel in this module.
_TILE_SIZE = 128

28 

29 

@triton.jit(do_not_specialize=["eps"])
def _fused_add_rms_norm_kernel(
    input_ptr,
    residual_ptr,
    weight_ptr,
    in_stride_r,
    r_stride_r,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """Add input to residual, then RMS-normalize one row per program.

    Each program handles the row selected by program id 0 and walks it in
    tiles of BLOCK_SIZE elements, twice:

      1. Sum input and residual in float32, write the sum back to the
         residual row, and accumulate the sum of squares.
      2. Re-read the stored sum from the residual row, scale it by the
         reciprocal RMS, cast to the input dtype, multiply by the weight
         and overwrite the input row with the result.

    Args:
        input_ptr: Base pointer of the input; overwritten with the output.
        residual_ptr: Base pointer of the residual; overwritten with
            input + residual.
        weight_ptr: Base pointer of the weight vector, indexed by column.
        in_stride_r: Element stride between consecutive input rows.
        r_stride_r: Element stride between consecutive residual rows.
        N: Number of columns processed in each row.
        eps: Value added to the mean of squares before the square root.
        BLOCK_SIZE: Compile-time tile width.
    """
    row = tle.program_id(0)
    in_base = input_ptr + row * in_stride_r
    res_base = residual_ptr + row * r_stride_r

    # First sweep: residual <- input + residual, gather sum of squares.
    sq_acc = tl.zeros([1], dtype=tl.float32)
    for start in range(0, N, BLOCK_SIZE):
        idx = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N

        a = tl.load(in_base + idx, mask=in_bounds, other=0.0).to(tl.float32)
        b = tl.load(res_base + idx, mask=in_bounds, other=0.0).to(tl.float32)
        summed = a + b

        # Persist the updated residual in the residual's own dtype.
        tl.store(
            res_base + idx,
            summed.to(residual_ptr.dtype.element_ty),
            mask=in_bounds,
        )

        # Masked lanes were loaded as 0.0, so they add nothing here.
        sq_acc += tl.sum(summed * summed, axis=0)

    # Reciprocal root-mean-square of the row.
    mean_sq = sq_acc / N
    inv_rms = 1.0 / tl.sqrt(mean_sq + eps)

    # Second sweep: normalize the stored sum and apply the weight.
    for start in range(0, N, BLOCK_SIZE):
        idx = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N

        # The value read here has been through the residual dtype.
        summed = tl.load(res_base + idx, mask=in_bounds, other=0.0).to(tl.float32)
        gamma = tl.load(weight_ptr + idx, mask=in_bounds, other=0.0)

        normed = (summed * inv_rms).to(input_ptr.dtype.element_ty) * gamma
        tl.store(in_base + idx, normed, mask=in_bounds)

82 

83 

84# Note: standalone _rms_norm_kernel (without residual add) was removed after 

85# A/B measurement showed zero E2E benefit vs ATen's Qwen3RMSNorm on BF16 M=1 

86# (see test_tle_phase1_plus.py ENABLE_RMSNORM_PATCH A/B, 3 rounds: 

87# ON=9.93 tok/s, OFF=9.97 tok/s — within noise). 

88# The fused add+rmsnorm path is kept because it saves a residual-add memory 

89# roundtrip and is used by vLLM's forward_cpu when residual is present. 

90 

91 

def _maybe_prewarm():
    """Launch the fused kernel once on tiny CPU tensors, at most one time.

    The launch is attempted only when prewarming is enabled and has not
    been attempted before. Any exception raised by the attempt is logged
    at DEBUG level and otherwise ignored. In every case the module-level
    ``_PREWARM_DONE`` flag is set afterwards so later calls do nothing.
    """
    global _PREWARM_DONE
    if _PREWARM_ENABLED and not _PREWARM_DONE:
        try:
            row_shape = (1, _TILE_SIZE)
            for dtype in (torch.float32,):
                dummy_in = torch.ones(row_shape, dtype=dtype, device="cpu")
                dummy_res = torch.ones(row_shape, dtype=dtype, device="cpu")
                dummy_w = torch.ones(_TILE_SIZE, dtype=dtype, device="cpu")
                # Single row of exactly one tile: strides and N all equal
                # the tile size.
                _fused_add_rms_norm_kernel[(1,)](
                    dummy_in,
                    dummy_res,
                    dummy_w,
                    _TILE_SIZE,
                    _TILE_SIZE,
                    _TILE_SIZE,
                    1e-6,
                    BLOCK_SIZE=_TILE_SIZE,
                    num_warps=1,
                    num_stages=1,
                )
        except Exception:
            # Best effort only: a failed prewarm must not break the caller.
            logger.debug("GEMS ARM fused RMSNorm prewarm failed", exc_info=True)
    _PREWARM_DONE = True

117 

118 

def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """Fused residual-add + RMS normalization (in-place).

    Modifies both x and residual tensors in-place:
        residual = x + residual
        x = rms_norm(residual) * weight

    Args:
        x: Input tensor; overwritten with the normalized output.
        residual: Residual tensor; overwritten with ``x + residual``.
            Assumed to hold the same number of elements as ``x``
            (not checked here).
        normalized_shape: Trailing dimensions to normalize over; their
            product is the row length handed to the kernel.
        weight: Scale applied after normalization. Presumably holds
            ``prod(normalized_shape)`` elements -- not validated here.
        eps: Value added to the mean of squares before the square root.

    Returns:
        ``(x, residual)`` -- the tensors passed in, updated in place.
    """
    _maybe_prewarm()

    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    # Nothing to compute for empty inputs; skip the launch (an empty grid,
    # or a zero-length row that would divide by N == 0 inside the kernel).
    if M == 0 or N == 0:
        return x, residual

    # The kernel addresses rows with stride N, so it needs contiguous
    # storage. Tensor.contiguous() returns the tensor itself when it is
    # already contiguous and a fresh copy otherwise.
    x_buf = x.contiguous()
    residual_buf = residual.contiguous()
    weight = weight.contiguous()

    _fused_add_rms_norm_kernel[(M,)](
        x_buf,
        residual_buf,
        weight,
        N,  # in_stride_r (contiguous: stride = N)
        N,  # r_stride_r
        N,
        eps,
        BLOCK_SIZE=_TILE_SIZE,
        num_warps=1,
        num_stages=1,
    )

    # If a contiguous copy had to be made, the kernel wrote into that copy;
    # propagate the results so the caller's tensors really are updated
    # in place, as the contract above promises.
    if x_buf is not x:
        x.copy_(x_buf)
    if residual_buf is not residual:
        residual.copy_(residual_buf)
    return x, residual

151 

152 

153# rms_norm_forward() (standalone RMSNorm without residual) removed: A/B 

154# measurement on Qwen3-1.7B INT8 decode showed no measurable benefit over 

155# ATen's native Qwen3RMSNorm.forward (9.93 vs 9.97 tok/s, within noise).