Coverage for src/flag_gems/runtime/backend/_ascend/fla/chunk.py: 0%
19 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
4#
5# This file contains code copied from the flash-linear-attention project.
6# The original source code was licensed under the MIT license and included
7# the following copyright notice:
8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
9# ruff: noqa: E501
10# mypy: ignore-errors
11import torch
13from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL
15from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
16from .chunk_o import chunk_fwd_o
17from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
18from .cumsum import chunk_local_cumsum
19from .solve_tril import solve_tril
20from .wy_fast import recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Forward pass of the chunked gated delta rule (linear attention).

    Pipeline: chunk-local cumsum of the gate ``g`` (chunk size 64) ->
    scaled K·Kᵀ block matrix ``A`` -> lower-triangular solve -> WY
    representation ``(w, u)`` -> hidden states ``(h, v_new, final_state)``
    -> output ``o``.

    Args:
        q, k, v: query / key / value tensors.
            (Shapes are fixed by the downstream kernels — not visible here.)
        g: gating tensor; replaced by its chunk-local cumulative sum.
        beta: per-token beta coefficients for the delta rule.
        scale: attention scaling factor forwarded to ``chunk_fwd_o``.
        initial_state: initial recurrent state passed to the hidden-state
            kernel (may presumably be ``None`` — confirm with callee).
        output_final_state: whether the hidden-state kernel should emit
            ``final_state``.
        cu_seqlens: optional cumulative sequence lengths for varlen batches.

    Returns:
        ``(g, o, A, final_state, w, h, v_new)`` when ``SUPPRESS_LEVEL >= 3``;
        otherwise ``(g, o, A, final_state, None, None, None)`` — the trailing
        intermediates are suppressed (presumably to save memory when they are
        recomputed in the backward pass).
    """
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # Obtain the WY representation; ``u`` is effectively the new ``v``.
    # A is accumulated in fp32 for numerical stability, then the triangular
    # solve casts back to k's dtype.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    # The two conditions were mutually exhaustive (`< 3` / `>= 3`); use a
    # plain else so no implicit `return None` path can ever appear.
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    else:
        return g, o, A, final_state, w, h, v_new