Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/grouped_topk.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4from triton.language.extra.cuda import libdevice 

5 

6 

@triton.jit
def topk_with_k2_triton(
    scores_ptr,
    bias_ptr,
    group_scores_ptr,
    num_experts_per_group,
    n_group,
    stride_scores_token,
    stride_group_scores_token,
    scoring_func: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
):
    """Score each (token, group) pair as the sum of its two largest biased
    expert scores.

    One program handles one (token, group) pair: it loads the group's slice
    of expert scores, optionally applies a sigmoid (scoring_func == 1), adds
    the per-expert bias, and stores max1 + max2 into ``group_scores_ptr``.

    Args:
        scores_ptr: [num_tokens, num_experts] expert scores (row-major,
            token stride ``stride_scores_token``, unit expert stride assumed).
        bias_ptr: [num_experts] per-expert bias added after scoring.
        group_scores_ptr: [num_tokens, n_group] output buffer.
        num_experts_per_group: experts per group (num_experts // n_group).
        n_group: number of expert groups.
        stride_scores_token: token-dimension stride of ``scores_ptr``.
        stride_group_scores_token: token-dimension stride of the output.
        scoring_func: 0 = use raw scores, 1 = apply sigmoid first.
        BLOCK_SIZE: power-of-two >= num_experts_per_group.
        INPUT_DTYPE: triton dtype matching the scores tensor.
    """
    pid = tl.program_id(0)

    # Grid is (num_tokens * n_group,): decode the (token, group) pair.
    token_id = pid // n_group
    group_id = pid % n_group

    lane = tl.arange(0, BLOCK_SIZE)
    mask = lane < num_experts_per_group

    scores_offset = token_id * stride_scores_token + group_id * num_experts_per_group
    bias_offset = group_id * num_experts_per_group

    # Padding lanes read -inf so they never win the max (for scoring_func == 0).
    x = tl.load(
        scores_ptr + scores_offset + lane,
        mask=mask,
        other=-float("inf"),
    )

    b = tl.load(
        bias_ptr + bias_offset + lane,
        mask=mask,
        other=0.0,
    )

    if scoring_func == 1:
        # sigmoid(x) expressed as 0.5 * tanh(0.5 * x) + 0.5 (numerically stable).
        # NOTE(review): sigmoid(-inf) == 0, so padding lanes become 0 here; if
        # every real (sigmoid + bias) value is negative and
        # BLOCK_SIZE > num_experts_per_group, a padding lane could win the
        # max — confirm inputs make this unreachable.
        x_f32 = x.to(tl.float32)
        x_f32 = 0.5 * libdevice.tanh(0.5 * x_f32) + 0.5
        x = x_f32.to(INPUT_DTYPE)

    x = x + b

    # Reduce in fp32 regardless of the input dtype.
    x_f32 = x.to(tl.float32)

    # Largest value and the number of lanes that attain it.
    max1 = tl.max(x_f32, axis=0)
    is_max1 = (x_f32 == max1) & mask
    count_max1 = tl.sum(is_max1.to(tl.int32), axis=0)

    # Knock out the maximum only when it is unique; with ties the second
    # largest equals max1 and must stay visible to the second reduction.
    x2 = tl.where(
        is_max1 & (count_max1 == 1),
        -float("inf"),
        x_f32,
    )
    max2 = tl.max(x2, axis=0)

    # Store the group's score (top-2 sum) back in the input dtype.
    group_scores_offset = token_id * stride_group_scores_token + group_id
    tl.store(
        group_scores_ptr + group_scores_offset,
        (max1 + max2).to(INPUT_DTYPE),
    )

68 

69 

@triton.jit
def group_idx_and_topk_triton(
    scores_ptr,
    group_scores_ptr,
    topk_values_ptr,
    topk_indices_ptr,
    bias_ptr,
    num_tokens,
    n_group,
    topk_group,
    topk,
    num_experts,
    num_experts_per_group,
    routed_scaling_factor,
    scoring_func: tl.constexpr,
    stride_scores_token,
    stride_group_scores_token,
    stride_out_token,
    N_GROUP: tl.constexpr,
    TOPK_GROUP: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_GROUP: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
    renormalize: tl.constexpr,
):
    """Per-token second stage: pick the top-``topk_group`` groups by group
    score, then the top-``topk`` experts within those groups.

    One program handles one token. Selection uses score + bias; the stored
    top-k *values* are the bias-free scores (optionally sigmoid-transformed),
    scaled by ``routed_scaling_factor`` (divided by their sum first when
    ``renormalize``). If all group scores are non-finite, uniform weights
    1/topk with indices 0..topk-1 are written instead.

    Args:
        scores_ptr: [num_tokens, num_experts] raw expert scores.
        group_scores_ptr: [num_tokens, n_group] stage-1 group scores.
        topk_values_ptr / topk_indices_ptr: [num_tokens, topk] outputs
            (fp32 values, int32 indices).
        bias_ptr: [num_experts] selection bias (not added to stored values).
        scoring_func: 0 = raw scores, 1 = sigmoid.
        renormalize: 1 = divide values by their sum before scaling.
        BLOCK_GROUP / BLOCK_EXPERT: powers of two >= n_group / num_experts.
    """
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return

    neg_inf = -float("inf")

    group_offsets = tl.arange(0, BLOCK_GROUP)
    valid_group = group_offsets < n_group

    group_scores = tl.load(
        group_scores_ptr + pid * stride_group_scores_token + group_offsets,
        mask=valid_group,
        other=neg_inf,
    )

    # Map NaN (x != x) and +inf to -inf so they can never be selected.
    group_scores_f32 = group_scores.to(tl.float32)
    is_finite = (group_scores_f32 == group_scores_f32) & (
        group_scores_f32 != float("inf")
    )
    group_scores_f32 = tl.where(is_finite & valid_group, group_scores_f32, neg_inf)

    # If every group score is -inf (all invalid), fall back to defaults at
    # the end instead of routing.
    max_group_score = tl.max(group_scores_f32, axis=0)
    if_proceed = max_group_score != neg_inf

    # Iteratively find the topk_group-th largest group score, counting ties.
    # Padding lanes (BLOCK_GROUP - n_group of them) are pre-counted as
    # already "taken" so the thresholds line up with the padded block.
    value = group_scores_f32
    target_num_min = BLOCK_GROUP - n_group + topk_group
    count_equal_to_top_value = BLOCK_GROUP - n_group
    pre_count_equal_to_top_value = 0
    topk_group_value = neg_inf

    for _ in range(TOPK_GROUP):
        # Stop updating once enough groups (incl. ties) have been consumed.
        need = count_equal_to_top_value < target_num_min
        max_val = tl.max(value, axis=0)

        is_max = need & (value == max_val)
        value = tl.where(is_max, neg_inf, value)

        newly = tl.sum(is_max.to(tl.int32), axis=0)

        # Snapshot the count *before* absorbing this tier of ties.
        pre_count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value, pre_count_equal_to_top_value
        )
        count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value + newly, count_equal_to_top_value
        )
        # topk_group_value ends up as the k-th-largest (threshold) score.
        topk_group_value = tl.where(need, max_val, topk_group_value)

    # How many threshold-tied groups may still be admitted.
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value

    group_gt = group_scores_f32 > topk_group_value
    group_eq = group_scores_f32 == topk_group_value

    # Exclusive prefix count of ties: admit only the first
    # num_equalto_topkth_group groups equal to the threshold.
    eq_i = group_eq.to(tl.int32)
    prefix_eq = tl.cumsum(eq_i, axis=0) - eq_i

    group_selected = (
        group_gt | (group_eq & (prefix_eq < num_equalto_topkth_group))
    ) & valid_group

    # Expand group selection to a per-expert mask.
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    valid_expert = expert_offsets < num_experts
    expert_group = expert_offsets // num_experts_per_group

    expert_in_group = expert_group[:, None] == group_offsets[None, :]
    expert_selected = (
        tl.sum((expert_in_group & group_selected[None, :]).to(tl.int32), axis=1) > 0
    ) & valid_expert

    raw_scores = tl.load(
        scores_ptr + pid * stride_scores_token + expert_offsets,
        mask=expert_selected,
        other=neg_inf,
    )

    expert_bias = tl.load(
        bias_ptr + expert_offsets,
        mask=valid_expert,
        other=0.0,
    )

    if scoring_func == 1:
        # sigmoid via 0.5 * tanh(0.5 * x) + 0.5, computed in fp32.
        scored_f32 = raw_scores.to(tl.float32)
        scored_f32 = 0.5 * libdevice.tanh(0.5 * scored_f32) + 0.5
        scored = scored_f32.to(INPUT_DTYPE)
    else:
        scored = raw_scores

    # Bias influences *selection* only; stored values use the bias-free score.
    selection_scores_native = scored + expert_bias

    # Explicitly mask unselected experts to -inf (sigmoid mapped -inf to 0).
    selection_scores = tl.where(
        expert_selected,
        selection_scores_native.to(tl.float32),
        neg_inf,
    )

    topk_vals = tl.full([TOPK], 0.0, tl.float32)
    topk_idx = tl.full([TOPK], 0, tl.int32)
    pos_range = tl.arange(0, TOPK)

    # Sequential argmax: peel off the best remaining expert TOPK times,
    # breaking ties by the lowest expert index.
    for i in range(TOPK):
        max_val = tl.max(selection_scores, axis=0)
        is_max = selection_scores == max_val

        candidate_idx = tl.where(is_max, expert_offsets, num_experts + 1)
        selected_idx = tl.min(candidate_idx, axis=0)

        # Re-read the raw score of the winner (bias-free value to store).
        selected_raw = tl.load(
            scores_ptr + pid * stride_scores_token + selected_idx,
            mask=selected_idx < num_experts,
            other=neg_inf,
        ).to(tl.float32)

        if scoring_func == 1:
            selected_score = 0.5 * libdevice.tanh(0.5 * selected_raw) + 0.5
        else:
            selected_score = selected_raw

        # Scatter into slot i of the result vectors.
        topk_vals = tl.where(pos_range == i, selected_score, topk_vals)
        topk_idx = tl.where(pos_range == i, selected_idx.to(tl.int32), topk_idx)

        # Remove the winner from further consideration.
        selection_scores = tl.where(
            expert_offsets == selected_idx, neg_inf, selection_scores
        )

    if renormalize == 1:
        # Epsilon guards against division by zero when all values are 0.
        topk_sum = tl.sum(topk_vals, axis=0) + 1e-20
        scale = routed_scaling_factor / topk_sum
    else:
        scale = routed_scaling_factor

    topk_vals = topk_vals * scale

    # Fallback outputs when no group had a finite score.
    default_idx = pos_range.to(tl.int32)
    default_vals = tl.full([TOPK], 1.0 / topk, tl.float32)

    final_vals = tl.where(if_proceed, topk_vals, default_vals)
    final_idx = tl.where(if_proceed, topk_idx, default_idx)

    tl.store(
        topk_values_ptr + pid * stride_out_token + pos_range,
        final_vals,
        mask=pos_range < topk,
    )

    tl.store(
        topk_indices_ptr + pid * stride_out_token + pos_range,
        final_idx,
        mask=pos_range < topk,
    )

245 

246 

def grouped_topk(
    scores: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Grouped top-k expert selection for MoE routing.

    Stage 1 scores each expert group by the sum of its two largest biased
    expert scores; stage 2 keeps the ``topk_group`` best groups and picks
    the ``topk`` best experts among them. The returned values are the
    bias-free expert scores (optionally sigmoid-transformed), scaled by
    ``routed_scaling_factor`` (after sum-normalization when ``renormalize``).

    Args:
        scores: [num_tokens, num_experts] router logits (fp32/fp16/bf16).
        n_group: number of expert groups (num_experts must divide evenly).
        topk_group: number of groups to keep (<= n_group).
        topk: number of experts to select (<= 32 and <= num_experts).
        renormalize: divide the top-k values by their sum before scaling.
        routed_scaling_factor: final multiplicative scale on the values.
        bias: [num_experts] selection bias (affects ranking, not values).
        scoring_func: 0 = raw scores, 1 = sigmoid.

    Returns:
        (topk_values, topk_indices): fp32 [num_tokens, topk] weights and
        int32 [num_tokens, topk] expert indices.

    Raises:
        ValueError: on malformed shapes, unsupported dtypes, or parameter
            combinations the kernels cannot handle.
    """
    if scores.ndim != 2:
        raise ValueError("scores must be a 2D Tensor")
    num_tokens, num_experts = scores.shape
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    if n_group > 32:
        raise ValueError("n_group should be smaller than or equal to 32")
    if topk > 32:
        raise ValueError("topk should be smaller than or equal to 32 for now")
    if scoring_func not in (0, 1):
        raise ValueError("scoring_func must be 0 (none) or 1 (sigmoid)")
    # Guard combinations that would silently produce garbage in the kernels.
    if topk_group > n_group:
        raise ValueError("topk_group must be smaller than or equal to n_group")
    if topk > num_experts:
        raise ValueError("topk must be smaller than or equal to num_experts")

    if bias.dtype != scores.dtype:
        bias = bias.to(scores.dtype)
    if bias.ndim != 1:
        bias = bias.flatten()
    if len(bias) != num_experts:
        raise ValueError(
            f"bias length ({len(bias)}) must match num_experts ({num_experts})"
        )

    # Both kernels index the expert dimension with unit stride; enforce it
    # (no-op copy for already-contiguous inputs) so non-contiguous views
    # don't silently read the wrong elements.
    if scores.stride(1) != 1:
        scores = scores.contiguous()
    if not bias.is_contiguous():
        bias = bias.contiguous()

    num_experts_per_group = num_experts // n_group

    if scores.dtype == torch.float32:
        INPUT_DTYPE = tl.float32
    elif scores.dtype == torch.float16:
        INPUT_DTYPE = tl.float16
    elif scores.dtype == torch.bfloat16:
        INPUT_DTYPE = tl.bfloat16
    else:
        raise ValueError(f"Unsupported dtype: {scores.dtype}")

    group_scores = torch.empty(
        (num_tokens, n_group),
        device=scores.device,
        dtype=scores.dtype,
    )

    topk_values = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.float32,
    )

    topk_indices = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.int32,
    )

    # Stage 1: one program per (token, group) pair.
    BLOCK1 = triton.next_power_of_2(num_experts_per_group)
    grid1 = (num_tokens * n_group,)

    topk_with_k2_triton[grid1](
        scores,
        bias,
        group_scores,
        num_experts_per_group,
        n_group,
        scores.stride(0),
        group_scores.stride(0),
        scoring_func,
        BLOCK_SIZE=BLOCK1,
        INPUT_DTYPE=INPUT_DTYPE,
    )

    # Stage 2: one program per token.
    BLOCK_GROUP = triton.next_power_of_2(n_group)
    BLOCK_EXPERT = triton.next_power_of_2(num_experts)
    grid2 = (num_tokens,)

    group_idx_and_topk_triton[grid2](
        scores,
        group_scores,
        topk_values,
        topk_indices,
        bias,
        num_tokens,
        n_group,
        topk_group,
        topk,
        num_experts,
        num_experts_per_group,
        routed_scaling_factor,
        scoring_func,
        scores.stride(0),
        group_scores.stride(0),
        topk_values.stride(0),
        N_GROUP=n_group,
        TOPK_GROUP=topk_group,
        TOPK=topk,
        BLOCK_GROUP=BLOCK_GROUP,
        BLOCK_EXPERT=BLOCK_EXPERT,
        INPUT_DTYPE=INPUT_DTYPE,
        renormalize=int(renormalize),
    )

    return topk_values, topk_indices