Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_layer_norm.py: 0%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""Patch Qwen3DecoderLayer.forward to use fused add+RMSNorm for the 

2post-attention layer-norm step. 

3 

4The vanilla forward does: 

5 residual = hidden_states 

6 hidden_states = self.input_layernorm(hidden_states) 

7 hidden_states, _ = self.self_attn(...) 

8 hidden_states = residual + hidden_states # ATen add (M=1×D) 

9 residual = hidden_states # alias 

10 hidden_states = self.post_attention_layernorm(hidden_states) # ATen rmsnorm 

11 hidden_states = self.mlp(hidden_states) 

12 hidden_states = residual + hidden_states # ATen add 

13 return hidden_states 

14 

15We fuse the first residual add and the post_attention_layernorm call that follows it into a single fused_add_rms_norm call. 

16This drops 1 ATen add + 1 ATen rmsnorm dispatch per layer. 

17 

18Skips fusion when: 

19- hidden_states is not 3-D, or is not a decode step (T != 1) 

20- dtype is not BF16 

21""" 

22import logging 

23 

24import torch 

25 

26from flag_gems.runtime.backend._arm.fused.fused_add_rms_norm import fused_add_rms_norm 

27 

# Module-level logger; each successful patch is reported through logger.info.
logger = logging.getLogger(name=__name__)

# Registry of applied patches: (module name, class name) -> (class, original
# forward). Used to skip re-patching and to restore the originals on unpatch.
_PATCHED: dict = dict()

30 

31 

32def _make_patched_forward(orig_forward): 

33 def patched_forward(self, hidden_states, **kwargs): 

34 # Eligibility: decode T=1, BF16 only 

35 if not ( 

36 hidden_states.dim() == 3 

37 and hidden_states.shape[1] == 1 

38 and hidden_states.dtype == torch.bfloat16 

39 ): 

40 return orig_forward(self, hidden_states, **kwargs) 

41 

42 residual = hidden_states 

43 hidden_states = self.input_layernorm(hidden_states) 

44 hidden_states, _ = self.self_attn( 

45 hidden_states=hidden_states, 

46 **kwargs, 

47 ) 

48 # Fuse: residual = residual + hidden_states ; hidden_states = post_attn_ln(residual) 

49 # fused_add_rms_norm modifies both in-place: residual := residual + hidden_states, 

50 # hidden_states := rms_norm(residual) * weight 

51 hidden_states, residual = fused_add_rms_norm( 

52 hidden_states.contiguous(), 

53 residual.contiguous(), 

54 normalized_shape=(self.post_attention_layernorm.weight.shape[0],), 

55 weight=self.post_attention_layernorm.weight, 

56 eps=self.post_attention_layernorm.variance_epsilon, 

57 ) 

58 

59 hidden_states = self.mlp(hidden_states) 

60 hidden_states = residual + hidden_states 

61 return hidden_states 

62 

63 return patched_forward 

64 

65 

def patch_qwen3_layer_norm() -> int:
    """Monkey-patch ``Qwen3DecoderLayer.forward`` to use fused add+RMSNorm.

    Only the regular Qwen3 decoder layer is targeted. Qwen3.5 has GDN/mamba
    layers with a different forward structure, so the fused rewrite does not
    apply there.

    Repeated calls are harmless: targets already recorded in ``_PATCHED`` are
    skipped.

    Returns:
        The number of classes patched by this call. This is 0 when the target
        module or class is unavailable, or when everything was already patched.
    """
    import importlib

    # (module path, decoder-layer class name) pairs to patch.
    targets = (
        ("transformers.models.qwen3.modeling_qwen3", "Qwen3DecoderLayer"),
    )
    n = 0
    for modname, cls_name in targets:
        key = (modname, cls_name)
        if key in _PATCHED:
            continue
        try:
            mod = importlib.import_module(modname)
        except (ImportError, AttributeError):
            # Target not installed or not importable: skip it silently,
            # exactly as before.
            continue
        cls = getattr(mod, cls_name, None)
        if cls is None:
            continue
        orig = cls.forward
        # Record the original before replacing it, so unpatch can restore it.
        _PATCHED[key] = (cls, orig)
        cls.forward = _make_patched_forward(orig)
        n += 1
        logger.info("Patched %s.%s.forward", modname, cls_name)
    return n

99 

100 

def unpatch_qwen3_layer_norm() -> int:
    """Restore every ``forward`` replaced by ``patch_qwen3_layer_norm``.

    Patches are undone newest-first, because ``dict.popitem`` removes entries
    in LIFO order. If the same class was ever recorded under more than one
    key, the original saved by the earliest patch is therefore the one left
    installed. The registry is empty afterwards.

    Returns:
        The number of patches that were undone.
    """
    n = 0
    while _PATCHED:
        _key, (cls, orig) = _PATCHED.popitem()
        cls.forward = orig
        n += 1
    return n