Coverage for src/flag_gems/runtime/backend/_aipu/heuristics_config_utils.py: 0%
136 statements
Report generated by coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import triton
def simple_elementwise_blocksize_heur(args):
    """Return the fixed block size 1024.

    ``args`` is accepted for signature compatibility with the other
    heuristics in this module but is never read.
    """
    return 1024
def argmax_heur_block_m(args):
    """Return 4 when ``args["M"]`` is below 4096, otherwise 8.

    Args:
        args: Mapping that provides the key ``"M"``.
    """
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Return ``triton.next_power_of_2(args["N"])`` capped at 4096.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    rounded = triton.next_power_of_2(args["N"])
    return min(4096, rounded)
def argmin_heur_block_m(args):
    """Return 4 when ``args["M"]`` is below 4096, otherwise 8.

    Args:
        args: Mapping that provides the key ``"M"``.
    """
    if args["M"] < 4096:
        return 4
    return 8
def argmin_heur_block_n(args):
    """Return ``triton.next_power_of_2(args["N"])`` capped at 4096.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    rounded = triton.next_power_of_2(args["N"])
    return min(4096, rounded)
def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``.

    Args:
        args: Mapping that provides the keys ``"M"`` and ``"TILE_M"``.
    """
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``.

    Args:
        args: Mapping that provides the keys ``"N"`` and ``"TILE_N"``.
    """
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0
def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``.

    Args:
        args: Mapping that provides the keys ``"K"`` and ``"TILE_K"``.
    """
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def dropout_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    The result is 4 for ``N <= 512``, 8 for ``512 < N <= 1024`` and 16 for
    anything larger.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def exponential_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    return 512 if args["N"] <= 512 else 1024
def exponential_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    The result is 4 for ``N <= 512``, 8 for ``512 < N <= 1024`` and 16 for
    anything larger.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def gather_heur_block_m(args):
    """Return the power of two covering ``ceil(N / 2048)``, capped at 4.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))
def gather_heur_block_n(args):
    """Return ``triton.next_power_of_2(args["N"])`` capped at 2048.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    rounded = triton.next_power_of_2(args["N"])
    return min(2048, rounded)
def index_select_heur_block_m(args):
    """Return the power of two covering ``ceil(256 / N)``, capped at 4.

    Args:
        args: Mapping that provides the key ``"N"``; it is used as a divisor.
    """
    rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows))
def index_select_heur_block_n(args):
    """Return the power of two covering ``ceil(N / 16)``, clamped to [16, 512].

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    rounded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(rounded, 512)
    return max(capped, 16)
def mm_heur_even_k(args):
    """Return True when ``K`` is a multiple of ``BLOCK_K * SPLIT_K``.

    Args:
        args: Mapping that provides the keys ``"K"``, ``"BLOCK_K"`` and
            ``"SPLIT_K"``.
    """
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % step == 0
def rand_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    return 512 if args["N"] <= 512 else 1024
def rand_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    The result is 4 for ``N <= 512``, 8 for ``512 < N <= 1024`` and 16 for
    anything larger.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def randn_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    return 512 if args["N"] <= 512 else 1024
def randn_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    The result is 4 for ``N <= 512``, 8 for ``512 < N <= 1024`` and 16 for
    anything larger.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def softmax_heur_tile_k(args):
    """Pick a power-of-two ``TILE_K`` by doubling from 1.

    Starting at 1, the tile is doubled while both conditions hold:

    * ``M * ceil(K / tile_k)`` blocks exceed ``NUM_SMS`` (more than one wave),
    * the doubled tile still fits within ``min(K, MAX_TILE_K)``.

    Args:
        args: Mapping that provides the keys ``"M"`` and ``"K"``.

    Returns:
        The selected tile size; 1 when ``K`` is smaller than 2.
    """
    MAX_TILE_K = 8192
    # NOTE(review): presumably the number of compute units on the target
    # device -- confirm against the backend's hardware description.
    NUM_SMS = 4

    upper_bound = min(args["K"], MAX_TILE_K)
    tile_k = 1
    while tile_k <= upper_bound:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        num_waves = num_blocks / NUM_SMS
        can_grow = tile_k * 2 <= upper_bound
        if not (num_waves > 1 and can_grow):
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return ``ceil(8192 / args["TILE_K"])``.

    Args:
        args: Mapping that provides the key ``"TILE_K"``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """Return True when ``args["TILE_N"]`` is at least ``args["N"]``.

    Args:
        args: Mapping that provides the keys ``"TILE_N"`` and ``"N"``.
    """
    tile_n = args["TILE_N"]
    return tile_n >= args["N"]
def softmax_heur_num_warps_non_inner(args):
    """Return 4, 8 or 16 depending on ``TILE_N * TILE_K``.

    The result is 4 for a product below 2048, 8 for a product below 4096 and
    16 otherwise.

    Args:
        args: Mapping that provides the keys ``"TILE_N"`` and ``"TILE_K"``.
    """
    tile_size = args["TILE_N"] * args["TILE_K"]
    if tile_size < 2048:
        return 4
    if tile_size < 4096:
        return 8
    return 16
def softmax_heur_tile_n_inner(args):
    """Return the tile size along N.

    For ``N <= 32 * 1024`` the result is ``triton.next_power_of_2(N)``;
    for larger ``N`` it is the fixed value 4096.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Return 4, 8 or 16 depending on ``args["TILE_N"]``.

    The result is 4 for ``TILE_N < 2048``, 8 for ``TILE_N < 4096`` and 16
    otherwise.

    Args:
        args: Mapping that provides the key ``"TILE_N"``.
    """
    tile_size = args["TILE_N"]
    if tile_size < 2048:
        return 4
    if tile_size < 4096:
        return 8
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return ``1024 // args["TILE_K"]``, but never less than 1.

    Args:
        args: Mapping that provides the key ``"TILE_K"``.
    """
    quotient = 1024 // args["TILE_K"]
    return max(1, quotient)
def softmax_heur_tile_m(args):
    """Return ``1024 // args["TILE_N"]``, but never less than 1.

    Args:
        args: Mapping that provides the key ``"TILE_N"``.
    """
    quotient = 1024 // args["TILE_N"]
    return max(1, quotient)
def uniform_heur_block(args):
    """Return 512 when ``args["N"]`` is at most 512, otherwise 1024.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    return 512 if args["N"] <= 512 else 1024
def uniform_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    The result is 4 for ``N <= 512``, 8 for ``512 < N <= 1024`` and 16 for
    anything larger.

    Args:
        args: Mapping that provides the key ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def var_mean_heur_block_n(args):
    """Return ``triton.next_power_of_2(args["BLOCK_NUM"])``.

    Args:
        args: Mapping that provides the key ``"BLOCK_NUM"``.
    """
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"]`` equals ``args["IH"]``.

    Args:
        args: Mapping that provides the keys ``"OH"`` and ``"IH"``.
    """
    out_h = args["OH"]
    return out_h == args["IH"]
def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"]`` equals ``args["IW"]``.

    Args:
        args: Mapping that provides the keys ``"OW"`` and ``"IW"``.
    """
    out_w = args["OW"]
    return out_w == args["IW"]
def batch_norm_heur_block_m(args):
    """Return ``triton.next_power_of_2(args["batch_dim"])`` capped at 2048.

    Args:
        args: Mapping that provides the key ``"batch_dim"``.
    """
    rounded = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded)
def batch_norm_heur_block_n(args):
    """Return the N block size so that ``BLOCK_M * BLOCK_N`` stays <= 16384.

    ``BLOCK_M`` is obtained from ``batch_norm_heur_block_m(args)``. The result
    is the smaller of ``triton.next_power_of_2(args["spatial_dim"])`` and
    ``max(1, 2**14 // BLOCK_M)``.

    Args:
        args: Mapping that provides the keys ``"batch_dim"`` (read by
            ``batch_norm_heur_block_m``) and ``"spatial_dim"``.
    """
    # A maximum of 16384 elements are loaded at once.
    max_elements = 2**14
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    return min(block_n, max(1, max_elements // block_m))
def vdot_heur_block_size(args):
    """Return 32, 256 or 1024 depending on ``args["n_elements"]``.

    The result is 32 for fewer than 1024 elements, 256 for fewer than 8192
    elements and 1024 otherwise.

    Args:
        args: Mapping that provides the key ``"n_elements"``.
    """
    n = args["n_elements"]
    if n < 1024:
        return 32
    if n < 8192:
        return 256
    return 1024
# Registry of heuristic callables: the outer key is a configuration name and
# the inner key is the parameter that the mapped callable computes from an
# ``args`` mapping.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        # Constant warp count regardless of the arguments.
        "num_warps": lambda args: 8,
    },
}