Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_gated_delta.py: 0%
72 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Monkey-patch Qwen3_5GatedDeltaNet.recurrent_gated_delta_rule with a
2TLE-CPU fused decode kernel.
4Replaces the per-layer ATen sequence
6 state *= exp(g)
7 kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
8 delta = (v - kv_mem) * beta
9 state += k.unsqueeze(-1) * delta.unsqueeze(-2)
10 out = (state * q.unsqueeze(-1)).sum(dim=-2)
12with a single fused @triton.jit kernel that calls the TLE builtin
13`tle_ops.gated_delta_decode`, which dispatches to the NEON C runtime
14`standalone_gated_delta_decode_fp32`.
16State update + output dot are fused into a single sweep over state, matching
17llama.cpp's Metal/SYCL backends (their CPU kernel does these as 2 separate
18passes).
Decode (T=1) only — any other call (T>1, or head dims the kernel does not
handle) is forwarded to the module's original recurrent_gated_delta_rule.
21"""
23import logging
25import torch
26import triton
27import triton.language as tl
28from triton.language.extra.cpu.tle_ops import (
29 gated_delta_decode as _tle_gated_delta_decode,
30)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# id() values of the modules that patch_qwen3_5_gated_delta has patched and
# unpatch_qwen3_5_gated_delta has not yet restored.
_PATCHED: set = set()
# Thin @triton.jit entry point: every argument is forwarded unchanged to the
# TLE builtin imported at the top of the file as `_tle_gated_delta_decode`.
# The sizes (B, H, k_dim, v_dim) and the use_l2norm flag are compile-time
# constants (tl.constexpr); the remaining arguments are the tensors the
# launcher passes in.
@triton.jit
def _gated_delta_decode_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    g_ptr,
    beta_ptr,
    state_ptr,
    out_ptr,
    B: tl.constexpr,
    H: tl.constexpr,
    k_dim: tl.constexpr,
    v_dim: tl.constexpr,
    use_l2norm: tl.constexpr,
):
    # Tensor arguments first, then the constexpr sizes and the l2norm flag,
    # in exactly the order the kernel received them.
    _tle_gated_delta_decode(
        q_ptr, k_ptr, v_ptr, g_ptr, beta_ptr, state_ptr, out_ptr,
        B, H, k_dim, v_dim, use_l2norm,
    )
def _patched_recurrent_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    initial_state,
    output_final_state,
    use_qk_l2norm_in_kernel=False,
):
    """Placeholder documenting the contract of the patched decode function.

    This function is never meant to be called directly: it always raises.
    The working implementation is the closure installed on each module by
    ``patch_qwen3_5_gated_delta(model)``, which shares this signature.

    Expected shapes (matching HF):
        query, key:    [B, T, H, k_dim]  (any dtype; cast to fp32 internally)
        value:         [B, T, H, v_dim]
        g, beta:       [B, T, H]
        initial_state: [B, H, k_dim, v_dim] fp32, or None

    Result of the installed implementation:
        core_attn_out:        [B, T, H, v_dim], cast back to query.dtype
        last_recurrent_state: [B, H, k_dim, v_dim] fp32 (or None)

    For T>1 the installed implementation defers to the original function of
    the host module, which it captures in a closure at patch time.

    Raises:
        NotImplementedError: always.
    """
    message = "install via patch_qwen3_5_gated_delta(model)"
    raise NotImplementedError(message)
def _make_patched_fn(torch_recurrent_fn):
    """Build the replacement for a module's ``recurrent_gated_delta_rule``.

    Args:
        torch_recurrent_fn: the module's original implementation. It is
            captured in the returned closure and used for every call the
            fused kernel does not handle.

    Returns:
        A function with the signature
        ``fn(query, key, value, g, beta, initial_state, output_final_state,
        use_qk_l2norm_in_kernel=False)`` returning
        ``(core_attn_out, last_recurrent_state)``.

    The fused kernel is used only when T == 1, the inputs live on the CPU,
    and both k_dim and v_dim are multiples of 4 no larger than 256. Every
    other call is forwarded to ``torch_recurrent_fn``.
    """

    def fn(
        query,
        key,
        value,
        g,
        beta,
        initial_state,
        output_final_state,
        use_qk_l2norm_in_kernel=False,
    ):
        B, T, H, k_dim = query.shape
        v_dim = value.shape[-1]

        kernel_supported = (
            T == 1
            # The fused kernel is a CPU kernel; tensors on any other device
            # must take the reference path.
            and query.device.type == "cpu"
            and k_dim <= 256
            and v_dim <= 256
            and k_dim % 4 == 0
            and v_dim % 4 == 0
        )
        if not kernel_supported:
            # Prefill or any non-decode shape: defer to the original function.
            # Everything after (query, key, value) is forwarded by keyword so
            # the call still binds correctly when the original's signature
            # has extra parameters in between (for example a `scale` slot
            # after `beta`).
            return torch_recurrent_fn(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=initial_state,
                output_final_state=output_final_state,
                use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            )

        orig_dtype = query.dtype
        device = query.device

        # Squeeze T=1; cast to fp32 contiguous flat tensors that the kernel
        # expects ([B, H, k_dim] / [B, H, v_dim] / [B, H]).
        q_f = query.squeeze(1).to(torch.float32).contiguous()
        k_f = key.squeeze(1).to(torch.float32).contiguous()
        v_f = value.squeeze(1).to(torch.float32).contiguous()
        g_f = g.squeeze(1).to(torch.float32).contiguous()
        b_f = beta.squeeze(1).to(torch.float32).contiguous()

        if initial_state is None:
            # Allocate next to the inputs rather than on torch's default
            # device, which the user may have redirected elsewhere.
            state = torch.zeros(
                B, H, k_dim, v_dim, dtype=torch.float32, device=device
            )
        else:
            # .to()/.contiguous() on an already-contiguous fp32 tensor return
            # the same tensor, so the kernel then works directly on the
            # caller's state. The caller replaces
            # cache_params.recurrent_states[layer_idx] with our return value,
            # so in-place update is safe here.
            state = initial_state.to(torch.float32).contiguous()

        out = torch.empty(B, H, v_dim, dtype=torch.float32, device=device)

        # Single program instance; the TLE builtin covers all of (B, H).
        _gated_delta_decode_kernel[(1,)](
            q_f,
            k_f,
            v_f,
            g_f,
            b_f,
            state,
            out,
            B=B,
            H=H,
            k_dim=k_dim,
            v_dim=v_dim,
            use_l2norm=1 if use_qk_l2norm_in_kernel else 0,
        )

        # Restore the T axis and the caller's dtype.
        core_attn_out = out.unsqueeze(1).to(orig_dtype).contiguous()
        last_recurrent_state = state if output_final_state else None
        return core_attn_out, last_recurrent_state

    return fn
def _get_qwen3_5_gated_delta_classes() -> tuple:
    """Return the GatedDeltaNet classes that can be imported from transformers.

    Each (module path, class name) candidate is tried in turn. A candidate
    whose module cannot be imported, or whose module lacks the class, is
    skipped silently, so the result may be an empty tuple.
    """
    candidates = (
        ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5GatedDeltaNet"),
        (
            "transformers.models.qwen3_5_moe.modeling_qwen3_5_moe",
            "Qwen3_5MoeGatedDeltaNet",
        ),
        (
            "transformers.models.qwen3_next.modeling_qwen3_next",
            "Qwen3NextGatedDeltaNet",
        ),
    )

    found = []
    for module_path, class_name in candidates:
        try:
            # A non-empty fromlist makes __import__ return the leaf module
            # instead of the top-level package.
            leaf_module = __import__(module_path, fromlist=[class_name])
            gdn_class = getattr(leaf_module, class_name)
        except (ImportError, AttributeError):
            # This model family is not available in the installed version.
            continue
        found.append(gdn_class)
    return tuple(found)
def patch_qwen3_5_gated_delta(model) -> int:
    """Replace each GDN module's recurrent_gated_delta_rule with the fused
    TLE kernel. Returns the number of modules patched by this call.

    The original function is saved on the module as
    ``_original_recurrent_gated_delta_rule`` so it can be restored later.

    Safe to call multiple times: a module that already carries the saved
    original is left alone. The saved attribute, not ``id(module)``, decides
    whether a module is patched, because ids are only unique among live
    objects. A freshly created module can reuse the id of a collected one,
    and a deep-copied patched module has a new id but is still patched.
    """
    gdn_classes = _get_qwen3_5_gated_delta_classes()
    if not gdn_classes:
        logger.debug("No Qwen GDN classes found in transformers, skipping patch")
        return 0

    n = 0
    for _name, module in list(model.named_modules()):
        if not isinstance(module, gdn_classes):
            continue
        if hasattr(module, "_original_recurrent_gated_delta_rule"):
            # Already patched. Wrapping again would overwrite the saved
            # original with the patched function and make a restore
            # impossible. Record the id so the bookkeeping stays complete.
            _PATCHED.add(id(module))
            continue
        torch_recurrent_fn = module.recurrent_gated_delta_rule
        module._original_recurrent_gated_delta_rule = torch_recurrent_fn
        module.recurrent_gated_delta_rule = _make_patched_fn(torch_recurrent_fn)
        _PATCHED.add(id(module))
        n += 1
    if n:
        cls_names = ", ".join(c.__name__ for c in gdn_classes)
        logger.info(
            "Patched %d GDN modules (classes: %s) with TLE gated_delta_decode",
            n,
            cls_names,
        )
    return n
def unpatch_qwen3_5_gated_delta(model) -> int:
    """Restore the original recurrent_gated_delta_rule on every patched GDN
    module of ``model``. Returns the number of modules actually restored.

    A module counts as patched when it carries the saved original in
    ``_original_recurrent_gated_delta_rule``. Its entry in the id
    bookkeeping is dropped either way, so a stale id left behind by a
    collected module is cleaned up without being counted as a restore.
    Returns 0 when no GDN class can be imported from transformers.
    """
    gdn_classes = _get_qwen3_5_gated_delta_classes()
    if not gdn_classes:
        return 0

    n = 0
    for _name, module in list(model.named_modules()):
        if not isinstance(module, gdn_classes):
            continue
        _PATCHED.discard(id(module))
        if hasattr(module, "_original_recurrent_gated_delta_rule"):
            module.recurrent_gated_delta_rule = (
                module._original_recurrent_gated_delta_rule
            )
            del module._original_recurrent_gated_delta_rule
            n += 1
    return n