Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_rmsnorm_gated.py: 0%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Monkey-patch Qwen3_5RMSNormGated.forward to use the fused multi-row
2tle_ops.rms_norm_gated builtin (NEON RMSNormGated in C runtime), replacing
3a 6-op ATen sequence (pow + mean + rsqrt + mul × 2 + silu*mul) and turning
4M single-row calls into one OMP-parallel multi-row kernel.
6Per Qwen3.5-2B GDN decode token, RMSNormGated is invoked at:
7 - GDN per-head norm (each linear_attention layer): 1 call/layer × 6 GDN
8 layers × 1 token = 6 calls/tok, each shape [M=num_v_heads, D=head_v_dim]
9 typically [16, 128].
11Reference formula:
12 out = (x / rms(x)) * weight * silu(gate)
14Decode (BF16, [M, D] last-dim contiguous, M aligned). Other shapes /
15dtypes fall back to the original forward.
16"""
17import logging
18import types
20import torch
21import triton
22import triton.language as tl
23from triton.language.extra.cpu.tle_ops import rms_norm_gated as _tle_rms_norm_gated
25logger = logging.getLogger(__name__)
27_PATCHED: set = set()
@triton.jit
def _rms_norm_gated_kernel(
    x_ptr,
    gate_ptr,
    w_ptr,
    out_ptr,
    M: tl.constexpr,
    D: tl.constexpr,
    eps: tl.constexpr,
):
    # Thin Triton entry point: all work is delegated to the imported
    # tle_ops.rms_norm_gated builtin. M, D and eps are compile-time
    # constants (tl.constexpr); the caller in this file passes M as the
    # row count and D as the row length of the flattened input.
    _tle_rms_norm_gated(x_ptr, gate_ptr, w_ptr, out_ptr, M, D, eps)
37def _patched_rmsnorm_gated_forward(self, hidden_states, gate=None):
38 if (
39 gate is not None
40 and hidden_states.dtype == torch.bfloat16
41 and gate.dtype == torch.bfloat16
42 and hidden_states.is_contiguous()
43 and gate.is_contiguous()
44 and hidden_states.shape == gate.shape
45 and hidden_states.shape[-1] == self._tle_D
46 ):
47 shape = hidden_states.shape
48 D = self._tle_D
49 M = hidden_states.numel() // D
50 x_flat = hidden_states.reshape(M, D).contiguous()
51 g_flat = gate.reshape(M, D).contiguous()
52 out = torch.empty_like(x_flat)
53 _rms_norm_gated_kernel[(1,)](
54 x_flat,
55 g_flat,
56 self.weight.to(torch.bfloat16),
57 out,
58 M=M,
59 D=D,
60 eps=float(self.variance_epsilon),
61 )
62 return out.reshape(*shape)
64 # Fallback: original forward
65 return self._original_forward(hidden_states, gate)
68def _get_qwen3_5_rmsnorm_gated_classes():
69 classes = []
70 for modname, clsname in [
71 ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5RMSNormGated"),
72 ("transformers.models.qwen3_5_moe.modeling_qwen3_5_moe", "Qwen3_5RMSNormGated"),
73 ("transformers.models.qwen3_next.modeling_qwen3_next", "Qwen3NextRMSNormGated"),
74 ]:
75 try:
76 mod = __import__(modname, fromlist=[clsname])
77 classes.append(getattr(mod, clsname))
78 except (ImportError, AttributeError):
79 pass
80 return tuple(classes)
def patch_qwen3_5_rmsnorm_gated(model) -> int:
    """Install the fused forward on every gated-RMSNorm module in *model*.

    For each matching module this stores the weight length as ``_tle_D``,
    saves the current ``forward`` as ``_original_forward`` and binds
    ``_patched_rmsnorm_gated_forward`` as the instance's ``forward``.

    Args:
        model: Object exposing ``named_modules()``.

    Returns:
        The number of modules newly patched by this call. Returns 0 when
        none of the supported classes can be imported.
    """
    rms_classes = _get_qwen3_5_rmsnorm_gated_classes()
    if not rms_classes:
        return 0

    n = 0
    for _name, mod in list(model.named_modules()):
        if not isinstance(mod, rms_classes):
            continue
        if hasattr(mod, "_original_forward"):
            # Already patched. The instance attribute is the reliable marker:
            # an id() in _PATCHED may belong to a module that has since been
            # garbage collected, and a deep-copied patched module carries the
            # patched forward under a new id. Re-patching such a copy would
            # save the patched forward as its own "original" and make the
            # fallback path recurse forever.
            _PATCHED.add(id(mod))
            continue
        mod._tle_D = mod.weight.shape[0]
        mod._original_forward = mod.forward
        mod.forward = types.MethodType(_patched_rmsnorm_gated_forward, mod)
        _PATCHED.add(id(mod))
        n += 1

    if n > 0:
        logger.info(
            "Patched %d Qwen3.5 RMSNormGated modules with TLE rms_norm_gated", n
        )
    return n
def unpatch_qwen3_5_rmsnorm_gated(model) -> int:
    """Undo ``patch_qwen3_5_rmsnorm_gated`` for every module in *model*.

    A module counts as patched when it carries the ``_original_forward``
    attribute that patching sets. Its id is dropped from ``_PATCHED``
    either way, so stale ids do not accumulate.

    Args:
        model: Object exposing ``named_modules()``.

    Returns:
        The number of modules whose forward was actually restored. Returns 0
        when none of the supported classes can be imported.
    """
    rms_classes = _get_qwen3_5_rmsnorm_gated_classes()
    if not rms_classes:
        return 0

    n = 0
    for _name, mod in list(model.named_modules()):
        if not isinstance(mod, rms_classes):
            continue
        # Always forget the id, even if nothing is restored below: an id with
        # no saved forward is a leftover from a different, freed object.
        _PATCHED.discard(id(mod))
        if not hasattr(mod, "_original_forward"):
            continue

        saved = mod._original_forward
        saved_is_class_forward = (
            getattr(saved, "__self__", None) is mod
            and getattr(saved, "__func__", None)
            is getattr(type(mod), "forward", None)
        )
        if saved_is_class_forward:
            # The saved forward is just the class method bound to this
            # module, so dropping the instance override restores the exact
            # pre-patch state instead of leaving a shadowing attribute.
            mod.__dict__.pop("forward", None)
        else:
            mod.forward = saved
        del mod._original_forward
        if hasattr(mod, "_tle_D"):
            del mod._tle_D
        n += 1
    return n