Coverage for src/flag_gems/runtime/backend/_enflame/heuristics_config_utils.py: 0%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import torch
2import triton
def argmax_heur_block_m(args):
    """Return BLOCK_M for the argmax kernel.

    Uses 4 when ``args["M"]`` is below 4096 and 8 otherwise.
    """
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Return BLOCK_N for the argmax kernel.

    The value is ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)
def argmin_heur_block_m(args):
    """Return BLOCK_M for the argmin kernel.

    Uses 4 when ``args["M"]`` is below 4096 and 8 otherwise.
    """
    small_m = args["M"] < 4096
    return 4 if small_m else 8
def argmin_heur_block_n(args):
    """Return BLOCK_N for the argmin kernel.

    The value is ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    cap = 4096
    return min(cap, triton.next_power_of_2(args["N"]))
21# def bmm_heur_divisible_m(args):
22# return args["M"] % args["BLOCK_M"] == 0
25# def bmm_heur_divisible_n(args):
26# return args["N"] % args["BLOCK_N"] == 0
29# def bmm_heur_divisible_k(args):
30# return args["K"] % args["BLOCK_K"] == 0
def dropout_heur_block(args):
    """Return BLOCK for the dropout kernel: 512 when N <= 512, else 4096."""
    small_input = args["N"] <= 512
    return 512 if small_input else 4096
def dropout_heur_num_warps(args):
    """Return num_warps for the dropout kernel; always 4, ``args`` is unused."""
    num_warps = 4
    return num_warps
def exponential_heur_block(args):
    """Return BLOCK for the exponential_ kernel: 512 when N <= 512, else 16384."""
    n = args["N"]
    if n <= 512:
        return 512
    return 16384
def exponential_heur_num_warps(args):
    """Return num_warps for the exponential_ kernel; always 4, ``args`` is unused."""
    warps = 4
    return warps
def gather_heur_block_m(args):
    """Return BLOCK_M for the gather kernel.

    The number of 2048-element chunks needed to cover N is computed as
    ``triton.cdiv(N, 2048)``. That count is rounded with
    ``triton.next_power_of_2`` and capped at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))
def gather_heur_block_n(args):
    """Return BLOCK_N for the gather kernel.

    The value is ``triton.next_power_of_2(args["N"])``, capped at 2048.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)
def index_select_heur_block_m(args):
    """Return BLOCK_M for the index_select kernel.

    BLOCK_M is ``ceil(32768 / N)``, rounded with ``triton.next_power_of_2``
    and capped at 16. A larger N therefore yields a smaller BLOCK_M.

    Args:
        args: kernel argument dict; only ``args["N"]`` is read.

    Returns:
        int: the BLOCK_M value.
    """
    # Clamp the divisor so N == 0 (or a non-positive N) cannot trigger a
    # ZeroDivisionError inside triton.cdiv. Clamped values give the cap, 16.
    n = max(args["N"], 1)
    return min(16, triton.next_power_of_2(triton.cdiv(32768, n)))
def index_select_heur_block_n(args):
    """Return BLOCK_N for the index_select kernel.

    The value is ``triton.next_power_of_2(ceil(N / 16))``, clamped to the
    inclusive range [16, 512].
    """
    rounded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(16, min(rounded, 512))
def mm_heur_even_k(args):
    """Return True when K is an exact multiple of ``BLOCK_K * SPLIT_K``.

    This value is supplied as the EVEN_K flag for the mm kernel.
    """
    k = args["K"]
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % k_step == 0
def rand_heur_block(args):
    """Return BLOCK for the rand kernel: 512 when N <= 512, else 16384."""
    return 512 if args["N"] <= 512 else 16384
def rand_heur_num_warps(args):
    """Return num_warps for the rand kernel; always 4, ``args`` is unused."""
    fixed_warps = 4
    return fixed_warps
def randn_heur_block(args):
    """Return BLOCK for the randn kernel: 512 when N <= 512, else 16384."""
    n = args["N"]
    block = 16384
    if n <= 512:
        block = 512
    return block
def randn_heur_num_warps(args):
    """Return num_warps for the randn kernel; always 4, ``args`` is unused."""
    warps = 4
    return warps
def softmax_heur_tile_k(args):
    """Return TILE_K for the non-inner softmax kernel.

    Starting from 1, TILE_K is doubled while both of these hold:

    * the launch needs more than one wave of blocks, i.e.
      ``M * ceil(K / TILE_K) / multi_processor_count > 1``;
    * the doubled tile stays within ``min(K, 8192)``.

    If K < 1 the loop never runs and 1 is returned.
    """
    max_tile_k = 8192
    # NOTE(review): this queries torch.cuda device properties even in the
    # Enflame backend; confirm the device exposes them through torch.cuda.
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    limit = min(args["K"], max_tile_k)
    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / sm_count
        can_double = tile_k * 2 <= limit
        if not (waves > 1 and can_double):
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for the non-inner softmax kernel: ``ceil(8192 / TILE_K)``."""
    element_budget = 8192
    return triton.cdiv(element_budget, args["TILE_K"])
def softmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile covers all N elements (TILE_N >= N)."""
    tile_n, n = args["TILE_N"], args["N"]
    return tile_n >= n
def softmax_heur_num_warps_non_inner(args):
    """Return num_warps for the non-inner softmax kernel; always 4, ``args`` is unused."""
    warps = 4
    return warps
def softmax_heur_tile_n_inner(args):
    """Return TILE_N for the inner softmax kernel.

    When N is at most 32768, the value is ``triton.next_power_of_2(N)``.
    Above that threshold a fixed 4096 is used instead.
    """
    n = args["N"]
    return triton.next_power_of_2(n) if n <= 32 * 1024 else 4096
def softmax_heur_num_warps_inner(args):
    """Return num_warps for the inner softmax kernel; always 4, ``args`` is unused."""
    fixed = 4
    return fixed
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for the non-inner softmax backward kernel.

    The value is ``1024 // TILE_K``, but never less than 1.
    """
    tile_n = 1024 // args["TILE_K"]
    return tile_n if tile_n > 1 else 1
def softmax_heur_tile_m(args):
    """Return TILE_M for the inner softmax backward kernel.

    TILE_M is ``1024 // TILE_N``, but never less than 1.

    Args:
        args: kernel argument dict; only ``args["TILE_N"]`` is read.

    Returns:
        int: the TILE_M value.
    """
    return max(1, 1024 // args["TILE_N"])


# Backward-compatible alias for the original, misspelled name ("heru"). It is
# still referenced by HEURISTICS_CONFIGS and may be used by other callers.
softmax_heru_tile_m = softmax_heur_tile_m
def uniform_heur_block(args):
    """Return BLOCK for the uniform kernel: 512 when N <= 512, else 16384."""
    if args["N"] <= 512:
        block = 512
    else:
        block = 16384
    return block
def uniform_heur_num_warps(args):
    """Return num_warps for the uniform kernel; always 4, ``args`` is unused."""
    num_warps = 4
    return num_warps
def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: ``triton.next_power_of_2(args["BLOCK_NUM"])``."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_NUM_TILE(args):
    """Return NUM_TILE for upsample_nearest2d.

    First compute ``grid_y = ceil(N * C / 4)``. NUM_TILE is
    ``ceil(grid_y / 128)``, floored at 1, so any grid_y <= 128 gives 1.
    """
    grid_y = triton.cdiv(args["N"] * args["C"], 4)
    # For grid_y <= 128, cdiv(grid_y, 128) is at most 1, so the max() yields 1.
    return max(1, triton.cdiv(grid_y, 128))
def upsample_nearest2d_TOTAL_TILE(args):
    """Return TOTAL_TILE for upsample_nearest2d: ``ceil(N * C / 4)``."""
    nc = args["N"] * args["C"]
    return triton.cdiv(nc, 4)
def upsample_nearest2d_SAME_H(args):
    """Return the SAME_H flag: True when ``args["OH"] == args["IH"]``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return the SAME_W flag: True when ``args["OW"] == args["IW"]``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when ``N * C * OH * OW`` fits in a signed 32-bit integer."""
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    int32_max = 2**31 - 1
    return total <= int32_max
def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm.

    The value is ``triton.next_power_of_2(args["batch_dim"])``, capped at 2048.
    """
    rounded = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded)
def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm.

    BLOCK_N starts as ``triton.next_power_of_2(args["spatial_dim"])``. It is
    then limited to ``max(1, 16384 // BLOCK_M)``, where BLOCK_M comes from
    ``batch_norm_heur_block_m``.
    """
    max_elements = 2**14  # at most 16384 elements are loaded at once
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    n_cap = max(1, max_elements // block_m)
    return min(block_n, n_cap)
def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot based on ``args["n_elements"]``.

    * n_elements < 1024: 32
    * n_elements < 8192: 256
    * otherwise: 1024
    """
    n = args["n_elements"]
    for upper, block_size in ((1024, 32), (8192, 256)):
        if n < upper:
            return block_size
    return 1024
def simple_elementwise_blocksize_heur(args):
    """Return BLOCK_SIZE for generic elementwise kernels.

    Uses 1024 when ``args["n_elements"]`` is below 65535 and 16384 otherwise.
    """
    small = args["n_elements"] < 65535
    return 1024 if small else 16384
# Per-operator heuristic table.
#
# Each top-level key is an operator/kernel name. Its value maps a kernel
# meta-parameter name (e.g. "BLOCK_M", "num_warps", "EVEN_K") to a callable.
# The callable takes the kernel's ``args`` dict and returns the value for that
# meta-parameter.
# NOTE(review): presumably consumed through ``triton.heuristics`` by the
# Enflame backend's kernels; confirm against the runtime loader.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        # No heuristics are active for bmm. The DIVISIBLE_* helpers are
        # commented out earlier in this file.
        # "DIVISIBLE_M": bmm_heur_divisible_m,
        # "DIVISIBLE_N": bmm_heur_divisible_n,
        # "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        # TILE_N and ONE_TILE_PER_CTA read TILE_K / TILE_N from ``args``.
        # NOTE(review): presumably TILE_K is resolved before them; confirm
        # the consumer evaluates entries in insertion order.
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heru_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "NUM_TILE": upsample_nearest2d_NUM_TILE,
        "TOTAL_TILE": upsample_nearest2d_TOTAL_TILE,
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        # Fixed value; ``args`` is ignored.
        "num_warps": lambda args: 4,
    },
    "mha_varlen_fwd": {
        # All values are fixed constants; ``args`` is ignored.
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
}