Coverage for src/flag_gems/fused/grouped_topk.py: 8%

133 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.jit
def topk_with_k2_triton(
    scores_ptr,
    bias_ptr,
    group_scores_ptr,
    num_experts_per_group,
    n_group,
    stride_scores_token,
    stride_group_scores_token,
    BLOCK_SIZE: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
):
    # One program instance per (token, group) pair: reduce the group's biased
    # expert scores to a single group score = top1 + top2.
    # NOTE(review): both loads assume a unit stride across experts within a
    # row (only the token stride is passed in) — confirm callers pass
    # contiguous rows.
    pid = tl.program_id(0)

    token_id = pid // n_group
    group_id = pid % n_group

    lane = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE is num_experts_per_group rounded up to a power of two;
    # mask off the padding lanes.
    mask = lane < num_experts_per_group

    scores_offset = token_id * stride_scores_token + group_id * num_experts_per_group
    bias_offset = group_id * num_experts_per_group

    # Padding lanes load -inf so they can never win a max reduction.
    x = tl.load(
        scores_ptr + scores_offset + lane,
        mask=mask,
        other=-float("inf"),
    )

    # Bias padding is 0.0: -inf + 0.0 stays -inf for masked lanes.
    b = tl.load(
        bias_ptr + bias_offset + lane,
        mask=mask,
        other=0.0,
    )

    # Selection score = raw score + per-expert bias.
    x = x + b

    # Reduce in float32 regardless of the input dtype for accuracy.
    x_f32 = x.to(tl.float32)

    max1 = tl.max(x_f32, axis=0)
    is_max1 = (x_f32 == max1) & mask
    count_max1 = tl.sum(is_max1.to(tl.int32), axis=0)

    # Mask out the maximum before the second reduction ONLY when it is
    # unique; if the max value is duplicated, the second-best value equals
    # max1 and must not be removed (top1 + top2 == 2 * max1 in that case).
    x2 = tl.where(
        is_max1 & (count_max1 == 1),
        -float("inf"),
        x_f32,
    )
    max2 = tl.max(x2, axis=0)

    # Store the group score back in the input dtype.
    group_scores_offset = token_id * stride_group_scores_token + group_id
    tl.store(
        group_scores_ptr + group_scores_offset,
        (max1 + max2).to(INPUT_DTYPE),
    )

65 

66 

@triton.jit
def group_idx_and_topk_triton(
    scores_ptr,
    group_scores_ptr,
    topk_values_ptr,
    topk_indices_ptr,
    bias_ptr,
    num_tokens,
    n_group,
    topk_group,
    topk,
    num_experts,
    num_experts_per_group,
    routed_scaling_factor,
    stride_scores_token,
    stride_group_scores_token,
    stride_out_token,
    N_GROUP: tl.constexpr,
    TOPK_GROUP: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_GROUP: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
    renormalize: tl.constexpr,
):
    # One program instance per token. Two stages:
    #   1) keep the `topk_group` best groups by their precomputed group scores
    #      (with deterministic tie handling), and
    #   2) pick the `topk` best experts — ranked by score + bias — restricted
    #      to the kept groups, writing out the *unbiased* scores and indices.
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return

    neg_inf = -float("inf")

    group_offsets = tl.arange(0, BLOCK_GROUP)
    valid_group = group_offsets < n_group

    group_scores = tl.load(
        group_scores_ptr + pid * stride_group_scores_token + group_offsets,
        mask=valid_group,
        other=neg_inf,
    )

    group_scores_f32 = group_scores.to(tl.float32)
    # Treat NaN (x != x) and +inf as non-finite so such groups can never be
    # selected; they are forced to -inf below.
    is_finite = (group_scores_f32 == group_scores_f32) & (
        group_scores_f32 != float("inf")
    )
    group_scores_f32 = tl.where(is_finite & valid_group, group_scores_f32, neg_inf)

    # If every group score is -inf (or non-finite), skip normal routing and
    # emit the uniform fallback at the end.
    max_group_score = tl.max(group_scores_f32, axis=0)
    if_proceed = max_group_score != neg_inf

    # Iteratively peel off the maximum group value TOPK_GROUP times to find
    # the threshold value of the topk_group-th group. All counters are
    # offset by the number of padding lanes (BLOCK_GROUP - n_group), hence
    # the shifted target below.
    value = group_scores_f32
    target_num_min = BLOCK_GROUP - n_group + topk_group
    count_equal_to_top_value = BLOCK_GROUP - n_group
    pre_count_equal_to_top_value = 0
    topk_group_value = neg_inf

    for _ in range(TOPK_GROUP):
        # Stop updating once enough groups (incl. padding offset) are taken.
        need = count_equal_to_top_value < target_num_min
        max_val = tl.max(value, axis=0)

        # Remove every lane tied at the current maximum in one step.
        is_max = need & (value == max_val)
        value = tl.where(is_max, neg_inf, value)

        newly = tl.sum(is_max.to(tl.int32), axis=0)

        # pre_count = count before the last accepted batch of ties; used to
        # compute how many threshold-tied groups may still be admitted.
        pre_count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value, pre_count_equal_to_top_value
        )
        count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value + newly, count_equal_to_top_value
        )
        # topk_group_value ends as the score of the topk_group-th group
        # (the selection threshold).
        topk_group_value = tl.where(need, max_val, topk_group_value)

    # How many groups tied exactly at the threshold are allowed in.
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value

    group_gt = group_scores_f32 > topk_group_value
    group_eq = group_scores_f32 == topk_group_value

    # Exclusive prefix count of threshold ties: of the tied groups, admit
    # only the first `num_equalto_topkth_group` (lowest group indices),
    # keeping the selection deterministic.
    eq_i = group_eq.to(tl.int32)
    prefix_eq = tl.cumsum(eq_i, axis=0) - eq_i

    group_selected = (
        group_gt | (group_eq & (prefix_eq < num_equalto_topkth_group))
    ) & valid_group

    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    valid_expert = expert_offsets < num_experts
    # Experts are laid out group-major: expert e belongs to group
    # e // num_experts_per_group.
    expert_group = expert_offsets // num_experts_per_group

    # (BLOCK_EXPERT, BLOCK_GROUP) membership matrix, reduced over groups to
    # mark experts whose group survived stage 1.
    expert_in_group = expert_group[:, None] == group_offsets[None, :]
    expert_selected = (
        tl.sum((expert_in_group & group_selected[None, :]).to(tl.int32), axis=1) > 0
    ) & valid_expert

    # Unselected experts load -inf so they cannot win the top-k loop.
    scored = tl.load(
        scores_ptr + pid * stride_scores_token + expert_offsets,
        mask=expert_selected,
        other=neg_inf,
    )

    expert_bias = tl.load(
        bias_ptr + expert_offsets,
        mask=valid_expert,
        other=0.0,
    )

    # Ranking uses biased scores; the stored values use raw scores (reloaded
    # below per selection).
    selection_scores_native = scored + expert_bias

    selection_scores = tl.where(
        expert_selected,
        selection_scores_native.to(tl.float32),
        neg_inf,
    )

    topk_vals = tl.full([TOPK], 0.0, tl.float32)
    topk_idx = tl.full([TOPK], 0, tl.int32)
    pos_range = tl.arange(0, TOPK)

    for i in range(TOPK):
        max_val = tl.max(selection_scores, axis=0)
        is_max = selection_scores == max_val

        # Break ties toward the smallest expert index; num_experts + 1 is a
        # sentinel larger than any real index.
        candidate_idx = tl.where(is_max, expert_offsets, num_experts + 1)
        selected_idx = tl.min(candidate_idx, axis=0)

        # Re-read the *unbiased* score of the chosen expert for the output.
        selected_score = tl.load(
            scores_ptr + pid * stride_scores_token + selected_idx,
            mask=selected_idx < num_experts,
            other=neg_inf,
        ).to(tl.float32)

        # Scatter into slot i of the fixed-size result vectors.
        topk_vals = tl.where(pos_range == i, selected_score, topk_vals)
        topk_idx = tl.where(pos_range == i, selected_idx.to(tl.int32), topk_idx)

        # Remove the chosen expert from contention.
        selection_scores = tl.where(
            expert_offsets == selected_idx, neg_inf, selection_scores
        )

    if renormalize == 1:
        # Scale so the outputs sum to routed_scaling_factor; epsilon guards
        # against a zero sum.
        topk_sum = tl.sum(topk_vals, axis=0) + 1e-20
        scale = routed_scaling_factor / topk_sum
    else:
        scale = routed_scaling_factor

    topk_vals = topk_vals * scale

    # Fallback when no group had a finite score: uniform weights 1/topk and
    # identity indices 0..topk-1.
    default_idx = pos_range.to(tl.int32)
    default_vals = tl.full([TOPK], 1.0 / topk, tl.float32)

    final_vals = tl.where(if_proceed, topk_vals, default_vals)
    final_idx = tl.where(if_proceed, topk_idx, default_idx)

    tl.store(
        topk_values_ptr + pid * stride_out_token + pos_range,
        final_vals,
        mask=pos_range < topk,
    )

    tl.store(
        topk_indices_ptr + pid * stride_out_token + pos_range,
        final_idx,
        mask=pos_range < topk,
    )

229 

230 

def grouped_topk(
    scores: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Grouped top-k expert routing.

    Experts are partitioned into ``n_group`` equal groups. For each token,
    the ``topk_group`` best groups are kept (group score = top1 + top2 of
    the biased expert scores within the group), then the ``topk`` best
    experts — ranked by score + bias — are selected from those groups.
    Returned values are the *unbiased* scores, scaled by
    ``routed_scaling_factor`` (after renormalizing to unit sum when
    ``renormalize`` is True).

    Args:
        scores: ``(num_tokens, num_experts)`` router scores.
        n_group: number of expert groups; must divide ``num_experts`` and
            be <= 32.
        topk_group: number of groups kept per token.
        topk: number of experts selected per token; must be <= 32.
        renormalize: if True, scale selected values so they sum to
            ``routed_scaling_factor``; otherwise multiply by it directly.
        routed_scaling_factor: final scaling of the returned values.
        bias: per-expert bias of length ``num_experts``, used for ranking
            only.
        scoring_func: 0 = use ``scores`` as-is, 1 = apply sigmoid first.

    Returns:
        ``(topk_values, topk_indices)``: float32 ``(num_tokens, topk)``
        values and int32 ``(num_tokens, topk)`` expert indices.

    Raises:
        ValueError: on bad shapes, unsupported dtype, or invalid arguments.
    """
    logger.debug("GEMS GROUPED TOPK")
    if scores.ndim != 2:
        raise ValueError("scores must be a 2D Tensor")
    num_tokens, num_experts = scores.shape
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    if n_group > 32:
        raise ValueError("n_group should be smaller than or equal to 32")
    if topk > 32:
        raise ValueError("topk should be smaller than or equal to 32 for now")
    if scoring_func not in (0, 1):
        raise ValueError("scoring_func must be 0 (none) or 1 (sigmoid)")

    if bias.dtype != scores.dtype:
        bias = bias.to(scores.dtype)
    if bias.ndim != 1:
        bias = bias.flatten()
    if len(bias) != num_experts:
        raise ValueError(
            f"bias length ({len(bias)}) must match num_experts ({num_experts})"
        )
    # Both kernels index experts with a unit inner stride (only the token
    # stride is passed in); enforce contiguity so strided/transposed inputs
    # are not silently misread. No-op for already-contiguous tensors.
    bias = bias.contiguous()

    num_experts_per_group = num_experts // n_group

    try:
        INPUT_DTYPE = {
            torch.float32: tl.float32,
            torch.float16: tl.float16,
            torch.bfloat16: tl.bfloat16,
        }[scores.dtype]
    except KeyError:
        raise ValueError(f"Unsupported dtype: {scores.dtype}") from None

    if scoring_func == 1:
        from flag_gems.ops.tanh import tanh as gems_tanh

        # sigmoid(x) == 0.5 * tanh(0.5 * x) + 0.5
        scores_processed = 0.5 * gems_tanh(0.5 * scores) + 0.5
    else:
        scores_processed = scores
    # See the bias note above: guarantee a unit inner stride for the kernels.
    scores_processed = scores_processed.contiguous()

    group_scores = torch.empty(
        (num_tokens, n_group),
        device=scores.device,
        dtype=scores.dtype,
    )

    topk_values = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.float32,
    )

    topk_indices = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.int32,
    )

    # Stage 1: one program per (token, group) computes group scores.
    BLOCK1 = triton.next_power_of_2(num_experts_per_group)
    grid1 = (num_tokens * n_group,)

    topk_with_k2_triton[grid1](
        scores_processed,
        bias,
        group_scores,
        num_experts_per_group,
        n_group,
        scores_processed.stride(0),
        group_scores.stride(0),
        BLOCK_SIZE=BLOCK1,
        INPUT_DTYPE=INPUT_DTYPE,
    )

    # Stage 2: one program per token selects groups then experts.
    BLOCK_GROUP = triton.next_power_of_2(n_group)
    BLOCK_EXPERT = triton.next_power_of_2(num_experts)
    grid2 = (num_tokens,)

    group_idx_and_topk_triton[grid2](
        scores_processed,
        group_scores,
        topk_values,
        topk_indices,
        bias,
        num_tokens,
        n_group,
        topk_group,
        topk,
        num_experts,
        num_experts_per_group,
        routed_scaling_factor,
        scores_processed.stride(0),
        group_scores.stride(0),
        topk_values.stride(0),
        N_GROUP=n_group,
        TOPK_GROUP=topk_group,
        TOPK=topk,
        BLOCK_GROUP=BLOCK_GROUP,
        BLOCK_EXPERT=BLOCK_EXPERT,
        INPUT_DTYPE=INPUT_DTYPE,
        renormalize=int(renormalize),
    )

    return topk_values, topk_indices