Coverage for src/flag_gems/runtime/backend/_sunrise/heuristics_config_utils.py: 0%

174 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import torch # noqa: F401 

2import triton 

3 

4 

def simple_elementwise_blocksize_heur(args):
    """Fixed BLOCK_SIZE for simple elementwise kernels (ignores args)."""
    BLOCK_SIZE = 1024
    return BLOCK_SIZE

7 

8 

def argmax_heur_block_m(args):
    """BLOCK_M for argmax: 1 while M stays under 4096, otherwise 8."""
    if args["M"] < 4096:
        return 1
    return 8

11 

12 

def argmax_heur_block_n(args):
    """BLOCK_N for argmax: N rounded up to a power of two, capped at 4096."""
    n_pow2 = triton.next_power_of_2(args["N"])
    return min(n_pow2, 4096)

15 

16 

def argmin_heur_block_m(args):
    """BLOCK_M for argmin: 4 while M stays under 4096, otherwise 8."""
    return 8 if args["M"] >= 4096 else 4

19 

20 

def argmin_heur_block_n(args):
    """BLOCK_N for argmin: N rounded up to a power of two, capped at 4096."""
    n_pow2 = triton.next_power_of_2(args["N"])
    return n_pow2 if n_pow2 < 4096 else 4096

23 

24 

def bmm_heur_divisible_m(args):
    """True when M is an exact multiple of TILE_M (no M-edge masking needed)."""
    return not (args["M"] % args["TILE_M"])

27 

28 

def bmm_heur_divisible_n(args):
    """True when N is an exact multiple of TILE_N (no N-edge masking needed)."""
    return not (args["N"] % args["TILE_N"])

31 

32 

def bmm_heur_divisible_k(args):
    """True when K is an exact multiple of TILE_K (no K-edge masking needed)."""
    return not (args["K"] % args["TILE_K"])

35 

36 

def dropout_heur_block(args):
    """BLOCK size for dropout: 256 for short rows (N <= 512), else 512."""
    return 256 if args["N"] <= 512 else 512

42 

43 

def dropout_heur_num_warps(args):
    """num_warps for dropout, stepped by row length N (2 / 4 / 8)."""
    n = args["N"]
    if n > 2048:
        return 8
    if n > 512:
        return 4
    return 2

51 

52 

def exponential_heur_block(args):
    """BLOCK size for exponential_: 512 up to N == 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

58 

59 

def exponential_heur_num_warps(args):
    """num_warps for exponential_, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16

67 

68 

def gather_heur_block_m(args):
    """BLOCK_M for gather.

    NOTE(review): ``min(1, ...)`` evaluates to 1 for every positive
    right-hand operand, so this heuristic is effectively constant —
    possibly intended as ``max(1, ...)``; confirm before changing.
    """
    rows_pow2 = triton.next_power_of_2(triton.cdiv(args["N"], 2048))
    return min(1, rows_pow2)

71 

72 

def gather_heur_block_n(args):
    """BLOCK_N for gather: N rounded up to a power of two, capped at 2048."""
    return min(triton.next_power_of_2(args["N"]), 2048)

75 

76 

def index_select_heur_block_m(args):
    """BLOCK_M for index_select: grows as rows get narrower, capped at 4."""
    inv_width = triton.next_power_of_2(triton.cdiv(256, args["N"]))
    return min(inv_width, 4)

79 

80 

def index_select_heur_block_n(args):
    """BLOCK_N for index_select: N/16 rounded up to a power of two,
    clamped to the range [16, 512]."""
    n_pow2 = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(16, min(n_pow2, 512))

84 

85 

def mm_heur_even_k(args):
    """True when K divides evenly into BLOCK_K * SPLIT_K chunks."""
    chunk = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % chunk == 0

88 

89 

def rand_heur_block(args):
    """BLOCK size for rand: 512 up to N == 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

95 

96 

def rand_heur_num_warps(args):
    """num_warps for rand, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4

104 

105 

def randn_heur_block(args):
    """BLOCK size for randn: 512 up to N == 512, else 1024."""
    if args["N"] <= 512:
        return 512
    return 1024

111 

112 

def randn_heur_num_warps(args):
    """num_warps for randn, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

120 

121 

def softmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner softmax kernel.

    Starts at 1 and keeps doubling while the grid still oversubscribes
    the SMs (num_waves > 1) and the next doubling would stay within
    min(K, MAX_TILE_K).
    """
    MAX_TILE_K = 512
    # NUM_SMS = torch.cuda.get_device_properties(
    #     torch.cuda.current_device()
    # ).multi_processor_count
    NUM_SMS = 32  # Not support now.

    limit = min(args["K"], MAX_TILE_K)
    tile_k = 1
    while tile_k <= limit:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        if num_blocks / NUM_SMS <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k

139 

140 

def softmax_heur_tile_n_non_inner(args):
    """TILE_N so that TILE_N * TILE_K covers roughly 8192 elements."""
    element_budget = 8192
    return triton.cdiv(element_budget, args["TILE_K"])

143 

144 

def softmax_heur_one_tile_per_cta(args):
    """True when a single tile spans the whole reduction axis N."""
    return args["N"] <= args["TILE_N"]

147 

148 

def softmax_heur_num_warps_non_inner(args):
    """num_warps from the 2-D tile footprint TILE_N * TILE_K (4 / 8 / 16)."""
    footprint = args["TILE_N"] * args["TILE_K"]
    if footprint >= 4096:
        return 16
    return 8 if footprint >= 2048 else 4

157 

158 

def softmax_heur_tile_n_inner(args):
    """TILE_N for inner softmax: exact power of two for tiny rows
    (N <= 32), otherwise 256 up to N == 1024 and 512 beyond."""
    n = args["N"]
    if n <= 32:
        return triton.next_power_of_2(n)
    return 256 if n <= 1024 else 512

166 

167 

def softmax_heur_num_warps_inner(args):
    """num_warps from TILE_N: 2 / 4 / 8 / 16 at thresholds 64 / 2048 / 4096."""
    tile = args["TILE_N"]
    for bound, warps in ((64, 2), (2048, 4), (4096, 8)):
        if tile < bound:
            return warps
    return 16

178 

179 

def softmax_heur_tile_n_bwd_non_inner(args):
    """TILE_N for softmax backward (non-inner): 1024 // TILE_K, at least 1."""
    tile_n = 1024 // args["TILE_K"]
    return tile_n if tile_n > 1 else 1

182 

183 

def softmax_heur_tile_m(args):
    """TILE_M for softmax backward (inner): 1024 // TILE_N, at least 1."""
    rows = 1024 // args["TILE_N"]
    return max(rows, 1)

186 

187 

def uniform_heur_block(args):
    """BLOCK size for uniform: 512 up to N == 512, else 1024."""
    return 1024 if args["N"] > 512 else 512

193 

194 

def uniform_heur_num_warps(args):
    """num_warps for uniform, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4

202 

203 

def var_mean_heur_block_n(args):
    """BLOCK_N for var_mean: partial-block count rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

206 

207 

def upsample_nearest2d_SAME_H(args):
    """True when output height equals input height (identity along H)."""
    return args["IH"] == args["OH"]

210 

211 

def upsample_nearest2d_SAME_W(args):
    """True when output width equals input width (identity along W)."""
    return args["IW"] == args["OW"]

214 

215 

def upsample_nearest2d_USE_INT32_IDX(args):
    """True when the flattened N*C*OH*OW output index fits in a signed int32."""
    INT32_MAX = 2**31 - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= INT32_MAX

218 

219 

def batch_norm_heur_block_m(args):
    """BLOCK_M for batch_norm: batch_dim rounded up to a power of two, capped at 256."""
    return min(triton.next_power_of_2(args["batch_dim"]), 256)

222 

223 

def batch_norm_heur_block_n(args):
    """BLOCK_N for batch_norm, sized so BLOCK_M * BLOCK_N stays within a
    fixed element budget.

    NOTE(review): original comment said "A maximum of 16384 elements are
    loaded at once", but the cap used here is 2**10 = 1024 -- confirm
    which figure is intended.
    """
    block_m = batch_norm_heur_block_m(args)
    spatial_pow2 = triton.next_power_of_2(args["spatial_dim"])
    return min(spatial_pow2, max(1, 2**10 // block_m))

229 

230 

def vdot_heur_block_size(args):
    """BLOCK_SIZE for vdot, stepped by element count (32 / 256 / 1024)."""
    n = args["n_elements"]
    if n >= 8192:
        return 1024
    if n >= 1024:
        return 256
    return 32

239 

240 

def sum_heur_num_warps_inner(args):
    """num_warps for inner sum from TILE_N: 2 / 4 / 8 / 16 at 64 / 2048 / 4096."""
    tile = args["TILE_N"]
    if tile >= 4096:
        return 16
    if tile >= 2048:
        return 8
    return 4 if tile >= 64 else 2

251 

252 

def sum_heur_tile_n_inner(args):
    """TILE_N for inner sum: exact power of two for tiny rows (N <= 32),
    otherwise 128 up to N == 1024 and 256 beyond."""
    n = args["N"]
    if n <= 32:
        return triton.next_power_of_2(n)
    return 128 if n <= 1024 else 256

260 

261 

def sum_heur_one_tile_per_cta(args):
    """True when one tile covers the whole reduced axis N."""
    return args["N"] <= args["TILE_N"]

264 

265 

def sum_heur_tile_k(args):
    """Pick TILE_K for the non-inner sum kernel.

    Doubles the tile while the grid still oversubscribes the SMs
    (num_waves > 1) and the doubled tile would stay within
    min(K, MAX_TILE_K).
    """
    MAX_TILE_K = 128
    # NUM_SMS = torch.cuda.get_device_properties(
    #     torch.cuda.current_device()
    # ).multi_processor_count
    NUM_SMS = 32  # Not support now.

    cap = min(args["K"], MAX_TILE_K)
    tile_k = 1
    while tile_k <= cap:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / NUM_SMS
        if waves <= 1 or tile_k * 2 > cap:
            break
        tile_k *= 2
    return tile_k

283 

284 

def sum_heur_tile_n_non_inner(args):
    """TILE_N so that TILE_N * TILE_K covers roughly 256 elements."""
    element_budget = 256
    return triton.cdiv(element_budget, args["TILE_K"])

287 

288 

# Maps op name -> {kernel parameter name -> heuristic callable}. Each
# callable receives the kernel's `args` dict and returns the value to use
# for that parameter (block/tile sizes, num_warps, divisibility flags, ...).
#
# Fix: "sum_non_inner" previously bound ONE_TILE_PER_CTA to
# softmax_heur_one_tile_per_cta; it now uses the sum-specific
# sum_heur_one_tile_per_cta (behaviorally identical — both compare
# TILE_N >= N — but keeps the table self-consistent).
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "mha_varlen_prefill": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_varlen_decode": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
    "sum_inner": {
        "TILE_N": sum_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        "num_warps": sum_heur_num_warps_inner,
    },
    "sum_non_inner": {
        "TILE_K": sum_heur_tile_k,
        "TILE_N": sum_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        # No sum-specific non-inner warp heuristic is defined; the softmax
        # one is reused deliberately.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
}