Coverage for src/flag_gems/runtime/backend/_tsingmicro/heuristics_config_utils.py: 0%
215 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
4_MIN_TILE_N = 64
5_MAX_TILE_N_PER_ROW = 4096
6_MAX_ONE_TILE_N = 2048
def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels.

    ``args`` is accepted for the heuristic-callable interface but ignored.
    """
    block_size = 1024
    return block_size
def argmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner argmax kernel.

    Reads ``args["M"]``, ``args["K"]`` and (only in the special case)
    ``args["inp"].dtype``. The result is always a power of two:

    * ``M == 64 and K == 512``: hand-tuned 64 for float32 input, else 128.
    * ``K <= 128``: the largest power of two not exceeding ``K`` (1 if
      ``K <= 0``).
    * otherwise: start at 64 and keep doubling, never past ``min(K, 512)``.
      Doubling continues while ``M * cdiv(K, tile_k)`` programs would still
      exceed one wave over the device's ``multi_processor_count``.

    The device-property query is issued only on the last path, which is the
    only one that needs it.
    """
    MAX_TILE_K = 512
    K = args["K"]
    M = args["M"]

    if M == 64 and K == 512:
        # Hand-tuned special case; the only place the input dtype matters.
        return 64 if args["inp"].dtype == torch.float32 else 128

    if K <= 128:
        return 1 << (K.bit_length() - 1) if K > 0 else 1

    # Deferred until here so the early-return paths avoid a device query.
    NUM_SMS = torch.txda.get_device_properties(
        torch.txda.current_device()
    ).multi_processor_count

    tile_k = 64
    upper_bound = min(K, MAX_TILE_K)
    while tile_k <= upper_bound:
        num_blocks = M * triton.cdiv(K, tile_k)
        num_waves = num_blocks / NUM_SMS
        if num_waves > 1 and (tile_k * 2 <= upper_bound):
            tile_k *= 2
        else:
            break
    return tile_k
def argmax_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner argmax kernel.

    Reads ``args["N"]`` and the already-chosen ``args["TILE_K"]``.

    * ``N <= 128``: ``next_power_of_2(N)``, so a single tile covers N.
    * otherwise: ``next_power_of_2(min(N, 8192))`` clamped to ``[64, 4096]``.
      It is then shrunk so ``TILE_N * TILE_K <= 32768``, but never below 64;
      for ``TILE_K > 512`` that floor lets the product exceed 32768.
    """
    n = args["N"]
    tile_k = args["TILE_K"]

    if n <= 128:
        # Round up instead of returning ``n`` itself. Every other path here
        # (and argmax_heur_tile_k) yields a power of two. TILE_N is presumably
        # used as a Triton block dimension (e.g. tl.arange), which must be a
        # power of two, so N=100 would otherwise fail to compile.
        # NOTE(review): confirm against the argmax non-inner kernel.
        # ONE_TILE_PER_CTA (TILE_N >= N) still holds after rounding up.
        return triton.next_power_of_2(n)

    target_tile = min(8192, n)
    tile_n = triton.next_power_of_2(target_tile)
    tile_n = max(64, min(tile_n, 4096))

    if tile_n * tile_k > 32768:
        tile_n = max(64, 32768 // tile_k)

    return tile_n
def argmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile already spans all N elements."""
    tile_n, n = args["TILE_N"], args["N"]
    return n <= tile_n
def argmax_heur_num_warps_non_inner(args):
    """Return ``num_warps`` for the non-inner argmax kernel.

    Always returns 1 on this backend; ``args`` is ignored. An earlier
    TILE_N/dtype-based table (2/4/8 warps) had been left behind as
    commented-out dead code. It was removed; recover it from version control
    if warp tuning is reintroduced.
    """
    return 1
def argmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner (last-dimension) argmax kernel.

    N up to 32 Ki elements is rounded up to a power of two; larger N gets a
    fixed TILE_N of 4096.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def argmax_heur_num_warps_inner(args):
    """Return ``num_warps`` for the inner argmax kernel.

    Always returns 1 on this backend; ``args`` is ignored. A TILE_N-based
    4/8/16-warp table had been left behind as commented-out dead code. It was
    removed; recover it from version control if warp tuning is reintroduced.
    """
    return 1
def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 32 once M reaches 4096, otherwise 16."""
    if args["M"] >= 4096:
        return 32
    return 16
def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin: N rounded up to a power of two, at most 16384."""
    block_n = triton.next_power_of_2(args["N"])
    return block_n if block_n < 16384 else 16384
def bmm_heur_divisible_m(args):
    """Return True when M is an exact multiple of TILE_M."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when N is an exact multiple of TILE_N."""
    return not args["N"] % args["TILE_N"]
def bmm_heur_divisible_k(args):
    """Return True when K is an exact multiple of TILE_K."""
    size, tile = args["K"], args["TILE_K"]
    return size % tile == 0
def baddbmm_heur_divisible_m(args):
    """Return True when M splits into whole TILE_M tiles (no remainder)."""
    m, tile_m = args["M"], args["TILE_M"]
    return m % tile_m == 0
def baddbmm_heur_divisible_n(args):
    """Return True when N splits into whole TILE_N tiles (no remainder)."""
    leftover = args["N"] % args["TILE_N"]
    return not leftover
def baddbmm_heur_divisible_k(args):
    """Return True when K splits into whole TILE_K tiles (no remainder)."""
    k = args["K"]
    tile_k = args["TILE_K"]
    return k % tile_k == 0
def dropout_heur_block(args):
    """Return BLOCK for dropout: 1024 when N exceeds 512, otherwise 512."""
    return 1024 if args["N"] > 512 else 512
def dropout_heur_num_warps(args):
    """Return num_warps for dropout: 4 up to N=512, 8 up to N=1024, else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4
def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 for N <= 512, otherwise 1024."""
    n = args["N"]
    if n > 512:
        return 1024
    return 512
def exponential_heur_num_warps(args):
    """Return num_warps for exponential_: 4 / 8 / 16 by N thresholds 512 and 1024."""
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16
def gather_heur_block_m(args):
    """Return BLOCK_M for gather.

    One row per 2048 elements of N (ceil-divided, then rounded up to a power
    of two), capped at 4.
    """
    rows = triton.cdiv(args["N"], 2048)
    return min(triton.next_power_of_2(rows), 4)
def gather_heur_block_n(args):
    """Return BLOCK_N for gather: N rounded up to a power of two, at most 2048."""
    block_n = triton.next_power_of_2(args["N"])
    return min(block_n, 2048)
def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select.

    ``cdiv(256, N)`` rounded up to a power of two and capped at 4.
    """
    rows = triton.cdiv(256, args["N"])
    return min(triton.next_power_of_2(rows), 4)
def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    ``cdiv(N, 16)`` rounded up to a power of two, then clamped to [16, 512].
    """
    per_chunk = triton.cdiv(args["N"], 16)
    block_n = min(triton.next_power_of_2(per_chunk), 512)
    return max(block_n, 16)
def mm_heur_even_k(args):
    """Return True when K is a multiple of BLOCK_K * SPLIT_K."""
    k = args["K"]
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % step == 0
def rand_heur_block(args):
    """Return BLOCK for rand: 512 for N <= 512, otherwise 1024."""
    small_problem = args["N"] <= 512
    return 512 if small_problem else 1024
def rand_heur_num_warps(args):
    """Return num_warps for rand: 4 up to N=512, 8 up to N=1024, else 16."""
    n = args["N"]
    return 4 if n <= 512 else (8 if n <= 1024 else 16)
def randn_heur_block(args):
    """Return BLOCK for randn: 512 for N <= 512, otherwise 1024."""
    threshold = 512
    block = 512 if args["N"] <= threshold else 1024
    return block
def randn_heur_num_warps(args):
    """Return num_warps for randn: 4 up to N=512, 8 up to N=1024, else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4
def softmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner softmax kernel.

    Starts at 1 and keeps doubling, never past ``min(K, 8192)``. Doubling
    continues while ``M * cdiv(K, tile_k)`` programs would still exceed one
    wave over the device's ``multi_processor_count``.
    """
    sm_count = torch.txda.get_device_properties(
        torch.txda.current_device()
    ).multi_processor_count
    k = args["K"]
    limit = min(k, 8192)

    def _needs_bigger_tile(tile):
        waves = args["M"] * triton.cdiv(k, tile) / sm_count
        return waves > 1 and tile * 2 <= limit

    tile_k = 1
    while tile_k <= limit and _needs_bigger_tile(tile_k):
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N = cdiv(8192, TILE_K), so TILE_N * TILE_K is at least 8192."""
    element_budget = 8192
    return triton.cdiv(element_budget, args["TILE_K"])
def softmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N tile covers the full N extent."""
    covered = args["TILE_N"]
    extent = args["N"]
    return covered >= extent
def softmax_heur_num_warps_non_inner(args):
    """Return num_warps from the TILE_N * TILE_K tile area.

    Areas below 2048 get 4 warps, below 4096 get 8, anything larger gets 16.
    """
    tile_size = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if tile_size < bound:
            return warps
    return 16
def softmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner softmax kernel.

    N up to 32 Ki elements is rounded up to a power of two; larger N gets a
    fixed TILE_N of 4096.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Return num_warps from TILE_N: 16 at >= 4096, 8 at >= 2048, else 4."""
    tile = args["TILE_N"]
    if tile >= 4096:
        return 16
    if tile >= 2048:
        return 8
    return 4
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for non-inner softmax backward: 1024 // TILE_K, at least 1."""
    quotient = 1024 // args["TILE_K"]
    return quotient if quotient > 1 else 1
def softmax_heur_tile_m(args):
    """Return TILE_M for inner softmax backward: 1024 // TILE_N, at least 1."""
    rows = 1024 // args["TILE_N"]
    if rows < 1:
        rows = 1
    return rows
def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 for N <= 512, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512
def uniform_heur_num_warps(args):
    """Return num_warps for uniform: 4 / 8 / 16 by N thresholds 512 and 1024."""
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16
def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: BLOCK_NUM rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest1d_SAME_L(args):
    """Return True when the output length OL equals the input length IL."""
    out_len, in_len = args["OL"], args["IL"]
    return out_len == in_len
def upsample_nearest1d_USE_INT32_IDX(args):
    """Return True when N * C * OL does not exceed the signed 32-bit maximum."""
    int32_max = 2**31 - 1
    num_outputs = args["N"] * args["C"] * args["OL"]
    return num_outputs <= int32_max
def upsample_nearest2d_SAME_H(args):
    """Return True when the output height OH equals the input height IH."""
    out_h = args["OH"]
    in_h = args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when the output width OW equals the input width IW."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when N * C * OH * OW does not exceed the signed 32-bit maximum."""
    int32_max = 2**31 - 1
    num_outputs = args["N"] * args["C"] * args["OH"] * args["OW"]
    return num_outputs <= int32_max
def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm: batch_dim rounded up to a power of two, at most 2048."""
    block_m = triton.next_power_of_2(args["batch_dim"])
    return min(block_m, 2048)
def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm.

    spatial_dim is rounded up to a power of two, then limited to
    ``max(1, 2**14 // BLOCK_M)``. This keeps the BLOCK_M x BLOCK_N tile at or
    below 16384 elements, which the original note describes as the number
    loaded at once.
    """
    max_elements = 2**14
    block_m = batch_norm_heur_block_m(args)
    budget_n = max_elements // block_m
    if budget_n < 1:
        budget_n = 1
    return min(triton.next_power_of_2(args["spatial_dim"]), budget_n)
def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot from n_elements.

    Below 1024 elements returns 32, below 8192 returns 256, otherwise 1024.
    """
    n = args["n_elements"]
    for limit, block in ((1024, 32), (8192, 256)):
        if n < limit:
            return block
    return 1024
def mean_heur_tile_k(args):
    """Pick TILE_K for the non-inner mean kernel.

    Starts at 1 and keeps doubling while ``M * cdiv(K, tile_k)`` programs would
    still exceed one wave over the device's ``multi_processor_count``. The cap
    is ``min(K, 512, max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N))``, i.e. 64 with
    the current constants.
    """
    sm_count = torch.txda.get_device_properties(
        torch.txda.current_device()
    ).multi_processor_count
    k = args["K"]
    # Presumably capped so that mean_heur_tile_n_non_inner's
    # _MAX_TILE_N_PER_ROW // TILE_K never drops below _MIN_TILE_N.
    limit = min(k, 512, max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N))

    def _needs_bigger_tile(tile):
        waves = args["M"] * triton.cdiv(k, tile) / sm_count
        return waves > 1 and tile * 2 <= limit

    tile_k = 1
    while tile_k <= limit and _needs_bigger_tile(tile_k):
        tile_k *= 2
    return tile_k
def mean_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner mean kernel.

    Reads ``TILE_K`` and ``N`` (each defaulting to 1 when absent). The
    requested size is ``max(N, _MIN_TILE_N)``, capped by ``_MAX_ONE_TILE_N``
    and by ``_MAX_TILE_N_PER_ROW // TILE_K``. It is rounded up to a power of
    two, re-capped by that per-row limit, and finally floored at _MIN_TILE_N.
    """
    tile_k = args.get("TILE_K", 1)
    n = args.get("N", 1)
    # Per-row cap: keeps TILE_N * TILE_K within _MAX_TILE_N_PER_ROW
    # (before the final _MIN_TILE_N floor).
    cap = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    requested = min(max(n, _MIN_TILE_N), _MAX_ONE_TILE_N, cap)
    tile_n = min(triton.next_power_of_2(requested), cap)
    return max(tile_n, _MIN_TILE_N)
def mean_heur_one_tile_per_cta(args):
    """Return True when TILE_N alone is wide enough to cover all N elements."""
    n = args["N"]
    tile_n = args["TILE_N"]
    return n <= tile_n
def _constant_heuristic(value):
    """Return a heuristic callable that ignores its args and yields *value*."""

    def heuristic(args):
        return value

    return heuristic


# Maps a config name to a {meta-parameter name: heuristic(args)} table.
# Each heuristic receives the kernel's argument dict and returns the value for
# that meta-parameter. NOTE(review): presumably consumed via
# triton.heuristics at the kernel definitions — confirm at call sites.
HEURISTICS_CONFIGS = {
    "argmax_non_inner": {
        "TILE_K": argmax_heur_tile_k,
        "TILE_N": argmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_non_inner,
    },
    "argmax_inner": {
        "TILE_N": argmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_inner,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "baddbmm": {
        "DIVISIBLE_M": baddbmm_heur_divisible_m,
        "DIVISIBLE_N": baddbmm_heur_divisible_n,
        "DIVISIBLE_K": baddbmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "mean_non_inner": {
        "TILE_K": mean_heur_tile_k,
        "TILE_N": mean_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
        # Reuses the softmax warp heuristic (TILE_N * TILE_K based).
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest1d": {
        "SAME_L": upsample_nearest1d_SAME_L,
        "USE_INT32_IDX": upsample_nearest1d_USE_INT32_IDX,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Fixed multi-head-attention tilings; all share 4 warps and 3 stages.
    "mha_block_128": {
        "BLOCK_M": _constant_heuristic(128),
        "BLOCK_N": _constant_heuristic(32),
        "num_warps": _constant_heuristic(4),
        "num_stages": _constant_heuristic(3),
    },
    "mha_block_64": {
        "BLOCK_M": _constant_heuristic(64),
        "BLOCK_N": _constant_heuristic(64),
        "num_warps": _constant_heuristic(4),
        "num_stages": _constant_heuristic(3),
    },
    "mha_block_32": {
        "BLOCK_M": _constant_heuristic(32),
        "BLOCK_N": _constant_heuristic(64),
        "num_warps": _constant_heuristic(4),
        "num_stages": _constant_heuristic(3),
    },
    "mha_block_16": {
        "BLOCK_M": _constant_heuristic(16),
        "BLOCK_N": _constant_heuristic(64),
        "num_warps": _constant_heuristic(4),
        "num_stages": _constant_heuristic(3),
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": _constant_heuristic(8),
    },
}