Coverage for src/flag_gems/utils/codegen_config_utils.py: 62%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1from dataclasses import dataclass 

2from typing import Tuple 

3 

4import triton 

5 

6from flag_gems.runtime import device 

7from flag_gems.runtime.backend import vendor_module 

8from flag_gems.runtime.common import vendors 

9 

10 

11def default_heuristics_for_num_warps(tile_size): 

12 if tile_size < 2048: 

13 return 4 

14 elif tile_size < 4096: 

15 return 8 

16 else: 

17 return 16 

18 

19 

20def metax_heuristics_for_num_warps(tile_size): 

21 if tile_size <= 1024: 

22 return 4 

23 elif tile_size <= 2048: 

24 return 8 

25 else: 

26 return 16 

27 

28 

29def hygon_heuristics_for_num_warps(tile_size): 

30 if tile_size <= 1024: 

31 return 4 

32 elif tile_size <= 2048: 

33 return 8 

34 else: 

35 return 16 

36 

37 

38def cambricon_heuristics_for_num_warps(tile_size): 

39 return 1 

40 

41 

42def tsingmicro_heuristics_for_num_warps(tile_size): 

43 return 1 

44 

45 

46def sunrise_heuristics_for_num_warps(tile_size): 

47 if tile_size < 1024: 

48 return 4 

49 elif tile_size < 2048: 

50 return 8 

51 elif tile_size < 4096: 

52 return 16 

53 else: 

54 return 32 

55 

56 

57def enflame_heuristics_for_num_warps(tile_size): 

58 return 4 

59 

60 

61@dataclass 

62class CodeGenConfig: 

63 max_tile_size: int 

64 max_grid_size: Tuple[int, int, int] 

65 max_num_warps_per_cta: int 

66 

67 prefer_block_pointer: bool 

68 prefer_1d_tile: bool 

69 # gen_configs: -> configs 

70 # prune_config: (as jit function, ) cofigs -> configs 

71 

72 def __post_init__(self): 

73 if self.prefer_1d_tile: 

74 self.prefer_block_pointer = False 

75 

76 

77CODEGEN_COFIGS = { 

78 vendors.NVIDIA: CodeGenConfig( 

79 512, 

80 (65536, 65536, 65536), 

81 32, 

82 True, 

83 prefer_1d_tile=int(triton.__version__[0]) < 3, 

84 ), 

85 vendors.CAMBRICON: ( 

86 CodeGenConfig( 

87 8192, 

88 tuple([vendor_module.TOTAL_CORE_NUM, 1, 1]), 

89 32, 

90 True, 

91 prefer_1d_tile=int(triton.__version__[0]) < 3, 

92 ) 

93 if vendor_module.vendor_info.vendor_name == "cambricon" 

94 else None 

95 ), 

96 vendors.METAX: CodeGenConfig( 

97 2048, 

98 (65536, 65536, 65536), 

99 16, 

100 True, 

101 prefer_1d_tile=int(triton.__version__[0]) < 3, 

102 ), 

103 vendors.MTHREADS: CodeGenConfig( 

104 512, 

105 (2147483647, 2147483647, 2147483647), 

106 32, 

107 True, 

108 prefer_1d_tile=int(triton.__version__[0]) < 3, 

109 ), 

110 vendors.KUNLUNXIN: CodeGenConfig( 

111 512, 

112 (65536, 65536, 65536), 

113 32, 

114 True, 

115 prefer_1d_tile=True, 

116 ), 

117 vendors.ASCEND: CodeGenConfig( 

118 512, 

119 tuple([48, 1, 1]), 

120 32, 

121 False, 

122 prefer_1d_tile=int(triton.__version__[0]) < 3, 

123 ), 

124 vendors.HYGON: CodeGenConfig( 

125 2048, 

126 (65536, 65536, 65536), 

127 16, 

128 True, 

129 prefer_1d_tile=int(triton.__version__[0]) < 3, 

130 ), 

131 vendors.TSINGMICRO: CodeGenConfig( 

132 4096, 

133 (16, 16, 16), 

134 1, 

135 True, 

136 prefer_1d_tile=int(triton.__version__[0]) < 3, 

137 ), 

138 vendors.ENFLAME: CodeGenConfig( 

139 512 * 8, 

140 (12, 1, 1), 

141 4, 

142 True, 

143 prefer_1d_tile=int(triton.__version__[0]) < 3, 

144 ), 

145} 

146 

147HEURISTICS_CONFIG = { 

148 vendors.NVIDIA: default_heuristics_for_num_warps, 

149 vendors.METAX: metax_heuristics_for_num_warps, 

150 vendors.CAMBRICON: cambricon_heuristics_for_num_warps, 

151 vendors.HYGON: hygon_heuristics_for_num_warps, 

152 vendors.TSINGMICRO: tsingmicro_heuristics_for_num_warps, 

153 vendors.SUNRISE: sunrise_heuristics_for_num_warps, 

154 vendors.ENFLAME: enflame_heuristics_for_num_warps, 

155} 

156 

157 

158def get_codegen_config(): 

159 if device.vendor not in CODEGEN_COFIGS: 

160 return CODEGEN_COFIGS.get(vendors.NVIDIA) 

161 return CODEGEN_COFIGS.get(device.vendor) 

162 

163 

164def get_heuristics_for_num_warps(tile_size): 

165 if device.vendor not in HEURISTICS_CONFIG: 

166 return HEURISTICS_CONFIG.get(vendors.NVIDIA)(tile_size) 

167 return HEURISTICS_CONFIG.get(device.vendor)(tile_size)