Coverage for src/flag_gems/runtime/backend/_ascend/fla/wy_fast.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
4#
5# This file contains code copied from the flash-linear-attention project.
6# The original source code was licensed under the MIT license and included
7# the following copyright notice:
8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
10# ruff: noqa: E501
11# mypy: ignore-errors
12import torch
13import triton
14import triton.language as tl
16from .utils import prepare_chunk_indices
# Recomputes the DeltaNet-style auxiliary tensors for the forward pass, one
# (chunk, head) tile at a time:
#   u = A @ (v * beta)                 -- value-side projection
#   w = A @ (k * beta * exp(g))        -- key-side projection, gated by cumsum g
# where A is the per-chunk (BT x BT) inverse/attention block precomputed by the
# caller. Grid axis 0 indexes chunks; heads are looped serially inside one
# program instance.
#
# NOTE(review): `i_b = i_bh // H` is always 0 because `i_bh < H`, and
# `tl.program_id(1)` (the batch axis of the `(NT, B)` launch grid) is never
# read — so for IS_VARLEN=False, batches beyond 0 appear to never be
# processed. Likewise `i_t` is only assigned on the IS_VARLEN branch, so the
# non-varlen specialization would reference an undefined name. Both suggest
# only the varlen path is actually exercised — confirm with callers.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,  # keys, addressed as (B, T, Hg, K) flattened
    v,  # values, addressed as (B, T, H, V) flattened
    beta,  # per-token gate, head-major (B, H, T) after the wrapper's transpose
    w,  # output, addressed as (B, T, H, K) flattened
    u,  # output, addressed as (B, T, H, V) flattened
    A,  # per-chunk block matrix, addressed as (B, T, H, BT) flattened
    g,  # cumulative log-gates, head-major (B, H, T) like beta
    cu_seqlens,  # varlen sequence boundaries, or None (see IS_VARLEN)
    chunk_indices,  # per-chunk (sequence idx, chunk-in-sequence idx) pairs
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,  # number of grouped KV heads (GQA); H is a multiple of Hg
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk (time-tile) size
    BK: tl.constexpr,  # K-dimension tile size
    BV: tl.constexpr,  # V-dimension tile size
    IS_VARLEN: tl.constexpr,
):
    # Keep the padded/global T around: beta and g are strided by the original
    # T even after T is overwritten with the per-sequence length below.
    T_max = T
    i_t_o = tl.program_id(0)
    # Serial loop over heads within this chunk's program instance.
    for i_bh in range(H):
        # NOTE(review): i_b is always 0 here (i_bh < H); see header note.
        i_b, i_h = i_bh // H, i_bh % H
        if IS_VARLEN:
            # Map the global chunk id to (sequence, chunk-within-sequence).
            i_n, i_t = (
                tl.load(chunk_indices + i_t_o * 2).to(tl.int32),
                tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32),
            )
            bos, eos = (
                tl.load(cu_seqlens + i_n).to(tl.int32),
                tl.load(cu_seqlens + i_n + 1).to(tl.int32),
            )
            # From here on, T is this sequence's length, used only for masking.
            T = eos - bos
        else:
            bos, eos = i_b * T, i_b * T + T
        offs_t = tl.arange(0, BT)
        # Token positions of this chunk within the current sequence.
        global_offs_t = i_t * BT + offs_t
        mask_t = global_offs_t < T
        offs_t_2d = global_offs_t[:, None]
        offs_bt = tl.arange(0, BT)[None, :]
        # A tile: rows are this chunk's tokens, columns the BT-wide block dim.
        ptr_A = A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1
        mask_A = mask_t[:, None]
        b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
        # g/beta are head-major (B, H, T); `bos + i_h * T_max` assumes the
        # batch offset is folded into bos (true only for B == 1 / varlen).
        ptr_g = g + bos + i_h * T_max + global_offs_t
        # exp(g): convert cumulative log-gates to multiplicative decay.
        b_g = tl.exp(tl.load(ptr_g, mask=mask_t, other=0.0)).to(tl.float32)
        ptr_beta = beta + bos + i_h * T_max + global_offs_t
        b_beta = tl.load(ptr_beta, mask=mask_t, other=0.0).to(tl.float32)
        # u = A @ (v * beta), tiled over the V dimension.
        for i_v in range(tl.cdiv(V, BV)):
            offs_v = i_v * BV + tl.arange(0, BV)[None, :]
            mask_v = (mask_t[:, None]) & (offs_v < V)
            ptr_v = v + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
            b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32)
            b_vb = b_v * b_beta[:, None]
            b_u = tl.dot(b_A, b_vb, allow_tf32=False)
            ptr_u = u + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
            tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v)
        # w = A @ (k * beta * exp(g)), tiled over the K dimension.
        for i_k in range(tl.cdiv(K, BK)):
            offs_k = i_k * BK + tl.arange(0, BK)[None, :]
            mask_k = (mask_t[:, None]) & (offs_k < K)
            # GQA: H query heads share Hg KV heads, hence i_h // (H // Hg).
            ptr_k = (
                k
                + (bos * Hg + i_h // (H // Hg)) * K
                + offs_t_2d * (Hg * K)
                + offs_k * 1
            )
            b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32)
            b_kb = b_k * b_beta[:, None] * b_g[:, None]
            b_w = tl.dot(b_A, b_kb)
            ptr_w = w + (bos * H + i_h) * K + offs_t_2d * (H * K) + offs_k * 1
            tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k)
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the Triton kernel that recomputes ``w`` and ``u`` for the
    forward pass.

    Args:
        k: keys of shape ``(B, T, Hg, K)``.
        v: values of shape ``(B, T, H, V)``.
        beta: per-token gates of shape ``(B, T, H)``.
        g_cumsum: cumulative log-gates of shape ``(B, T, H)``.
        A: per-chunk block matrices; the last dim is the chunk size ``BT``.
        cu_seqlens: optional cumulative sequence lengths for varlen batches.

    Returns:
        Tuple ``(w, u)`` with ``w`` shaped ``(B, T, H, K)`` and ``u`` shaped
        like ``v``.
    """
    batch, seq_len, num_kv_heads, head_k_dim = k.shape
    num_heads = v.shape[-2]
    head_v_dim = v.shape[-1]
    chunk_size = A.shape[-1]

    # Varlen batches enumerate chunks explicitly; dense batches tile T evenly.
    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(seq_len, chunk_size)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    u = torch.empty_like(v)
    w = k.new_empty(batch, seq_len, num_heads, head_k_dim)
    # The kernel reads beta/g head-major: (B, H, T), contiguous.
    beta = beta.transpose(1, 2).contiguous()
    g_cumsum = g_cumsum.transpose(1, 2).contiguous()

    recompute_w_u_fwd_kernel[(num_chunks, batch)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=seq_len,
        H=num_heads,
        Hg=num_kv_heads,
        K=head_k_dim,
        V=head_v_dim,
        BT=chunk_size,
        BK=64,
        BV=64,
        num_warps=4,
        num_stages=3,
    )
    return w, u