Coverage for src/flag_gems/runtime/backend/_amd/heuristics_config_utils.py: 0%
139 statements
coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch
2import triton
def simple_elementwise_blocksize_heur(args):
    """Return the fixed block size (1024) for generic elementwise kernels.

    ``args`` is accepted for interface compatibility with the other
    heuristics but is not inspected.
    """
    block_size = 1024
    return block_size
def argmax_heur_block_m(args):
    """Choose ``BLOCK_M`` for argmax from ``args["M"]``.

    Returns 4 while ``args["M"]`` is below 4096 and 8 otherwise.
    """
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Choose ``BLOCK_N`` for argmax.

    Rounds ``args["N"]`` up with ``triton.next_power_of_2`` and caps the
    result at 4096.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)
def argmin_heur_block_m(args):
    """Choose ``BLOCK_M`` for argmin from ``args["M"]``.

    Returns 4 while ``args["M"]`` is below 4096 and 8 otherwise.
    """
    m = args["M"]
    if m < 4096:
        return 4
    return 8
def argmin_heur_block_n(args):
    """Choose ``BLOCK_N`` for argmin.

    Rounds ``args["N"]`` up with ``triton.next_power_of_2`` and caps the
    result at 4096.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)
def bmm_heur_divisible_m(args):
    """Report whether ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    leftover = args["M"] % args["TILE_M"]
    return leftover == 0
def bmm_heur_divisible_n(args):
    """Report whether ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    leftover = args["N"] % args["TILE_N"]
    return leftover == 0
def bmm_heur_divisible_k(args):
    """Report whether ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    leftover = args["K"] % args["TILE_K"]
    return leftover == 0
def dropout_heur_block(args):
    """Choose ``BLOCK`` for dropout from ``args["N"]``.

    Returns 512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Choose ``num_warps`` for dropout from ``args["N"]``.

    Returns 4 for ``N <= 512``, 8 for ``N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def exponential_heur_block(args):
    """Choose ``BLOCK`` for ``exponential_`` from ``args["N"]``.

    Returns 512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    n = args["N"]
    return 512 if n <= 512 else 1024
def exponential_heur_num_warps(args):
    """Choose ``num_warps`` for ``exponential_`` from ``args["N"]``.

    Returns 4 for ``N <= 512``, 8 for ``N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def gather_heur_block_m(args):
    """Choose ``BLOCK_M`` for gather.

    Computes ``triton.cdiv(args["N"], 2048)``, rounds it up with
    ``triton.next_power_of_2``, and caps the result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    rounded_chunks = triton.next_power_of_2(chunks)
    return min(4, rounded_chunks)
def gather_heur_block_n(args):
    """Choose ``BLOCK_N`` for gather.

    Rounds ``args["N"]`` up with ``triton.next_power_of_2`` and caps the
    result at 2048.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)
def index_select_heur_block_m(args):
    """Choose ``BLOCK_M`` for index_select.

    Computes ``triton.cdiv(256, args["N"])``, rounds it up with
    ``triton.next_power_of_2``, and caps the result at 4.
    """
    per_256 = triton.cdiv(256, args["N"])
    rounded = triton.next_power_of_2(per_256)
    return min(4, rounded)
def index_select_heur_block_n(args):
    """Choose ``BLOCK_N`` for index_select.

    Takes ``triton.next_power_of_2(triton.cdiv(args["N"], 16))`` and
    clamps it to the range [16, 512].
    """
    rounded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(rounded, 512)
    return max(capped, 16)
def mm_heur_even_k(args):
    """Report whether ``K`` divides evenly by ``BLOCK_K * SPLIT_K``.

    All three values are read from ``args``.
    """
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % step == 0
def rand_heur_block(args):
    """Choose ``BLOCK`` for rand from ``args["N"]``.

    Returns 512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024
def rand_heur_num_warps(args):
    """Choose ``num_warps`` for rand from ``args["N"]``.

    Returns 4 for ``N <= 512``, 8 for ``N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def randn_heur_block(args):
    """Choose ``BLOCK`` for randn from ``args["N"]``.

    Returns 512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    n = args["N"]
    return 512 if n <= 512 else 1024
def randn_heur_num_warps(args):
    """Choose ``num_warps`` for randn from ``args["N"]``.

    Returns 4 for ``N <= 512``, 8 for ``N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def softmax_heur_tile_k(args):
    """Choose ``TILE_K`` for the non-inner softmax kernel.

    Starting from 1, the tile size is doubled for as long as both hold:

    * ``args["M"] * triton.cdiv(args["K"], tile)`` blocks still exceed the
      multiprocessor count of the current device, and
    * the doubled tile does not exceed ``min(args["K"], 8192)``.

    Returns the last tile size reached (always a power of two, at least 1).
    """
    tile_cap = 8192
    device = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    limit = min(args["K"], tile_cap)

    tile = 1
    while tile <= limit:
        blocks = args["M"] * triton.cdiv(args["K"], tile)
        waves = blocks / sm_count
        doubled = tile * 2
        # Stop once the launch fits in a single wave or the tile cannot grow.
        if not (waves > 1 and doubled <= limit):
            break
        tile = doubled
    return tile
def softmax_heur_tile_n_non_inner(args):
    """Choose ``TILE_N`` for the non-inner softmax kernel.

    Returns ``triton.cdiv(8192, args["TILE_K"])``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """Report whether ``args["TILE_N"]`` is at least ``args["N"]``."""
    tile_n = args["TILE_N"]
    return tile_n >= args["N"]
def softmax_heur_num_warps_non_inner(args):
    """Choose ``num_warps`` for the non-inner softmax kernel.

    Based on ``args["TILE_N"] * args["TILE_K"]``: 4 below 2048, 8 below
    4096, and 16 otherwise.
    """
    elements = args["TILE_N"] * args["TILE_K"]
    if elements < 2048:
        return 4
    if elements < 4096:
        return 8
    return 16
def softmax_heur_tile_n_inner(args):
    """Choose ``TILE_N`` for the inner softmax kernel.

    For ``args["N"]`` up to 32 * 1024 the value is ``args["N"]`` rounded up
    with ``triton.next_power_of_2``; for anything larger it is 4096.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Choose ``num_warps`` for the inner softmax kernel.

    Based on ``args["TILE_N"]``: 4 below 2048, 8 below 4096, and 16
    otherwise.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Choose ``TILE_N`` for the non-inner softmax backward kernel.

    Returns ``1024 // args["TILE_K"]``, but never less than 1.
    """
    per_tile = 1024 // args["TILE_K"]
    return max(1, per_tile)
def softmax_heur_tile_m(args):
    """Choose ``TILE_M`` for the inner softmax backward kernel.

    Returns ``1024 // args["TILE_N"]``, but never less than 1.
    """
    per_tile = 1024 // args["TILE_N"]
    return max(1, per_tile)
def uniform_heur_block(args):
    """Choose ``BLOCK`` for uniform from ``args["N"]``.

    Returns 512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024
def uniform_heur_num_warps(args):
    """Choose ``num_warps`` for uniform from ``args["N"]``.

    Returns 4 for ``N <= 512``, 8 for ``N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def var_mean_heur_block_n(args):
    """Choose ``BLOCK_N`` for var_mean.

    Returns ``args["BLOCK_NUM"]`` rounded up with
    ``triton.next_power_of_2``.
    """
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Report whether ``args["OH"]`` equals ``args["IH"]``."""
    out_h = args["OH"]
    return out_h == args["IH"]
def upsample_nearest2d_SAME_W(args):
    """Report whether ``args["OW"]`` equals ``args["IW"]``."""
    out_w = args["OW"]
    return out_w == args["IW"]
def upsample_nearest2d_USE_INT32_IDX(args):
    """Report whether ``N * C * OH * OW`` fits in a signed 32-bit integer.

    Returns True when the product of those four ``args`` entries is at
    most ``2**31 - 1``.
    """
    int32_max = 2**31 - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= int32_max
def batch_norm_heur_block_m(args):
    """Choose ``BLOCK_M`` for batch_norm.

    Rounds ``args["batch_dim"]`` up with ``triton.next_power_of_2`` and
    caps the result at 2048.
    """
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded_batch)
def batch_norm_heur_block_n(args):
    """Choose ``BLOCK_N`` for batch_norm.

    Rounds ``args["spatial_dim"]`` up with ``triton.next_power_of_2`` and
    limits it to ``2**14 // BLOCK_M`` (but at least 1), where ``BLOCK_M``
    comes from ``batch_norm_heur_block_m(args)``.
    """
    block_m = batch_norm_heur_block_m(args)
    rounded_spatial = triton.next_power_of_2(args["spatial_dim"])
    # A maximum of 16384 elements are loaded at once.
    n_limit = max(1, 2**14 // block_m)
    return min(rounded_spatial, n_limit)
def vdot_heur_block_size(args):
    """Choose ``BLOCK_SIZE`` for vdot from ``args["n_elements"]``.

    Returns 32 below 1024 elements, 256 below 8192 elements, and 1024
    otherwise.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024
def _mha_block_heuristics(block_m, block_n):
    """Build one ``mha_block_*`` entry of constant-valued heuristics.

    Every returned callable ignores its ``args`` argument. ``num_warps``
    and ``num_stages`` are fixed at 4 and 3 for all ``mha_block_*`` entries.
    """
    return {
        "BLOCK_M": lambda args: block_m,
        "BLOCK_N": lambda args: block_n,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    }


# Registry of heuristics: outer key -> {parameter name -> callable(args)}.
# Each callable receives an ``args`` mapping and returns that parameter's value.
HEURISTICS_CONFIGS = {
    # Reductions over an index.
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # Matrix multiplication.
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    # Indexing.
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    # Random number generation.
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    # Softmax, forward and backward.
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Attention: fixed (BLOCK_M, BLOCK_N) pairs.
    "mha_block_128": _mha_block_heuristics(128, 32),
    "mha_block_64": _mha_block_heuristics(64, 64),
    "mha_block_32": _mha_block_heuristics(32, 64),
    "mha_block_16": _mha_block_heuristics(16, 64),
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
}