Coverage for src/flag_gems/fused/FLA/chunk.py: 47%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

6import torch 

7 

8from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h 

9from flag_gems.fused.FLA.chunk_o import chunk_fwd_o 

10from flag_gems.fused.FLA.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd 

11from flag_gems.fused.FLA.cumsum import chunk_local_cumsum 

12from flag_gems.fused.FLA.solve_tril import solve_tril 

13from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL 

14from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd 

15 

16 

def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Forward pass of the chunked gated delta rule.

    Pipeline (all helpers come from flag_gems.fused.FLA):
      1. ``chunk_local_cumsum`` — cumulative sum of the gates ``g`` within
         each chunk of 64 timesteps.
      2. ``chunk_scaled_dot_kkt_fwd`` + ``solve_tril`` — build and invert
         the (lower-triangular) intra-chunk matrix ``A`` used by the WY
         representation.
      3. ``recompute_w_u_fwd`` — WY representation; ``u`` is effectively
         the new ``v``.
      4. ``chunk_gated_delta_rule_fwd_h`` — chunkwise hidden states ``h``,
         updated values ``v_new``, and optionally the final state.
      5. ``chunk_fwd_o`` — the output ``o``.

    Args:
        q, k, v: query / key / value tensors (shapes are fixed by the FLA
            kernels; presumably ``(batch, seq, heads, dim)`` — not verifiable
            from this file alone).
        g: gate tensor; replaced here by its chunk-local cumulative sum.
        beta: per-timestep delta-rule learning rates.
        scale: output scaling factor forwarded to ``chunk_fwd_o``.
        initial_state: initial recurrent state (may be ``None`` depending on
            the kernel's contract — not enforced here).
        output_final_state: if True, the kernels also produce the final
            recurrent state.
        cu_seqlens: cumulative sequence lengths for varlen (packed) input,
            or ``None`` for fixed-length batches.

    Returns:
        ``(g_cumsum, o, A, final_state, w, h, v_new)`` where the last three
        entries are ``None`` when ``SUPPRESS_LEVEL < 3`` (they are only
        retained for backward when the suppress level asks for it).
    """
    # Gates are consumed in cumulative form within each 64-step chunk.
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # Obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    # Invert the lower-triangular system; result cast back to the key dtype.
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    # Below level 3 the intermediates (w, h, v_new) are dropped; the original
    # `elif SUPPRESS_LEVEL >= 3` was redundant — the two conditions are
    # exhaustive, so a plain fall-through is equivalent.
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    return g, o, A, final_state, w, h, v_new