Coverage for src/flag_gems/fused/grouped_topk.py: 6%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def topk_with_k2_triton(
    scores_ptr,
    bias_ptr,
    group_scores_ptr,
    num_experts_per_group,
    n_group,
    stride_scores_token,
    stride_group_scores_token,
    BLOCK_SIZE: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
):
    """Write each group's score as the sum of its top-2 biased expert scores.

    One program handles one (token, group) pair: it loads that group's
    expert scores, adds the per-expert bias, finds the two largest values
    in float32, and stores their sum into ``group_scores``.
    """
    prog = tl.program_id(0)
    tok = prog // n_group
    grp = prog % n_group

    lanes = tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < num_experts_per_group

    score_base = tok * stride_scores_token + grp * num_experts_per_group
    bias_base = grp * num_experts_per_group

    raw = tl.load(
        scores_ptr + score_base + lanes,
        mask=in_bounds,
        other=-float("inf"),
    )
    expert_bias = tl.load(
        bias_ptr + bias_base + lanes,
        mask=in_bounds,
        other=0.0,
    )

    # Bias in the input dtype first (matching the original accumulation
    # order), then widen to float32 for the reductions.
    biased = (raw + expert_bias).to(tl.float32)

    top1 = tl.max(biased, axis=0)
    hits_top1 = (biased == top1) & in_bounds
    n_top1 = tl.sum(hits_top1.to(tl.int32), axis=0)

    # Knock out the maximum only when it is unique; a duplicated maximum
    # is itself the legitimate second-largest value.
    remaining = tl.where(
        hits_top1 & (n_top1 == 1),
        -float("inf"),
        biased,
    )
    top2 = tl.max(remaining, axis=0)

    out_off = tok * stride_group_scores_token + grp
    tl.store(
        group_scores_ptr + out_off,
        (top1 + top2).to(INPUT_DTYPE),
    )

61 

62 

@triton.jit
def group_idx_and_topk_triton(
    scores_ptr,
    group_scores_ptr,
    topk_values_ptr,
    topk_indices_ptr,
    bias_ptr,
    num_tokens,
    n_group,
    topk_group,
    topk,
    num_experts,
    num_experts_per_group,
    routed_scaling_factor,
    stride_scores_token,
    stride_group_scores_token,
    stride_out_token,
    N_GROUP: tl.constexpr,
    TOPK_GROUP: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_GROUP: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
    renormalize: tl.constexpr,
):
    """One program per token: keep the best ``topk_group`` groups by group
    score, then pick the top-``topk`` experts (ranked by score + bias)
    inside those groups.

    The stored weights are the *unbiased* expert scores — the bias only
    influences which experts win — scaled by ``routed_scaling_factor`` and
    optionally renormalized to sum to it.
    """
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return

    neg_inf = -float("inf")

    group_offsets = tl.arange(0, BLOCK_GROUP)
    valid_group = group_offsets < n_group

    group_scores = tl.load(
        group_scores_ptr + pid * stride_group_scores_token + group_offsets,
        mask=valid_group,
        other=neg_inf,
    )

    # Map NaN (x != x) and +inf group scores to -inf so they never win.
    group_scores_f32 = group_scores.to(tl.float32)
    is_finite = (group_scores_f32 == group_scores_f32) & (
        group_scores_f32 != float("inf")
    )
    group_scores_f32 = tl.where(is_finite & valid_group, group_scores_f32, neg_inf)

    # If every group score collapsed to -inf, emit the default fallback
    # (see final_vals/final_idx at the bottom).
    max_group_score = tl.max(group_scores_f32, axis=0)
    if_proceed = max_group_score != neg_inf

    # Iteratively peel off the current maximum up to TOPK_GROUP times to
    # find the score of the topk_group-th best group.  The padding lanes
    # (BLOCK_GROUP - n_group of them, all -inf) are pre-counted so the
    # budget below is expressed in padded-lane units.
    value = group_scores_f32
    target_num_min = BLOCK_GROUP - n_group + topk_group
    count_equal_to_top_value = BLOCK_GROUP - n_group
    pre_count_equal_to_top_value = 0
    topk_group_value = neg_inf

    for _ in range(TOPK_GROUP):
        need = count_equal_to_top_value < target_num_min
        max_val = tl.max(value, axis=0)

        # Remove every lane tied at the current max in one step.
        is_max = need & (value == max_val)
        value = tl.where(is_max, neg_inf, value)

        newly = tl.sum(is_max.to(tl.int32), axis=0)

        # Remember the count *before* the final round so we can later tell
        # how many of the threshold-tied groups still fit in the budget.
        pre_count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value, pre_count_equal_to_top_value
        )
        count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value + newly, count_equal_to_top_value
        )
        topk_group_value = tl.where(need, max_val, topk_group_value)

    # Number of groups tied at the threshold score that may be admitted.
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value

    group_gt = group_scores_f32 > topk_group_value
    group_eq = group_scores_f32 == topk_group_value

    # Exclusive prefix count over the ties: admit them in ascending group
    # index order until the budget is used up.
    eq_i = group_eq.to(tl.int32)
    prefix_eq = tl.cumsum(eq_i, axis=0) - eq_i

    group_selected = (
        group_gt | (group_eq & (prefix_eq < num_equalto_topkth_group))
    ) & valid_group

    # Broadcast group selection to expert lanes: expert e belongs to group
    # e // num_experts_per_group (experts laid out group-contiguously).
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    valid_expert = expert_offsets < num_experts
    expert_group = expert_offsets // num_experts_per_group

    expert_in_group = expert_group[:, None] == group_offsets[None, :]
    expert_selected = (
        tl.sum((expert_in_group & group_selected[None, :]).to(tl.int32), axis=1) > 0
    ) & valid_expert

    scored = tl.load(
        scores_ptr + pid * stride_scores_token + expert_offsets,
        mask=expert_selected,
        other=neg_inf,
    )

    expert_bias = tl.load(
        bias_ptr + expert_offsets,
        mask=valid_expert,
        other=0.0,
    )

    # Experts compete on score + bias, restricted to the selected groups.
    selection_scores_native = scored + expert_bias

    selection_scores = tl.where(
        expert_selected,
        selection_scores_native.to(tl.float32),
        neg_inf,
    )

    topk_vals = tl.full([TOPK], 0.0, tl.float32)
    topk_idx = tl.full([TOPK], 0, tl.int32)
    pos_range = tl.arange(0, TOPK)

    for i in range(TOPK):
        max_val = tl.max(selection_scores, axis=0)
        is_max = selection_scores == max_val

        # Tie-break by smallest expert index; num_experts + 1 is a sentinel
        # that can never win tl.min against a real index.
        candidate_idx = tl.where(is_max, expert_offsets, num_experts + 1)
        selected_idx = tl.min(candidate_idx, axis=0)

        # Re-load the raw (bias-free) score: bias steers selection only,
        # the emitted weight comes from the original scores.
        selected_score = tl.load(
            scores_ptr + pid * stride_scores_token + selected_idx,
            mask=selected_idx < num_experts,
            other=neg_inf,
        ).to(tl.float32)

        topk_vals = tl.where(pos_range == i, selected_score, topk_vals)
        topk_idx = tl.where(pos_range == i, selected_idx.to(tl.int32), topk_idx)

        # Knock the winner out before the next round.
        selection_scores = tl.where(
            expert_offsets == selected_idx, neg_inf, selection_scores
        )

    if renormalize == 1:
        # Epsilon guards against division by a zero sum.
        topk_sum = tl.sum(topk_vals, axis=0) + 1e-20
        scale = routed_scaling_factor / topk_sum
    else:
        scale = routed_scaling_factor

    topk_vals = topk_vals * scale

    # Degenerate-input fallback: uniform weights 1/topk, indices 0..topk-1.
    default_idx = pos_range.to(tl.int32)
    default_vals = tl.full([TOPK], 1.0 / topk, tl.float32)

    final_vals = tl.where(if_proceed, topk_vals, default_vals)
    final_idx = tl.where(if_proceed, topk_idx, default_idx)

    tl.store(
        topk_values_ptr + pid * stride_out_token + pos_range,
        final_vals,
        mask=pos_range < topk,
    )

    tl.store(
        topk_indices_ptr + pid * stride_out_token + pos_range,
        final_idx,
        mask=pos_range < topk,
    )

225 

226 

def grouped_topk(
    scores: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Grouped top-k routing: pick ``topk`` experts per token from the best
    ``topk_group`` expert groups.

    Args:
        scores: ``(num_tokens, num_experts)`` router scores.
        n_group: number of expert groups; must divide ``num_experts``.
        topk_group: number of groups to keep per token (1..n_group).
        topk: number of experts returned per token (<= 32).
        renormalize: if True, the returned weights are normalized so they
            sum to ``routed_scaling_factor``; otherwise they are scaled
            by it directly.
        routed_scaling_factor: multiplier applied to the output weights.
        bias: per-expert bias of length ``num_experts``; it influences
            which experts are selected but not the returned weights.
        scoring_func: 0 = use ``scores`` as-is; 1 = apply sigmoid first
            (computed as ``0.5 * tanh(0.5 * x) + 0.5``).

    Returns:
        ``(topk_values, topk_indices)``: float32 weights and int32 expert
        ids, each of shape ``(num_tokens, topk)``.

    Raises:
        ValueError: on invalid shapes, unsupported dtype, or out-of-range
            group/topk parameters.
    """
    if scores.ndim != 2:
        raise ValueError("scores must be a 2D Tensor")
    num_tokens, num_experts = scores.shape
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    if n_group > 32:
        raise ValueError("n_group should be smaller than or equal to 32")
    if topk > 32:
        raise ValueError("topk should be smaller than or equal to 32 for now")
    # Fix: previously unvalidated — a topk_group outside [1, n_group] makes
    # the kernel's group-selection threshold loop compute garbage.
    if not 1 <= topk_group <= n_group:
        raise ValueError("topk_group must be in [1, n_group]")
    if scoring_func not in (0, 1):
        raise ValueError("scoring_func must be 0 (none) or 1 (sigmoid)")

    if bias.dtype != scores.dtype:
        bias = bias.to(scores.dtype)
    if bias.ndim != 1:
        bias = bias.flatten()
    if len(bias) != num_experts:
        raise ValueError(
            f"bias length ({len(bias)}) must match num_experts ({num_experts})"
        )
    # Fix: the kernels index bias with unit element stride; a non-contiguous
    # view (e.g. from slicing) would silently read the wrong elements.
    bias = bias.contiguous()

    num_experts_per_group = num_experts // n_group

    if scores.dtype == torch.float32:
        INPUT_DTYPE = tl.float32
    elif scores.dtype == torch.float16:
        INPUT_DTYPE = tl.float16
    elif scores.dtype == torch.bfloat16:
        INPUT_DTYPE = tl.bfloat16
    else:
        raise ValueError(f"Unsupported dtype: {scores.dtype}")

    if scoring_func == 1:
        from flag_gems.ops.tanh import tanh as gems_tanh

        # sigmoid(x) == 0.5 * tanh(0.5 * x) + 0.5; the arithmetic result
        # is a fresh contiguous tensor.
        scores_processed = 0.5 * gems_tanh(0.5 * scores) + 0.5
    else:
        # Fix: the kernels assume a dense last dimension (per-row offsets
        # use unit element stride), so densify non-contiguous inputs.
        # contiguous() is a no-op for already-contiguous tensors.
        scores_processed = scores.contiguous()

    group_scores = torch.empty(
        (num_tokens, n_group),
        device=scores.device,
        dtype=scores.dtype,
    )

    topk_values = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.float32,
    )

    topk_indices = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.int32,
    )

    # Stage 1: one program per (token, group) computes the group score as
    # the sum of the group's top-2 biased expert scores.
    BLOCK1 = triton.next_power_of_2(num_experts_per_group)
    grid1 = (num_tokens * n_group,)

    topk_with_k2_triton[grid1](
        scores_processed,
        bias,
        group_scores,
        num_experts_per_group,
        n_group,
        scores_processed.stride(0),
        group_scores.stride(0),
        BLOCK_SIZE=BLOCK1,
        INPUT_DTYPE=INPUT_DTYPE,
    )

    # Stage 2: one program per token selects the groups and the experts.
    BLOCK_GROUP = triton.next_power_of_2(n_group)
    BLOCK_EXPERT = triton.next_power_of_2(num_experts)
    grid2 = (num_tokens,)

    group_idx_and_topk_triton[grid2](
        scores_processed,
        group_scores,
        topk_values,
        topk_indices,
        bias,
        num_tokens,
        n_group,
        topk_group,
        topk,
        num_experts,
        num_experts_per_group,
        routed_scaling_factor,
        scores_processed.stride(0),
        group_scores.stride(0),
        topk_values.stride(0),
        N_GROUP=n_group,
        TOPK_GROUP=topk_group,
        TOPK=topk,
        BLOCK_GROUP=BLOCK_GROUP,
        BLOCK_EXPERT=BLOCK_EXPERT,
        INPUT_DTYPE=INPUT_DTYPE,
        renormalize=int(renormalize),
    )

    return topk_values, topk_indices