Coverage for src/flag_gems/fused/FLA/chunk.py: 47% (15 statements)
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch

from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from flag_gems.fused.FLA.chunk_o import chunk_fwd_o
from flag_gems.fused.FLA.fused_cumsum_kkt_solve_tril import (
    chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril,
)
from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL
from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
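    # Forward pass of the chunked gated delta rule (chunk size 64). The tensor
    # layout is an assumption inferred from flash-linear-attention conventions
    # rather than stated in this file: q/k ~ (B, T, H, K), v ~ (B, T, H, V),
    # g and beta ~ (B, T, H); cu_seqlens switches to varlen (packed) batches.
    # Stage 1: fused per-chunk cumulative sum of the gates g and the
    # lower-triangular solve against the K K^T Gram matrix, yielding the
    # triangular factor A consumed by the WY recomputation below.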
    g, A = chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
        g=g, k=k, beta=beta, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=k.dtype
    )
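    # Stage 2: recompute the WY-representation tensors w and u from k, v,
    # beta, the triangular factor A, and the gate cumsum.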
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
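    # Stage 3: inter-chunk recurrence over the hidden state, producing the
    # per-chunk states h, the updated values v_new, and, if requested, the
    # final recurrent state.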
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
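    # Stage 4: assemble the output o by combining q with the chunk states h
    # (inter-chunk part) and with k/v_new (intra-chunk part), scaled by `scale`.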
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
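    # SUPPRESS_LEVEL decides how many intermediates are handed back,
    # presumably so the backward pass can reuse them instead of recomputing.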
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    else:  # SUPPRESS_LEVEL >= 3
        return g, o, A, final_state, w, h, v_new
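

# A minimal smoke-test sketch of how this forward pass might be invoked.
# The shapes, dtypes, scale, log-gate construction, the CUDA requirement,
# and passing None for initial_state are all assumptions based on typical
# flash-linear-attention usage, not guarantees from this file.
if __name__ == "__main__":
    B, T, H, K, V = 1, 128, 2, 64, 64  # hypothetical sizes
    device, dtype = "cuda", torch.bfloat16
    q = torch.randn(B, T, H, K, device=device, dtype=dtype)
    k = torch.randn(B, T, H, K, device=device, dtype=dtype)
    v = torch.randn(B, T, H, V, device=device, dtype=dtype)
    # Log-space decay gates in (-inf, 0); float32 is a guess at the expected dtype.
    g = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid().log()
    beta = torch.rand(B, T, H, device=device, dtype=dtype)
    g_out, o, A, final_state, *_ = chunk_gated_delta_rule_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=K**-0.5,
        initial_state=None,
        output_final_state=True,
    )
    print(o.shape, None if final_state is None else final_state.shape)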