Coverage for src/flag_gems/runtime/backend/_arm/heuristics_config_utils.py: 0%
122 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import triton
def argmax_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for argmax.

    Args:
        args: Mapping of kernel argument names to values; ``args["M"]``
            is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for argmax.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4, rounded_n)
def argmin_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for argmin.

    Args:
        args: Mapping of kernel argument names to values; ``args["M"]``
            is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    if args["M"] < 4096:
        return 4
    return 8
def argmin_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for argmin.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4, rounded_n)
def bmm_heur_divisible_m(args):
    """Tell whether ``M`` is an exact multiple of ``TILE_M`` for bmm.

    Args:
        args: Mapping of kernel argument names to values; ``args["M"]``
            and ``args["TILE_M"]`` are read.

    Returns:
        True when ``args["M"] % args["TILE_M"]`` equals zero.
    """
    leftover = args["M"] % args["TILE_M"]
    return leftover == 0
def bmm_heur_divisible_n(args):
    """Tell whether ``N`` is an exact multiple of ``TILE_N`` for bmm.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            and ``args["TILE_N"]`` are read.

    Returns:
        True when ``args["N"] % args["TILE_N"]`` equals zero.
    """
    leftover = args["N"] % args["TILE_N"]
    return leftover == 0
def bmm_heur_divisible_k(args):
    """Tell whether ``K`` is an exact multiple of ``TILE_K`` for bmm.

    Args:
        args: Mapping of kernel argument names to values; ``args["K"]``
            and ``args["TILE_K"]`` are read.

    Returns:
        True when ``args["K"] % args["TILE_K"]`` equals zero.
    """
    leftover = args["K"] % args["TILE_K"]
    return leftover == 0
def dropout_heur_block(args):
    """Return the ``BLOCK`` heuristic value for dropout.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Return the ``num_warps`` heuristic value for dropout.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def exponential_heur_block(args):
    """Return the ``BLOCK`` heuristic value for ``exponential_``.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        4 when ``args["N"]`` is at most 512, otherwise 8.
    """
    return 4 if args["N"] <= 512 else 8
def exponential_heur_num_warps(args):
    """Return the ``num_warps`` heuristic value for ``exponential_``.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def gather_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for gather.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(args["N"], 2048))``,
        capped at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    rounded_chunks = triton.next_power_of_2(chunks)
    return min(4, rounded_chunks)
def gather_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for gather.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 16.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(16, rounded_n)
def index_select_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for index_select.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(256, args["N"]))``,
        capped at 4.
    """
    rows = triton.cdiv(256, args["N"])
    rounded_rows = triton.next_power_of_2(rows)
    return min(4, rounded_rows)
def index_select_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for index_select.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(args["N"], 16))`` clamped
        to the inclusive range 16..512.
    """
    rounded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(rounded, 512)  # upper clamp first, exactly as before
    return max(capped, 16)
def mm_heur_even_k(args):
    """Return the ``EVEN_K`` heuristic flag for mm.

    Args:
        args: Mapping of kernel argument names to values; ``args["K"]``,
            ``args["BLOCK_K"]`` and ``args["SPLIT_K"]`` are read.

    Returns:
        True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``.
    """
    k = args["K"]
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % step == 0
def rand_heur_block(args):
    """Return the ``BLOCK`` heuristic value for rand.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        4 when ``args["N"]`` is at most 512, otherwise 16.
    """
    return 4 if args["N"] <= 512 else 16
def rand_heur_num_warps(args):
    """Return the ``num_warps`` heuristic value for rand.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def randn_heur_block(args):
    """Return the ``BLOCK`` heuristic value for randn.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024
def randn_heur_num_warps(args):
    """Return the ``num_warps`` heuristic value for randn.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def softmax_heur_tile_k(args):
    """Return the ``TILE_K`` heuristic value for non-inner softmax.

    The value is a fixed 16; ``args`` is accepted but never read.

    Args:
        args: Mapping of kernel argument names to values (unused).

    Returns:
        The constant 16.
    """
    # Disabled heuristic, kept for reference. It grew the tile in powers
    # of two based on the CUDA multiprocessor count:
    #
    #   MAX_TILE_K = 8192
    #   NUM_SMS = torch.cuda.get_device_properties(
    #       torch.cuda.current_device()
    #   ).multi_processor_count
    #   tile_k = 1
    #   upper_bound = min(args["K"], MAX_TILE_K)
    #   while tile_k <= upper_bound:
    #       num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
    #       num_waves = num_blocks / NUM_SMS
    #       if (num_waves > 1) and (tile_k * 2 <= upper_bound):
    #           tile_k *= 2
    #       else:
    #           break
    #   return tile_k
    fixed_tile_k = 16
    return fixed_tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return the ``TILE_N`` heuristic value for non-inner softmax.

    The value is a fixed 16; ``args`` is accepted but never read.

    Args:
        args: Mapping of kernel argument names to values (unused).

    Returns:
        The constant 16.
    """
    # Disabled alternative, kept for reference:
    #   return triton.cdiv(8192, args["TILE_K"])
    fixed_tile_n = 16
    return fixed_tile_n
def softmax_heur_one_tile_per_cta(args):
    """Return the ``ONE_TILE_PER_CTA`` heuristic flag for softmax.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["TILE_N"]`` and ``args["N"]`` are read.

    Returns:
        True when ``TILE_N`` is greater than or equal to ``N``.
    """
    tile_n = args["TILE_N"]
    n = args["N"]
    return tile_n >= n
def softmax_heur_num_warps_non_inner(args):
    """Return the ``num_warps`` heuristic value for non-inner softmax.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["TILE_N"]`` and ``args["TILE_K"]`` are read.

    Returns:
        4 when ``TILE_N * TILE_K`` is below 2048, 8 when it is below
        4096, and 16 otherwise.
    """
    elements = args["TILE_N"] * args["TILE_K"]
    if elements < 2048:
        return 4
    if elements < 4096:
        return 8
    return 16
def softmax_heur_tile_n_inner(args):
    """Return the ``TILE_N`` heuristic value for inner softmax.

    The value is a fixed 4; ``args`` is accepted but never read.

    Args:
        args: Mapping of kernel argument names to values (unused).

    Returns:
        The constant 4.
    """
    # Disabled alternative, kept for reference:
    #   if args["N"] <= (32 * 1024):
    #       return triton.next_power_of_2(args["N"])
    #   else:
    #       return 4096
    fixed_tile_n = 4
    return fixed_tile_n
def softmax_heur_num_warps_inner(args):
    """Return the ``num_warps`` heuristic value for inner softmax.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["TILE_N"]`` is read.

    Returns:
        4 when ``TILE_N`` is below 2048, 8 when it is below 4096, and
        16 otherwise.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return the ``TILE_N`` heuristic value for non-inner softmax backward.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["TILE_K"]`` is read.

    Returns:
        ``1024 // TILE_K``, but never less than 1.
    """
    quotient = 1024 // args["TILE_K"]
    return max(1, quotient)
def softmax_heur_tile_m(args):
    """Return the ``TILE_M`` heuristic value for inner softmax backward.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["TILE_N"]`` is read.

    Returns:
        ``1024 // TILE_N``, but never less than 1.
    """
    quotient = 1024 // args["TILE_N"]
    return max(1, quotient)
def uniform_heur_block(args):
    """Return the ``BLOCK`` heuristic value for uniform.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024
def uniform_heur_num_warps(args):
    """Return the ``num_warps`` heuristic value for uniform.

    Args:
        args: Mapping of kernel argument names to values; ``args["N"]``
            is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, and 16 above that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def var_mean_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for var_mean.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["BLOCK_NUM"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["BLOCK_NUM"])``.
    """
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return the ``SAME_H`` heuristic flag for upsample_nearest2d.

    Args:
        args: Mapping of kernel argument names to values; ``args["OH"]``
            and ``args["IH"]`` are read.

    Returns:
        True when ``args["OH"]`` equals ``args["IH"]``.
    """
    out_h = args["OH"]
    in_h = args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return the ``SAME_W`` heuristic flag for upsample_nearest2d.

    Args:
        args: Mapping of kernel argument names to values; ``args["OW"]``
            and ``args["IW"]`` are read.

    Returns:
        True when ``args["OW"]`` equals ``args["IW"]``.
    """
    out_w = args["OW"]
    in_w = args["IW"]
    return out_w == in_w
def batch_norm_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for batch_norm.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["batch_dim"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["batch_dim"])``, capped at 4.
    """
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(4, rounded_batch)
def batch_norm_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for batch_norm.

    Keeps ``BLOCK_M * BLOCK_N`` within 16384 (2**14) elements, where
    ``BLOCK_M`` comes from ``batch_norm_heur_block_m``.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["spatial_dim"]`` is read here, and ``args`` is also
            passed on to ``batch_norm_heur_block_m``.

    Returns:
        ``triton.next_power_of_2(args["spatial_dim"])``, limited to
        ``max(1, 16384 // BLOCK_M)``.
    """
    block_m = batch_norm_heur_block_m(args)
    wanted_n = triton.next_power_of_2(args["spatial_dim"])
    # At most 16384 elements are loaded at once.
    n_limit = max(1, 2**14 // block_m)
    return min(wanted_n, n_limit)
def vdot_heur_block_size(args):
    """Return the ``BLOCK_SIZE`` heuristic value for vdot.

    Args:
        args: Mapping of kernel argument names to values;
            ``args["n_elements"]`` is read.

    Returns:
        32 when ``n_elements`` is below 1024, 256 when it is below 8192,
        and 1024 otherwise.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024
# Registry of heuristic callables: the outer key is an operator name, the
# inner key is the name of the value that callable computes. Every callable
# takes a single ``args`` mapping.
HEURISTICS_CONFIGS = {
    # Reductions that return an index.
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # Batched matrix multiply: tile divisibility flags.
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    # Indexing operators.
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    # Random number generators.
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    # Softmax, forward and backward, inner and non-inner variants.
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
}