Coverage for src/flag_gems/runtime/backend/_iluvatar/heuristics_config_utils.py: 0%
131 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import torch
2import triton
def simple_elementwise_blocksize_heur(args):
    """Return a constant BLOCK_SIZE of 1024.

    ``args`` is ignored; it is accepted only so the function matches the
    ``heuristic(args)`` callable shape used in ``HEURISTICS_CONFIGS``.
    """
    block_size = 1024
    return block_size
def argmax_heur_block_m(args):
    """Heuristic for argmax's ``BLOCK_M``: 8 once ``args["M"]`` reaches 4096, else 4."""
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Heuristic for argmax's ``BLOCK_N``.

    Rounds ``args["N"]`` up to a power of two and caps the result at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(padded_n, 4096)
def argmin_heur_block_m(args):
    """Heuristic for argmin's ``BLOCK_M``: 4 while ``args["M"]`` is below 4096, else 8."""
    small_m = args["M"] < 4096
    if small_m:
        return 4
    return 8
def argmin_heur_block_n(args):
    """Heuristic for argmin's ``BLOCK_N``.

    ``args["N"]`` rounded up to a power of two, never more than 4096.
    """
    cap = 4096
    rounded = triton.next_power_of_2(args["N"])
    return min(rounded, cap)
def dropout_heur_block(args):
    """Heuristic for dropout's ``BLOCK``: 512 when ``args["N"] <= 512``, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Heuristic for dropout's ``num_warps``.

    4 warps for ``N <= 512``, 8 for ``N <= 1024``, and 16 beyond that.
    """
    # (inclusive upper bound on N, warp count), checked in ascending order.
    for limit, warps in ((512, 4), (1024, 8)):
        if args["N"] <= limit:
            return warps
    return 16
def exponential_heur_block(args):
    """Heuristic for ``exponential_``'s ``BLOCK``: 512 for ``N <= 512``, else 1024."""
    small = args["N"] <= 512
    return 512 if small else 1024
def exponential_heur_num_warps(args):
    """Heuristic for ``exponential_``'s ``num_warps``: 4 / 8 / 16 with N cut-offs at 512 and 1024."""
    n = args["N"]
    if n <= 512:
        warps = 4
    elif n <= 1024:
        warps = 8
    else:
        warps = 16
    return warps
def gather_heur_block_m(args):
    """Heuristic for gather's ``BLOCK_M``.

    Counts how many 2048-wide chunks ``args["N"]`` spans, rounds that up to a
    power of two, and caps the result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(triton.next_power_of_2(chunks), 4)
def gather_heur_block_n(args):
    """Heuristic for gather's ``BLOCK_N``: ``args["N"]`` rounded up to a power of two, at most 2048."""
    padded = triton.next_power_of_2(args["N"])
    return min(padded, 2048)
def index_select_heur_block_m(args):
    """Heuristic for index_select's ``BLOCK_M``.

    Computes ``next_power_of_2(cdiv(256, N))`` capped at 4, where
    ``N = args["N"]``.

    ``triton.cdiv`` divides by its second argument, so ``N == 0`` would raise
    ``ZeroDivisionError`` inside the heuristic. Non-positive ``N`` is therefore
    treated as 1, which yields 4, the same value as ``N == 1``.
    """
    # Guard the divisor: an empty N must not crash the heuristic itself.
    n = max(args["N"], 1)
    return min(4, triton.next_power_of_2(triton.cdiv(256, n)))
def index_select_heur_block_n(args):
    """Heuristic for index_select's ``BLOCK_N``.

    ``cdiv(N, 16)`` rounded up to a power of two, then clamped into the
    inclusive range [16, 512].
    """
    return max(16, min(512, triton.next_power_of_2(triton.cdiv(args["N"], 16))))
def mm_heur_even_k(args):
    """Heuristic for mm's ``EVEN_K`` flag.

    True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``.
    """
    step = args["BLOCK_K"] * args["SPLIT_K"]
    remainder = args["K"] % step
    return remainder == 0
def rand_heur_block(args):
    """Heuristic for rand's ``BLOCK``: 1024 unless ``args["N"] <= 512``, in which case 512."""
    if args["N"] <= 512:
        return 512
    return 1024
def rand_heur_num_warps(args):
    """Heuristic for rand's ``num_warps``: 4 for N <= 512, 8 for N <= 1024, else 16."""
    thresholds = ((512, 4), (1024, 8))
    for upper, warps in thresholds:
        if args["N"] <= upper:
            return warps
    return 16
def randn_heur_block(args):
    """Heuristic for randn's ``BLOCK``: 512 for ``N <= 512``, otherwise 1024."""
    block = 512 if args["N"] <= 512 else 1024
    return block
def randn_heur_num_warps(args):
    """Heuristic for randn's ``num_warps``: 4 when ``N <= 512``, otherwise 8.

    NOTE(review): unlike the rand/dropout/uniform/exponential_ heuristics in
    this file, this never returns 16 for large N. Confirm that this is an
    intentional Iluvatar tuning choice.
    """
    return 4 if args["N"] <= 512 else 8
def softmax_heur_tile_k(args):
    """Heuristic for non-inner softmax's ``TILE_K``.

    Starts from a tile of 1 and keeps doubling it while both conditions hold:

    * the block count ``M * cdiv(K, tile_k)`` exceeds the current device's
      ``multi_processor_count``, and
    * the doubled tile stays within ``min(K, 8192)``.

    Returns 1 without searching when ``min(K, 8192) < 1``. Queries the current
    CUDA device on every call.
    """
    max_tile_k = 8192
    device = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    limit = min(args["K"], max_tile_k)

    tile_k = 1
    if tile_k > limit:
        return tile_k
    # The wave ratio is evaluated first so the doubling test mirrors the
    # original left-to-right order.
    while (args["M"] * triton.cdiv(args["K"], tile_k)) / sm_count > 1 and (
        tile_k * 2 <= limit
    ):
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Heuristic for non-inner softmax's ``TILE_N``.

    Targets a total tile of 8192 elements: ``cdiv(8192, TILE_K)``.
    """
    tile_elems = 8192
    return triton.cdiv(tile_elems, args["TILE_K"])
def softmax_heur_one_tile_per_cta(args):
    """Heuristic for softmax's ``ONE_TILE_PER_CTA`` flag.

    True when a single ``TILE_N`` tile is at least as wide as ``N``.
    """
    tile_n, n = args["TILE_N"], args["N"]
    return tile_n >= n
def softmax_heur_num_warps_non_inner(args):
    """Heuristic for non-inner softmax's ``num_warps``, based on ``TILE_N * TILE_K``.

    Fewer than 2048 elements -> 4 warps, fewer than 4096 -> 8, otherwise 16.
    """
    area = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if area < bound:
            return warps
    return 16
def softmax_heur_tile_n_inner(args):
    """Heuristic for inner softmax's ``TILE_N``.

    For ``N`` up to 32 * 1024 the whole row is covered (``N`` rounded up to a
    power of two); larger rows use a fixed tile of 4096.
    """
    n = args["N"]
    single_tile_limit = 32 * 1024
    return triton.next_power_of_2(n) if n <= single_tile_limit else 4096
def softmax_heur_num_warps_inner(args):
    """Heuristic for inner softmax's ``num_warps``, based on ``TILE_N`` alone.

    ``TILE_N < 2048`` -> 4, ``TILE_N < 4096`` -> 8, otherwise 16.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    return 8 if tile_n < 4096 else 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Heuristic for non-inner softmax backward's ``TILE_N``.

    ``1024 // TILE_K``, but never less than 1.
    """
    tile_n = 1024 // args["TILE_K"]
    return tile_n if tile_n > 1 else 1
def softmax_heur_tile_m(args):
    """Heuristic for inner softmax backward's ``TILE_M``: ``1024 // TILE_N``, floored at 1."""
    rows = 1024 // args["TILE_N"]
    if rows > 1:
        return rows
    return 1
def uniform_heur_block(args):
    """Heuristic for uniform's ``BLOCK``: 512 for ``N <= 512``, else 1024."""
    if args["N"] > 512:
        return 1024
    return 512
def uniform_heur_num_warps(args):
    """Heuristic for uniform's ``num_warps``: 4 up to N=512, 8 up to N=1024, then 16."""
    n = args["N"]
    for ceiling, warps in ((512, 4), (1024, 8)):
        if n <= ceiling:
            return warps
    return 16
def var_mean_heur_block_n(args):
    """Heuristic for var_mean's ``BLOCK_N``: ``BLOCK_NUM`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Flag for upsample_nearest2d: True when output height ``OH`` equals input height ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Flag for upsample_nearest2d: True when output width ``OW`` equals input width ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """Flag for upsample_nearest2d: True when ``N * C * OH * OW`` fits in a signed 32-bit int."""
    int32_max = 2**31 - 1
    output_numel = args["N"] * args["C"] * args["OH"] * args["OW"]
    return output_numel <= int32_max
def batch_norm_heur_block_m(args):
    """Heuristic for batch_norm's ``BLOCK_M``: ``batch_dim`` rounded up to a power of two, at most 2048."""
    rows = triton.next_power_of_2(args["batch_dim"])
    return min(rows, 2048)
def batch_norm_heur_block_n(args):
    """Heuristic for batch_norm's ``BLOCK_N``.

    Bounds the tile at 2**14 (16384) elements: ``BLOCK_N`` is ``spatial_dim``
    rounded up to a power of two, limited to ``2**14 // BLOCK_M`` (at least 1),
    where ``BLOCK_M`` comes from ``batch_norm_heur_block_m``.
    """
    max_tile_elems = 2**14
    block_m = batch_norm_heur_block_m(args)
    padded_spatial = triton.next_power_of_2(args["spatial_dim"])
    budget = max(1, max_tile_elems // block_m)
    return min(padded_spatial, budget)
def vdot_heur_block_size(args):
    """Heuristic for vdot's ``BLOCK_SIZE``, based on ``n_elements``.

    Below 1024 -> 32, below 8192 -> 256, otherwise 1024.
    """
    for bound, size in ((1024, 32), (8192, 256)):
        if args["n_elements"] < bound:
            return size
    return 1024
def _fixed_heur(value):
    """Build a heuristic callable that ignores ``args`` and always returns ``value``."""
    return lambda args: value


def _mha_block_heuristics(block_m):
    """Constant meta-parameters shared by the ``mha_block_*`` entries; only BLOCK_M varies."""
    return {
        "BLOCK_M": _fixed_heur(block_m),
        "BLOCK_N": _fixed_heur(32),
        "num_warps": _fixed_heur(4),
        "num_stages": _fixed_heur(3),
    }


# Operator key -> {meta-parameter name: callable(args) -> value}.
HEURISTICS_CONFIGS = {
    "argmax": {"BLOCK_M": argmax_heur_block_m, "BLOCK_N": argmax_heur_block_n},
    "argmin": {"BLOCK_M": argmin_heur_block_m, "BLOCK_N": argmin_heur_block_n},
    "dropout": {"BLOCK": dropout_heur_block, "num_warps": dropout_heur_num_warps},
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {"BLOCK_M": gather_heur_block_m, "BLOCK_N": gather_heur_block_n},
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {"EVEN_K": mm_heur_even_k},
    "rand": {"BLOCK": rand_heur_block, "num_warps": rand_heur_num_warps},
    "randn": {"BLOCK": randn_heur_block, "num_warps": randn_heur_num_warps},
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {"BLOCK": uniform_heur_block, "num_warps": uniform_heur_num_warps},
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {"BLOCK_N": var_mean_heur_block_n},
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {"BLOCK_SIZE": vdot_heur_block_size},
    "mha_block_128": _mha_block_heuristics(128),
    "mha_block_64": _mha_block_heuristics(64),
    "mha_block_32": _mha_block_heuristics(32),
    "mha_block_16": _mha_block_heuristics(16),
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": _fixed_heur(16),
    },
}