Coverage for src/flag_gems/runtime/backend/_ascend/fla/fused_qkvzba_split_reshape.py: 0% of 45 statements

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501
# mypy: ignore-errors
import torch
import triton
import triton.language as tl


@triton.jit
def fused_qkvzba_split_reshape_cat_kernel(
    mixed_qkv,
    z,
    b,
    a,
    mixed_qkvz,
    mixed_ba,
    NUM_HEADS_QK: tl.constexpr,
    NUM_HEADS_V: tl.constexpr,
    HEAD_QK: tl.constexpr,
    HEAD_V: tl.constexpr,
):
    # One program instance per (batch element, QK head).
    i_bs, i_qk = tl.program_id(0), tl.program_id(1)
    # Per-QK-head row widths: mixed_qkvz packs [q | k | v | z] per QK head,
    # mixed_ba packs [b | a] per QK head, and QKV_DIM_T is the per-QK-head
    # share of the output mixed_qkv row.
    QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
    BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
    QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    # Source pointers: the q/k/v/z slices of this QK head inside mixed_qkvz.
    q_end: tl.constexpr = HEAD_QK
    blk_q_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(0, q_end)
    )
    k_end: tl.constexpr = q_end + HEAD_QK
    blk_k_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(q_end, k_end)
    )
    v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_v_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(k_end, v_end)
    )
    z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_z_ptr = (
        mixed_qkvz
        + i_bs * NUM_HEADS_QK * QKVZ_DIM_T
        + i_qk * QKVZ_DIM_T
        + tl.arange(v_end, z_end)
    )
    # Destination pointers: each mixed_qkv row is laid out as
    # [all q heads | all k heads | all v heads]; z goes to its own tensor.
    blk_q_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + i_qk * HEAD_QK
        + tl.arange(0, HEAD_QK)
    )
    blk_k_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + NUM_HEADS_QK * HEAD_QK
        + i_qk * HEAD_QK
        + tl.arange(0, HEAD_QK)
    )
    blk_v_st_ptr = (
        mixed_qkv
        + i_bs * NUM_HEADS_QK * QKV_DIM_T
        + NUM_HEADS_QK * HEAD_QK * 2
        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
    )
    blk_z_st_ptr = (
        z
        + i_bs * NUM_HEADS_V * HEAD_V
        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
    )
    tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
    tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
    tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
    tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
    # Scatter the per-head b and a scalars out of the interleaved mixed_ba row.
    b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
    a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
    for i in tl.static_range(b_end):
        blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
        tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
    for i in tl.static_range(b_end, a_end):
        blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_a_st_ptr = (
            a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
        )
        tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
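For readability, here is a minimal non-fused PyTorch sketch of what the kernel above computes, inferred from its pointer arithmetic. The function name is hypothetical and this reference is not part of the original module:

def qkvzba_split_reshape_cat_reference(
    mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_qk, head_v
):
    # Illustrative reference only; not part of the original file.
    batch = mixed_qkvz.shape[0]
    group = num_heads_v // num_heads_qk  # V heads sharing one QK head
    qkvz = mixed_qkvz.view(batch, num_heads_qk, 2 * head_qk + 2 * group * head_v)
    q, k, v, z = torch.split(
        qkvz, [head_qk, head_qk, group * head_v, group * head_v], dim=2
    )
    ba = mixed_ba.view(batch, num_heads_qk, 2 * group)
    b, a = torch.split(ba, [group, group], dim=2)
    # Concatenate q, k, v head-major, matching the kernel's store layout.
    mixed_qkv = torch.cat(
        [q.reshape(batch, -1), k.reshape(batch, -1), v.reshape(batch, -1)], dim=1
    )
    return (
        mixed_qkv,
        z.reshape(batch, num_heads_v, head_v),
        b.reshape(batch, num_heads_v),
        a.reshape(batch, num_heads_v),
    )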

def fused_qkvzba_split_reshape_cat(
    mixed_qkvz,
    mixed_ba,
    num_heads_qk,
    num_heads_v,
    head_qk,
    head_v,
):
    # seq_len is fixed to 1, so each row of mixed_qkvz holds a single token.
    batch, seq_len = mixed_qkvz.shape[0], 1
    qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
    mixed_qkv = torch.empty(
        [batch * seq_len, qkv_dim_t],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    z = torch.empty(
        [batch * seq_len, num_heads_v, head_v],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    b = torch.empty(
        [batch * seq_len, num_heads_v],
        dtype=mixed_ba.dtype,
        device=mixed_ba.device,
    )
    a = torch.empty_like(b)
    # One kernel instance per (token, QK head).
    grid = (batch * seq_len, num_heads_qk)
    fused_qkvzba_split_reshape_cat_kernel[grid](
        mixed_qkv,
        z,
        b,
        a,
        mixed_qkvz,
        mixed_ba,
        num_heads_qk,
        num_heads_v,
        head_qk,
        head_v,
        num_warps=1,
        num_stages=3,
    )
    return mixed_qkv, z, b, a
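
A minimal usage sketch. The head counts, dtypes, and device string below are assumptions chosen for illustration, not values taken from this file:

batch, num_heads_qk, num_heads_v, head_qk, head_v = 4, 16, 32, 128, 128
group = num_heads_v // num_heads_qk  # = 2

# Packed inputs: [q | k | v | z] per QK head, and [b | a] per QK head.
mixed_qkvz = torch.randn(
    batch, num_heads_qk * (2 * head_qk + 2 * group * head_v),
    dtype=torch.float16, device="cuda",  # device is a placeholder
)
mixed_ba = torch.randn(
    batch, num_heads_qk * 2 * group,
    dtype=torch.float32, device="cuda",
)

mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
    mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_qk, head_v
)
# mixed_qkv: [4, 2 * 16 * 128 + 32 * 128] == [4, 8192]
# z:         [4, 32, 128]
# b, a:      [4, 32]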