Coverage for src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py: 83%
243 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import torch
2import triton
# Shared tiling bounds for the mean reduction heuristics below.
_MIN_TILE_N = 64  # smallest TILE_N ever returned by mean_heur_tile_n_non_inner
_MAX_TILE_N_PER_ROW = 4096  # budget for TILE_N * TILE_K per row
_MAX_ONE_TILE_N = 2048  # cap on the desired TILE_N in mean_heur_tile_n_non_inner
def simple_elementwise_blocksize_heur(args):
    """Fixed 1024-element block size for simple elementwise kernels."""
    del args  # unused; kept for the heuristic-callback interface
    return 1024
def argmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner argmax kernel.

    Doubles the tile (starting at 64) while more than one wave of blocks
    would be launched on the available SMs, capped at ``min(K, 512)``.
    Small K gets the largest power of two not exceeding K, and one
    benchmarked shape (M=64, K=512) is special-cased.
    """
    MAX_TILE_K = 512
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count

    k, m = args["K"], args["M"]
    # NOTE: anything that is not fp32 is treated like fp16 here.
    is_fp32 = args["inp"].dtype == torch.float32

    # Hand-tuned special case for a specific shape.
    if m == 64 and k == 512:
        return 64 if is_fp32 else 128

    # Small reductions: largest power of two <= K (1 for K == 0).
    if k <= 128:
        return 1 << (k.bit_length() - 1) if k > 0 else 1

    cap = min(k, MAX_TILE_K)
    tile_k = 64
    while tile_k <= cap:
        waves = m * triton.cdiv(k, tile_k) / sm_count
        if waves > 1 and tile_k * 2 <= cap:
            tile_k *= 2
        else:
            break
    return tile_k
def argmax_heur_tile_n_non_inner(args):
    """TILE_N for non-inner argmax.

    Small N is used as-is; otherwise a power of two clamped to [64, 4096]
    and shrunk so that TILE_N * TILE_K stays within a 32768-element budget.
    """
    n = args["N"]
    tile_k = args["TILE_K"]

    if n <= 128:
        return n

    tile_n = triton.next_power_of_2(min(8192, n))
    tile_n = min(max(tile_n, 64), 4096)

    budget = 32768
    if tile_n * tile_k > budget:
        tile_n = max(64, budget // tile_k)
    return tile_n
def argmax_heur_one_tile_per_cta(args):
    """True when a single tile covers the whole reduced dimension."""
    n, tile_n = args["N"], args["TILE_N"]
    return tile_n >= n
def argmax_heur_num_warps_non_inner(args):
    """num_warps scaled with TILE_N; fp32 inputs are capped at 4 warps."""
    tile_n = args["TILE_N"]
    # The original <=64 and <=128 steps both yield 4, so they are merged.
    warps = 2 if tile_n <= 32 else (4 if tile_n <= 128 else 8)
    if args["inp"].dtype == torch.float32:
        warps = min(warps, 4)
    return warps
def argmax_heur_tile_n_inner(args):
    """TILE_N for inner argmax: the whole row rounded up to a power of two
    when N fits in 32K elements, otherwise a fixed 4096 tile."""
    n = args["N"]
    return triton.next_power_of_2(n) if n <= 32 * 1024 else 4096
def argmax_heur_num_warps_inner(args):
    """num_warps for inner argmax, stepped by TILE_N."""
    for bound, warps in ((2048, 4), (4096, 8)):
        if args["TILE_N"] < bound:
            return warps
    return 16
def argmin_heur_block_m(args):
    """Rows handled per program: 4 below 4096 rows, 8 at or above."""
    if args["M"] < 4096:
        return 4
    return 8
def argmin_heur_block_n(args):
    """Column tile: next power of two of N, capped at 4096."""
    block_n = triton.next_power_of_2(args["N"])
    return block_n if block_n < 4096 else 4096
def bmm_heur_divisible_m(args):
    """True when M splits evenly into TILE_M tiles (no boundary masking)."""
    m, tile_m = args["M"], args["TILE_M"]
    return m % tile_m == 0
def bmm_heur_divisible_n(args):
    """True when N splits evenly into TILE_N tiles (no boundary masking)."""
    n, tile_n = args["N"], args["TILE_N"]
    return n % tile_n == 0
def bmm_heur_divisible_k(args):
    """True when K splits evenly into TILE_K tiles (no boundary masking)."""
    k, tile_k = args["K"], args["TILE_K"]
    return k % tile_k == 0
def baddbmm_heur_divisible_m(args):
    """True when M is a whole multiple of TILE_M."""
    return not args["M"] % args["TILE_M"]
def baddbmm_heur_divisible_n(args):
    """True when N is a whole multiple of TILE_N."""
    return not args["N"] % args["TILE_N"]
def baddbmm_heur_divisible_k(args):
    """True when K is a whole multiple of TILE_K."""
    return not args["K"] % args["TILE_K"]
def dropout_heur_block(args):
    """Block size for dropout: 512 for short inputs, 1024 otherwise."""
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Warps for dropout, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16
def exponential_heur_block(args):
    """Block size for exponential_: 512 up to N=512, else 1024."""
    small = args["N"] <= 512
    return 512 if small else 1024
def exponential_heur_num_warps(args):
    """Warps for exponential_, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    for bound, warps in ((512, 4), (1024, 8)):
        if n <= bound:
            return warps
    return 16
def gather_heur_block_m(args):
    """Row tile for gather: grows with ceil(N / 2048), capped at 4."""
    rows = triton.next_power_of_2(triton.cdiv(args["N"], 2048))
    return min(4, rows)
def gather_heur_block_n(args):
    """Column tile for gather: next power of two of N, capped at 2048."""
    block_n = triton.next_power_of_2(args["N"])
    return block_n if block_n < 2048 else 2048
def index_select_heur_block_m(args):
    """Row tile for index_select: inversely scaled with N, capped at 4."""
    rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows))
def index_select_heur_block_n(args):
    """Column tile for index_select: power of two in [16, 512] scaled with N/16."""
    block_n = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(16, min(block_n, 512))
def mm_heur_even_k(args):
    """True when K is a whole multiple of BLOCK_K * SPLIT_K."""
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % step == 0
def rand_heur_block(args):
    """Block size for rand: 512 up to N=512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def rand_heur_num_warps(args):
    """Warps for rand, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16
def randn_heur_block(args):
    """Block size for randn: 512 up to N=512, else 1024."""
    if args["N"] <= 512:
        return 512
    return 1024
def randn_heur_num_warps(args):
    """Warps for randn, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    for bound, warps in ((512, 4), (1024, 8)):
        if n <= bound:
            return warps
    return 16
def softmax_heur_tile_k(args):
    """TILE_K for non-inner softmax.

    Doubles the tile from 1 while more than one wave of blocks would be
    launched on the available SMs, capped at ``min(K, 8192)``.
    """
    MAX_TILE_K = 8192
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    k, m = args["K"], args["M"]
    cap = min(k, MAX_TILE_K)
    tile_k = 1
    while tile_k <= cap:
        waves = m * triton.cdiv(k, tile_k) / sm_count
        if waves > 1 and tile_k * 2 <= cap:
            tile_k *= 2
        else:
            break
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """TILE_N sized so TILE_N * TILE_K covers roughly 8192 elements."""
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """True when a single TILE_N tile spans the whole softmax axis."""
    return args["N"] <= args["TILE_N"]
def softmax_heur_num_warps_non_inner(args):
    """Warps stepped by the tile footprint TILE_N * TILE_K (4 / 8 / 16)."""
    footprint = args["TILE_N"] * args["TILE_K"]
    if footprint < 2048:
        return 4
    return 8 if footprint < 4096 else 16
def softmax_heur_tile_n_inner(args):
    """TILE_N for inner softmax: whole row rounded up to a power of two
    when N fits in 32K elements, otherwise a fixed 4096 tile."""
    n = args["N"]
    return triton.next_power_of_2(n) if n <= 32 * 1024 else 4096
def softmax_heur_num_warps_inner(args):
    """Warps for inner softmax, stepped by TILE_N (4 / 8 / 16)."""
    for bound, warps in ((2048, 4), (4096, 8)):
        if args["TILE_N"] < bound:
            return warps
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Backward TILE_N: a 1024-element budget divided by TILE_K, at least 1."""
    tile_n = 1024 // args["TILE_K"]
    return tile_n if tile_n > 0 else 1
def softmax_heur_tile_m(args):
    """TILE_M: a 1024-element budget divided by TILE_N, at least 1."""
    tile_m = 1024 // args["TILE_N"]
    return tile_m if tile_m > 0 else 1
def uniform_heur_block(args):
    """Block size for uniform: 512 up to N=512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def uniform_heur_num_warps(args):
    """Warps for uniform, stepped by N (4 / 8 / 16)."""
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16
def var_mean_heur_block_n(args):
    """Round the partial-reduction block count up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest1d_SAME_L(args):
    """True when output length equals input length (identity resize)."""
    out_l, in_l = args["OL"], args["IL"]
    return out_l == in_l
def upsample_nearest1d_USE_INT32_IDX(args):
    """True when every flat output index fits in a signed 32-bit int."""
    total = args["N"] * args["C"] * args["OL"]
    return total <= 2**31 - 1  # INT32 MAX
def upsample_nearest2d_SAME_H(args):
    """True when the output height matches the input height."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """True when the output width matches the input width."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """True when every flat output index fits in a signed 32-bit int."""
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= 2**31 - 1  # INT32 MAX
def upsample_nearest3d_SAME_D(args):
    """True when the output depth matches the input depth."""
    out_d, in_d = args["OD"], args["ID"]
    return out_d == in_d
def upsample_nearest3d_SAME_H(args):
    """True when the output height matches the input height."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest3d_SAME_W(args):
    """True when the output width matches the input width."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def upsample_nearest3d_USE_INT32_IDX(args):
    """True when every flat output index fits in a signed 32-bit int."""
    total = args["N"] * args["C"] * args["OD"] * args["OH"] * args["OW"]
    return total <= 2**31 - 1
def batch_norm_heur_block_m(args):
    """Batch tile: next power of two of the batch dim, capped at 2048."""
    block_m = triton.next_power_of_2(args["batch_dim"])
    return block_m if block_m < 2048 else 2048
def batch_norm_heur_block_n(args):
    """Spatial tile sized so BLOCK_M * BLOCK_N loads at most 2**14 elements
    at once, and never larger than the spatial dim rounded up to a power of two."""
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    budget = max(1, 2**14 // block_m)
    return min(block_n, budget)
def vdot_heur_block_size(args):
    """Block size for vdot, stepped by element count (32 / 256 / 1024)."""
    n = args["n_elements"]
    for bound, block in ((1024, 32), (8192, 256)):
        if n < bound:
            return block
    return 1024
def mean_heur_tile_k(args):
    """TILE_K for the non-inner mean reduction.

    Doubles the tile from 1 while more than one wave of blocks would be
    launched, subject to three caps: K itself, the 512 hard cap, and the
    TILE_N budget (_MAX_TILE_N_PER_ROW // _MIN_TILE_N). The result is then
    raised, if needed, so the grid Y dimension stays within the CUDA limit.
    """
    MAX_TILE_K = 512
    MAX_GRID_Y = 65535
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count

    k, m = args["K"], args["M"]
    cap = min(k, MAX_TILE_K, max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N))

    tile_k = 1
    while tile_k <= cap:
        waves = m * triton.cdiv(k, tile_k) / sm_count
        if waves > 1 and tile_k * 2 <= cap:
            tile_k *= 2
        else:
            break

    # Grid Y dimension must not exceed the CUDA limit of 65535.
    floor_tile_k = triton.cdiv(k, MAX_GRID_Y)
    if floor_tile_k > tile_k:
        tile_k = triton.next_power_of_2(floor_tile_k)
    return tile_k
def mean_heur_tile_n_non_inner(args):
    """TILE_N for the non-inner mean: a power of two of at least _MIN_TILE_N,
    limited by the per-row budget _MAX_TILE_N_PER_ROW // TILE_K and by
    _MAX_ONE_TILE_N."""
    tile_k = args.get("TILE_K", 1)
    limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    n = args.get("N", 1)
    # The two successive mins of the original collapse into one expression.
    desired = min(max(n, _MIN_TILE_N), limit_by_k, _MAX_ONE_TILE_N)
    tile_n = min(triton.next_power_of_2(desired), limit_by_k)
    return max(tile_n, _MIN_TILE_N)
def mean_heur_one_tile_per_cta(args):
    """True when a single TILE_N tile covers the whole reduced axis."""
    return args["N"] <= args["TILE_N"]
# Registry mapping kernel name -> {meta-parameter name -> heuristic callback}.
# Each callback receives the heuristics `args` dict and returns that
# meta-parameter's value (tile sizes, warp counts, boolean specialization flags).
HEURISTICS_CONFIGS = {
    "argmax_non_inner": {
        "TILE_K": argmax_heur_tile_k,
        "TILE_N": argmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_non_inner,
    },
    "argmax_inner": {
        "TILE_N": argmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_inner,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "baddbmm": {
        "DIVISIBLE_M": baddbmm_heur_divisible_m,
        "DIVISIBLE_N": baddbmm_heur_divisible_n,
        "DIVISIBLE_K": baddbmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "mean_non_inner": {
        "TILE_K": mean_heur_tile_k,
        "TILE_N": mean_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
        # Deliberately reuses the softmax warp heuristic (same TILE_N * TILE_K
        # footprint logic).
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest1d": {
        "SAME_L": upsample_nearest1d_SAME_L,
        "USE_INT32_IDX": upsample_nearest1d_USE_INT32_IDX,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "upsample_nearest3d": {
        "SAME_D": upsample_nearest3d_SAME_D,
        "SAME_H": upsample_nearest3d_SAME_H,
        "SAME_W": upsample_nearest3d_SAME_W,
        "USE_INT32_IDX": upsample_nearest3d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Fixed tile configurations for the mha kernels, keyed by BLOCK_M size;
    # the lambdas ignore `args` and return constants.
    "mha_block_128": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_64": {
        "BLOCK_M": lambda args: 64,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_32": {
        "BLOCK_M": lambda args: 32,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_16": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
}