Coverage for src/flag_gems/runtime/backend/_ascend/fla/cumsum.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9# ruff: noqa: E501 

10# mypy: ignore-errors 

11import torch 

12import triton 

13import triton.language as tl 

14 

15from .utils import prepare_chunk_indices 

16 

17 

@triton.heuristics(
    {
        "HAS_SCALE": lambda args: args["scale"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,  # input tensor pointer
    o,  # output tensor pointer (same shape/layout as `s`)
    scale,  # optional multiplier applied to the cumsum result
    cu_seqlens,  # varlen: cumulative sequence lengths, or None
    chunk_indices,  # varlen: (seq_idx, block_idx) pairs per program, or None
    T,  # sequence length (per-sequence length when IS_VARLEN)
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # number of heads
    BLOCK_T: tl.constexpr,  # time-span covered by one program
    REVERSE: tl.constexpr,  # compute suffix (reversed) cumsum instead of prefix
    HAS_SCALE: tl.constexpr,  # derived by @triton.heuristics from `scale`
    IS_VARLEN: tl.constexpr,  # derived by @triton.heuristics from `cu_seqlens`
    HEAD_FIRST: tl.constexpr,  # layout flag: (H, T) per batch vs (T, H)
    CHUNK_SIZE: tl.constexpr = 64,  # cumsum is local to chunks of this length
):
    """Compute a cumsum restricted to CHUNK_SIZE-long chunks of the time axis.

    Grid: (num_blocks, B). Each program loads a BLOCK_T-long slab, reshapes
    it into N_CHUNKS chunks of CHUNK_SIZE, and cumsums along the chunk axis
    only, so sums never cross a chunk boundary.
    """
    i_block, i_b = tl.program_id(0), tl.program_id(1)
    N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE

    if IS_VARLEN:
        # chunk_indices stores (sequence index, block index within sequence)
        # pairs; remap this program onto its sequence's [bos, eos) window.
        i_s, i_block = (
            tl.load(chunk_indices + i_block * 2).to(tl.int32),
            tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_s).to(tl.int32),
            tl.load(cu_seqlens + i_s + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        # Fixed-length batches: each batch occupies a contiguous T-long span.
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        # Layout (H, T) per sequence, stride (T, 1).
        # NOTE(review): boundary_check=(0,) guards axis 0 (H), but the offset
        # i_block * BLOCK_T advances along axis 1 (T) — confirm whether
        # boundary_check=(1,) was intended here, as in the branch below.
        ptr_s = tl.make_block_ptr(
            s + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)
        )
        ptr_o = tl.make_block_ptr(
            o + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)
        )
        b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
        # (H, BLOCK_T) -> (H, N_CHUNKS, CHUNK_SIZE) -> (CHUNK_SIZE, H, N_CHUNKS)
        # so the cumsum axis (0) runs within a single chunk.
        b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE))
        b_s = tl.trans(b_s, (2, 0, 1))
        b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
        if HAS_SCALE:
            b_o *= scale
        # NOTE(review): (2, 0, 1) is not the inverse of the (2, 0, 1) applied
        # above (the inverse would be (1, 2, 0)); this yields
        # (N_CHUNKS, CHUNK_SIZE, H), and the reshape to (H, BLOCK_T) may then
        # scramble the layout — confirm against the non-head-first branch,
        # which uses the self-inverse permutation (1, 0, 2).
        b_o = tl.trans(b_o, (2, 0, 1))
        b_o = tl.reshape(b_o, (H, BLOCK_T))
    else:
        # Layout (T, H) per sequence, stride (H, 1); boundary_check=(0,)
        # guards the T axis, which is the one offset by i_block * BLOCK_T.
        ptr_s = tl.make_block_ptr(
            s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)
        )
        ptr_o = tl.make_block_ptr(
            o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)
        )
        b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
        # (BLOCK_T, H) -> (N_CHUNKS, CHUNK_SIZE, H) -> (CHUNK_SIZE, N_CHUNKS, H);
        # cumsum along axis 0 stays within one chunk. The (1, 0, 2) permutation
        # is its own inverse, so applying it again restores the layout.
        b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
        b_s = tl.trans(b_s, (1, 0, 2))
        b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
        if HAS_SCALE:
            b_o *= scale
        b_o = tl.trans(b_o, (1, 0, 2))
        b_o = tl.reshape(b_o, (BLOCK_T, H))

    # Store in the input's dtype; boundary_check mirrors the load above.
    tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0,))
    return

90 

91 

def chunk_local_cumsum_scalar(
    g,
    chunk_size,
    reverse: bool = False,
    scale: float | None = None,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
):
    """Launch the chunk-local cumsum kernel over a 3-D scalar tensor.

    Args:
        g: input tensor, shaped (B, H, T) if ``head_first`` else (B, T, H).
        chunk_size: length of each local-cumsum chunk; must be a power of 2.
        reverse: if True, compute a reversed (suffix) cumsum within chunks.
        scale: optional multiplier applied to the cumsum result.
        cu_seqlens: cumulative sequence lengths for variable-length input,
            or None for fixed-length batches.
        head_first: layout flag (see ``g``).
        output_dtype: dtype of the returned tensor; ``g.dtype`` is used
            when None.

    Returns:
        A new tensor of the same shape as ``g`` holding the chunk-local
        cumsum; the input itself is not modified.
    """
    # Fix: `scale` is Optional (the kernel heuristic checks it against None)
    # and `output_dtype` is a torch.dtype, not a Tensor — it is forwarded as
    # the `dtype=` argument of torch.empty_like and annotated as
    # `torch.dtype | None` by the public `chunk_local_cumsum` wrapper.
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    # Size each program's time-span from a ~2**18 element budget, rounded up
    # to a power of two (so it splits evenly into power-of-two chunks).
    OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
    block_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE)
        if cu_seqlens is not None
        else None
    )
    num_blocks = (
        len(block_indices)
        if cu_seqlens is not None
        else triton.cdiv(T, OPTIM_BLOCK_SIZE)
    )
    # Keep the original tensor as the kernel's read-only source; write the
    # result into a fresh buffer of the requested dtype.
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (num_blocks, B)
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=block_indices,
        T=T,
        B=B,
        H=H,
        BLOCK_T=OPTIM_BLOCK_SIZE,
        CHUNK_SIZE=chunk_size,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
        num_warps=8,
        num_stages=3,
    )
    return g

138 

139 

140def chunk_local_cumsum( 

141 g: torch.Tensor, 

142 chunk_size: int, 

143 reverse: bool = False, 

144 scale: float = None, 

145 cu_seqlens: torch.Tensor | None = None, 

146 head_first: bool = False, 

147 output_dtype: torch.dtype | None = torch.float, 

148 **kwargs, 

149) -> torch.Tensor: 

150 if cu_seqlens is not None: 

151 assert ( 

152 g.shape[0] == 1 

153 ), "Only batch size 1 is supported when cu_seqlens are provided" 

154 if len(g.shape) == 3: 

155 return chunk_local_cumsum_scalar( 

156 g=g, 

157 chunk_size=chunk_size, 

158 reverse=reverse, 

159 scale=scale, 

160 cu_seqlens=cu_seqlens, 

161 head_first=head_first, 

162 output_dtype=output_dtype, 

163 ) 

164 else: 

165 raise ValueError( 

166 f"Unsupported input shape {g.shape}, " 

167 f"which should be (B, T, H, D) if `head_first=False` " 

168 f"or (B, H, T, D) otherwise" 

169 )