Coverage for src/flag_gems/fused/FLA/chunk.py: 47%
19 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch

from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from flag_gems.fused.FLA.chunk_o import chunk_fwd_o
from flag_gems.fused.FLA.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from flag_gems.fused.FLA.cumsum import chunk_local_cumsum
from flag_gems.fused.FLA.solve_tril import solve_tril
from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL
from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Chunked forward pass of the gated delta rule (flash-linear-attention).

    Pipeline: (1) replace ``g`` with its chunk-local cumulative sum,
    (2) build the WY representation (``w``, ``u``; ``u`` is the new value
    tensor), (3) run the recurrent hidden-state pass, (4) compute the
    attention output.

    Args:
        q, k, v: query / key / value tensors.
            NOTE(review): shapes are not visible here — presumably
            (batch, seq, heads, dim); confirm against the FLA kernels.
        g: gating tensor; overwritten below by its chunk-local cumsum.
        beta: per-position scaling used in the WY representation.
        scale: softmax-style scale applied in the output kernel.
        initial_state: initial recurrent state (may be None upstream —
            passed through unchecked).
        output_final_state: whether the hidden-state kernel should also
            return the final recurrent state.
        cu_seqlens: cumulative sequence lengths for varlen batches, or
            None for fixed-length input.

    Returns:
        ``(g, o, A, final_state, w, h, v_new)`` when ``SUPPRESS_LEVEL >= 3``
        (intermediates kept for the backward pass); otherwise
        ``(g, o, A, final_state, None, None, None)``.
    """
    # Chunk size 64 must match what the downstream kernels assume.
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)

    # Obtain the WY representation. u is actually the new v.
    # A is accumulated in fp32 for accuracy, then the triangular solve
    # casts it back to the key dtype.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )

    # Recurrent pass over chunks: produces per-chunk hidden states h,
    # the updated values v_new, and (optionally) the final state.
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )

    # Attention output from q against the per-chunk hidden states.
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )

    # At low suppress levels, drop the intermediates (w, h, v_new) so the
    # backward pass recomputes them instead of holding them in memory.
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    # SUPPRESS_LEVEL >= 3: keep intermediates for the backward pass.
    return g, o, A, final_state, w, h, v_new