Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_rmsnorm_gated.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""Monkey-patch Qwen3_5RMSNormGated.forward to use the fused multi-row 

2tle_ops.rms_norm_gated builtin (NEON RMSNormGated in C runtime), replacing 

3a 6-op ATen sequence (pow + mean + rsqrt + mul × 2 + silu*mul) and turning 

4M single-row calls into one OMP-parallel multi-row kernel. 

5 

6Per Qwen3.5-2B GDN decode token, RMSNormGated is invoked at: 

7 - GDN per-head norm (each linear_attention layer): 1 call/layer × 6 GDN 

8 layers × 1 token = 6 calls/tok, each shape [M=num_v_heads, D=head_v_dim] 

9 typically [16, 128]. 

10 

11Reference formula: 

12 out = (x / rms(x)) * weight * silu(gate) 

13 

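Fused path: decode-style BF16 calls where x and gate have identical shape
[..., D], are contiguous, and D equals the weight length; M is the product of
the leading dims (no alignment of M is checked on the Python side). Other
calls (different shapes or dtypes, or a missing gate) fall back to the
original forward.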

16""" 

17import logging 

18import types 

19 

20import torch 

21import triton 

22import triton.language as tl 

23from triton.language.extra.cpu.tle_ops import rms_norm_gated as _tle_rms_norm_gated 

24 

# Module logger; the patch entry point reports how many modules it swapped.
logger = logging.getLogger(name=__name__)

# id() of every module instance whose forward this file has replaced;
# maintained by the patch / unpatch entry points.
_PATCHED: set = set()

28 

29 

@triton.jit
def _rms_norm_gated_kernel(
    x_ptr, gate_ptr, w_ptr, out_ptr, M: tl.constexpr, D: tl.constexpr, eps: tl.constexpr
):
    # Thin Triton wrapper: its only statement forwards everything to the TLE
    # builtin imported from triton.language.extra.cpu.tle_ops, which handles
    # all M rows of length D in one call (the Python caller launches this
    # kernel on a single-program grid).
    # x_ptr / gate_ptr / out_ptr: the caller passes [M, D] BF16 tensors;
    # w_ptr: the weight vector cast to BF16. The exact memory-layout contract
    # is the builtin's — NOTE(review): confirm against tle_ops.
    # M, D and eps are tl.constexpr, so Triton specializes (compiles) the
    # kernel separately for each distinct (M, D, eps) combination.
    _tle_rms_norm_gated(x_ptr, gate_ptr, w_ptr, out_ptr, M, D, eps)

35 

36 

def _tle_fast_path_ok(module, hidden_states, gate):
    """Return True when the fused TLE kernel can serve this call.

    Conditions: a gate is given; both activations are BF16, non-scalar,
    identically shaped and contiguous; the last dim equals the weight length
    recorded at patch time; the input is non-empty; all tensors are on CPU;
    and no autograd graph is required.
    """
    if gate is None:
        return False
    if hidden_states.dtype != torch.bfloat16 or gate.dtype != torch.bfloat16:
        return False
    # dim() check first: shape[-1] on a 0-dim tensor would raise IndexError.
    if hidden_states.dim() == 0 or hidden_states.shape != gate.shape:
        return False
    if hidden_states.shape[-1] != module._tle_D or hidden_states.numel() == 0:
        return False
    if not (hidden_states.is_contiguous() and gate.is_contiguous()):
        return False
    # The builtin comes from triton's CPU extra ops; other devices must keep
    # using the reference implementation.
    if any(t.device.type != "cpu" for t in (hidden_states, gate, module.weight)):
        return False
    # The kernel fills a fresh buffer that carries no grad_fn, so any call
    # that needs gradients must take the ATen path to keep autograd intact.
    if torch.is_grad_enabled() and (
        hidden_states.requires_grad or gate.requires_grad or module.weight.requires_grad
    ):
        return False
    return True


def _patched_rmsnorm_gated_forward(self, hidden_states, gate=None):
    """Replacement ``forward`` installed on patched RMSNormGated modules.

    Computes ``out = (x / rms(x)) * weight * silu(gate)`` with the fused TLE
    kernel when the call qualifies (see ``_tle_fast_path_ok``); otherwise
    delegates to the module's saved original forward, so unsupported calls
    behave exactly as before patching.

    Args:
        self: the patched module; must carry ``_tle_D`` (weight length),
            ``_original_forward``, ``weight`` and ``variance_epsilon``.
        hidden_states: activations normalized over their last dimension.
        gate: gating tensor with the same shape as ``hidden_states``;
            ``None`` always takes the original forward.

    Returns:
        Tensor with the shape of ``hidden_states`` (BF16 on the fused path).
    """
    if not _tle_fast_path_ok(self, hidden_states, gate):
        return self._original_forward(hidden_states, gate)

    shape = hidden_states.shape
    D = self._tle_D
    M = hidden_states.numel() // D
    # Both inputs were verified contiguous, so these reshapes are views.
    x_flat = hidden_states.reshape(M, D)
    g_flat = gate.reshape(M, D)
    out = torch.empty_like(x_flat)
    # Single-program grid: the builtin itself iterates over all M rows.
    _rms_norm_gated_kernel[(1,)](
        x_flat,
        g_flat,
        self.weight.to(torch.bfloat16).contiguous(),
        out,
        M=M,
        D=D,
        eps=float(self.variance_epsilon),
    )
    return out.reshape(shape)

66 

67 

def _get_qwen3_5_rmsnorm_gated_classes():
    """Collect the RMSNormGated classes resolvable in the installed transformers.

    Each (module path, class name) candidate is tried independently. A module
    that fails to import, or that lacks the class, is skipped silently, so
    environments missing some of these model families still work.

    Returns:
        tuple: the resolved classes in candidate order; empty if none exist.
    """
    candidates = (
        ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5RMSNormGated"),
        ("transformers.models.qwen3_5_moe.modeling_qwen3_5_moe", "Qwen3_5RMSNormGated"),
        ("transformers.models.qwen3_next.modeling_qwen3_next", "Qwen3NextRMSNormGated"),
    )
    found = []
    for module_path, class_name in candidates:
        try:
            module = __import__(module_path, fromlist=[class_name])
            resolved = getattr(module, class_name)
        except (ImportError, AttributeError):
            continue
        found.append(resolved)
    return tuple(found)

81 

82 

def patch_qwen3_5_rmsnorm_gated(model) -> int:
    """Route every supported RMSNormGated submodule of *model* through the TLE kernel.

    Each matching module receives instance attributes ``_tle_D`` (its weight
    length) and ``_original_forward`` (its forward at patch time), and its
    ``forward`` is replaced by ``_patched_rmsnorm_gated_forward``. Modules
    that already carry ``_original_forward`` are skipped, so repeated calls
    never wrap a module twice.

    Returns:
        int: number of modules patched by this call; 0 when none of the
        target classes can be imported.
    """
    rms_classes = _get_qwen3_5_rmsnorm_gated_classes()
    if not rms_classes:
        return 0
    n = 0
    for mod in model.modules():
        if not isinstance(mod, rms_classes):
            continue
        # Decide "already patched" from the instance itself rather than from
        # id(): ids of garbage-collected modules can be reused by new ones,
        # which would make a freshly loaded model be skipped silently.
        if "_original_forward" in vars(mod):
            continue
        mod._tle_D = mod.weight.shape[0]
        mod._original_forward = mod.forward
        mod.forward = types.MethodType(_patched_rmsnorm_gated_forward, mod)
        _PATCHED.add(id(mod))
        n += 1
    if n > 0:
        logger.info(
            "Patched %d Qwen3.5 RMSNormGated modules with TLE rms_norm_gated", n
        )
    return n

101 

102 

def unpatch_qwen3_5_rmsnorm_gated(model) -> int:
    """Restore the original ``forward`` on RMSNormGated submodules of *model*.

    A module counts as patched when its instance dict holds
    ``_original_forward``; that saved forward is reinstated and the
    ``_original_forward`` / ``_tle_D`` attributes are removed.

    Returns:
        int: number of modules actually restored; 0 when none of the target
        classes can be imported.
    """
    rms_classes = _get_qwen3_5_rmsnorm_gated_classes()
    if not rms_classes:
        return 0
    n = 0
    for mod in model.modules():
        if not isinstance(mod, rms_classes):
            continue
        state = vars(mod)
        if "_original_forward" in state:
            mod.forward = state.pop("_original_forward")
            state.pop("_tle_D", None)
            n += 1
        # Always drop the id: if it was left behind by a garbage-collected
        # module whose id this instance reuses, the entry is stale.
        _PATCHED.discard(id(mod))
    return n