Coverage for src/flag_gems/utils/codegen_config_utils.py: 62%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1from dataclasses import dataclass
2from typing import Tuple
4import triton
6from flag_gems.runtime import device
7from flag_gems.runtime.backend import vendor_module
8from flag_gems.runtime.common import vendors
def default_heuristics_for_num_warps(tile_size):
    """Choose num_warps from a tile size (default / NVIDIA heuristic).

    Larger tiles get more warps: 4 below 2048, 8 below 4096, else 16.
    """
    for upper_bound, num_warps in ((2048, 4), (4096, 8)):
        if tile_size < upper_bound:
            return num_warps
    return 16
def metax_heuristics_for_num_warps(tile_size):
    """Choose num_warps for MetaX: 4 up to 1024, 8 up to 2048, else 16."""
    for upper_bound, num_warps in ((1024, 4), (2048, 8)):
        if tile_size <= upper_bound:
            return num_warps
    return 16
def hygon_heuristics_for_num_warps(tile_size):
    """Choose num_warps for Hygon: 4 up to 1024, 8 up to 2048, else 16."""
    return 4 if tile_size <= 1024 else (8 if tile_size <= 2048 else 16)
def cambricon_heuristics_for_num_warps(tile_size):
    """Choose num_warps for Cambricon: always 1, independent of tile size."""
    return 1
def tsingmicro_heuristics_for_num_warps(tile_size):
    """Choose num_warps for Tsingmicro: always 1, independent of tile size."""
    return 1
def sunrise_heuristics_for_num_warps(tile_size):
    """Choose num_warps for Sunrise.

    Doubles at each power-of-two bound: 4 below 1024, 8 below 2048,
    16 below 4096, and 32 for anything larger.
    """
    num_warps = 4
    for upper_bound in (1024, 2048, 4096):
        if tile_size < upper_bound:
            return num_warps
        num_warps *= 2
    return num_warps
def enflame_heuristics_for_num_warps(tile_size):
    """Choose num_warps for Enflame: always 4, independent of tile size."""
    return 4
@dataclass
class CodeGenConfig:
    """Per-vendor knobs constraining pointwise-kernel code generation."""

    # Maximum number of elements a generated tile may cover.
    max_tile_size: int
    # Launch-grid size limit per axis (x, y, z).
    max_grid_size: Tuple[int, int, int]
    # Upper bound on num_warps for a single CTA.
    max_num_warps_per_cta: int

    # Prefer Triton block-pointer addressing for loads/stores.
    prefer_block_pointer: bool
    # Flatten the task into a 1D tile; forces block pointers off (see below).
    prefer_1d_tile: bool
    # gen_configs: -> configs
    # prune_config: (as jit function, ) configs -> configs

    def __post_init__(self):
        # 1D tiling and block pointers are mutually exclusive; 1D tiling wins.
        if self.prefer_1d_tile:
            self.prefer_block_pointer = False
# Shared default for `prefer_1d_tile`: enabled only on Triton major version < 3.
# NOTE(review): presumably newer Triton handles multi-d tiles better — confirm.
_PREFER_1D_TILE_DEFAULT = int(triton.__version__[0]) < 3

# Vendor -> CodeGenConfig table; looked up via get_codegen_config().
CODEGEN_COFIGS = {
    vendors.NVIDIA: CodeGenConfig(
        max_tile_size=512,
        max_grid_size=(65536, 65536, 65536),
        max_num_warps_per_cta=32,
        prefer_block_pointer=True,
        prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
    ),
    # The Cambricon entry needs attributes that only exist when the cambricon
    # backend is actually loaded; otherwise the entry is None.
    vendors.CAMBRICON: (
        CodeGenConfig(
            max_tile_size=8192,
            max_grid_size=(vendor_module.TOTAL_CORE_NUM, 1, 1),
            max_num_warps_per_cta=32,
            prefer_block_pointer=True,
            prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
        )
        if vendor_module.vendor_info.vendor_name == "cambricon"
        else None
    ),
    vendors.METAX: CodeGenConfig(
        max_tile_size=2048,
        max_grid_size=(65536, 65536, 65536),
        max_num_warps_per_cta=16,
        prefer_block_pointer=True,
        prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
    ),
    vendors.MTHREADS: CodeGenConfig(
        max_tile_size=512,
        max_grid_size=(2147483647, 2147483647, 2147483647),
        max_num_warps_per_cta=32,
        prefer_block_pointer=True,
        prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
    ),
    vendors.KUNLUNXIN: CodeGenConfig(
        max_tile_size=512,
        max_grid_size=(65536, 65536, 65536),
        max_num_warps_per_cta=32,
        prefer_block_pointer=True,
        prefer_1d_tile=True,  # always 1D on Kunlunxin, regardless of Triton
    ),
    vendors.ASCEND: CodeGenConfig(
        max_tile_size=512,
        max_grid_size=(48, 1, 1),
        max_num_warps_per_cta=32,
        prefer_block_pointer=False,
        prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
    ),
    vendors.HYGON: CodeGenConfig(
        max_tile_size=2048,
        max_grid_size=(65536, 65536, 65536),
        max_num_warps_per_cta=16,
        prefer_block_pointer=True,
        prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
    ),
    vendors.TSINGMICRO: CodeGenConfig(
        max_tile_size=4096,
        max_grid_size=(16, 16, 16),
        max_num_warps_per_cta=1,
        prefer_block_pointer=True,
        prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
    ),
    vendors.ENFLAME: CodeGenConfig(
        max_tile_size=4096,  # 512 * 8
        max_grid_size=(12, 1, 1),
        max_num_warps_per_cta=4,
        prefer_block_pointer=True,
        prefer_1d_tile=_PREFER_1D_TILE_DEFAULT,
    ),
}
# Vendor -> num_warps heuristic function, consumed by
# get_heuristics_for_num_warps(); vendors absent here fall back to the
# NVIDIA (default) heuristic there.
HEURISTICS_CONFIG = {
    vendors.NVIDIA: default_heuristics_for_num_warps,
    vendors.METAX: metax_heuristics_for_num_warps,
    vendors.CAMBRICON: cambricon_heuristics_for_num_warps,
    vendors.HYGON: hygon_heuristics_for_num_warps,
    vendors.TSINGMICRO: tsingmicro_heuristics_for_num_warps,
    vendors.SUNRISE: sunrise_heuristics_for_num_warps,
    vendors.ENFLAME: enflame_heuristics_for_num_warps,
}
def get_codegen_config():
    """Return the CodeGenConfig for the active device vendor.

    Vendors without an entry in CODEGEN_COFIGS fall back to the NVIDIA
    configuration.
    """
    return CODEGEN_COFIGS.get(device.vendor, CODEGEN_COFIGS[vendors.NVIDIA])
def get_heuristics_for_num_warps(tile_size):
    """Return num_warps for `tile_size` using the active vendor's heuristic.

    Vendors without an entry in HEURISTICS_CONFIG fall back to the NVIDIA
    (default) heuristic.
    """
    heuristic = HEURISTICS_CONFIG.get(
        device.vendor, HEURISTICS_CONFIG[vendors.NVIDIA]
    )
    return heuristic(tile_size)