Coverage for src/flag_gems/fused/FLA/chunk.py: 47%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6import torch 

7 

8from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h 

9from flag_gems.fused.FLA.chunk_o import chunk_fwd_o 

10from flag_gems.fused.FLA.fused_cumsum_kkt_solve_tril import ( 

11 chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril, 

12) 

13from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL 

14from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd 

15 

16 

def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Forward pass of the chunked gated delta rule.

    Pipeline (all helpers operate on fixed chunks of 64 positions):
      1. fuse the gate cumulative-sum with the K·Kᵀ lower-triangular solve,
      2. recompute the W/U projections from K, V, beta and the solve result,
      3. run the recurrent hidden-state scan over chunks,
      4. compute the chunk-level output O from Q, the updated V and H.

    Args:
        q, k, v: query/key/value tensors. Shapes are consumed opaquely by
            the FLA helper kernels — assumed (batch, seq, heads, dim);
            TODO confirm against the kernel signatures.
        g: per-position gate values; replaced below by their chunkwise
            cumulative sum (computed in ``k``'s dtype).
        beta: per-position beta coefficients for the delta rule.
        scale: softmax-style scaling applied to ``q`` inside ``chunk_fwd_o``.
        initial_state: recurrent state to seed the hidden-state scan.
        output_final_state: if True, the scan also returns the final state.
        cu_seqlens: optional cumulative sequence lengths for varlen
            (packed) batches; forwarded to every kernel.

    Returns:
        ``(g_cumsum, o, A, final_state, w, h, v_new)``. When
        ``SUPPRESS_LEVEL < 3`` the last three entries are ``None`` so the
        intermediates can be freed; otherwise they are kept (e.g. for reuse
        in the backward pass).
    """
    # Step 1: fused gate cumsum + (I + tril(K·Kᵀ·diag(beta)))⁻¹ solve.
    # NOTE: `g` is rebound to its cumulative sum from here on.
    g, A = chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
        g=g, k=k, beta=beta, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=k.dtype
    )
    # Step 2: recompute the W and U tensors of the WY representation.
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    # Step 3: chunkwise recurrent scan producing per-chunk hidden states,
    # the delta-updated values, and (optionally) the final state.
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    # Step 4: intra/inter-chunk output computation.
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        # Drop the large intermediates; callers only need the outputs.
        return g, o, A, final_state, None, None, None
    else:
        # Was `elif SUPPRESS_LEVEL >= 3:` — the exact complement of the
        # branch above. A plain `else` makes exhaustiveness explicit and
        # removes the implicit fall-through-to-None if the thresholds
        # were ever edited inconsistently.
        return g, o, A, final_state, w, h, v_new