Coverage for src/flag_gems/fused/FLA/chunk.py: 50%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6 

7import logging 

8 

9import torch 

10 

11from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h 

12from flag_gems.fused.FLA.chunk_o import chunk_fwd_o 

13from flag_gems.fused.FLA.fused_cumsum_kkt_solve_tril import ( 

14 chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril, 

15) 

16from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL 

17from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd 

18 

# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)

20 

21 

def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Forward pass of the chunked gated delta rule (flash-linear-attention style).

    Orchestrates four project kernels in sequence:
      1. a fused kernel computing the per-chunk gate cumulative sum and the
         lower-triangular solve producing ``A`` (chunk size fixed at 64 here),
      2. recomputation of the ``w``/``u`` intermediates from ``k``/``v``/``beta``/``A``,
      3. the recurrent hidden-state pass producing ``h``, ``v_new`` and the
         optional final state,
      4. the output projection producing ``o`` scaled by ``scale``.

    Args:
        q, k, v: attention-style input tensors (shapes defined by the kernels;
            not visible here — assumed consistent with the FLA chunk kernels).
        g: gate tensor; replaced by its fused cumulative sum before being
            passed downstream.
        beta: per-token beta coefficients for the delta rule.
        scale: output scaling factor forwarded to ``chunk_fwd_o``.
        initial_state: initial recurrent state (may presumably be ``None``;
            forwarded unchecked to ``chunk_gated_delta_rule_fwd_h``).
        output_final_state: whether the hidden-state kernel should also return
            the final recurrent state.
        cu_seqlens: optional cumulative sequence lengths for varlen batching.

    Returns:
        ``(g, o, A, final_state, w, h, v_new)`` when ``SUPPRESS_LEVEL >= 3``
        (intermediates kept, presumably for the backward pass); otherwise the
        last three slots are ``None``.
    """
    logger.debug("GEMS CHUNK GATED DELTA RULE FWD")
    # NOTE: `g` is rebound to its chunk-wise cumulative sum; all later kernels
    # receive the cumsum, not the raw gates.
    g, A = chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
        g=g, k=k, beta=beta, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=k.dtype
    )
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        # Drop the intermediates; callers get placeholders in their slots.
        return g, o, A, final_state, None, None, None
    # SUPPRESS_LEVEL >= 3: keep w/h/v_new (the original guarded this with a
    # redundant `elif SUPPRESS_LEVEL >= 3`, which is the exact complement of
    # the branch above — a plain fall-through is equivalent and removes the
    # implicit-None return path a reader would otherwise have to rule out).
    return g, o, A, final_state, w, h, v_new