Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_layer_norm.py: 0%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
"""Patch Qwen3DecoderLayer.forward to use fused add+RMSNorm for the
post-attention layer-norm step.

The vanilla forward does:
    residual = hidden_states
    hidden_states = self.input_layernorm(hidden_states)
    hidden_states, _ = self.self_attn(...)
    hidden_states = residual + hidden_states  # ATen add (M=1×D)
    residual = hidden_states  # alias
    hidden_states = self.post_attention_layernorm(hidden_states)  # ATen rmsnorm
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states  # ATen add
    return hidden_states

We fuse the first residual add (the one right after self_attn) and
post_attention_layernorm into a single fused_add_rms_norm call. This drops
one ATen add and one ATen rmsnorm dispatch per layer.

The fused path is skipped, and the original forward is used, when:
- hidden_states is not 3-D, or is not a single-token decode step (T != 1)
- dtype is not BF16
"""
import functools
import importlib
import logging

import torch

from flag_gems.runtime.backend._arm.fused.fused_add_rms_norm import fused_add_rms_norm
# Module-level logger used to report which classes were patched.
logger = logging.getLogger(name=__name__)
# Registry of applied patches: (module name, class name) -> (class, original forward).
# It keeps patching idempotent and lets unpatch_qwen3_layer_norm restore the originals.
_PATCHED: dict = dict()
32def _make_patched_forward(orig_forward):
33 def patched_forward(self, hidden_states, **kwargs):
34 # Eligibility: decode T=1, BF16 only
35 if not (
36 hidden_states.dim() == 3
37 and hidden_states.shape[1] == 1
38 and hidden_states.dtype == torch.bfloat16
39 ):
40 return orig_forward(self, hidden_states, **kwargs)
42 residual = hidden_states
43 hidden_states = self.input_layernorm(hidden_states)
44 hidden_states, _ = self.self_attn(
45 hidden_states=hidden_states,
46 **kwargs,
47 )
48 # Fuse: residual = residual + hidden_states ; hidden_states = post_attn_ln(residual)
49 # fused_add_rms_norm modifies both in-place: residual := residual + hidden_states,
50 # hidden_states := rms_norm(residual) * weight
51 hidden_states, residual = fused_add_rms_norm(
52 hidden_states.contiguous(),
53 residual.contiguous(),
54 normalized_shape=(self.post_attention_layernorm.weight.shape[0],),
55 weight=self.post_attention_layernorm.weight,
56 eps=self.post_attention_layernorm.variance_epsilon,
57 )
59 hidden_states = self.mlp(hidden_states)
60 hidden_states = residual + hidden_states
61 return hidden_states
63 return patched_forward
def patch_qwen3_layer_norm() -> int:
    """Monkey-patch Qwen3DecoderLayer.forward to use fused add+RMSNorm.

    Only the regular Qwen3 decoder layer is targeted. Qwen3.5 has GDN/mamba
    layers with a different forward structure, so the fused rewrite is not
    applicable there.

    Modules that fail to import, or that do not define the target class, are
    skipped. Classes already recorded in ``_PATCHED`` are left alone, so
    repeated calls do not stack patches.

    Returns:
        Number of classes newly patched by this call.
    """
    # (module path, decoder-layer class name) pairs to patch.
    targets = (
        ("transformers.models.qwen3.modeling_qwen3", "Qwen3DecoderLayer"),
    )
    n = 0
    for modname, cls_name in targets:
        try:
            mod = importlib.import_module(modname)
        except (ImportError, AttributeError):
            continue
        cls = getattr(mod, cls_name, None)
        if cls is None:
            continue
        key = (modname, cls_name)
        if key in _PATCHED:
            continue
        orig = cls.forward
        # Record the original before swapping so unpatching can restore it.
        _PATCHED[key] = (cls, orig)
        cls.forward = _make_patched_forward(orig)
        n += 1
        logger.info("Patched %s.%s.forward", modname, cls_name)
    return n
def unpatch_qwen3_layer_norm() -> int:
    """Undo every patch recorded in ``_PATCHED``.

    Each recorded class gets its original ``forward`` back, in the order the
    patches were recorded, and its entry is removed from the registry.

    Returns:
        Number of classes whose ``forward`` was restored.
    """
    restored = 0
    # Snapshot the keys: entries are deleted while we walk them.
    for key in list(_PATCHED):
        cls, original_forward = _PATCHED[key]
        cls.forward = original_forward
        del _PATCHED[key]
        restored += 1
    return restored