Coverage for src/flag_gems/runtime/backend/_ascend/fla/chunk_scaled_dot_kkt.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9# ruff: noqa: E501 

10# mypy: ignore-errors 

11import torch 

12import triton 

13import triton.language as tl 

14 

15from .utils import prepare_chunk_indices, safe_exp 

16 

17 

# Triton kernel: for one chunk of BT rows (selected by program_id(0)) it loops
# over every (batch, head) pair and computes the BT x BT block
#     A[i, j] = beta[i] * dot(k[i], k[j])            (USE_G == False)
#     A[i, j] = beta[i] * dot(k[i], k[j]) * safe_exp(g_cumsum[i] - g_cumsum[j])
# keeping only entries with i > j inside the chunk (zero elsewhere), then stores
# it into A (cast to A's element type).
#
# Memory layouts, as implied by the pointer arithmetic below:
#   k:        [*, T, Hg, K] contiguous (row stride Hg * K); head i_h reads key
#             group i_h // (H // Hg).
#   beta:     [H, B, T] contiguous (head stride B * T).
#   g_cumsum: [H, B, T] contiguous, only read when USE_G.
#   A:        [*, T, H, BT] contiguous (row stride BT * H).
# IS_VARLEN / USE_G are derived by the heuristics from whether cu_seqlens /
# g_cumsum are None.
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "USE_G": lambda args: args["g_cumsum"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,  # [H, B, T]
    g_cumsum,  # [H, B, T]
    A,
    cu_seqlens,
    chunk_indices,
    T,
    B,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    # Head stride of the [H, B, T] beta / g_cumsum buffers. It is computed
    # before T may be rebound to a per-sequence length in the varlen branch, so
    # it always uses the full (packed) time length.
    bt_stride = B * T
    # program_id(0) selects the chunk; the second grid axis is not used.
    i_t_i, _ = tl.program_id(0), tl.program_id(1)

    # Each program handles all B * H (batch, head) pairs for its chunk.
    for i_bh in range(B * H):
        i_b, i_h = i_bh // H, i_bh % H
        if IS_VARLEN:
            # chunk_indices stores pairs: (sequence id, chunk id within that sequence).
            i_n, i_t = (
                tl.load(chunk_indices + i_t_i * 2).to(tl.int32),
                tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32),
            )
            # [bos, eos) is the sequence's range in the packed time axis.
            bos, eos = (
                tl.load(cu_seqlens + i_n).to(tl.int32),
                tl.load(cu_seqlens + i_n + 1).to(tl.int32),
            )
            # NOTE(review): i_b is not used to compute bos here, so every batch
            # index maps to the same packed range -- presumably callers pass
            # B == 1 in varlen mode; confirm.
            T = eos - bos
        else:
            # Fixed-length: batch i_b occupies rows [i_b * T, (i_b + 1) * T).
            bos, eos = i_b * T, i_b * T + T
            i_t = i_t_i
        o_t = tl.arange(0, BT)
        # The causal mask below compares float32 copies of the indices.
        # NOTE(review): presumably a backend (Ascend) workaround for integer
        # compares -- confirm before changing.
        o_t_fp32 = o_t.to(tl.float32)

        p_beta = tl.make_block_ptr(
            beta + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
        b_beta = tl.load(p_beta, boundary_check=(0,))

        # Accumulate the chunk's Gram matrix K_c @ K_c^T over BK-wide slices of K.
        b_A = tl.zeros([BT, BT], dtype=tl.float32)
        for i_k in range(tl.cdiv(K, BK)):
            p_k = tl.make_block_ptr(
                k + (bos * Hg + i_h // (H // Hg)) * K,
                (T, K),
                (Hg * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_A += tl.dot(b_k, tl.trans(b_k))

        if USE_G:
            # Scale entry (i, j) by safe_exp(g_i - g_j) using the cumulative gate.
            p_g = tl.make_block_ptr(
                g_cumsum + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_g_diff = b_g[:, None] - b_g[None, :]
            b_A *= safe_exp(b_g_diff)

        # Row-scale by beta, then keep only the strictly lower triangle (i > j).
        b_A *= b_beta[:, None]
        b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
        p_A = tl.make_block_ptr(
            A + (bos * H + i_h) * BT,
            (T, BT),
            (BT * H, 1),
            (i_t * BT, 0),
            (BT, BT),
            (1, 0),
        )
        # boundary_check drops rows past the end of the (sub)sequence.
        tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

100 

101 

def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""
    Compute the chunk-local, strictly lower-triangular beta * K * K^T.

    For every chunk of `chunk_size` consecutive time steps and every head, entry
    (i, j) of the chunk block is `beta[i] * dot(k[i], k[j])`, additionally scaled
    by `safe_exp(g_cumsum[i] - g_cumsum[j])` when `g_cumsum` is given. Only
    entries with i > j (within the chunk) are kept; all others are zero.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, Hg, K]`. `Hg` may be smaller than the
            number of heads `H` taken from `beta`; head `h` uses key group
            `h // (H // Hg)`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g_cumsum (torch.Tensor, optional):
            The cumulative sum of the gate tensor of shape `[B, T, H]`.
            Default: `None` (no gating).
        cu_seqlens (torch.LongTensor, optional):
            The cumulative sequence lengths for variable-length (packed) inputs.
            In this mode the kernel ignores the batch index, so inputs are
            presumably packed with `B == 1`. Default: `None`.
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`.

    Returns:
        Tensor of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size

    if cu_seqlens is not None:
        # Chunk indices are prepared from a CPU copy of cu_seqlens. Both tensors
        # are then placed on k's device (not merely "the current NPU"), so the
        # kernel reads them from the same device as its other inputs.
        chunk_indices = prepare_chunk_indices(cu_seqlens.cpu(), BT).to(k.device)
        cu_seqlens = cu_seqlens.to(k.device)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)

    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)

    # The kernel indexes beta / g_cumsum as contiguous head-major [H, B, T] buffers.
    beta_hbt = torch.permute(beta, (2, 0, 1)).contiguous()
    # g_cumsum is optional. Pass None through so the kernel's USE_G heuristic
    # disables gating; torch.permute(None, ...) would raise a TypeError.
    g_hbt = (
        None
        if g_cumsum is None
        else torch.permute(g_cumsum, (2, 0, 1)).contiguous()
    )

    chunk_scaled_dot_kkt_fwd_kernel[(NT, 1)](
        k=k,
        beta=beta_hbt,
        g_cumsum=g_hbt,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=128,
        num_warps=8,
        num_stages=3,
        multibuffer=True,
    )
    return A