Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_gated_delta.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1"""Monkey-patch Qwen3_5GatedDeltaNet.recurrent_gated_delta_rule with a 

2TLE-CPU fused decode kernel. 

3 

4Replaces the per-layer ATen sequence 

5 

6 state *= exp(g) 

7 kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2) 

8 delta = (v - kv_mem) * beta 

9 state += k.unsqueeze(-1) * delta.unsqueeze(-2) 

10 out = (state * q.unsqueeze(-1)).sum(dim=-2) 

11 

12with a single fused @triton.jit kernel that calls the TLE builtin 

13`tle_ops.gated_delta_decode`, which dispatches to the NEON C runtime 

14`standalone_gated_delta_decode_fp32`. 

15 

16State update + output dot are fused into a single sweep over state, matching 

17llama.cpp's Metal/SYCL backends (their CPU kernel does these as 2 separate 

18passes). 

19 

20Decode (T=1) only — prefill (T>1) falls back to torch_chunk_gated_delta_rule. 

21""" 

22 

23import logging 

24 

25import torch 

26import triton 

27import triton.language as tl 

28from triton.language.extra.cpu.tle_ops import ( 

29 gated_delta_decode as _tle_gated_delta_decode, 

30) 

31 

32logger = logging.getLogger(__name__) 

33 

34_PATCHED: set = set() 

35 

36 

@triton.jit
def _gated_delta_decode_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    g_ptr,
    beta_ptr,
    state_ptr,
    out_ptr,
    B: tl.constexpr,
    H: tl.constexpr,
    k_dim: tl.constexpr,
    v_dim: tl.constexpr,
    use_l2norm: tl.constexpr,
):
    # JIT entry point with no logic of its own: every pointer and every
    # compile-time constant is forwarded, unchanged and in the same order,
    # to the TLE builtin `gated_delta_decode` (imported under the alias
    # `_tle_gated_delta_decode`).
    #
    # The launch site in this file passes contiguous fp32 tensors for the
    # seven pointer arguments and plain Python ints for the constexprs.
    _tle_gated_delta_decode(
        q_ptr, k_ptr, v_ptr, g_ptr, beta_ptr,
        state_ptr, out_ptr,
        B, H, k_dim, v_dim, use_l2norm,
    )

66 

67 

68def _patched_recurrent_gated_delta_rule( 

69 query, 

70 key, 

71 value, 

72 g, 

73 beta, 

74 initial_state, 

75 output_final_state, 

76 use_qk_l2norm_in_kernel=False, 

77): 

78 """Drop-in for torch_recurrent_gated_delta_rule on T=1 decode path. 

79 

80 Shapes (matching HF): 

81 query, key: [B, T, H, k_dim] (any dtype; cast to fp32 internally) 

82 value: [B, T, H, v_dim] 

83 g, beta: [B, T, H] 

84 initial_state: [B, H, k_dim, v_dim] fp32, or None 

85 

86 Returns: 

87 core_attn_out: [B, T, H, v_dim] cast back to query.dtype 

88 last_recurrent_state: [B, H, k_dim, v_dim] fp32 (or None) 

89 

90 For T>1, falls back to the original torch implementation in the host 

91 module (caller passes that in via the closure during patching). 

92 """ 

93 raise NotImplementedError("install via patch_qwen3_5_gated_delta(model)") 

94 

95 

96def _make_patched_fn(torch_recurrent_fn): 

97 def fn( 

98 query, 

99 key, 

100 value, 

101 g, 

102 beta, 

103 initial_state, 

104 output_final_state, 

105 use_qk_l2norm_in_kernel=False, 

106 ): 

107 B, T, H, k_dim = query.shape 

108 v_dim = value.shape[-1] 

109 

110 # Prefill or any non-decode shape: defer to the torch reference. 

111 if T != 1 or k_dim > 256 or v_dim > 256 or k_dim % 4 != 0 or v_dim % 4 != 0: 

112 return torch_recurrent_fn( 

113 query, 

114 key, 

115 value, 

116 g, 

117 beta, 

118 initial_state, 

119 output_final_state, 

120 use_qk_l2norm_in_kernel, 

121 ) 

122 

123 orig_dtype = query.dtype 

124 

125 # Squeeze T=1; cast to fp32 contiguous flat tensors that our kernel 

126 # expects ([B, H, k_dim] / [B, H, v_dim] / [B, H]). 

127 q_f = query.squeeze(1).to(torch.float32).contiguous() 

128 k_f = key.squeeze(1).to(torch.float32).contiguous() 

129 v_f = value.squeeze(1).to(torch.float32).contiguous() 

130 g_f = g.squeeze(1).to(torch.float32).contiguous() 

131 b_f = beta.squeeze(1).to(torch.float32).contiguous() 

132 

133 if initial_state is None: 

134 state = torch.zeros(B, H, k_dim, v_dim, dtype=torch.float32) 

135 else: 

136 # .contiguous() on already-contiguous fp32 is a no-op. 

137 # The caller replaces cache_params.recurrent_states[layer_idx] 

138 # with our return value, so in-place update is safe here. 

139 state = initial_state.to(torch.float32).contiguous() 

140 

141 out = torch.empty(B, H, v_dim, dtype=torch.float32) 

142 

143 _gated_delta_decode_kernel[(1,)]( 

144 q_f, 

145 k_f, 

146 v_f, 

147 g_f, 

148 b_f, 

149 state, 

150 out, 

151 B=B, 

152 H=H, 

153 k_dim=k_dim, 

154 v_dim=v_dim, 

155 use_l2norm=1 if use_qk_l2norm_in_kernel else 0, 

156 ) 

157 

158 core_attn_out = out.unsqueeze(1).to(orig_dtype).contiguous() 

159 last_recurrent_state = state if output_final_state else None 

160 return core_attn_out, last_recurrent_state 

161 

162 return fn 

163 

164 

165def _get_qwen3_5_gated_delta_classes() -> tuple: 

166 classes = [] 

167 for modname, clsname in [ 

168 ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5GatedDeltaNet"), 

169 ( 

170 "transformers.models.qwen3_5_moe.modeling_qwen3_5_moe", 

171 "Qwen3_5MoeGatedDeltaNet", 

172 ), 

173 ( 

174 "transformers.models.qwen3_next.modeling_qwen3_next", 

175 "Qwen3NextGatedDeltaNet", 

176 ), 

177 ]: 

178 try: 

179 mod = __import__(modname, fromlist=[clsname]) 

180 classes.append(getattr(mod, clsname)) 

181 except (ImportError, AttributeError): 

182 pass 

183 return tuple(classes) 

184 

185 

def patch_qwen3_5_gated_delta(model) -> int:
    """Replace each GDN module's recurrent_gated_delta_rule with the fused
    TLE kernel.

    Safe to call multiple times. A module counts as already patched when it
    carries the ``_original_recurrent_gated_delta_rule`` attribute that this
    function sets. The attribute, not ``id(module)``, is the authority:
    ids are only unique while an object is alive, so an id left over from a
    freed model could otherwise make a fresh module look patched, and a
    ``copy.deepcopy`` of a patched module (new id, same attributes) would
    otherwise be wrapped a second time and lose its true original.

    Args:
        model: object exposing ``named_modules()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        The number of modules newly patched by this call.
    """
    gdn_classes = _get_qwen3_5_gated_delta_classes()
    if not gdn_classes:
        logger.debug("No Qwen GDN classes found in transformers, skipping patch")
        return 0

    n = 0
    # Assigning plain function attributes does not alter the module tree,
    # so the generator can be consumed directly.
    for _name, module in model.named_modules():
        if not isinstance(module, gdn_classes):
            continue
        if hasattr(module, "_original_recurrent_gated_delta_rule"):
            # Already patched; just keep the id registry in sync so that
            # id-based bookkeeping elsewhere still recognises the module.
            _PATCHED.add(id(module))
            continue
        torch_recurrent_fn = module.recurrent_gated_delta_rule
        module._original_recurrent_gated_delta_rule = torch_recurrent_fn
        module.recurrent_gated_delta_rule = _make_patched_fn(torch_recurrent_fn)
        _PATCHED.add(id(module))
        n += 1

    if n > 0:
        cls_names = ", ".join(c.__name__ for c in gdn_classes)
        logger.info(
            "Patched %d GDN modules (classes: %s) with TLE gated_delta_decode",
            n,
            cls_names,
        )
    return n

213 

214 

def unpatch_qwen3_5_gated_delta(model) -> int:
    """Undo ``patch_qwen3_5_gated_delta`` on every GDN module of ``model``.

    A module is restored when it carries the
    ``_original_recurrent_gated_delta_rule`` attribute saved at patch time.
    The attribute, not ``id(module)``, decides: ids can be reused after an
    object is freed, and a deep-copied patched module has a new id but
    still needs restoring. The module's id is dropped from the registry
    either way, so stale entries cannot block a later re-patch.

    Args:
        model: object exposing ``named_modules()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        The number of modules whose original function was actually
        restored by this call.
    """
    gdn_classes = _get_qwen3_5_gated_delta_classes()
    if not gdn_classes:
        return 0

    n = 0
    for _name, module in model.named_modules():
        if not isinstance(module, gdn_classes):
            continue
        # Always clear the id, even if there is nothing to restore.
        _PATCHED.discard(id(module))
        if not hasattr(module, "_original_recurrent_gated_delta_rule"):
            continue
        module.recurrent_gated_delta_rule = (
            module._original_recurrent_gated_delta_rule
        )
        del module._original_recurrent_gated_delta_rule
        n += 1
    return n