# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import logging

import torch

from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from flag_gems.fused.FLA.chunk_o import chunk_fwd_o
from flag_gems.fused.FLA.fused_cumsum_kkt_solve_tril import (
    chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril,
)
from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL
from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd

logger = logging.getLogger(__name__)


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
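    """
    Forward pass of the chunked gated delta rule.

    Runs four fused stages (gate cumsum + triangular solve, w/u
    recomputation, chunkwise state recurrence, output computation) and
    returns (g, o, A, final_state, w, h, v_new), where the last three
    entries are None when SUPPRESS_LEVEL < 3.
    """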
    logger.debug("GEMS CHUNK GATED DELTA RULE FWD")
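    # Stage 1: fuse the chunkwise cumulative sum of the gates with the
    # construction and lower-triangular solve of the KK^T-based system,
    # producing the gate cumsum g and triangular factor A per 64-token chunk.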
    g, A = chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
        g=g, k=k, beta=beta, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=k.dtype
    )
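    # Stage 2: recompute the w and u tensors from k, v, beta, the triangular
    # factor A, and the gate cumsum.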
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
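    # Stage 3: run the chunkwise recurrence to produce the hidden states h,
    # the updated values v_new, and (if requested) the final recurrent state.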
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
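    # Stage 4: combine q, k, v_new, and the hidden states h into the scaled
    # output o.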
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
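    # Below SUPPRESS_LEVEL 3 the intermediates w, h, and v_new are dropped
    # (returned as None) to save memory; at level 3 and above they are kept.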
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    return g, o, A, final_state, w, h, v_new
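

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# how this forward entry point might be called. The tensor shapes below
# follow the (batch, seq_len, num_heads, head_dim) layout used by
# flash-linear-attention and are an assumption, not a documented contract;
# the sizes are hypothetical. Requires a CUDA device with Triton available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        B, T, H, K, V = 1, 128, 2, 64, 64  # hypothetical sizes
        q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
        v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
        # Gates are assumed to be log-space decay factors in (-inf, 0).
        g = torch.rand(B, T, H, device="cuda", dtype=torch.float32).sigmoid().log()
        beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
        # Passing None for initial_state is assumed to mean "start from zero".
        g_out, o, A, final_state, *_ = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=K**-0.5,
            initial_state=None,
            output_final_state=False,
        )
        print(o.shape)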