Coverage for src/flag_gems/fused/FLA/chunk_scaled_dot_kkt.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6 

7 

8import torch 

9import triton 

10import triton.language as tl 

11 

12from flag_gems.fused.FLA.index import prepare_chunk_indices 

13from flag_gems.fused.FLA.triton_ops_helper import exp 

14from flag_gems.utils import libentry, libtuner 

15 

16 

@libentry()
@triton.heuristics(
    {
        # Gate cumsum tensor is optional; enable decay scaling only when present.
        "USE_G": lambda args: args["g"] is not None,
        # Variable-length (packed) mode is selected by the presence of cu_seqlens.
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@libtuner(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    # Computes, per (chunk, batch*head) program, the BT x BT block
    #   A = tril( (beta * K) @ K^T, diagonal=-1 )
    # optionally scaled elementwise by exp(g_i - g_j) when USE_G.
    # Grid: axis 0 = chunk index, axis 1 = flattened batch * head.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs;
        # remap the flat chunk index and recover this sequence's [bos, eos).
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        # T becomes the effective length of this sequence only.
        T = eos - bos
    else:
        # Fixed-length mode: each batch element owns a contiguous span of T rows.
        bos, eos = i_b * T, i_b * T + T
    # Row offsets of this chunk and validity mask against the sequence end.
    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T

    # Per-row beta scaling factors, laid out as [T, H] with stride H per row.
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # Accumulate (beta * K) @ K^T in fp32 over BK-wide tiles of the K dimension.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        # k has Hg heads (may be fewer than H); i_h // (H // Hg) maps this
        # program's head onto its shared k head — presumably GQA-style
        # grouping, TODO confirm against callers.
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # Scale rows by beta, then cast back to k's dtype for the tensor-core dot.
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    if USE_G:
        # g is a cumulative gate; scale A[i, j] by exp(g_i - g_j).
        p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * exp(b_g_diff)

    # Keep only the strictly lower triangle (i > j) of valid rows/cols;
    # the diagonal and upper triangle are zeroed.
    m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0)
    # Output layout is [T, H, BT] flattened; cast to A's storage dtype on store.
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

98 

99 

def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    g: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""
    Compute the per-chunk strictly-lower-triangular product beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, Hg, K]`.
        g (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`. Required.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.

    Raises:
        ValueError: If `beta` is not provided.
    """
    # This kernel is slightly different from fla to support Q/K with different head numbers.
    # In fla, Q/K always have the same head number, so Hg is always equal to H.
    if beta is None:
        # beta is required despite the Optional annotation (kept for signature
        # compatibility); fail loudly instead of an opaque AttributeError below.
        raise ValueError("chunk_scaled_dot_kkt_fwd requires `beta`; got None")
    B, T, Hg, K = k.shape
    # The number of output heads H is taken from beta and may exceed Hg.
    H = beta.shape[-1]
    BT = chunk_size
    # In varlen mode, precompute (sequence id, chunk id) pairs for every chunk.
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    # Total number of chunks: per-batch ceil(T / BT), or one entry per varlen chunk.
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # Kernel writes every valid position (invalid ones are zeroed in-kernel),
    # so uninitialized allocation is safe.
    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        g=g,
        beta=beta,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
    )
    return A