Coverage for src/flag_gems/runtime/backend/_ascend/heuristics_config_utils.py: 0%
143 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import triton
def argmax_heur_block_m(args):
    """Return the fixed BLOCK_M (16) for the "argmax" config; ``args`` is ignored."""
    block_m = 16
    return block_m
def argmax_heur_block_n(args):
    """Return the fixed BLOCK_N (100) for the "argmax" config; ``args`` is ignored."""
    block_n = 100
    return block_n
def argmax_heur_tile_k(args):
    """Return the fixed TILE_K (64) for the "argmax_non_inner" config."""
    return 64
def argmax_heur_tile_n_non_inner(args):
    """Return the fixed TILE_N (128) for the "argmax_non_inner" config."""
    tile_n = 128
    return tile_n
def argmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile covers all ``N`` elements."""
    return args["N"] <= args["TILE_N"]
def argmin_heur_block_m(args):
    """Return the fixed BLOCK_M (16) for the "argmin" config; ``args`` is ignored."""
    block_m = 16
    return block_m
def argmin_heur_block_n(args):
    """Return the fixed BLOCK_N (100) for the "argmin" config; ``args`` is ignored."""
    block_n = 100
    return block_n
def bmm_heur_divisible_m(args):
    """Return True when ``M`` is an exact multiple of ``TILE_M``."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when ``N`` is an exact multiple of ``TILE_N``."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0
def bmm_heur_divisible_k(args):
    """Return True when ``K`` is an exact multiple of ``TILE_K``."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def dropout_heur_block(args):
    """Return BLOCK for the "dropout" config: 512 when N <= 512, else 4096."""
    return 512 if args["N"] <= 512 else 4096
def dropout_heur_num_warps(args):
    """Return num_warps for "dropout": 4 if N <= 512, 8 if N <= 1024, else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4
def exponential_heur_block(args):
    """Return BLOCK for the "exponential_" config: 512 when N <= 512, else 1024."""
    small = args["N"] <= 512
    return 512 if small else 1024
def exponential_heur_num_warps(args):
    """Return num_warps for "exponential_": 4 / 8 / 16 at the 512 and 1024 bounds on N."""
    n = args["N"]
    # (inclusive upper bound on N, warps); anything larger gets 16.
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16
def gather_heur_block_m(args):
    """Return BLOCK_M for "gather": next power of two of ceil(N / 2048), capped at 4."""
    chunks = triton.cdiv(args["N"], 2048)
    rounded = triton.next_power_of_2(chunks)
    return min(rounded, 4)
def gather_heur_block_n(args):
    """Return BLOCK_N for "gather": next power of two of N, capped at 2048."""
    rounded = triton.next_power_of_2(args["N"])
    return rounded if rounded < 2048 else 2048
def index_select_heur_block_m(args):
    """Return BLOCK_M for "index_select": next power of two of ceil(256 / N), capped at 4.

    NOTE(review): ``N == 0`` makes ``triton.cdiv`` divide by zero; presumably
    callers never launch with an empty N — confirm.
    """
    rows = triton.cdiv(256, args["N"])
    return min(triton.next_power_of_2(rows), 4)
def index_select_heur_block_n(args):
    """Return BLOCK_N for "index_select": next power of two of ceil(N / 16), clamped to [16, 512]."""
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(16, min(512, candidate))
def mm_heur_even_k(args):
    """Return True when ``K`` is an exact multiple of ``BLOCK_K``."""
    leftover = args["K"] % args["BLOCK_K"]
    return leftover == 0
def rand_heur_block(args):
    """Return BLOCK for the "rand" config: 2048 when N <= 512, else 4097.

    NOTE(review): 4097 is not a power of two, while dropout's analogous branch
    uses 4096. Presumably deliberate for the Ascend backend — confirm.
    """
    if args["N"] > 512:
        return 4097
    return 2048
def rand_heur_num_warps(args):
    """Return num_warps for "rand": 4 if N <= 512, 8 if N <= 1024, else 16."""
    n = args["N"]
    return 4 if n <= 512 else (8 if n <= 1024 else 16)
def randn_heur_block(args):
    """Return BLOCK for the "randn" config: 2048 when N <= 512, else 4097.

    NOTE(review): 4097 is not a power of two, while dropout's analogous branch
    uses 4096. Presumably deliberate for the Ascend backend — confirm.
    """
    is_small = args["N"] <= 512
    return 2048 if is_small else 4097
def randn_heur_num_warps(args):
    """Return num_warps for "randn": 4 if N <= 512, 8 if N <= 1024, else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4
def softmax_heur_tile_k(args):
    """Pick TILE_K for the "softmax_non_inner" config.

    Starting from 1, the tile is doubled while two conditions hold:
    ``M * ceil(K / tile_k) / num_sms`` exceeds 1, and the doubled tile stays
    within ``min(K, 4096)``. Returns 1 when ``K < 1``.
    """
    max_tile_k = 4096
    # FIXME:
    # NUM_SMS should be obtained by API.
    # It is actually the number of AIV cores which depends on the Ascend version.
    num_sms = 40
    k = args["K"]
    limit = min(k, max_tile_k)
    tile_k = 1
    if tile_k > limit:
        return tile_k
    while True:
        waves = args["M"] * triton.cdiv(k, tile_k) / num_sms
        if waves <= 1 or tile_k * 2 > limit:
            return tile_k
        tile_k *= 2
def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for "softmax_non_inner": ceil(768 / TILE_K)."""
    budget = 768
    return triton.cdiv(budget, args["TILE_K"])
def softmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile covers all ``N`` elements."""
    return args["N"] <= args["TILE_N"]
def softmax_heur_num_warps_non_inner(args):
    """Return num_warps from the TILE_N * TILE_K tile area: <2048 -> 4, <4096 -> 8, else 16."""
    area = args["TILE_N"] * args["TILE_K"]
    if area >= 4096:
        return 16
    return 8 if area >= 2048 else 4
def softmax_heur_tile_n_inner(args):
    """Return TILE_N for "softmax_inner": next power of two of N up to 32K, else 4096."""
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Return num_warps from TILE_N: <2048 -> 4, <4096 -> 8, else 16."""
    width = args["TILE_N"]
    if width >= 4096:
        return 16
    if width >= 2048:
        return 8
    return 4
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for "softmax_backward_non_inner": 1024 // TILE_K, at least 1."""
    rows = 1024 // args["TILE_K"]
    return rows if rows >= 1 else 1
def softmax_heur_tile_m(args):
    """Return TILE_M for "softmax_backward_inner": 1024 // TILE_N, at least 1."""
    rows = 1024 // args["TILE_N"]
    return rows if rows >= 1 else 1
def uniform_heur_block(args):
    """Return BLOCK for the "uniform" config.

    512 when N <= 512, 4097 when N >= 2**30 (1073741824), otherwise 1024.

    NOTE(review): 4097 is not a power of two; presumably deliberate for very
    large N on the Ascend backend — confirm.
    """
    n = args["N"]
    if n <= 512:
        return 512
    if n >= 2**30:
        return 4097
    return 1024
def uniform_heur_num_warps(args):
    """Return num_warps for "uniform": 4 / 8 / 16 at the 512 and 1024 bounds on N."""
    n = args["N"]
    # (inclusive upper bound on N, warps); anything larger gets 16.
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16
def var_mean_heur_block_n(args):
    """Return BLOCK_N for "var_mean": BLOCK_NUM rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when the output height ``OH`` equals the input height ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when the output width ``OW`` equals the input width ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def batch_norm_heur_block_m(args):
    """Return BLOCK_M for "batch_norm": batch_dim rounded up to a power of two, capped at 128."""
    rounded = triton.next_power_of_2(args["batch_dim"])
    return rounded if rounded <= 128 else 128
def batch_norm_heur_block_n(args):
    """Return BLOCK_N for "batch_norm".

    The result is ``spatial_dim`` rounded up to a power of two, capped so that
    ``BLOCK_M * BLOCK_N`` stays within 4096 (2**12) elements loaded at once.
    The cap is never below 1.
    """
    block_m = batch_norm_heur_block_m(args)
    element_budget = max(1, 4096 // block_m)
    return min(triton.next_power_of_2(args["spatial_dim"]), element_budget)
def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for "vdot" from n_elements: <1024 -> 32, <8192 -> 256, else 1024."""
    count = args["n_elements"]
    if count >= 8192:
        return 1024
    return 256 if count >= 1024 else 32
# Registry mapping each op name to a dict of meta-parameter name -> heuristic
# function. Each function takes the launch ``args`` dict and returns the
# value; presumably consumed via ``triton.heuristics`` (confirm at call sites).
HEURISTICS_CONFIGS = {
    "argmax_non_inner": dict(
        TILE_K=argmax_heur_tile_k,
        TILE_N=argmax_heur_tile_n_non_inner,
        ONE_TILE_PER_CTA=argmax_heur_one_tile_per_cta,
    ),
    "argmax": dict(
        BLOCK_M=argmax_heur_block_m,
        BLOCK_N=argmax_heur_block_n,
    ),
    "argmin": dict(
        BLOCK_M=argmin_heur_block_m,
        BLOCK_N=argmin_heur_block_n,
    ),
    "bmm": dict(
        DIVISIBLE_M=bmm_heur_divisible_m,
        DIVISIBLE_N=bmm_heur_divisible_n,
        DIVISIBLE_K=bmm_heur_divisible_k,
    ),
    "dropout": dict(
        BLOCK=dropout_heur_block,
        num_warps=dropout_heur_num_warps,
    ),
    "exponential_": dict(
        BLOCK=exponential_heur_block,
        num_warps=exponential_heur_num_warps,
    ),
    "gather": dict(
        BLOCK_M=gather_heur_block_m,
        BLOCK_N=gather_heur_block_n,
    ),
    "index_select": dict(
        BLOCK_M=index_select_heur_block_m,
        BLOCK_N=index_select_heur_block_n,
    ),
    "mm": dict(
        EVEN_K=mm_heur_even_k,
    ),
    "rand": dict(
        BLOCK=rand_heur_block,
        num_warps=rand_heur_num_warps,
    ),
    "randn": dict(
        BLOCK=randn_heur_block,
        num_warps=randn_heur_num_warps,
    ),
    "softmax_non_inner": dict(
        TILE_K=softmax_heur_tile_k,
        TILE_N=softmax_heur_tile_n_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_non_inner,
    ),
    "softmax_inner": dict(
        TILE_N=softmax_heur_tile_n_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_inner,
    ),
    "softmax_backward_non_inner": dict(
        TILE_N=softmax_heur_tile_n_bwd_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    "softmax_backward_inner": dict(
        TILE_M=softmax_heur_tile_m,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    "uniform": dict(
        BLOCK=uniform_heur_block,
        num_warps=uniform_heur_num_warps,
    ),
    "upsample_nearest2d": dict(
        SAME_H=upsample_nearest2d_SAME_H,
        SAME_W=upsample_nearest2d_SAME_W,
    ),
    "var_mean": dict(
        BLOCK_N=var_mean_heur_block_n,
    ),
    "batch_norm": dict(
        BLOCK_M=batch_norm_heur_block_m,
        BLOCK_N=batch_norm_heur_block_n,
    ),
    "vdot": dict(
        BLOCK_SIZE=vdot_heur_block_size,
    ),
}