Coverage for src/flag_gems/runtime/backend/_cambricon/heuristics_config_utils.py: 0%
110 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch
2import triton
4from .utils import TOTAL_CORE_NUM
def argmax_heur_block_m(args):
    """Return the argmax BLOCK_M: 4 while ``args["M"]`` is below 4096, else 8."""
    if args["M"] < 4096:
        return 4
    return 8
def argmax_heur_block_n(args):
    """Return the argmax BLOCK_N: ``args["N"]`` rounded up to a power of two, capped at 4096."""
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)
def argmin_heur_block_m(args):
    """Return the argmin BLOCK_M: 4 while ``args["M"]`` is below 4096, else 8."""
    if args["M"] < 4096:
        return 4
    return 8
def argmin_heur_block_n(args):
    """Return the argmin BLOCK_N: ``args["N"]`` rounded up to a power of two, capped at 4096."""
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)
def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    leftover = args["M"] % args["TILE_M"]
    return leftover == 0
def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    leftover = args["N"] % args["TILE_N"]
    return leftover == 0
def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    leftover = args["K"] % args["TILE_K"]
    return leftover == 0
def dropout_heur_block(args):
    """Return the dropout BLOCK: 512 when ``args["N"]`` is at most 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def exponential_heur_block(args):
    """Return the exponential_ BLOCK: 512 when ``args["N"]`` is at most 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def exponential_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    The thresholds are inclusive: N <= 512 gives 4, N <= 1024 gives 8,
    anything larger gives 16.
    """
    size = args["N"]
    if size <= 512:
        return 4
    if size <= 1024:
        return 8
    return 16
def gather_heur_block_m(args):
    """Return the gather BLOCK_M.

    Computes ceil(N / 2048), rounds it up to a power of two, and caps the
    result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))
def gather_heur_block_n(args):
    """Return the gather BLOCK_N: ``args["N"]`` rounded up to a power of two, capped at 2048."""
    padded_n = triton.next_power_of_2(args["N"])
    return min(2048, padded_n)
def index_select_heur_block_m(args):
    """Return the index_select BLOCK_M.

    Computes ceil(256 / N), rounds it up to a power of two, and caps the
    result at 4.
    """
    rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows))
def index_select_heur_block_n(args):
    """Return the index_select BLOCK_N.

    Computes ceil(N / 16), rounds it up to a power of two, then clamps the
    result to the range [16, 512].
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(candidate, 512)
    return max(capped, 16)
def mm_heur_even_k(args):
    """Return True when ``args["K"]`` is a multiple of ``BLOCK_K * SPLIT_K``."""
    k = args["K"]
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % k_step == 0
def rand_heur_block(args):
    """Return the rand BLOCK: 512 when ``args["N"]`` is at most 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def randn_heur_block(args):
    """Return the randn BLOCK: 512 when ``args["N"]`` is at most 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def softmax_heur_tile_k(args):
    """Pick a power-of-two TILE_K for softmax from ``args["M"]`` and ``args["K"]``.

    Starting at 1, TILE_K is doubled while both conditions hold:

    * the doubled value does not exceed ``min(K, 8192)``, and
    * ``M * ceil(K / TILE_K)`` blocks is still more than one block per core.

    Args:
        args: mapping that provides integer entries ``"M"`` and ``"K"``.

    Returns:
        The selected TILE_K, a power of two that is at least 1.
    """
    MAX_TILE_K = 8192
    # Use the backend's own core count, as the other softmax heuristics in
    # this module do, rather than querying torch.cuda for an SM count.
    # NOTE(review): presumably torch.cuda is not usable on this backend -
    # confirm against the target runtime.
    num_cores = TOTAL_CORE_NUM
    upper_bound = min(args["K"], MAX_TILE_K)

    tile_k = 1
    # Growth is only possible while the doubled tile stays within the bound.
    while tile_k * 2 <= upper_bound:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        num_waves = num_blocks / num_cores
        if num_waves <= 1:
            # Already at most one block per core: stop growing the tile.
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_mode_non_inner(args):
    """Return the TILE_MODE (0, 1 or 2) for the non-inner softmax kernels.

    * 0 - ``TILE_N`` covers N and ``TILE_K * max(TOTAL_CORE_NUM // M, 1)`` covers K.
    * 1 - ``TILE_N`` covers N but the K condition fails.
    * 2 - ``TILE_N`` does not cover N.
    """
    m, n, k = args["M"], args["N"], args["K"]
    tile_n, tile_k = args["TILE_N"], args["TILE_K"]

    # Both predicates are evaluated up front, exactly as the original did.
    k_fits = tile_k * max(TOTAL_CORE_NUM // m, 1) >= k
    n_fits = tile_n >= n

    if not n_fits:
        return 2
    return 0 if k_fits else 1
def softmax_heur_tile_mode_inner(args):
    """Return the TILE_MODE (0, 1 or 2) for the inner softmax kernels.

    * 0 - ``BLOCK_N`` covers N and ``BLOCK_M * TOTAL_CORE_NUM`` covers M.
    * 1 - ``BLOCK_N`` covers N but the M condition fails.
    * 2 - ``BLOCK_N`` does not cover N.
    """
    m_fits = args["BLOCK_M"] * TOTAL_CORE_NUM >= args["M"]
    n_fits = args["BLOCK_N"] >= args["N"]

    if not n_fits:
        return 2
    return 0 if m_fits else 1
def uniform_heur_block(args):
    """Return the uniform BLOCK: 512 when ``args["N"]`` is at most 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def uniform_heur_num_warps(args):
    """Return 4, 8 or 16 depending on ``args["N"]``.

    The thresholds are inclusive: N <= 512 gives 4, N <= 1024 gives 8,
    anything larger gives 16.
    """
    size = args["N"]
    if size <= 512:
        return 4
    if size <= 1024:
        return 8
    return 16
def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"]`` equals ``args["IH"]``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"]`` equals ``args["IW"]``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def vdot_heur_block_size(args):
    """Return the vdot BLOCK_SIZE chosen from ``args["n_elements"]``.

    Fewer than 1024 elements gives 32, fewer than 8192 gives 256, and
    anything larger gives 1024.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024
def linspace_heur_inner_block_size(args):
    """Return the linspace INNER_BLOCK_SIZE chosen from ``args["BLOCK_SIZE"]``.

    A BLOCK_SIZE below 1024 gives 64, below 8192 gives 1024, and anything
    larger gives 8192.
    """
    outer = args["BLOCK_SIZE"]
    if outer < 1024:
        return 64
    if outer < 8192:
        return 1024
    return 8192
def simple_elementwise_blocksize_heur(args):
    """Return a fixed block size of 1024; ``args`` is accepted but not read."""
    fixed_block_size = 1024
    return fixed_block_size
# Registry of heuristics, keyed by operator name. Each value maps a
# meta-parameter name to a callable that takes the ``args`` mapping and
# returns that parameter's value. Entry order is kept as originally written.
HEURISTICS_CONFIGS = {
    # Reductions over a 2-D (M, N) problem.
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # Tile-divisibility flags for batched matmul.
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {"BLOCK": dropout_heur_block},
    "exponential_": {"BLOCK": exponential_heur_block},
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {"EVEN_K": mm_heur_even_k},
    "rand": {"BLOCK": rand_heur_block},
    "randn": {"BLOCK": randn_heur_block},
    # Forward and backward softmax share the same tile-mode heuristics.
    "softmax_non_inner": {"TILE_MODE": softmax_heur_tile_mode_non_inner},
    "softmax_inner": {"TILE_MODE": softmax_heur_tile_mode_inner},
    "softmax_backward_non_inner": {"TILE_MODE": softmax_heur_tile_mode_non_inner},
    "softmax_backward_inner": {"TILE_MODE": softmax_heur_tile_mode_inner},
    "uniform": {"BLOCK": uniform_heur_block},
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    # Registered with no heuristics of their own.
    "var_mean": {},
    "batch_norm": {},
    "vdot": {"BLOCK_SIZE": vdot_heur_block_size},
    "linspace": {"INNER_BLOCK_SIZE": linspace_heur_inner_block_size},
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
    # Constant launch parameters for the variable-length attention kernels.
    "mha_varlen_prefill": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_varlen_decode": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
}