Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_rmsnorm.py: 0%
69 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Monkey-patch Qwen3_5RMSNorm.forward to use the existing
2tle_ops.rms_norm @builtin (NEON RMSNorm in C runtime), eliminating the
3~30 us ATen sequence (pow + mean + rsqrt + mul × 2) per call.
5Per Qwen3.5-2B decode token, RMSNorm is invoked at:
6 - input_layernorm (per decoder layer) : 24 calls
7 - post_attention_layernorm (per decoder layer) : 24 calls
8 - q_norm, k_norm (full_attention layers) : 6 layers × 2 = 12 calls
Total: ~60 RMSNorm calls per token, i.e. ~1.8 ms/token of RMSNorm at 30 us
per call; dropping each call to 5 us saves ~1.5 ms/token (60 × 25 us).
12Qwen3.5's RMSNorm uses `out = (x / rms(x)) * (1 + weight)` (note the +1).
We pre-compute `_weight_plus_one_bf16 = (1.0 + weight)` in BF16 at patch
time so the existing tle_ops.rms_norm (which computes
`out = (x / rms(x)) * w_in`) gives the right result.
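Decode inputs (a single contiguous BF16 row, M=1) take the single-launch
fast path. Multi-row BF16 inputs reuse the same kernel with one launch per
row; other dtypes, non-contiguous inputs, and inputs whose last dimension
does not match the weight fall back to the original forward.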
19"""
20import logging
21import types
23import torch
24import triton
25import triton.language as tl
26from triton.language.extra.cpu.tle_ops import rms_norm as _tle_rms_norm
logger = logging.getLogger(__name__)

# id() of each RMSNorm module patched by patch_qwen3_5_rmsnorm; entries are
# removed again by unpatch_qwen3_5_rmsnorm.
# NOTE(review): id() values are only unique among objects that are alive at
# the same time, so an entry can outlive its (garbage-collected) module and
# match an unrelated object created later that reuses the same id.
_PATCHED: set = set()
# Triton wrapper around the TLE ``rms_norm`` builtin (imported as
# ``_tle_rms_norm``) so the builtin can be launched from Python.
# The body never reads ``tl.program_id``, so nothing offsets the pointers per
# program: the callers in this file launch it with a ``(1,)`` grid, passing
# one row of ``D`` elements per launch.
@triton.jit
def _rms_norm_kernel(x_ptr, w_ptr, out_ptr, D: tl.constexpr, eps: tl.constexpr):
    # x_ptr:   input row of D elements (BF16 at the call sites in this file).
    # w_ptr:   per-channel scale; callers pass the precomputed ``1 + weight``.
    # out_ptr: destination row of D elements.
    # D and eps are tl.constexpr, so Triton compiles a separate
    # specialization for each distinct (D, eps) pair.
    _tle_rms_norm(x_ptr, w_ptr, out_ptr, D, eps)
def _patched_rmsnorm_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Replacement ``forward`` installed on patched Qwen3.5 RMSNorm modules.

    Normalizes the last dimension of ``x`` with the TLE ``rms_norm`` kernel,
    scaling by the precomputed ``1 + weight``, when the input allows it, and
    otherwise defers to the module's original forward.

    The kernel path is taken only when all of these hold:

    * ``x`` is BF16 and contiguous, and ``x.shape[-1] == self._tle_D``;
    * autograd would not record a graph (the kernel output has no history).

    A single row (``x.numel() == D``, e.g. one decode token) is one kernel
    launch; ``M`` rows are normalized with ``M`` launches, one per row.

    Args:
        self: The patched module. ``patch_qwen3_5_rmsnorm`` attaches
            ``_tle_D`` (weight width), ``_weight_plus_one_bf16`` (``1 + weight``
            in BF16) and ``_original_forward``; ``eps`` and ``weight`` are the
            module's own attributes.
        x: Input activations.

    Returns:
        A tensor with the shape of ``x``. On the kernel path it is a new BF16
        tensor on ``x``'s device; otherwise whatever the original forward
        returns.

    Note:
        ``_weight_plus_one_bf16`` is a snapshot taken at patch time; weight
        changes made afterwards are not seen until the module is re-patched.
    """
    D = self._tle_D
    if x.dtype != torch.bfloat16 or not x.is_contiguous() or x.shape[-1] != D:
        return self._original_forward(x)
    # The kernel writes into a fresh buffer and records no autograd history.
    # When a graph is needed, use the differentiable original instead of
    # silently returning a detached result.
    if torch.is_grad_enabled() and (x.requires_grad or self.weight.requires_grad):
        return self._original_forward(x)

    w = self._weight_plus_one_bf16
    eps = float(self.eps)
    M = x.numel() // D

    if M == 1:
        # Decode fast path: x is contiguous, so reshape(D) is a view.
        row = x.reshape(D)
        out = torch.empty_like(row)
        _rms_norm_kernel[(1,)](row, w, out, D=D, eps=eps)
        return out.view(x.shape)

    # Multi-row (e.g. prefill): one launch per row. Rows of a contiguous
    # [M, D] view are themselves contiguous, so no copies are needed.
    out = torch.empty_like(x)
    for row_in, row_out in zip(x.reshape(M, D), out.view(M, D)):
        _rms_norm_kernel[(1,)](row_in, w, row_out, D=D, eps=eps)
    return out
def _get_qwen3_5_rmsnorm_classes():
    """Return the RMSNorm classes this patch targets, as a tuple.

    Each candidate ``(module, class)`` pair is imported lazily. A candidate is
    skipped when its module cannot be imported (e.g. a ``transformers`` build
    without that model) or when the module lacks the class. The result is
    suitable for ``isinstance`` and is empty when no candidate is available.
    """
    import importlib

    candidates = (
        ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5RMSNorm"),
        ("transformers.models.qwen3_5_moe.modeling_qwen3_5_moe", "Qwen3_5RMSNorm"),
        ("transformers.models.qwen3_next.modeling_qwen3_next", "Qwen3NextRMSNorm"),
    )
    classes = []
    for modname, clsname in candidates:
        try:
            module = importlib.import_module(modname)
        except (ImportError, AttributeError):
            # Best effort: a missing or broken model module just means there
            # is nothing of that kind to patch.
            continue
        cls = getattr(module, clsname, None)
        if cls is not None:
            classes.append(cls)
    return tuple(classes)
def patch_qwen3_5_rmsnorm(model) -> int:
    """Route every Qwen3.5-family RMSNorm in *model* through the TLE kernel.

    Each matching module gets ``_weight_plus_one_bf16`` (a BF16 snapshot of
    ``1 + weight``), ``_tle_D`` (the weight width) and ``_original_forward``,
    plus an instance-level ``forward`` bound to ``_patched_rmsnorm_forward``.
    Modules that are already patched are left alone, so calling this twice is
    harmless.

    Args:
        model: A ``torch.nn.Module``; it and all of its submodules are scanned.

    Returns:
        The number of modules patched by this call. This is 0 when none of the
        RMSNorm classes can be imported.
    """
    rms_classes = _get_qwen3_5_rmsnorm_classes()
    if not rms_classes:
        return 0
    n = 0
    for mod in model.modules():
        if not isinstance(mod, rms_classes):
            continue
        # Decide "already patched" from the module's own state. Ids kept in
        # _PATCHED are only unique among live objects, so a stale id left by a
        # garbage-collected module could otherwise make us skip a new one.
        if hasattr(mod, "_tle_D"):
            continue
        # detach(): the snapshot is only read by the kernel and must not keep
        # an autograd graph of the weight alive.
        weight = mod.weight.detach()
        # 1 + weight matches Qwen3.5's RMSNorm scale; stored in BF16 for the kernel.
        mod._weight_plus_one_bf16 = (1.0 + weight).to(torch.bfloat16).contiguous()
        mod._tle_D = weight.shape[0]
        mod._original_forward = mod.forward
        mod.forward = types.MethodType(_patched_rmsnorm_forward, mod)
        _PATCHED.add(id(mod))
        n += 1
    if n > 0:
        logger.info("Patched %d Qwen3.5 RMSNorm modules with TLE rms_norm", n)
    return n
def unpatch_qwen3_5_rmsnorm(model) -> int:
    """Undo ``patch_qwen3_5_rmsnorm`` on every matching module in *model*.

    Restores the saved forward, removes the attributes the patch attached and
    drops the module's id from ``_PATCHED``. When the saved forward is just
    the class's own ``forward`` bound to the module, the instance-level
    override is removed rather than re-assigned. The module then resolves
    ``forward`` through its class again.

    Args:
        model: A ``torch.nn.Module``; it and all of its submodules are scanned.

    Returns:
        The number of modules whose forward was actually restored.
    """
    rms_classes = _get_qwen3_5_rmsnorm_classes()
    if not rms_classes:
        return 0
    n = 0
    for mod in model.modules():
        if not isinstance(mod, rms_classes):
            continue
        # Always clear the id. A stale entry (left by a freed module whose id
        # was reused) must neither linger nor be counted as a restore.
        _PATCHED.discard(id(mod))
        state = vars(mod)
        saved = state.pop("_original_forward", None)
        # pop(..., None) tolerates modules that were only partially patched.
        state.pop("_weight_plus_one_bf16", None)
        state.pop("_tle_D", None)
        if saved is None:
            continue
        class_forward = getattr(type(mod), "forward", None)
        if (
            getattr(saved, "__self__", None) is mod
            and getattr(saved, "__func__", None) is class_forward
        ):
            # Saved forward was the plain class method: drop the override.
            state.pop("forward", None)
        else:
            mod.forward = saved
        n += 1
    return n