Coverage for src/flag_gems/runtime/backend/_ascend/fla/chunk.py: 0%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9# ruff: noqa: E501 

10# mypy: ignore-errors 

11import torch 

12 

13from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL 

14 

15from .chunk_delta_h import chunk_gated_delta_rule_fwd_h 

16from .chunk_o import chunk_fwd_o 

17from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd 

18from .cumsum import chunk_local_cumsum 

19from .solve_tril import solve_tril 

20from .wy_fast import recompute_w_u_fwd 

21 

22 

def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Run the forward pass of the chunked gated delta rule.

    This function only orchestrates the stage helpers imported by this
    module; all numerical work happens inside them. The stages run in this
    order:

    1. ``chunk_local_cumsum`` replaces ``g`` with its chunk-local cumulative
       sum (chunk size fixed at 64 here).
    2. ``chunk_scaled_dot_kkt_fwd`` builds ``A`` from ``k``, ``beta`` and the
       cumulated ``g``, requesting ``torch.float32`` output.
    3. ``solve_tril`` transforms ``A`` and emits it in ``k.dtype``.
    4. ``recompute_w_u_fwd`` produces the WY representation ``(w, u)``.
    5. ``chunk_gated_delta_rule_fwd_h`` produces ``h``, ``v_new`` and
       ``final_state``.
    6. ``chunk_fwd_o`` produces the output ``o``.

    Args:
        q: Query tensor, forwarded to ``chunk_fwd_o``.
        k: Key tensor; its dtype is also used as the output dtype of
            ``solve_tril``.
        v: Value tensor, consumed by ``recompute_w_u_fwd``.
        g: Gate tensor; cumulated chunk-locally before any other use.
        beta: Forwarded to ``chunk_scaled_dot_kkt_fwd`` and
            ``recompute_w_u_fwd``.
        scale: Forwarded to ``chunk_fwd_o``.
        initial_state: Forwarded to ``chunk_gated_delta_rule_fwd_h``.
        output_final_state: Forwarded to ``chunk_gated_delta_rule_fwd_h``.
        cu_seqlens: Optional tensor forwarded unchanged to every stage.
            NOTE(review): presumably cumulative sequence lengths for
            variable-length batches -- confirm against the helpers.

    Tensor shapes/layouts are not established in this file; see the
    individual helpers.

    Returns:
        A 7-tuple ``(g, o, A, final_state, w, h, v_new)`` where ``g`` is the
        cumulated gate (not the input ``g``) and ``A`` is the result of
        ``solve_tril``. When ``SUPPRESS_LEVEL < 3`` the last three entries
        are ``None``; otherwise the intermediates ``w``, ``h`` and ``v_new``
        are returned as well.
    """
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        # Intermediates are withheld at low suppress levels.
        return g, o, A, final_state, None, None, None
    else:
        # A plain `else` (instead of re-testing `SUPPRESS_LEVEL >= 3`) keeps
        # the branches exhaustive, so the function can never fall off the end
        # and implicitly return None.
        return g, o, A, final_state, w, h, v_new