Coverage for src/flag_gems/fused/FLA/cumsum.py: 0%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6import warnings 

7 

8import torch 

9import triton 

10import triton.language as tl 

11 

12from flag_gems.fused.FLA.index import prepare_chunk_indices 

13from flag_gems.fused.FLA.utils import check_shared_mem, input_guard 

14from flag_gems.utils import libentry, libtuner 

15 

# Candidate feature-dim tile sizes for the vector kernel's autotuner:
# larger tiles when the device reports enough shared memory, smaller otherwise.
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]

17 

18 

@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,  # input pointer: scalar gate values
    o,  # output pointer: chunk-local cumsum of `s`, same layout as `s`
    cu_seqlens,  # varlen mode: cumulative sequence lengths; None otherwise
    chunk_indices,  # varlen mode: flat [(seq_idx, chunk_idx), ...] table; None otherwise
    T,  # time length; re-derived per sequence when IS_VARLEN
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # number of heads
    BT: tl.constexpr,  # chunk size along the time axis
    REVERSE: tl.constexpr,  # True: reversed (suffix) cumsum within each chunk
    IS_VARLEN: tl.constexpr,  # set by the heuristic from cu_seqlens
    HEAD_FIRST: tl.constexpr,  # True: [B, H, T] layout; False: [B, T, H]
):
    # One program per (time-chunk, batch*head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence, chunk-within-sequence), then
        # fetch that sequence's [bos, eos) range from cu_seqlens.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        # NOTE(review): the offset mixes `bos * H` with `i_h * T`; copied
        # verbatim from upstream FLA — presumably head-first callers use the
        # fixed-length layout where this addressing is consistent. TODO confirm.
        p_s = tl.make_block_ptr(
            s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
        p_o = tl.make_block_ptr(
            o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
    else:
        # [B, T, H]: consecutive time steps of one head are H elements apart.
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT] — accumulate in fp32 regardless of input dtype.
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # Reversed inclusive cumsum: total - forward_prefix + current element.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))

71 

72 

@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BS": BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
    s,  # input pointer: vector gate values with feature dim S
    o,  # output pointer: chunk-local cumsum of `s`, same layout as `s`
    cu_seqlens,  # varlen mode: cumulative sequence lengths; None otherwise
    chunk_indices,  # varlen mode: flat [(seq_idx, chunk_idx), ...] table; None otherwise
    T,  # time length; re-derived per sequence when IS_VARLEN
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # number of heads
    S: tl.constexpr,  # feature dimension
    BT: tl.constexpr,  # chunk size along the time axis
    BS: tl.constexpr,  # tile size along the feature axis (autotuned)
    REVERSE: tl.constexpr,  # True: reversed (suffix) cumsum within each chunk
    IS_VARLEN: tl.constexpr,  # set by the heuristic from cu_seqlens
    HEAD_FIRST: tl.constexpr,  # True: [B, H, T, S] layout; False: [B, T, H, S]
):
    # One program per (feature-tile, time-chunk, batch*head) triple.
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence, chunk-within-sequence), then
        # fetch that sequence's [bos, eos) range from cu_seqlens.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    # Triangular [BT, BT] mask: multiplying by it realizes an inclusive
    # cumsum along the chunk's time axis (upper-triangular when reversed).
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)

    if HEAD_FIRST:
        # NOTE(review): the offset mixes `bos * H` with `i_h * T`; copied
        # verbatim from upstream FLA — presumably head-first callers use the
        # fixed-length layout where this addressing is consistent. TODO confirm.
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    else:
        # [B, T, H, S]: consecutive time steps of one head are H*S elements apart.
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    # [BT, BS] — cumsum computed as a [BT, BT] x [BT, BS] matmul in fp32.
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

157 

158 

def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """Chunk-local cumulative sum over the time axis of a 3-D gate tensor.

    Args:
        g: input of shape [B, H, T] if `head_first` else [B, T, H].
        chunk_size: chunk length BT; must be a power of 2.
        reverse: compute a reversed (suffix) cumsum within each chunk.
        cu_seqlens: optional cumulative sequence lengths for varlen batches.
        head_first: layout flag for `g`.
        output_dtype: dtype of the result; defaults to float32.

    Returns:
        A new tensor of the same shape as `g` holding per-chunk cumsums.
    """
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    BT = chunk_size
    # Varlen batches launch one program per (sequence, chunk) pair; fixed
    # length launches cdiv(T, BT) chunks per batch row.
    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)
    chunk_local_cumsum_scalar_kernel[(NT, B * H)](
        g,
        out,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return out

194 

195 

def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """Chunk-local cumulative sum over the time axis of a 4-D gate tensor.

    Args:
        g: input of shape [B, H, T, S] if `head_first` else [B, T, H, S].
        chunk_size: chunk length BT; must be a power of 2.
        reverse: compute a reversed (suffix) cumsum within each chunk.
        cu_seqlens: optional cumulative sequence lengths for varlen batches.
        head_first: layout flag for `g`.
        output_dtype: dtype of the result; defaults to float32.

    Returns:
        A new tensor of the same shape as `g` holding per-chunk cumsums.
    """
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    # Fix: validate chunk_size BEFORE using it to build the chunk index table
    # and grid size (previously the assert ran only afterwards); this also
    # matches the ordering in chunk_local_cumsum_scalar.
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        # BS (feature-tile size) is chosen by the autotuner at launch time.
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    # keep cumulative normalizer in fp32
    # this kernel is equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    chunk_local_cumsum_vector_kernel[grid](
        g_org,
        g,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g

241 

242 

@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
    **kwargs,
) -> torch.Tensor:
    """Compute a chunk-local (per-chunk) cumulative sum along the time axis.

    Dispatches to the scalar implementation for 3-D inputs and the vector
    implementation for 4-D inputs.

    Args:
        g: gate tensor, [B, T, H(, D)] or [B, H, T(, D)] per `head_first`.
        chunk_size: chunk length; must be a power of 2.
        reverse: reversed (suffix) cumsum within each chunk.
        cu_seqlens: optional cumulative sequence lengths (varlen); requires
            batch size 1.
        head_first: whether `g` is laid out head-first.
        output_dtype: dtype of the result; defaults to float32.

    Raises:
        ValueError: if `g` is neither 3-D nor 4-D.
    """
    # Heuristic sanity check: seq_len shorter than num_heads usually means
    # the caller passed a head-first tensor with head_first=False.
    if not head_first and g.shape[1] < g.shape[2]:
        warnings.warn(
            "Input tensor shape suggests potential format mismatch: "
            f" seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2,
        )
    if cu_seqlens is not None:
        assert (
            g.shape[0] == 1
        ), "Only batch size 1 is supported when cu_seqlens are provided"
    rank = g.dim()
    if rank == 3:
        return chunk_local_cumsum_scalar(
            g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
        )
    if rank == 4:
        return chunk_local_cumsum_vector(
            g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
        )
    raise ValueError(
        f"Unsupported input shape {g.shape}. "
        f"which should be (B, T, H, D) if `head_first=False` "
        f"or (B, H, T, D) otherwise"
    )