Coverage for src/flag_gems/runtime/backend/_sunrise/heuristics_config_utils.py: 0%
174 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import torch # noqa: F401
2import triton
def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels.

    ``args`` is accepted only to match the heuristic-callback signature;
    it is never inspected.
    """
    block_size = 1024
    return block_size
def argmax_heur_block_m(args):
    """Return BLOCK_M for argmax: 8 once ``args["M"]`` reaches 4096, else 1."""
    if args["M"] >= 4096:
        return 8
    return 1
def argmax_heur_block_n(args):
    """Return BLOCK_N for argmax.

    This is ``args["N"]`` rounded up to a power of two, capped at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(padded_n, 4096)
def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 8 once ``args["M"]`` reaches 4096, else 4."""
    if args["M"] >= 4096:
        return 8
    return 4
def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin.

    This is ``args["N"]`` rounded up to a power of two, capped at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(padded_n, 4096)
def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0
def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def dropout_heur_block(args):
    """Return BLOCK for dropout: 256 while ``args["N"] <= 512``, else 512."""
    return 256 if args["N"] <= 512 else 512
def dropout_heur_num_warps(args):
    """Return num_warps for dropout, stepping 2 -> 4 -> 8 as ``args["N"]`` grows.

    The thresholds are inclusive upper bounds: N <= 512 gives 2 and
    N <= 2048 gives 4; anything larger gives 8.
    """
    n = args["N"]
    for limit, warps in ((512, 2), (2048, 4)):
        if n <= limit:
            return warps
    return 8
def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 while ``args["N"] <= 512``, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def exponential_heur_num_warps(args):
    """Return num_warps for exponential_.

    N <= 512 gives 4, N <= 1024 gives 8, and anything larger gives 16.
    """
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4
def gather_heur_block_m(args):
    """Return BLOCK_M for gather.

    The candidate is ``next_power_of_2(cdiv(N, 2048))``, capped at 1.

    NOTE(review): for any N >= 1 the candidate is >= 1, so this always
    evaluates to 1. Confirm that the cap of 1 is intentional for this
    backend.
    """
    rows_needed = triton.cdiv(args["N"], 2048)
    candidate = triton.next_power_of_2(rows_needed)
    return min(1, candidate)
def gather_heur_block_n(args):
    """Return BLOCK_N for gather.

    This is ``args["N"]`` rounded up to a power of two, capped at 2048.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(2048, padded_n)
def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select.

    The candidate is ``next_power_of_2(cdiv(256, N))``, capped at 4.

    NOTE(review): ``args["N"] == 0`` would make cdiv divide by zero; this
    presumably never happens for callers.
    """
    rows_per_256 = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows_per_256))
def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    This is ``next_power_of_2(cdiv(N, 16))``, clamped to the range [16, 512].
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(candidate, 512)
    return max(capped, 16)
def mm_heur_even_k(args):
    """Return True when K is an exact multiple of ``BLOCK_K * SPLIT_K``."""
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0
def rand_heur_block(args):
    """Return BLOCK for rand: 512 while ``args["N"] <= 512``, else 1024."""
    small = args["N"] <= 512
    return 512 if small else 1024
def rand_heur_num_warps(args):
    """Return num_warps for rand.

    N <= 512 gives 4, N <= 1024 gives 8, and anything larger gives 16.
    """
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16
def randn_heur_block(args):
    """Return BLOCK for randn: 512 while ``args["N"] <= 512``, else 1024."""
    if args["N"] > 512:
        return 1024
    return 512
def randn_heur_num_warps(args):
    """Return num_warps for randn.

    N <= 512 gives 4, N <= 1024 gives 8, and anything larger gives 16.
    """
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4
def softmax_heur_tile_k(args, *, max_tile_k=512, num_sms=32):
    """Choose TILE_K for the non-inner softmax kernel.

    TILE_K starts at 1 and is doubled as long as two conditions hold:

    * the launch still spans more than one "wave"
      (``M * cdiv(K, TILE_K) / num_sms > 1``), and
    * the doubled value stays within ``min(K, max_tile_k)``.

    Args:
        args: Kernel-argument mapping. Reads ``"M"`` and ``"K"``.
        max_tile_k: Upper cap on TILE_K. The default, 512, is the value that
            was previously hard-coded.
        num_sms: Number of SMs assumed when estimating waves. Querying it
            via ``torch.cuda.get_device_properties(...).multi_processor_count``
            is not supported on this backend, so a fixed 32 is the default.

    Returns:
        A power of two in ``[1, max(1, min(K, max_tile_k))]``.
    """
    k = args["K"]
    upper_bound = min(k, max_tile_k)
    tile_k = 1
    while tile_k <= upper_bound:
        num_blocks = args["M"] * triton.cdiv(k, tile_k)
        # Keep growing the tile while the grid still covers more than one
        # wave of programs; stop once doubling would exceed the cap.
        if num_blocks / num_sms > 1 and tile_k * 2 <= upper_bound:
            tile_k *= 2
        else:
            break
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for non-inner softmax.

    TILE_N is chosen so that TILE_N * TILE_K covers roughly 8192 elements,
    computed as ``cdiv(8192, TILE_K)``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile covers all of ``args["N"]``."""
    n = args["N"]
    tile_n = args["TILE_N"]
    return n <= tile_n
def softmax_heur_num_warps_non_inner(args):
    """Return num_warps for non-inner softmax, based on the TILE_N * TILE_K area.

    An area below 2048 gives 4, below 4096 gives 8, and otherwise 16.
    """
    area = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if area < bound:
            return warps
    return 16
def softmax_heur_tile_n_inner(args):
    """Return TILE_N for inner softmax.

    Small rows (N <= 32) get N rounded up to a power of two. Otherwise the
    result is 256 for N <= 1024 and 512 beyond that.
    """
    n = args["N"]
    if n > 1024:
        return 512
    if n > 32:
        return 256
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Return num_warps for inner softmax, based on TILE_N.

    TILE_N below 64 gives 2, below 2048 gives 4, below 4096 gives 8, and
    otherwise 16.
    """
    width = args["TILE_N"]
    for bound, warps in ((64, 2), (2048, 4), (4096, 8)):
        if width < bound:
            return warps
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for non-inner softmax backward.

    This is ``1024 // TILE_K``, floored at 1.
    """
    per_k = 1024 // args["TILE_K"]
    return per_k if per_k > 1 else 1
def softmax_heur_tile_m(args):
    """Return TILE_M for inner softmax backward.

    This is ``1024 // TILE_N``, floored at 1.
    """
    rows = 1024 // args["TILE_N"]
    return rows if rows > 1 else 1
def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 while ``args["N"] <= 512``, else 1024."""
    if args["N"] > 512:
        return 1024
    return 512
def uniform_heur_num_warps(args):
    """Return num_warps for uniform.

    N <= 512 gives 4, N <= 1024 gives 8, and anything larger gives 16.
    """
    n = args["N"]
    if n <= 512:
        warps = 4
    elif n <= 1024:
        warps = 8
    else:
        warps = 16
    return warps
def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: ``args["BLOCK_NUM"]`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when the output height ``OH`` equals the input height ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when the output width ``OW`` equals the input width ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when N * C * OH * OW fits in a signed 32-bit int (<= 2**31 - 1)."""
    int32_max = 2**31 - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= int32_max
def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm.

    This is ``args["batch_dim"]`` rounded up to a power of two, capped at 256.
    """
    padded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(256, padded_batch)
def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm, bounding the BLOCK_M x BLOCK_N tile.

    BLOCK_N is ``spatial_dim`` rounded up to a power of two. It is capped so
    that BLOCK_M * BLOCK_N stays within 2**10 = 1024 elements, and it is
    never less than 1.

    NOTE(review): if ``batch_dim`` is 0, BLOCK_M is 0 and the floor
    division raises ZeroDivisionError. Confirm that callers never launch
    with an empty batch.
    """
    # At most 2**10 (1024) elements are loaded at once. A previous comment
    # said 16384, which did not match this budget.
    max_tile_elems = 2**10
    BLOCK_M = batch_norm_heur_block_m(args)
    BLOCK_N = triton.next_power_of_2(args["spatial_dim"])
    return min(BLOCK_N, max(1, max_tile_elems // BLOCK_M))
def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot, based on ``args["n_elements"]``.

    Fewer than 1024 elements gives 32, fewer than 8192 gives 256, and
    otherwise 1024.
    """
    count = args["n_elements"]
    for bound, size in ((1024, 32), (8192, 256)):
        if count < bound:
            return size
    return 1024
def sum_heur_num_warps_inner(args):
    """Return num_warps for inner sum, based on TILE_N.

    TILE_N below 64 gives 2, below 2048 gives 4, below 4096 gives 8, and
    otherwise 16.
    """
    width = args["TILE_N"]
    if width >= 4096:
        return 16
    if width >= 2048:
        return 8
    if width >= 64:
        return 4
    return 2
def sum_heur_tile_n_inner(args):
    """Return TILE_N for inner sum.

    Small rows (N <= 32) get N rounded up to a power of two. Otherwise the
    result is 128 for N <= 1024 and 256 beyond that.
    """
    n = args["N"]
    if n <= 32:
        return triton.next_power_of_2(n)
    return 128 if n <= 1024 else 256
def sum_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile covers all of ``args["N"]``."""
    n, tile_n = args["N"], args["TILE_N"]
    return n <= tile_n
def sum_heur_tile_k(args, *, max_tile_k=128, num_sms=32):
    """Choose TILE_K for the non-inner sum kernel.

    TILE_K starts at 1 and is doubled as long as two conditions hold:

    * the launch still spans more than one "wave"
      (``M * cdiv(K, TILE_K) / num_sms > 1``), and
    * the doubled value stays within ``min(K, max_tile_k)``.

    Args:
        args: Kernel-argument mapping. Reads ``"M"`` and ``"K"``.
        max_tile_k: Upper cap on TILE_K. The default, 128, is the value that
            was previously hard-coded.
        num_sms: Number of SMs assumed when estimating waves. Querying it
            via ``torch.cuda.get_device_properties(...).multi_processor_count``
            is not supported on this backend, so a fixed 32 is the default.

    Returns:
        A power of two in ``[1, max(1, min(K, max_tile_k))]``.
    """
    k = args["K"]
    upper_bound = min(k, max_tile_k)
    tile_k = 1
    while tile_k <= upper_bound:
        num_blocks = args["M"] * triton.cdiv(k, tile_k)
        # Keep growing the tile while the grid still covers more than one
        # wave of programs; stop once doubling would exceed the cap.
        if num_blocks / num_sms > 1 and tile_k * 2 <= upper_bound:
            tile_k *= 2
        else:
            break
    return tile_k
def sum_heur_tile_n_non_inner(args):
    """Return TILE_N for non-inner sum.

    TILE_N is chosen so that TILE_N * TILE_K covers roughly 256 elements,
    computed as ``cdiv(256, TILE_K)``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(256, tile_k)
# Per-op heuristic table for this backend, keyed by op name.
# Each inner dict maps a kernel meta-parameter name (e.g. "BLOCK_M",
# "num_warps") to a callable. The callable takes the ``args`` mapping,
# indexed by kernel argument names such as "M" or "N", and returns the
# value for that meta-parameter.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Fixed launch configurations: these ignore ``args`` entirely.
    "mha_varlen_prefill": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_varlen_decode": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
    "sum_inner": {
        "TILE_N": sum_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        "num_warps": sum_heur_num_warps_inner,
    },
    "sum_non_inner": {
        "TILE_K": sum_heur_tile_k,
        "TILE_N": sum_heur_tile_n_non_inner,
        # Use the sum-specific predicate, matching "sum_inner". Its logic
        # (TILE_N >= N) is the same as the softmax predicate used previously.
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        # No sum-specific non-inner warp heuristic exists, so softmax's is reused.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
}