Coverage for src/flag_gems/runtime/backend/_metax/heuristics_config_utils.py: 0%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch 

2import triton 

3 

4 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE (512) used by generic elementwise kernels.

    ``args`` is accepted for the heuristic-callable signature but ignored.
    """
    block_size = 512
    return block_size

7 

8 

def argmax_heur_block_m(args):
    """Return BLOCK_M for argmax: 4 rows when ``M < 4096``, otherwise 8."""
    if args["M"] < 4096:
        return 4
    return 8

11 

12 

def argmax_heur_block_n(args):
    """Return BLOCK_N for argmax: ``N`` rounded up to a power of two, capped at 4096."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)

15 

16 

def bmm_heur_divisible_m(args):
    """Return True when ``M`` is an exact multiple of ``TILE_M`` (no ragged M tile)."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0

19 

20 

def bmm_heur_divisible_n(args):
    """Return True when ``N`` is an exact multiple of ``TILE_N`` (no ragged N tile)."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0

23 

24 

def bmm_heur_divisible_k(args):
    """Return True when ``K`` is an exact multiple of ``TILE_K`` (no ragged K tile)."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

27 

28 

def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 4 rows when ``M < 4096``, otherwise 8."""
    few_rows = args["M"] < 4096
    if few_rows:
        return 4
    return 8

31 

32 

def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin: ``N`` rounded up to a power of two, capped at 4096."""
    cap = 4096
    return min(cap, triton.next_power_of_2(args["N"]))

35 

36 

def dropout_heur_block(args):
    """Return BLOCK for dropout: 512 when ``N <= 512``, otherwise 1024."""
    fits_small_block = args["N"] <= 512
    return 512 if fits_small_block else 1024

42 

43 

def dropout_heur_num_warps(args):
    """Return num_warps for dropout: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16

51 

52 

def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 when ``N <= 512``, otherwise 1024."""
    if args["N"] <= 512:
        return 512
    return 1024

58 

59 

def exponential_heur_num_warps(args):
    """Return num_warps for exponential_: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16

67 

68 

def gather_heur_block_m(args):
    """Return BLOCK_M for gather.

    The value is the power of two that covers ``cdiv(N, 2048)``, capped at 4.
    NOTE(review): if ``N == 0``, the inner value is 0 and this returns 0;
    presumably callers never launch with an empty N — confirm.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))

71 

72 

def gather_heur_block_n(args):
    """Return BLOCK_N for gather: ``N`` rounded up to a power of two, capped at 2048."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)

75 

76 

def index_heur_block_0(args):
    """Return the fixed BLOCK_SIZE0 (2) for the index kernel; ``args`` is unused."""
    block_size0 = 2
    return block_size0

79 

80 

def index_heur_block_1(args):
    """Return the fixed BLOCK_SIZE1 (1024) for the index kernel; ``args`` is unused."""
    block_size1 = 1024
    return block_size1

83 

84 

def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select.

    Aims for roughly 256 elements per block: ``cdiv(256, N)`` rows, rounded up
    to a power of two and capped at 4.
    NOTE(review): ``N == 0`` would raise ZeroDivisionError in ``triton.cdiv``;
    presumably callers guarantee ``N > 0`` — confirm.
    """
    rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows))

87 

88 

def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    The value is the power of two that covers ``cdiv(N, 16)``, clamped to the
    range [16, 512].
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(min(candidate, 512), 16)

92 

93 

def mm_heur_even_k(args):
    """Return True when ``K`` divides evenly by ``BLOCK_K * SPLIT_K``."""
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0

96 

97 

def ones_heur_block_size(args):
    """Return BLOCK_SIZE for ones: 1024 for N <= 1024, 2048 for N <= 2048, else 4096."""
    n = args["N"]
    for limit in (1024, 2048):
        if n <= limit:
            return limit
    return 4096

105 

106 

def ones_heur_num_warps(args):
    """Return num_warps for ones.

    Half-precision outputs (float16 / bfloat16) get 2 warps; every other dtype
    gets 4. ``args["output_ptr"]`` must expose a ``.dtype`` attribute.
    """
    out_dtype = args["output_ptr"].dtype
    is_half = out_dtype == torch.float16 or out_dtype == torch.bfloat16
    return 2 if is_half else 4

115 

116 

def rand_heur_block(args):
    """Return BLOCK for rand: 512 when ``N <= 512``, otherwise 1024."""
    if args["N"] <= 512:
        block = 512
    else:
        block = 1024
    return block

122 

123 

def rand_heur_num_warps(args):
    """Return num_warps for rand: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    if n <= 512:
        warps = 4
    elif n <= 1024:
        warps = 8
    else:
        warps = 16
    return warps

131 

132 

def randn_heur_block(args):
    """Return BLOCK for randn: 512 when ``N <= 512``, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

138 

139 

def randn_heur_num_warps(args):
    """Return num_warps for randn: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16

147 

148 

def softmax_heur_tile_k(args):
    """Return TILE_K for the non-inner softmax kernel.

    TILE_K starts at 1 and doubles while both of these hold:
    * the launch (``M * cdiv(K, TILE_K)`` blocks) is still more than one
      wave across the current CUDA device's multiprocessors;
    * the doubled tile stays within ``min(K, 8192)``.

    The device properties are queried on every call.
    """
    max_tile_k = 8192
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    limit = min(args["K"], max_tile_k)
    tile_k = 1
    if tile_k > limit:
        # Degenerate K (<= 0): keep the minimal tile without inspecting M.
        return tile_k
    while (
        args["M"] * triton.cdiv(args["K"], tile_k) / sm_count > 1
        and tile_k * 2 <= limit
    ):
        tile_k *= 2
    return tile_k

164 

165 

def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for the non-inner softmax kernel.

    ``N`` is rounded up to a power of two, and the result is limited so that
    ``TILE_N * TILE_K`` stays around 8192 elements (``cdiv(8192, TILE_K)``).
    """
    budget_n = triton.cdiv(8192, args["TILE_K"])
    return min(triton.next_power_of_2(args["N"]), budget_n)

169 

170 

def softmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile covers the whole ``N`` dimension."""
    covers_all = args["TILE_N"] >= args["N"]
    return covers_all

173 

174 

def softmax_heur_num_warps_non_inner(args):
    """Return num_warps for the non-inner softmax kernel, based on tile area.

    The tile holds ``TILE_N * TILE_K`` elements. Bigger tiles get more warps:
    < 256 -> 1, < 512 -> 2, < 2048 -> 4, < 4096 -> 8, otherwise 16.
    """
    tile_size = args["TILE_N"] * args["TILE_K"]
    # Thresholds must be tested in ascending order. Testing ``< 512`` before
    # ``< 256`` would make the 2-warp branch unreachable.
    if tile_size < 256:
        return 1
    elif tile_size < 512:
        return 2
    elif tile_size < 2048:
        return 4
    elif tile_size < 4096:
        return 8
    else:
        return 16

187 

188 

def softmax_heur_tile_n_inner(args):
    """Return TILE_N for the inner softmax kernel.

    For ``N <= 32768`` the whole row fits in one tile (``N`` rounded up to a
    power of two). Larger rows are processed in fixed 4096-wide tiles.
    """
    n = args["N"]
    return triton.next_power_of_2(n) if n <= (32 * 1024) else 4096

194 

195 

def softmax_heur_num_warps_inner(args):
    """Return num_warps for the inner softmax kernel: 4 / 8 / 16 by TILE_N (< 2048, < 4096, else)."""
    tile_n = args["TILE_N"]
    for limit, warps in ((2048, 4), (4096, 8)):
        if tile_n < limit:
            return warps
    return 16

204 

205 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for non-inner softmax backward: ``1024 // TILE_K``, at least 1."""
    tile_n = 1024 // args["TILE_K"]
    return max(1, tile_n)

208 

209 

def softmax_heur_tile_m(args):
    """Return TILE_M for inner softmax backward: ``1024 // TILE_N``, at least 1."""
    tile_m = 1024 // args["TILE_N"]
    return max(1, tile_m)

212 

213 

def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 when ``N <= 512``, otherwise 1024."""
    if args["N"] <= 512:
        return 512
    return 1024

219 

220 

def uniform_heur_num_warps(args):
    """Return num_warps for uniform: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

228 

229 

def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: ``BLOCK_NUM`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

232 

233 

def upsample_nearest2d_SAME_H(args):
    """Return True when the output height ``OH`` equals the input height ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

236 

237 

def upsample_nearest2d_SAME_W(args):
    """Return True when the output width ``OW`` equals the input width ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

240 

241 

def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when the output element count ``N*C*OH*OW`` fits in int32.

    NOTE(review): only the output size is checked here. When downsampling, the
    input (``N*C*IH*IW``) can be larger; confirm the kernel's input offsets are
    safe in that case.
    """
    int32_max = 2**31 - 1
    out_numel = args["N"] * args["C"] * args["OH"] * args["OW"]
    return out_numel <= int32_max

244 

245 

def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm: ``batch_dim`` rounded up to a power of two, capped at 2048."""
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded_batch)

248 

249 

def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm.

    ``spatial_dim`` is rounded up to a power of two, then limited so that
    ``BLOCK_M * BLOCK_N`` stays at or below 2**14 (16384) elements. BLOCK_N is
    always at least 1.
    """
    block_m = batch_norm_heur_block_m(args)
    rounded_spatial = triton.next_power_of_2(args["spatial_dim"])
    element_budget = 2**14
    return min(rounded_spatial, max(1, element_budget // block_m))

255 

256 

def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot: 32 / 256 / 1024 by ``n_elements`` (< 1024, < 8192, else)."""
    count = args["n_elements"]
    for limit, block in ((1024, 32), (8192, 256)):
        if count < limit:
            return block
    return 1024

265 

266 

def zeros_heur_block_size(args):
    """Return BLOCK_SIZE for zeros: 1024 for N <= 1024, 2048 for N <= 2048, else 4096."""
    n = args["N"]
    if n <= 1024:
        return 1024
    return 2048 if n <= 2048 else 4096

274 

275 

def zeros_heur_num_warps(args):
    """Return num_warps for zeros.

    Half-precision outputs (float16 / bfloat16) get 2 warps; every other dtype
    gets 4. ``args["output_ptr"]`` must expose a ``.dtype`` attribute.
    """
    out_dtype = args["output_ptr"].dtype
    if out_dtype == torch.float16 or out_dtype == torch.bfloat16:
        return 2
    return 4

284 

285 

# Registry of heuristic callables, keyed by kernel/op name.
# Each inner mapping is keyed by meta-parameter name (including launch options
# such as ``num_warps`` / ``num_stages``). Each value is a callable that takes
# the kernel ``args`` mapping and returns the value for that parameter.
# Key order is preserved exactly.
HEURISTICS_CONFIGS = {
    "amax": dict(
        BLOCK_M=lambda args: 4,
        BLOCK_N=lambda args: 1024,
    ),
    "argmax": dict(
        BLOCK_M=argmax_heur_block_m,
        BLOCK_N=argmax_heur_block_n,
    ),
    "argmin": dict(
        BLOCK_M=argmin_heur_block_m,
        BLOCK_N=argmin_heur_block_n,
    ),
    "bmm": dict(
        DIVISIBLE_M=bmm_heur_divisible_m,
        DIVISIBLE_N=bmm_heur_divisible_n,
        DIVISIBLE_K=bmm_heur_divisible_k,
    ),
    "dropout": dict(
        BLOCK=dropout_heur_block,
        num_warps=dropout_heur_num_warps,
    ),
    "exponential_": dict(
        BLOCK=exponential_heur_block,
        num_warps=exponential_heur_num_warps,
    ),
    "gather": dict(
        BLOCK_M=gather_heur_block_m,
        BLOCK_N=gather_heur_block_n,
    ),
    "index": dict(
        BLOCK_SIZE0=index_heur_block_0,
        BLOCK_SIZE1=index_heur_block_1,
    ),
    "index_select": dict(
        BLOCK_M=index_select_heur_block_m,
        BLOCK_N=index_select_heur_block_n,
    ),
    "mm": dict(
        EVEN_K=mm_heur_even_k,
    ),
    "nonzero": dict(
        BLOCK_SIZE=lambda args: 2048,
    ),
    "ones": dict(
        BLOCK_SIZE=ones_heur_block_size,
        num_warps=ones_heur_num_warps,
    ),
    "rand": dict(
        BLOCK=rand_heur_block,
        num_warps=rand_heur_num_warps,
    ),
    "randn": dict(
        BLOCK=randn_heur_block,
        num_warps=randn_heur_num_warps,
    ),
    "softmax_non_inner": dict(
        TILE_K=softmax_heur_tile_k,
        TILE_N=softmax_heur_tile_n_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_non_inner,
    ),
    "softmax_inner": dict(
        TILE_N=softmax_heur_tile_n_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_inner,
    ),
    "softmax_backward_non_inner": dict(
        TILE_N=softmax_heur_tile_n_bwd_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    "softmax_backward_inner": dict(
        TILE_M=softmax_heur_tile_m,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    "uniform": dict(
        BLOCK=uniform_heur_block,
        num_warps=uniform_heur_num_warps,
    ),
    "upsample_nearest2d": dict(
        SAME_H=upsample_nearest2d_SAME_H,
        SAME_W=upsample_nearest2d_SAME_W,
        USE_INT32_IDX=upsample_nearest2d_USE_INT32_IDX,
    ),
    "var_mean": dict(
        BLOCK_N=var_mean_heur_block_n,
    ),
    "batch_norm": dict(
        BLOCK_M=batch_norm_heur_block_m,
        BLOCK_N=batch_norm_heur_block_n,
    ),
    "vdot": dict(
        BLOCK_SIZE=vdot_heur_block_size,
    ),
    "zeros": dict(
        BLOCK_SIZE=zeros_heur_block_size,
        num_warps=zeros_heur_num_warps,
    ),
    # Fixed attention tilings; only BLOCK_M / BLOCK_N differ between variants.
    "mha_block_128": dict(
        BLOCK_M=lambda args: 128,
        BLOCK_N=lambda args: 32,
        num_warps=lambda args: 4,
        num_stages=lambda args: 3,
    ),
    "mha_block_64": dict(
        BLOCK_M=lambda args: 64,
        BLOCK_N=lambda args: 32,
        num_warps=lambda args: 4,
        num_stages=lambda args: 3,
    ),
    "mha_block_32": dict(
        BLOCK_M=lambda args: 32,
        BLOCK_N=lambda args: 16,
        num_warps=lambda args: 4,
        num_stages=lambda args: 3,
    ),
    "mha_block_16": dict(
        BLOCK_M=lambda args: 16,
        BLOCK_N=lambda args: 16,
        num_warps=lambda args: 4,
        num_stages=lambda args: 3,
    ),
    "elementwise_generic": dict(
        BLOCK_SIZE=simple_elementwise_blocksize_heur,
        num_warps=lambda args: 8,
    ),
}