Coverage for src/flag_gems/ops/scaled_softmax.py: 41%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8 

# Autotuning candidates: (BLOCK_Q, BLOCK_K, num_warps, num_stages).
# Small BLOCK_Q tiles with wide key blocks target short query lengths;
# the num_warps=8 / num_stages=4 entries target large tiles.
autotune_configs = [
    triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=warps, num_stages=stages)
    for bq, bk, warps, stages in (
        (1, 128, 4, 2),
        (1, 256, 4, 2),
        (1, 512, 4, 2),
        (1, 1024, 4, 2),
        (1, 2048, 4, 2),
        (2, 128, 4, 2),
        (4, 32, 4, 2),
        (4, 64, 4, 2),
        (4, 128, 4, 2),
        (8, 32, 4, 2),
        (8, 64, 4, 2),
        (16, 32, 4, 2),
        (32, 128, 8, 4),
        (32, 256, 8, 4),
        (32, 512, 8, 4),
        (64, 128, 8, 4),
        (64, 256, 8, 4),
        (64, 512, 8, 4),
        (64, 1024, 8, 4),
        (128, 512, 8, 4),
    )
]

31 

32 

# Module-level logger, named after this module per the project convention.
logger = logging.getLogger(__name__)

34 

35 

@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_forward_kernel(
    output_ptr,
    input_ptr,
    scale_factor,
    query_seq_len,
    key_seq_len,
    stride_b,
    stride_h,
    stride_q,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute softmax(input * scale_factor) along the key (last) dimension.

    Each program handles a BLOCK_Q-row tile of one (batch, head) slice and
    makes two passes over the key dimension: pass 1 computes the per-row
    running max and exp-sum (online softmax), pass 2 normalizes and stores.
    No stride is supplied for the key dimension, so the kernel assumes the
    key axis is contiguous (unit stride).
    """
    # Grid: axis 0 tiles the query dim, axis 1 = head, axis 2 = batch.
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)

    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)

    # Guards the (possibly partial) last query tile.
    query_mask = query_offsets < query_seq_len

    # Base pointer of each query row in this tile (key offset added later).
    row_start_ptr = (
        input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Online-softmax state per row: running max `m` and running exp-sum,
    # accumulated in fp32 for numerical stability.
    m = tl.full([BLOCK_Q], -float("inf"), dtype=tl.float32)
    exp_sum = tl.zeros([BLOCK_Q], dtype=tl.float32)

    # Pass 1: streaming max/exp-sum over key blocks.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        block_ptr = row_start_ptr[:, None] + k_offsets[None, :]

        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        # -inf padding keeps masked lanes from affecting the max; ".ca"
        # caches at all levels since pass 2 re-reads the same data.
        s_block = tl.load(
            block_ptr, mask=mask, other=-float("inf"), cache_modifier=".ca"
        )
        s_block = s_block * scale_factor

        m_new = tl.max(s_block, axis=1)
        m_old = m
        m = tl.maximum(m_old, m_new)

        # Rescale the previous exp-sum to the new max before accumulating.
        s_prev = tl.exp(m_old - m)
        exp_sum = exp_sum * s_prev

        s_curr = tl.exp(s_block - m[:, None])
        # Masked lanes would contribute exp(-inf - m) = 0 anyway; the
        # tl.where keeps any NaN from fully-masked rows out of the sum.
        l_new = tl.sum(tl.where(mask, s_curr, 0.0), axis=1)
        exp_sum = exp_sum + l_new

    # Multiply by the reciprocal in pass 2 instead of dividing per element.
    exp_sum_inv = 1.0 / exp_sum

    out_row_start_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Pass 2: recompute the scaled scores, normalize, and store.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)

        block_ptr_in = row_start_ptr[:, None] + k_offsets[None, :]
        block_ptr_out = out_row_start_ptr[:, None] + k_offsets[None, :]

        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        # Data is consumed once here, so prefer evicting it first.
        s_block = tl.load(
            block_ptr_in, mask=mask, other=-float("inf"), eviction_policy="evict_first"
        )

        s_block = s_block * scale_factor
        s_block = s_block - m[:, None]
        p_block = tl.exp(s_block)
        p_block = p_block * exp_sum_inv[:, None]

        # ".cs" (streaming) store: the output is not re-read by this kernel.
        tl.store(block_ptr_out, p_block, mask=mask, cache_modifier=".cs")

123 

124 

def scaled_softmax_forward(input_t: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Compute ``softmax(input_t * scale_factor)`` along the last (key) dim.

    Args:
        input_t: 4D tensor of shape (batch, heads, query_seq_len, key_seq_len),
            fp16 or bf16.
        scale_factor: Multiplier applied to the scores before the softmax.

    Returns:
        A tensor of the same shape and dtype as ``input_t``.

    Raises:
        AssertionError: If the rank, dtype, or sequence-length constraints
            are violated.
    """
    logger.debug("GEMS SCALED SOFTMAX FORWARD")
    assert input_t.dim() == 4, "expected 4D tensor"
    batch_size, attn_heads, query_seq_len, key_seq_len = input_t.shape
    assert input_t.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    assert key_seq_len <= 16384, "Key sequence length must be 16384 or less"
    assert key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"
    assert query_seq_len > 1, "Query sequence length must be greater than 1"

    # The kernel takes no stride for the key dimension and indexes it with
    # unit stride, so force a contiguous layout (no-op if already contiguous).
    # This matches scaled_softmax_backward, which also calls .contiguous().
    input_t = input_t.contiguous()

    def grid(meta):
        # One program per BLOCK_Q query rows, per head, per batch.
        BLOCK_Q = meta["BLOCK_Q"]
        query_seq_tile_len = triton.cdiv(query_seq_len, BLOCK_Q)
        return (query_seq_tile_len, attn_heads, batch_size)

    output_t = torch.empty_like(input_t)

    # Strides taken after .contiguous(), so the key-dim stride is 1 as the
    # kernel assumes.
    stride_b = input_t.stride(0)
    stride_h = input_t.stride(1)
    stride_q = input_t.stride(2)

    scaled_softmax_forward_kernel[grid](
        output_t,
        input_t,
        scale_factor,
        query_seq_len,
        key_seq_len,
        stride_b,
        stride_h,
        stride_q,
    )
    return output_t

159 

160 

@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_backward_kernel(
    grad_input_ptr,
    grad_output_ptr,
    output_ptr,
    scale_factor,
    query_seq_len,
    key_seq_len,
    stride_b,
    stride_h,
    stride_q,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Backward of the scaled softmax.

    With P = softmax(S * scale) and dP the upstream gradient, computes
    dS = scale * P * (dP - sum_k(P * dP)) in two passes over the key dim:
    pass 1 accumulates the per-row dot product D = sum_k(P * dP), pass 2
    forms and stores dS. No stride is supplied for the key dimension, so
    the kernel assumes the key axis is contiguous (unit stride).
    """
    # Grid: axis 0 tiles the query dim, axis 1 = head, axis 2 = batch.
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)

    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)

    # Guards the (possibly partial) last query tile.
    query_mask = query_offsets < query_seq_len

    # Row base pointers into the saved softmax output P ...
    output_row_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # ... the upstream gradient dP ...
    grad_output_row_ptr = (
        grad_output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # ... and the result dS.
    grad_input_row_ptr = (
        grad_input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Per-row accumulator D = sum_k(P * dP), kept in fp32.
    D = tl.zeros([BLOCK_Q], dtype=tl.float32)

    # Pass 1: accumulate D over key blocks.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]

        # ".ca" caches at all levels since pass 2 re-reads the same data.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, cache_modifier=".ca")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, cache_modifier=".ca")

        dot_block = P_block * dP_block
        # Masked lanes are already 0.0 via `other`; tl.where is belt-and-braces.
        D += tl.sum(tl.where(mask, dot_block, 0.0), axis=1)

    # Pass 2: re-read P and dP, form dS, and store.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dS = grad_input_row_ptr[:, None] + k_offsets[None, :]

        # Data is consumed once here, so prefer evicting it first.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, eviction_policy="evict_first")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, eviction_policy="evict_first")

        # dZ = P * (dP - D); dS = scale * dZ (chain rule through the scaling).
        dZ_block = P_block * (dP_block - D[:, None])
        dS_block = scale_factor * dZ_block

        # ".cs" (streaming) store: the output is not re-read by this kernel.
        tl.store(ptr_dS, dS_block, mask=mask, cache_modifier=".cs")

241 

242 

def scaled_softmax_backward(
    grad_output: torch.Tensor, softmax_results: torch.Tensor, scale_factor: float
):
    """Compute the gradient of the scaled softmax w.r.t. its input scores.

    Args:
        grad_output: Upstream gradient dP, 4D fp16/bf16 tensor.
        softmax_results: Saved forward output P, 4D fp16/bf16 tensor.
        scale_factor: The scale used in the forward pass.

    Returns:
        dS = scale_factor * P * (dP - sum_k(P * dP)), same shape/dtype as
        ``grad_output``.

    Raises:
        AssertionError: If either tensor is not 4D or not fp16/bf16.
    """
    logger.debug("GEMS SCALED SOFTMAX BACKWARD")
    assert grad_output.dim() == 4, "expected 4D tensor"
    assert softmax_results.dim() == 4, "expected 4D tensor"
    half_dtypes = (torch.float16, torch.bfloat16)
    assert grad_output.dtype in half_dtypes, "Only fp16 and bf16 are supported"
    assert softmax_results.dtype in half_dtypes, "Only fp16 and bf16 are supported"

    # The kernel assumes a unit stride along the key dimension.
    grad_output = grad_output.contiguous()
    softmax_results = softmax_results.contiguous()

    num_batches, num_heads, q_len, k_len = softmax_results.shape
    grad_input = torch.empty_like(grad_output)

    def grid(meta):
        # One program per BLOCK_Q query rows, per head, per batch.
        return (triton.cdiv(q_len, meta["BLOCK_Q"]), num_heads, num_batches)

    scaled_softmax_backward_kernel[grid](
        grad_input,
        grad_output,
        softmax_results,
        scale_factor,
        q_len,
        k_len,
        softmax_results.stride(0),
        softmax_results.stride(1),
        softmax_results.stride(2),
    )
    return grad_input