Coverage for src/flag_gems/runtime/backend/_hygon/heuristics_config_utils.py: 0%
139 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
def simple_elementwise_blocksize_heur(args):
    """Return a fixed block size of 1024; ``args`` is not inspected."""
    block_size = 1024
    return block_size
def argmax_heur_block_m(args):
    """Choose BLOCK_M for argmax: 4 while ``args["M"]`` is below 4096, else 8."""
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Choose BLOCK_N for argmax.

    ``args["N"]`` is passed through ``triton.next_power_of_2`` and the
    result is capped at 4096.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)
def argmin_heur_block_m(args):
    """Choose BLOCK_M for argmin: 4 while ``args["M"]`` is below 4096, else 8."""
    rows = args["M"]
    block_m = 4 if rows < 4096 else 8
    return block_m
def argmin_heur_block_n(args):
    """Choose BLOCK_N for argmin.

    ``args["N"]`` is passed through ``triton.next_power_of_2`` and the
    result is capped at 4096.
    """
    cap = 4096
    return min(cap, triton.next_power_of_2(args["N"]))
def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0
def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def dropout_heur_block(args):
    """Choose BLOCK for dropout from ``args["N"]``.

    Returns 512 for N <= 512, 1024 for N <= 1024, and 4096 otherwise.
    """
    n = args["N"]
    if n <= 512:
        return 512
    if n <= 1024:
        return 1024
    return 4096
def dropout_heur_num_warps(args):
    """Choose num_warps for dropout from ``args["N"]``.

    Returns 4 for N <= 512, 8 for N <= 1024, and 16 otherwise.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def exponential_heur_block(args):
    """Choose BLOCK for exponential_: 512 when ``args["N"]`` <= 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def exponential_heur_num_warps(args):
    """Choose num_warps for exponential_ from ``args["N"]``.

    Returns 4 for N <= 512, 8 for N <= 1024, and 16 otherwise.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def gather_heur_block_m(args):
    """Choose BLOCK_M for gather.

    Computes ``triton.cdiv(args["N"], 2048)``, passes it through
    ``triton.next_power_of_2`` and caps the result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))
def gather_heur_block_n(args):
    """Choose BLOCK_N for gather.

    ``args["N"]`` is passed through ``triton.next_power_of_2`` and the
    result is capped at 2048.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)
def index_select_heur_block_m(args):
    """Choose BLOCK_M for index_select.

    Computes ``triton.cdiv(256, args["N"])``, passes it through
    ``triton.next_power_of_2`` and caps the result at 4.
    """
    rows_per_256 = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows_per_256))
def index_select_heur_block_n(args):
    """Choose BLOCK_N for index_select.

    Computes ``triton.next_power_of_2(triton.cdiv(args["N"], 16))`` and
    clamps it to the range [16, 512].
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(candidate, 512)  # upper bound first, as in the original
    return max(capped, 16)
def mm_heur_even_k(args):
    """Return True when ``args["K"]`` is a multiple of BLOCK_K * SPLIT_K."""
    k = args["K"]
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % step == 0
def rand_heur_block(args):
    """Choose BLOCK for rand: 512 when ``args["N"]`` <= 512, else 1024."""
    if args["N"] <= 512:
        return 512
    return 1024
def rand_heur_num_warps(args):
    """Choose num_warps for rand from ``args["N"]``.

    Returns 4 for N <= 512, 8 for N <= 1024, and 16 otherwise.
    """
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16
def randn_heur_block(args):
    """Choose BLOCK for randn: 512 when ``args["N"]`` <= 512, else 1024."""
    n = args["N"]
    return 512 if n <= 512 else 1024
def randn_heur_num_warps(args):
    """Choose num_warps for randn from ``args["N"]``.

    Returns 4 for N <= 512, 8 for N <= 1024, and 16 otherwise.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def softmax_heur_tile_k(args):
    """Choose TILE_K for the non-inner softmax kernel.

    Starting from 1, the tile is doubled while both conditions hold:
    ``args["M"] * cdiv(args["K"], tile_k)`` blocks exceed the current CUDA
    device's multiprocessor count (more than one "wave"), and the doubled
    tile would still fit within ``min(args["K"], 8192)``.

    Returns 1 when ``args["K"]`` is below 1, because the loop never runs.
    """
    max_tile_k = 8192
    device_index = torch.cuda.current_device()
    num_sms = torch.cuda.get_device_properties(device_index).multi_processor_count

    limit = min(args["K"], max_tile_k)
    tile_k = 1
    while tile_k <= limit:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        num_waves = num_blocks / num_sms
        keep_growing = (num_waves > 1) and (tile_k * 2 <= limit)
        if not keep_growing:
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Choose TILE_N for the non-inner softmax kernel.

    Returns ``triton.cdiv(8192, args["TILE_K"])``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """Return True when ``args["TILE_N"]`` is at least ``args["N"]``."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return tile_n >= n
def softmax_heur_num_warps_non_inner(args):
    """Choose num_warps for the non-inner softmax kernel.

    Based on ``args["TILE_N"] * args["TILE_K"]``: 4 below 2048,
    8 below 4096, and 16 otherwise.
    """
    elements = args["TILE_N"] * args["TILE_K"]
    if elements < 2048:
        return 4
    if elements < 4096:
        return 8
    return 16
def softmax_heur_tile_n_inner(args):
    """Choose TILE_N for the inner softmax kernel.

    For ``args["N"]`` up to 32 * 1024 the result is
    ``triton.next_power_of_2(args["N"])``; above that it is fixed at 4096.
    """
    if args["N"] <= (32 * 1024):
        return triton.next_power_of_2(args["N"])
    return 4096
def softmax_heur_num_warps_inner(args):
    """Choose num_warps for the inner softmax kernel.

    Based on ``args["TILE_N"]``: 4 below 2048, 8 below 4096, and 16 otherwise.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    return 8 if tile_n < 4096 else 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Choose TILE_N for the non-inner softmax backward kernel.

    Returns ``1024 // args["TILE_K"]``, raised to 1 if the quotient is smaller.
    """
    quotient = 1024 // args["TILE_K"]
    return max(1, quotient)
def softmax_heur_tile_m(args):
    """Choose TILE_M for the inner softmax backward kernel.

    Returns ``1024 // args["TILE_N"]``, raised to 1 if the quotient is smaller.
    """
    quotient = 1024 // args["TILE_N"]
    return max(1, quotient)
def uniform_heur_block(args):
    """Choose BLOCK for uniform: 512 when ``args["N"]`` <= 512, else 1024."""
    if args["N"] <= 512:
        return 512
    return 1024
def uniform_heur_num_warps(args):
    """Choose num_warps for uniform from ``args["N"]``.

    Returns 4 for N <= 512, 8 for N <= 1024, and 16 otherwise.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def var_mean_heur_block_n(args):
    """Choose BLOCK_N for var_mean.

    Returns ``triton.next_power_of_2(args["BLOCK_NUM"])``.
    """
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"]`` equals ``args["IH"]``."""
    out_h = args["OH"]
    in_h = args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"]`` equals ``args["IW"]``."""
    out_w = args["OW"]
    in_w = args["IW"]
    return out_w == in_w
def batch_norm_heur_block_m(args):
    """Choose BLOCK_M for batch_norm.

    ``args["batch_dim"]`` is passed through ``triton.next_power_of_2`` and
    the result is capped at 2048.
    """
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded_batch)
def batch_norm_heur_block_n(args):
    """Choose BLOCK_N for batch_norm.

    Starts from ``triton.next_power_of_2(args["spatial_dim"])`` and caps it
    at ``2**14 // BLOCK_M`` (at least 1), where BLOCK_M comes from
    ``batch_norm_heur_block_m(args)``.
    """
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    # At most 16384 elements are loaded at once, so BLOCK_N is bounded by
    # the share of that budget left over by BLOCK_M.
    budget_n = max(1, 2**14 // block_m)
    return min(block_n, budget_n)
def vdot_heur_block_size(args):
    """Choose BLOCK_SIZE for vdot from ``args["n_elements"]``.

    Returns 32 below 1024 elements, 256 below 8192, and 1024 otherwise.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024
# Registry of heuristics: operator name -> {meta-parameter name -> callable}.
# Every callable takes the ``args`` mapping and returns the chosen value.
# NOTE(review): presumably each inner dict is handed to ``triton.heuristics``
# by the operator implementations -- confirm against the callers.
HEURISTICS_CONFIGS = {
    # Reductions
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # Batched matmul divisibility flags
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    # Indexing
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    # Random number generation
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    # Softmax forward / backward variants
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Generic elementwise kernels: fixed block size and a constant 8 warps.
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
}