Coverage for src/flag_gems/runtime/backend/_ascend/fla/chunk_scaled_dot_kkt.py: 0%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
4#
5# This file contains code copied from the flash-linear-attention project.
6# The original source code was licensed under the MIT license and included
7# the following copyright notice:
8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
9# ruff: noqa: E501
10# mypy: ignore-errors
11import torch
12import triton
13import triton.language as tl
15from .utils import prepare_chunk_indices, safe_exp
@triton.heuristics(
    {
        # Compile-time flags derived from whether the optional inputs were passed.
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "USE_G": lambda args: args["g_cumsum"] is not None,
    }
)
# T is excluded from Triton's value specialization.
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,  # keys, addressed as a dense [B, T, Hg, K] buffer (see p_k strides)
    beta,  # [H, B, T]
    g_cumsum,  # [H, B, T]
    A,  # output, addressed as a dense [B, T, H, BT] buffer (see p_A strides)
    cu_seqlens,
    chunk_indices,
    T,
    B,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    """Compute one BT x BT chunk tile of beta * K @ K^T for every (batch, head).

    For the chunk selected by program_id(0), each (batch, head) pair gets a tile
    with entry (i, j) = beta[i] * dot(k[i], k[j]), additionally multiplied by
    safe_exp(g[i] - g[j]) when USE_G. Only entries with i > j are kept, and all
    other entries are 0. The tile is stored into the chunk's rows of A.
    """
    # Flattened size of one head's [B, T] slice in beta / g_cumsum. It is computed
    # before the loop because T may be rebound below in varlen mode.
    bt_stride = B * T
    # Axis 0 selects the chunk. Axis 1 is read but unused.
    i_t_i, _ = tl.program_id(0), tl.program_id(1)
    # Each program processes its chunk for all B * H (batch, head) pairs serially.
    for i_bh in range(B * H):
        i_b, i_h = i_bh // H, i_bh % H
        if IS_VARLEN:
            # chunk_indices holds two entries per chunk:
            # (index of the sequence in cu_seqlens, chunk index inside that sequence).
            i_n, i_t = (
                tl.load(chunk_indices + i_t_i * 2).to(tl.int32),
                tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32),
            )
            # [bos, eos) is the token span of sequence i_n in the packed input.
            bos, eos = (
                tl.load(cu_seqlens + i_n).to(tl.int32),
                tl.load(cu_seqlens + i_n + 1).to(tl.int32),
            )
            # Rebind T to the current sequence length so that the block pointers
            # below are bounded by this sequence.
            # NOTE(review): i_b does not enter any offset in this branch, so varlen
            # mode appears to assume B == 1 — confirm with callers.
            T = eos - bos
        else:
            bos, eos = i_b * T, i_b * T + T
            i_t = i_t_i
        o_t = tl.arange(0, BT)
        # Float32 copy of the in-chunk indices, used only by the triangular mask
        # below (presumably a backend workaround for integer compares — confirm).
        o_t_fp32 = o_t.to(tl.float32)

        # beta values for the BT rows of this chunk.
        p_beta = tl.make_block_ptr(
            beta + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
        b_beta = tl.load(p_beta, boundary_check=(0,))

        # Accumulate K_chunk @ K_chunk^T in float32 over BK-wide slices of K.
        b_A = tl.zeros([BT, BT], dtype=tl.float32)
        for i_k in range(tl.cdiv(K, BK)):
            # Head i_h reads key head i_h // (H // Hg): H // Hg heads share one key head.
            p_k = tl.make_block_ptr(
                k + (bos * Hg + i_h // (H // Hg)) * K,
                (T, K),
                (Hg * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_A += tl.dot(b_k, tl.trans(b_k))

        if USE_G:
            # Scale entry (i, j) by safe_exp(g[i] - g[j]) (see .utils.safe_exp).
            p_g = tl.make_block_ptr(
                g_cumsum + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_g_diff = b_g[:, None] - b_g[None, :]
            b_A *= safe_exp(b_g_diff)

        # Scale row i by beta[i], then keep only the strictly lower triangle (i > j).
        b_A *= b_beta[:, None]
        b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
        # Write the tile to rows [i_t * BT, i_t * BT + BT) (relative to bos) of
        # head i_h, casting to A's element type.
        p_A = tl.make_block_ptr(
            A + (bos * H + i_h) * BT,
            (T, BT),
            (BT * H, 1),
            (i_t * BT, 0),
            (BT, BT),
            (1, 0),
        )
        tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""
    Compute beta * K * K^T within each chunk, keeping the strictly lower triangle.

    For every chunk of `chunk_size` consecutive tokens and every head `h`, the
    row for token `i` holds `beta[i] * <k[i], k[j]>` at chunk-local column
    `j < i`. When `g_cumsum` is given, that value is additionally multiplied by
    `safe_exp(g_cumsum[i] - g_cumsum[j])`. All other columns are 0.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, Hg, K]`. Output head `h` reads key
            head `h // (H // Hg)`, so `Hg` may be smaller than `H`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g_cumsum (torch.Tensor | None):
            The cumulative sum of the gate tensor, of shape `[B, T, H]`.
            Default: `None` (no gating).
        cu_seqlens (torch.LongTensor | None):
            The cumulative sequence lengths of packed variable-length input.
            Default: `None`.
        chunk_size (int):
            The chunk size `BT`. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`.

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    # The kernel addresses `k` as a dense [B, T, Hg, K] buffer and receives no
    # strides, so enforce that layout. This is a no-op if `k` is already contiguous.
    k = k.contiguous()
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size

    if cu_seqlens is not None:
        # Build the per-chunk (sequence, chunk) index table on the host. Then move
        # both tables to the device that holds `k`, not merely the current device.
        chunk_indices = prepare_chunk_indices(cu_seqlens.cpu(), BT).to(k.device)
        cu_seqlens = cu_seqlens.to(k.device)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)

    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)

    # The kernel reads beta / g_cumsum as contiguous head-major [H, B, T].
    # g_cumsum is optional: forwarding None selects the kernel's USE_G=False path.
    beta_hbt = beta.permute(2, 0, 1).contiguous()
    g_hbt = None if g_cumsum is None else g_cumsum.permute(2, 0, 1).contiguous()

    # One program per chunk. Each program loops over all B * H (batch, head) pairs.
    chunk_scaled_dot_kkt_fwd_kernel[(NT, 1)](
        k=k,
        beta=beta_hbt,
        g_cumsum=g_hbt,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=128,
        num_warps=8,
        num_stages=3,
        multibuffer=True,
    )
    return A