Coverage for src/flag_gems/runtime/backend/_metax/heuristics_config_utils.py: 0%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import torch
2import triton
def simple_elementwise_blocksize_heur(args):
    """Return the fixed block size 512; ``args`` is not consulted."""
    block_size = 512
    return block_size
def argmax_heur_block_m(args):
    """Return 4 when ``args["M"]`` is below 4096, otherwise 8."""
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Return ``triton.next_power_of_2(args["N"])`` capped at 4096."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)
def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    leftover = args["M"] % args["TILE_M"]
    return leftover == 0
def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    leftover = args["N"] % args["TILE_N"]
    return leftover == 0
def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    leftover = args["K"] % args["TILE_K"]
    return leftover == 0
def argmin_heur_block_m(args):
    """Return 8 once ``args["M"]`` reaches 4096; return 4 for smaller values."""
    m = args["M"]
    if m < 4096:
        return 4
    return 8
def argmin_heur_block_n(args):
    """Return ``triton.next_power_of_2(args["N"])`` capped at 4096."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)
def dropout_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    ``N <= 512`` gives 4, ``N <= 1024`` gives 8, anything larger gives 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def exponential_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024
def exponential_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    ``N <= 512`` gives 4, ``N <= 1024`` gives 8, anything larger gives 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def gather_heur_block_m(args):
    """Return ``next_power_of_2(cdiv(args["N"], 2048))`` capped at 4."""
    chunks = triton.cdiv(args["N"], 2048)
    rounded_chunks = triton.next_power_of_2(chunks)
    return min(4, rounded_chunks)
def gather_heur_block_n(args):
    """Return ``triton.next_power_of_2(args["N"])`` capped at 2048."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)
def index_heur_block_0(args):
    """Return the fixed value 2; ``args`` is not consulted."""
    block_size0 = 2
    return block_size0
def index_heur_block_1(args):
    """Return the fixed value 1024; ``args`` is not consulted."""
    block_size1 = 1024
    return block_size1
def index_select_heur_block_m(args):
    """Return ``next_power_of_2(cdiv(256, args["N"]))`` capped at 4."""
    rows_per_256 = triton.cdiv(256, args["N"])
    rounded_rows = triton.next_power_of_2(rows_per_256)
    return min(4, rounded_rows)
def index_select_heur_block_n(args):
    """Return ``next_power_of_2(cdiv(args["N"], 16))`` clamped to [16, 512]."""
    sixteenth = triton.cdiv(args["N"], 16)
    rounded = triton.next_power_of_2(sixteenth)
    # Apply the upper cap first, then the lower floor, as the original did.
    capped = min(rounded, 512)
    return max(capped, 16)
def mm_heur_even_k(args):
    """Return True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``."""
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    leftover = args["K"] % k_step
    return leftover == 0
def ones_heur_block_size(args):
    """Return 1024, 2048 or 4096 depending on ``args["N"]``.

    ``N <= 1024`` gives 1024, ``N <= 2048`` gives 2048, anything larger
    gives 4096.
    """
    n = args["N"]
    if n <= 1024:
        return 1024
    if n <= 2048:
        return 2048
    return 4096
def ones_heur_num_warps(args):
    """Return 2 for float16/bfloat16 ``args["output_ptr"]``, otherwise 4."""
    out_dtype = args["output_ptr"].dtype
    is_half_precision = out_dtype in (torch.float16, torch.bfloat16)
    return 2 if is_half_precision else 4
def rand_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024
def rand_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    ``N <= 512`` gives 4, ``N <= 1024`` gives 8, anything larger gives 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def randn_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024
def randn_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    ``N <= 512`` gives 4, ``N <= 1024`` gives 8, anything larger gives 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def softmax_heur_tile_k(args):
    """Pick a power-of-two ``TILE_K`` by repeated doubling.

    Starting from 1, the tile is doubled for as long as both hold:

    * ``args["M"] * cdiv(args["K"], tile_k)`` divided by the current CUDA
      device's ``multi_processor_count`` is greater than 1, and
    * the doubled tile would still be at most ``min(args["K"], 8192)``.

    Returns the last tile size reached (1 if no doubling was possible).
    """
    MAX_TILE_K = 8192
    device_index = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device_index).multi_processor_count
    limit = min(args["K"], MAX_TILE_K)

    tile_k = 1
    while tile_k <= limit:
        blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        waves = blocks / sm_count
        may_double = tile_k * 2 <= limit
        if not (waves > 1 and may_double):
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return the smaller of ``next_power_of_2(N)`` and ``cdiv(8192, TILE_K)``."""
    rounded_n = triton.next_power_of_2(args["N"])
    per_tile_budget = triton.cdiv(8192, args["TILE_K"])
    return min(rounded_n, per_tile_budget)
def softmax_heur_one_tile_per_cta(args):
    """Return True when ``args["TILE_N"]`` is at least ``args["N"]``."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return tile_n >= n
def softmax_heur_num_warps_non_inner(args):
    """Return a warp count chosen from the tile size ``TILE_N * TILE_K``.

    Mapping (identical to what the previous implementation actually
    returned):

    * tile size < 512   -> 1
    * tile size < 2048  -> 4
    * tile size < 4096  -> 8
    * otherwise         -> 16

    NOTE(review): the earlier code contained ``elif tile_size < 256:
    return 2`` *after* the ``tile_size < 512`` check, so that branch could
    never execute. It is dropped here to remove the dead code without
    changing results. If a 2-warp tier was intended (for example with the
    256/512 thresholds in the opposite order), confirm against tuning data
    before reintroducing it.
    """
    tile_size = args["TILE_N"] * args["TILE_K"]
    if tile_size < 512:
        return 1
    if tile_size < 2048:
        return 4
    if tile_size < 4096:
        return 8
    return 16
def softmax_heur_tile_n_inner(args):
    """Return ``next_power_of_2(N)`` for ``N <= 32768``, otherwise 4096."""
    if args["N"] > 32 * 1024:
        return 4096
    return triton.next_power_of_2(args["N"])
def softmax_heur_num_warps_inner(args):
    """Return 4, 8 or 16 depending on ``args["TILE_N"]``.

    ``TILE_N < 2048`` gives 4, ``TILE_N < 4096`` gives 8, anything larger
    gives 16.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return ``1024 // args["TILE_K"]``, but never less than 1."""
    quotient = 1024 // args["TILE_K"]
    return max(1, quotient)
def softmax_heur_tile_m(args):
    """Return ``1024 // args["TILE_N"]``, but never less than 1."""
    quotient = 1024 // args["TILE_N"]
    return max(1, quotient)
def uniform_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024
def uniform_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    ``N <= 512`` gives 4, ``N <= 1024`` gives 8, anything larger gives 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def var_mean_heur_block_n(args):
    """Return ``args["BLOCK_NUM"]`` rounded by ``triton.next_power_of_2``."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"]`` equals ``args["IH"]``."""
    out_h = args["OH"]
    in_h = args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"]`` equals ``args["IW"]``."""
    out_w = args["OW"]
    in_w = args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when ``N * C * OH * OW`` does not exceed ``2**31 - 1``."""
    int32_max = 2**31 - 1
    element_count = args["N"] * args["C"] * args["OH"] * args["OW"]
    return element_count <= int32_max
def batch_norm_heur_block_m(args):
    """Return ``triton.next_power_of_2(args["batch_dim"])`` capped at 2048."""
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded_batch)
def batch_norm_heur_block_n(args):
    """Return a ``BLOCK_N`` such that ``BLOCK_M * BLOCK_N`` stays within 16384.

    ``BLOCK_M`` comes from ``batch_norm_heur_block_m(args)``. The result is
    ``next_power_of_2(args["spatial_dim"])`` capped at
    ``max(1, 16384 // BLOCK_M)``.
    """
    # A maximum of 16384 elements are loaded at once.
    max_elements = 2**14
    block_m = batch_norm_heur_block_m(args)
    rounded_spatial = triton.next_power_of_2(args["spatial_dim"])
    n_budget = max(1, max_elements // block_m)
    return min(rounded_spatial, n_budget)
def vdot_heur_block_size(args):
    """Return 32, 256 or 1024 depending on ``args["n_elements"]``.

    Fewer than 1024 elements gives 32, fewer than 8192 gives 256, anything
    larger gives 1024.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024
def zeros_heur_block_size(args):
    """Return 1024, 2048 or 4096 depending on ``args["N"]``.

    ``N <= 1024`` gives 1024, ``N <= 2048`` gives 2048, anything larger
    gives 4096.
    """
    n = args["N"]
    if n <= 1024:
        return 1024
    if n <= 2048:
        return 2048
    return 4096
def zeros_heur_num_warps(args):
    """Return 2 for float16/bfloat16 ``args["output_ptr"]``, otherwise 4."""
    out_dtype = args["output_ptr"].dtype
    is_half_precision = out_dtype in (torch.float16, torch.bfloat16)
    return 2 if is_half_precision else 4
def _constant(value):
    """Build a one-argument callable that ignores its input and returns ``value``."""
    return lambda args: value


def _mha_block(block_m, block_n):
    """Build one ``mha_block_*`` entry; ``num_warps`` is 4 and ``num_stages`` is 3."""
    return {
        "BLOCK_M": _constant(block_m),
        "BLOCK_N": _constant(block_n),
        "num_warps": _constant(4),
        "num_stages": _constant(3),
    }


# Maps a configuration name to a dict of {parameter name: callable(args)}.
# Entry and key order match the original literal.
HEURISTICS_CONFIGS = {
    "amax": {
        "BLOCK_M": _constant(4),
        "BLOCK_N": _constant(1024),
    },
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index": {
        "BLOCK_SIZE0": index_heur_block_0,
        "BLOCK_SIZE1": index_heur_block_1,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "nonzero": {
        "BLOCK_SIZE": _constant(2048),
    },
    "ones": {
        "BLOCK_SIZE": ones_heur_block_size,
        "num_warps": ones_heur_num_warps,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "zeros": {
        "BLOCK_SIZE": zeros_heur_block_size,
        "num_warps": zeros_heur_num_warps,
    },
    "mha_block_128": _mha_block(128, 32),
    "mha_block_64": _mha_block(64, 32),
    "mha_block_32": _mha_block(32, 16),
    "mha_block_16": _mha_block(16, 16),
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": _constant(8),
    },
}