Coverage for src/flag_gems/runtime/backend/_kunlunxin/heuristics_config_utils.py: 0%
139 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
def simple_elementwise_blocksize_heur(args):
    """Return the fixed 1024-element block size used by generic elementwise kernels."""
    return 1024
def argmax_heur_block_m(args):
    """Choose BLOCK_M for argmax: 4 for reductions shorter than 4096 rows, 8 otherwise."""
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Choose BLOCK_N for argmax: next power of two of N, capped at 4096."""
    block_n = triton.next_power_of_2(args["N"])
    return min(block_n, 4096)
def argmin_heur_block_m(args):
    """Choose BLOCK_M for argmin: M split over the 12 clusters, rounded up to a power of two."""
    per_cluster = triton.cdiv(args["M"], 12)  # 12 = cluster_num on this backend
    return triton.next_power_of_2(per_cluster)
def argmin_heur_block_n(args):
    """Choose BLOCK_N for argmin: next power of two of N, capped at 8192.

    NOTE(review): the original routed through a function-local
    ``import builtins`` and ``builtins.min``; the builtin ``min`` is not
    shadowed anywhere in this module (sibling heuristics call it directly),
    so the indirection was unnecessary and has been removed.
    """
    return min(triton.next_power_of_2(args["N"]), 8192)
def bmm_heur_divisible_m(args):
    """True when M is an exact multiple of TILE_M (no boundary masking on M)."""
    m, tile_m = args["M"], args["TILE_M"]
    return m % tile_m == 0
def bmm_heur_divisible_n(args):
    """True when N is an exact multiple of TILE_N (no boundary masking on N)."""
    n, tile_n = args["N"], args["TILE_N"]
    return n % tile_n == 0
def bmm_heur_divisible_k(args):
    """True when K is an exact multiple of TILE_K (no boundary masking on K)."""
    k, tile_k = args["K"], args["TILE_K"]
    return k % tile_k == 0
def dropout_heur_block(args):
    """Tile size for dropout: 512 for rows of at most 512 elements, 1024 otherwise."""
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Warp count for dropout, stepped up with the row length N."""
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def exponential_heur_block(args):
    """Tile size for exponential_: small rows (<=512) use 512, larger rows use 1024."""
    small = args["N"] <= 512
    return 512 if small else 1024
def exponential_heur_num_warps(args):
    """Warp count for exponential_, stepped up with the row length N."""
    n = args["N"]
    for bound, warps in ((512, 4), (1024, 8)):
        if n <= bound:
            return warps
    return 16
def gather_heur_block_m(args):
    """BLOCK_M for gather: grows with N/2048 (rounded up to a power of two), capped at 4."""
    grown = triton.next_power_of_2(triton.cdiv(args["N"], 2048))
    return min(grown, 4)
def gather_heur_block_n(args):
    """BLOCK_N for gather: next power of two of N, capped at 2048."""
    block_n = triton.next_power_of_2(args["N"])
    return min(block_n, 2048)
def index_add_heur_block_m(args):
    """BLOCK_M for index_add: M split over the 12 clusters, rounded up to a power of two."""
    per_cluster = triton.cdiv(args["M"], 12)  # 12 = cluster_num on this backend
    return triton.next_power_of_2(per_cluster)
def index_add_heur_block_n(args):
    """BLOCK_N for index_add: next power of two of N, capped at 8192.

    NOTE(review): an older commented-out tiered scheme (64 / 256 / N) was
    removed as dead code; the live behavior is unchanged.
    """
    return min(8192, triton.next_power_of_2(args["N"]))
def index_select_heur_block_m(args):
    """BLOCK_M for index_select: M split over the 12 clusters, rounded up to a power of two."""
    per_cluster = triton.cdiv(args["M"], 12)  # 12 = cluster_num on this backend
    return triton.next_power_of_2(per_cluster)
def index_select_heur_block_n(args):
    """Fixed BLOCK_N of 64 for index_select."""
    BLOCK_N = 64
    return BLOCK_N
def mm_heur_even_k(args):
    """True when K divides evenly by BLOCK_K * SPLIT_K (no K-tail masking needed)."""
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % step == 0
def rand_heur_block(args):
    """BLOCK for rand: N spread over 12 clusters x 4, rounded up to a power of two.

    NOTE(review): the original body contained an unreachable 512/1024
    fallback after this return statement; it has been removed. If the
    fallback was the intended behavior, the first return should be
    deleted instead.
    """
    return triton.next_power_of_2(triton.cdiv(args["N"], 12 * 4))  # CLUSTER_NUM = 12
def rand_heur_num_warps(args):
    """Warp count for rand, stepped up with the row length N."""
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4
def randn_heur_block(args):
    """Tile size for randn: 512 for rows of at most 512 elements, 1024 otherwise."""
    if args["N"] > 512:
        return 1024
    return 512
def randn_heur_num_warps(args):
    """Warp count for randn, stepped up with the row length N."""
    n = args["N"]
    for bound, warps in ((512, 4), (1024, 8)):
        if n <= bound:
            return warps
    return 16
def softmax_heur_tile_k(args):
    """Grow TILE_K (power of two) while the launch still oversubscribes the SMs.

    Starting from 1, tile_k is doubled as long as the resulting grid
    (M * ceil(K / tile_k) blocks) exceeds one wave per SM and the doubled
    tile still fits under min(K, 8192).
    """
    MAX_TILE_K = 8192
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    cap = min(args["K"], MAX_TILE_K)
    tile_k = 1
    while tile_k <= cap:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        oversubscribed = num_blocks / sm_count > 1
        if not (oversubscribed and tile_k * 2 <= cap):
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """TILE_N sized so that TILE_N * TILE_K covers roughly 8192 elements."""
    budget = 8192  # target elements per tile
    return triton.cdiv(budget, args["TILE_K"])
def softmax_heur_one_tile_per_cta(args):
    """True when a single tile already spans the whole reduction axis N."""
    return not (args["TILE_N"] < args["N"])
def softmax_heur_num_warps_non_inner(args):
    """Warp count scaled by the 2-D tile footprint TILE_N * TILE_K."""
    footprint = args["TILE_N"] * args["TILE_K"]
    if footprint >= 4096:
        return 16
    if footprint >= 2048:
        return 8
    return 4
def softmax_heur_tile_n_inner(args):
    """TILE_N for inner-axis softmax: full row (power of two) up to 32K elements, else 4096."""
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Warp count scaled by the 1-D tile size TILE_N."""
    tile = args["TILE_N"]
    if tile >= 4096:
        return 16
    if tile >= 2048:
        return 8
    return 4
def softmax_heur_tile_n_bwd_non_inner(args):
    """Backward non-inner TILE_N: keep TILE_N * TILE_K around 1024, with a floor of 1."""
    per_k = 1024 // args["TILE_K"]
    return per_k if per_k > 1 else 1
def softmax_heur_tile_m(args):
    """Backward inner TILE_M: keep TILE_M * TILE_N around 1024, with a floor of 1."""
    tile_m = 1024 // args["TILE_N"]
    return max(tile_m, 1)
def uniform_heur_block(args):
    """Tile size for uniform_: small rows (<=512) use 512, larger rows use 1024."""
    small = args["N"] <= 512
    return 512 if small else 1024
def uniform_heur_num_warps(args):
    """Warp count for uniform_, stepped up with the row length N."""
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4
def var_mean_heur_block_n(args):
    """BLOCK_N for var_mean: the block count rounded up to a power of two."""
    num_blocks = args["BLOCK_NUM"]
    return triton.next_power_of_2(num_blocks)
def upsample_nearest2d_SAME_H(args):
    """True when output height equals input height (no resampling along H)."""
    return args["IH"] == args["OH"]
def upsample_nearest2d_SAME_W(args):
    """True when output width equals input width (no resampling along W)."""
    return args["IW"] == args["OW"]
def batch_norm_heur_block_m(args):
    """BLOCK_M for batch_norm: next power of two of the batch dim, capped at 2048."""
    block_m = triton.next_power_of_2(args["batch_dim"])
    return min(block_m, 2048)
def batch_norm_heur_block_n(args):
    """BLOCK_N for batch_norm, bounded so BLOCK_M * BLOCK_N stays within 2**14 elements."""
    block_m = batch_norm_heur_block_m(args)
    # At most 16384 elements are loaded at once; never let the budget drop below 1.
    row_budget = max(1, 2**14 // block_m)
    spatial = triton.next_power_of_2(args["spatial_dim"])
    return min(spatial, row_budget)
def vdot_heur_block_size(args):
    """BLOCK_SIZE for vdot, stepped up with the element count."""
    n = args["n_elements"]
    if n >= 8192:
        return 1024
    if n >= 1024:
        return 256
    return 32
# Registry mapping each op name to the heuristic callables that derive its
# Triton meta-parameters (BLOCK_*, TILE_*, num_warps, divisibility flags)
# from the runtime `args` dict. Consumed by the runtime's heuristics loader
# for this backend.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "index_add": {
        "BLOCK_M": index_add_heur_block_m,
        "BLOCK_N": index_add_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    # Softmax has separate configs for the inner-axis and non-inner-axis
    # kernels, forward and backward.
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
}