# src/flag_gems/runtime/backend/_ascend/fla/fused_qkvzba_split_reshape.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501
# mypy: ignore-errors
import torch
import triton
import triton.language as tl


@triton.jit
def fused_qkvzba_split_reshape_cat_kernel(
    mixed_qkv,
    z,
    b,
    a,
    mixed_qkvz,
    mixed_ba,
    NUM_HEADS_QK: tl.constexpr,
    NUM_HEADS_V: tl.constexpr,
    HEAD_QK: tl.constexpr,
    HEAD_V: tl.constexpr,
):
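    """Split fused (q|k|v|z) and (b|a) projections into contiguous outputs.

    Each program handles one (batch element, QK head) pair: it copies the
    q, k, and v slices of `mixed_qkvz` into the packed `mixed_qkv` layout
    and scatters the z, b, and a components into their own tensors.
    """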
    i_bs, i_qk = tl.program_id(0), tl.program_id(1)
    QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
    BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
    QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
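    # Per-QK-head widths of the flattened tensors:
    # QKVZ_DIM_T: q + k, plus v + z for every V head sharing this QK head;
    # BA_DIM_T:   one b and one a scalar per shared V head;
    # QKV_DIM_T:  q + k + v only (z is split out into its own tensor).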
    q_end: tl.constexpr = HEAD_QK
    blk_q_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(0, q_end)
    )
    k_end: tl.constexpr = q_end + HEAD_QK
    blk_k_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(q_end, k_end)
    )
    v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_v_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(k_end, v_end)
    )
    z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_z_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(v_end, z_end)
    )
    blk_q_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + i_qk * HEAD_QK
        + tl.arange(0, HEAD_QK)
    )
    blk_k_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + NUM_HEADS_QK * HEAD_QK
        + i_qk * HEAD_QK
        + tl.arange(0, HEAD_QK)
    )
    blk_v_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + NUM_HEADS_QK * HEAD_QK * 2
        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
    )
    blk_z_st_ptr = (
        z
        + i_bs * NUM_HEADS_V * HEAD_V
        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
    )
    tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
    tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
    tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
    tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
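    # mixed_ba packs, per QK head, the b scalars followed by the a scalars
    # for the V heads that share it; copy each scalar to its own tensor.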
    b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
    a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
    for i in tl.static_range(b_end):
        blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
        tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
    for i in tl.static_range(b_end, a_end):
        blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_a_st_ptr = (
            a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
        )
        tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))


def fused_qkvzba_split_reshape_cat(
    mixed_qkvz,
    mixed_ba,
    num_heads_qk,
    num_heads_v,
    head_qk,
    head_v,
):
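    """Allocate outputs and launch the split/reshape kernel.

    Handles decode-style inputs with an implicit sequence length of 1:
    `mixed_qkvz` is [batch, num_heads_qk * (2*head_qk + 2*ratio*head_v)] and
    `mixed_ba` is [batch, num_heads_qk * 2*ratio], where
    ratio = num_heads_v // num_heads_qk.
    """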
    batch, seq_len = mixed_qkvz.shape[0], 1
    qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
    mixed_qkv = torch.empty(
        [batch * seq_len, qkv_dim_t],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    z = torch.empty(
        [batch * seq_len, num_heads_v, head_v],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    b = torch.empty(
        [batch * seq_len, num_heads_v],
        dtype=mixed_ba.dtype,
        device=mixed_ba.device,
    )
    a = torch.empty_like(b)
    grid = (batch * seq_len, num_heads_qk)
    fused_qkvzba_split_reshape_cat_kernel[grid](
        mixed_qkv,
        z,
        b,
        a,
        mixed_qkvz,
        mixed_ba,
        num_heads_qk,
        num_heads_v,
        head_qk,
        head_v,
        num_warps=1,
        num_stages=3,
    )
    return mixed_qkv, z, b, a
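
# A minimal smoke-test sketch (illustrative, not part of the original file).
# The dimensions are assumptions chosen to satisfy the kernel's layout and
# Triton's power-of-two tl.arange requirement; run on a Triton-capable
# device (e.g. "cuda", or "npu" on this Ascend backend).
#
#   batch, num_heads_qk, num_heads_v, head_qk, head_v = 2, 4, 8, 64, 64
#   ratio = num_heads_v // num_heads_qk
#   qkvz_dim = num_heads_qk * (2 * head_qk + 2 * ratio * head_v)
#   ba_dim = num_heads_qk * 2 * ratio
#   mixed_qkvz = torch.randn(batch, qkvz_dim, device="cuda", dtype=torch.float16)
#   mixed_ba = torch.randn(batch, ba_dim, device="cuda", dtype=torch.float32)
#   mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
#       mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_qk, head_v
#   )
#   # mixed_qkv: [batch, 2*num_heads_qk*head_qk + num_heads_v*head_v]
#   # z: [batch, num_heads_v, head_v]; b, a: [batch, num_heads_v]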