Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/scaled_softmax.py: 0%

122 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems.utils import libentry 

6 

# Candidate tile shapes for the autotuner, keyed on the sequence lengths.
# BLOCK_Q is the number of query rows handled per program; BLOCK_K is the
# number of key columns processed per inner-loop iteration. Small-row
# configs use 4 warps / 2 stages; large tiles use 8 warps / 4 stages.
autotune_configs = [
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 1024}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 2048}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 2, "BLOCK_K": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 4, "BLOCK_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 4, "BLOCK_K": 64}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 4, "BLOCK_K": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 8, "BLOCK_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 8, "BLOCK_K": 64}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 16, "BLOCK_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 32, "BLOCK_K": 128}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 32, "BLOCK_K": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 32, "BLOCK_K": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 128}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 128, "BLOCK_K": 512}, num_warps=8, num_stages=4),
]

29 

30 

@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_forward_kernel(
    output_ptr,  # result tensor, written with the same layout as the input
    input_ptr,  # attention scores
    scale_factor,  # scalar multiplier applied to scores before softmax
    query_seq_len,  # number of query rows per (batch, head)
    key_seq_len,  # reduction length (last dimension)
    stride_b,  # batch stride, in elements
    stride_h,  # head stride, in elements
    stride_q,  # query-row stride, in elements
    BLOCK_Q: tl.constexpr,  # query rows handled per program
    BLOCK_K: tl.constexpr,  # key columns handled per inner-loop iteration
):
    """Row-wise softmax(input * scale_factor) along the key dimension.

    Two-pass streaming ("online") softmax: pass 1 computes each row's
    running maximum and rescaled sum of exponentials; pass 2 re-reads
    the scores, normalizes, and stores the result.

    NOTE(review): the key dimension is indexed with implicit unit stride
    (no stride_k argument), so the innermost dimension must be
    contiguous — confirm callers guarantee this.
    """
    # One program covers BLOCK_Q consecutive query rows of one
    # (batch, head) slice.
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)

    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)

    # Guard the last, possibly partial, tile of query rows.
    query_mask = query_offsets < query_seq_len

    # Base pointer of each query row handled by this program.
    row_start_ptr = (
        input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Online-softmax state: m is the running row max of the scaled
    # scores, exp_sum the running sum of exp(score - m), both in fp32.
    m = tl.full([BLOCK_Q], -float("inf"), dtype=tl.float32)
    exp_sum = tl.zeros([BLOCK_Q], dtype=tl.float32)

    # Pass 1: stream over key blocks accumulating max and exp-sum.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        block_ptr = row_start_ptr[:, None] + k_offsets[None, :]

        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        # -inf padding keeps masked lanes from affecting the row max;
        # ".ca" hints the cache that pass 2 re-reads this data.
        s_block = tl.load(
            block_ptr, mask=mask, other=-float("inf"), cache_modifier=".ca"
        )
        s_block = s_block * scale_factor

        m_new = tl.max(s_block, axis=1)
        m_old = m
        m = tl.maximum(m_old, m_new)

        # Rescale the accumulated sum to the new maximum before adding
        # the current block's contribution (exp(m_old - m) <= 1).
        s_prev = tl.exp(m_old - m)
        exp_sum = exp_sum * s_prev

        s_curr = tl.exp(s_block - m[:, None])
        l_new = tl.sum(tl.where(mask, s_curr, 0.0), axis=1)
        exp_sum = exp_sum + l_new

    # Reciprocal once per row instead of a divide per element.
    exp_sum_inv = 1.0 / exp_sum

    # Output rows share the input's strides (same layout assumed).
    out_row_start_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Pass 2: re-read scores, normalize with (m, exp_sum), and store.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)

        block_ptr_in = row_start_ptr[:, None] + k_offsets[None, :]
        block_ptr_out = out_row_start_ptr[:, None] + k_offsets[None, :]

        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        # evict_first: this pass is the last use of the input data.
        s_block = tl.load(
            block_ptr_in, mask=mask, other=-float("inf"), eviction_policy="evict_first"
        )

        s_block = s_block * scale_factor
        s_block = s_block - m[:, None]
        p_block = tl.exp(s_block)
        p_block = p_block * exp_sum_inv[:, None]

        # ".cs": streaming store — the output is not re-read here.
        tl.store(block_ptr_out, p_block, mask=mask, cache_modifier=".cs")

118 

119 

def scaled_softmax_forward(input_t: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Compute softmax(input_t * scale_factor) over the last dimension.

    Args:
        input_t: attention scores of shape
            (batch, heads, query_seq_len, key_seq_len), fp16 or bf16.
        scale_factor: scalar multiplier applied before the softmax.

    Returns:
        A tensor of the same shape and dtype as ``input_t``.
    """
    assert input_t.dim() == 4, "expected 4D tensor"
    assert input_t.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"

    # Fix: the kernel indexes the key dimension with implicit unit stride
    # (it passes no stride for dim 3), so a non-contiguous last dimension
    # would silently produce wrong results. Mirror the backward path and
    # force a contiguous layout up front (no-op for contiguous inputs).
    input_t = input_t.contiguous()

    batch_size, attn_heads, query_seq_len, key_seq_len = input_t.shape
    assert key_seq_len <= 16384, "Key sequence length must be 16384 or less"
    assert key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"
    assert query_seq_len > 1, "Query sequence length must be greater than 1"

    def grid(meta):
        # One program per (query tile, head, batch); the tile count
        # depends on the autotuned BLOCK_Q.
        BLOCK_Q = meta["BLOCK_Q"]
        query_seq_tile_len = triton.cdiv(query_seq_len, BLOCK_Q)
        return (query_seq_tile_len, attn_heads, batch_size)

    output_t = torch.empty_like(input_t)

    # Strides in elements; dim 3 is contiguous by construction above.
    stride_b = input_t.stride(0)
    stride_h = input_t.stride(1)
    stride_q = input_t.stride(2)

    scaled_softmax_forward_kernel[grid](
        output_t,
        input_t,
        scale_factor,
        query_seq_len,
        key_seq_len,
        stride_b,
        stride_h,
        stride_q,
    )
    return output_t

153 

154 

@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_backward_kernel(
    grad_input_ptr,  # dS: gradient w.r.t. the pre-softmax scores (written)
    grad_output_ptr,  # dP: gradient w.r.t. the softmax output (read)
    output_ptr,  # P: forward softmax output (read)
    scale_factor,  # same scalar used in the forward pass
    query_seq_len,  # number of query rows per (batch, head)
    key_seq_len,  # reduction length (last dimension)
    stride_b,  # batch stride, in elements
    stride_h,  # head stride, in elements
    stride_q,  # query-row stride, in elements
    BLOCK_Q: tl.constexpr,  # query rows handled per program
    BLOCK_K: tl.constexpr,  # key columns handled per inner-loop iteration
):
    """Backward of scaled softmax: dS = scale * P * (dP - sum(P * dP)).

    Pass 1 accumulates the per-row dot product D = sum_k P * dP; pass 2
    re-reads P and dP and stores dS. All three tensors are assumed to
    share the given strides with a unit-stride key dimension (the Python
    wrapper makes them contiguous).
    """
    # One program covers BLOCK_Q consecutive query rows of one
    # (batch, head) slice.
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)

    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)

    # Guard the last, possibly partial, tile of query rows.
    query_mask = query_offsets < query_seq_len

    output_row_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    grad_output_row_ptr = (
        grad_output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    grad_input_row_ptr = (
        grad_input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Per-row accumulator for D = sum_k P[k] * dP[k], in fp32.
    D = tl.zeros([BLOCK_Q], dtype=tl.float32)

    # Pass 1: accumulate the row-wise dot product of P and dP.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]

        # ".ca" hints the cache that pass 2 re-reads this data.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, cache_modifier=".ca")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, cache_modifier=".ca")

        dot_block = P_block * dP_block
        D += tl.sum(tl.where(mask, dot_block, 0.0), axis=1)

    # Pass 2: dS = scale_factor * P * (dP - D), streamed block by block.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dS = grad_input_row_ptr[:, None] + k_offsets[None, :]

        # evict_first: this pass is the last use of P and dP.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, eviction_policy="evict_first")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, eviction_policy="evict_first")

        dZ_block = P_block * (dP_block - D[:, None])
        dS_block = scale_factor * dZ_block

        # ".cs": streaming store — the result is not re-read here.
        tl.store(ptr_dS, dS_block, mask=mask, cache_modifier=".cs")

236 

def scaled_softmax_backward(
    grad_output: torch.Tensor, softmax_results: torch.Tensor, scale_factor: float
) -> torch.Tensor:
    """Backward pass of :func:`scaled_softmax_forward`.

    Args:
        grad_output: gradient w.r.t. the softmax output, shape
            (batch, heads, query_seq_len, key_seq_len), fp16 or bf16.
        softmax_results: the forward softmax output, same shape/dtype.
        scale_factor: the scalar used in the forward pass.

    Returns:
        Gradient w.r.t. the pre-softmax scores, same shape as the inputs.
    """
    assert grad_output.dim() == 4, "expected 4D tensor"
    assert softmax_results.dim() == 4, "expected 4D tensor"
    assert grad_output.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    assert softmax_results.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    # Fix: the kernel addresses all three tensors with one set of strides,
    # so a shape mismatch would read/write out of bounds. Fail fast here.
    assert (
        grad_output.shape == softmax_results.shape
    ), "grad_output and softmax_results must have the same shape"

    # The kernel assumes a unit-stride key dimension and identical
    # layouts for all three tensors; enforce both via contiguous copies.
    grad_output = grad_output.contiguous()
    softmax_results = softmax_results.contiguous()

    batch_size, attn_heads, query_seq_len, key_seq_len = softmax_results.shape

    def grid(meta):
        # One program per (query tile, head, batch); the tile count
        # depends on the autotuned BLOCK_Q.
        BLOCK_Q = meta["BLOCK_Q"]
        query_seq_tile_len = triton.cdiv(query_seq_len, BLOCK_Q)
        return (query_seq_tile_len, attn_heads, batch_size)

    grad_input = torch.empty_like(grad_output)

    # Strides in elements; contiguous, so they match grad_output's too.
    stride_b = softmax_results.stride(0)
    stride_h = softmax_results.stride(1)
    stride_q = softmax_results.stride(2)

    scaled_softmax_backward_kernel[grid](
        grad_input,
        grad_output,
        softmax_results,
        scale_factor,
        query_seq_len,
        key_seq_len,
        stride_b,
        stride_h,
        stride_q,
    )

    return grad_input