Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_rmsnorm.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""Monkey-patch Qwen3_5RMSNorm.forward to use the existing 

2tle_ops.rms_norm @builtin (NEON RMSNorm in C runtime), eliminating the 

3~30 us ATen sequence (pow + mean + rsqrt + mul × 2) per call. 

4 

5Per Qwen3.5-2B decode token, RMSNorm is invoked at: 

6 - input_layernorm (per decoder layer) : 24 calls 

7 - post_attention_layernorm (per decoder layer) : 24 calls 

8 - q_norm, k_norm (full_attention layers) : 6 layers × 2 = 12 calls 

9Total: ~60 RMSNorm calls per token → ~1.5 ms/token saved if each call drops 

10from 30 us to 5 us (60 calls × 25 us). 

11 

12Qwen3.5's RMSNorm uses `out = (x / rms(x)) * (1 + weight)` (note the +1). 

13We pre-compute `_weight_plus_one = weight + 1.0` at patch time so the 

14existing tle.rms_norm (which computes `out = (x / rms(x)) * w_in`) gives 

15the right result. 

16 

17Contiguous BF16 input takes the TLE path: one launch for decode (M=1), one launch 

18per row otherwise. Other dtypes / layouts fall back to the original forward. 

19""" 

20import logging 

21import types 

22 

23import torch 

24import triton 

25import triton.language as tl 

26from triton.language.extra.cpu.tle_ops import rms_norm as _tle_rms_norm 

27 

# Module-level logger; used to report how many modules were patched.
logger = logging.getLogger(__name__)

# Registry of `id(module)` values for modules whose `forward` was replaced by
# `patch_qwen3_5_rmsnorm`; `unpatch_qwen3_5_rmsnorm` removes entries again.
# NOTE(review): `id()` values are only unique among live objects, so an entry
# can outlive the module it was recorded for.
_PATCHED: set = set()

31 

32 

# Thin Triton entry point that forwards straight to the `tle_ops.rms_norm`
# builtin imported above as `_tle_rms_norm`.
#   x_ptr   - input row
#   w_ptr   - scale vector (callers in this file pass the cached `weight + 1`)
#   out_ptr - destination buffer written by the builtin
#   D, eps  - declared `tl.constexpr`; callers pass the row length and epsilon
# Per the module docstring the builtin computes `out = (x / rms(x)) * w_in`.
@triton.jit
def _rms_norm_kernel(x_ptr, w_ptr, out_ptr, D: tl.constexpr, eps: tl.constexpr):
    _tle_rms_norm(x_ptr, w_ptr, out_ptr, D, eps)

36 

37 

38def _patched_rmsnorm_forward(self, x: torch.Tensor) -> torch.Tensor: 

39 # Fast path: bf16 input, last-dim contiguous, single row. 

40 if x.dtype == torch.bfloat16 and x.is_contiguous() and x.shape[-1] == self._tle_D: 

41 # Reshape to [M, D] flat 

42 shape = x.shape 

43 D = self._tle_D 

44 M = x.numel() // D 

45 if M == 1: 

46 xc = x.reshape(D).contiguous() 

47 out = torch.empty(D, dtype=torch.bfloat16) 

48 _rms_norm_kernel[(1,)]( 

49 xc, self._weight_plus_one_bf16, out, D=D, eps=float(self.eps) 

50 ) 

51 return out.reshape(*shape) 

52 # Multi-row: call kernel M times. For decode M=1 we never hit this. 

53 out = torch.empty_like(x) 

54 x_2d = x.reshape(M, D) 

55 out_2d = out.reshape(M, D) 

56 for i in range(M): 

57 _rms_norm_kernel[(1,)]( 

58 x_2d[i].contiguous(), 

59 self._weight_plus_one_bf16, 

60 out_2d[i], 

61 D=D, 

62 eps=float(self.eps), 

63 ) 

64 return out 

65 

66 # Slow / fallback path: original forward 

67 return self._original_forward(x) 

68 

69 

70def _get_qwen3_5_rmsnorm_classes(): 

71 classes = [] 

72 for modname, clsname in [ 

73 ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5RMSNorm"), 

74 ("transformers.models.qwen3_5_moe.modeling_qwen3_5_moe", "Qwen3_5RMSNorm"), 

75 ("transformers.models.qwen3_next.modeling_qwen3_next", "Qwen3NextRMSNorm"), 

76 ]: 

77 try: 

78 mod = __import__(modname, fromlist=[clsname]) 

79 classes.append(getattr(mod, clsname)) 

80 except (ImportError, AttributeError): 

81 pass 

82 return tuple(classes) 

83 

84 

def patch_qwen3_5_rmsnorm(model) -> int:
    """Swap ``forward`` on every Qwen3.5-family RMSNorm module inside ``model``.

    For each matching module this caches ``weight + 1`` in bfloat16, records
    the hidden size, keeps the current ``forward`` as ``_original_forward``
    and installs ``_patched_rmsnorm_forward`` as a bound method.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Number of modules newly patched by this call. Modules that already
        carry the patch are skipped and not counted.
    """
    rms_classes = _get_qwen3_5_rmsnorm_classes()
    if not rms_classes:
        return 0

    patch_attrs = ("_original_forward", "_weight_plus_one_bf16", "_tle_D")
    n = 0
    for _name, mod in list(model.named_modules()):
        if not isinstance(mod, rms_classes):
            continue
        if all(hasattr(mod, attr) for attr in patch_attrs):
            # Decide "already patched" from the instance itself rather than
            # from its id alone: ids may be reused once an earlier patched
            # module has been garbage-collected. Re-register the id so the
            # registry stays in sync with the instance.
            _PATCHED.add(id(mod))
            continue

        # Detach first so the cached tensor is a plain leaf with no autograd
        # history attached to the live parameter.
        weight = mod.weight.detach()
        # Pre-compute weight + 1.0 in bf16 (matches Qwen3.5's RMSNorm formula).
        mod._weight_plus_one_bf16 = (1.0 + weight).to(torch.bfloat16).contiguous()
        mod._tle_D = weight.shape[0]
        mod._original_forward = mod.forward
        mod.forward = types.MethodType(_patched_rmsnorm_forward, mod)
        _PATCHED.add(id(mod))
        n += 1

    if n > 0:
        logger.info("Patched %d Qwen3.5 RMSNorm modules with TLE rms_norm", n)
    return n

105 

106 

def unpatch_qwen3_5_rmsnorm(model) -> int:
    """Undo ``patch_qwen3_5_rmsnorm`` on every patched RMSNorm module in ``model``.

    A module counts as patched when it carries all attributes the patch
    installs. Its ``forward`` is reset to the saved ``_original_forward`` and
    the helper attributes are removed.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Number of modules whose ``forward`` was actually restored.
    """
    rms_classes = _get_qwen3_5_rmsnorm_classes()
    if not rms_classes:
        return 0

    patch_attrs = ("_original_forward", "_weight_plus_one_bf16", "_tle_D")
    n = 0
    for _name, mod in list(model.named_modules()):
        if not isinstance(mod, rms_classes):
            continue
        # Always drop the id: it is either about to become invalid, or it is
        # a stale entry left behind by a module that no longer exists.
        _PATCHED.discard(id(mod))
        if not all(hasattr(mod, attr) for attr in patch_attrs):
            continue
        mod.forward = mod._original_forward
        for attr in patch_attrs:
            delattr(mod, attr)
        n += 1
    return n