Coverage for src/flag_gems/runtime/backend/_cambricon/heuristics_config_utils.py: 0%

110 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import torch 

2import triton 

3 

4from .utils import TOTAL_CORE_NUM 

5 

6 

def argmax_heur_block_m(args):
    """Return the argmax ``BLOCK_M``: 4 when ``args["M"]`` < 4096, else 8."""
    if args["M"] < 4096:
        return 4
    return 8

9 

10 

def argmax_heur_block_n(args):
    """Return the argmax ``BLOCK_N``.

    The value is ``args["N"]`` rounded up to a power of two by
    ``triton.next_power_of_2`` and capped at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)

13 

14 

def argmin_heur_block_m(args):
    """Return the argmin ``BLOCK_M``: 4 when ``args["M"]`` < 4096, else 8."""
    if args["M"] < 4096:
        return 4
    return 8

17 

18 

def argmin_heur_block_n(args):
    """Return the argmin ``BLOCK_N``.

    The value is ``args["N"]`` rounded up to a power of two by
    ``triton.next_power_of_2`` and capped at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)

21 

22 

def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    leftover = args["M"] % args["TILE_M"]
    return leftover == 0

25 

26 

def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    leftover = args["N"] % args["TILE_N"]
    return leftover == 0

29 

30 

def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    leftover = args["K"] % args["TILE_K"]
    return leftover == 0

33 

34 

def dropout_heur_block(args):
    """Return the dropout ``BLOCK``: 512 when ``args["N"]`` <= 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

40 

41 

def exponential_heur_block(args):
    """Return the exponential_ ``BLOCK``: 512 when ``args["N"]`` <= 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

47 

48 

def exponential_heur_num_warps(args):
    """Return a warp count stepped on ``args["N"]``.

    4 for N <= 512, 8 for N <= 1024, and 16 for anything larger.
    """
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16

56 

57 

def gather_heur_block_m(args):
    """Return the gather ``BLOCK_M``.

    Computes ``ceil(args["N"] / 2048)`` via ``triton.cdiv``, rounds it up to
    a power of two, and caps the result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    padded_chunks = triton.next_power_of_2(chunks)
    return min(4, padded_chunks)

60 

61 

def gather_heur_block_n(args):
    """Return the gather ``BLOCK_N``.

    The value is ``args["N"]`` rounded up to a power of two by
    ``triton.next_power_of_2`` and capped at 2048.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(2048, padded_n)

64 

65 

def index_select_heur_block_m(args):
    """Return the index_select ``BLOCK_M``.

    Computes ``ceil(256 / args["N"])`` via ``triton.cdiv``, rounds it up to
    a power of two, and caps the result at 4.
    """
    rows = triton.cdiv(256, args["N"])
    padded_rows = triton.next_power_of_2(rows)
    return min(4, padded_rows)

68 

69 

def index_select_heur_block_n(args):
    """Return the index_select ``BLOCK_N``.

    Computes ``ceil(args["N"] / 16)`` via ``triton.cdiv``, rounds it up to a
    power of two, then clamps: first to at most 512, then to at least 16.
    """
    padded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(padded, 512)
    return max(capped, 16)

73 

74 

def mm_heur_even_k(args):
    """Return True when ``args["K"]`` divides evenly by ``BLOCK_K * SPLIT_K``."""
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0

77 

78 

def rand_heur_block(args):
    """Return the rand ``BLOCK``: 512 when ``args["N"]`` <= 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

84 

85 

def randn_heur_block(args):
    """Return the randn ``BLOCK``: 512 when ``args["N"]`` <= 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

91 

92 

def softmax_heur_tile_k(args):
    """Return a power-of-two tile size along K.

    Starting from 1, the tile is doubled while both conditions hold:

    * the block count ``args["M"] * cdiv(args["K"], tile_k)`` divided by the
      device's ``multi_processor_count`` is greater than 1, and
    * the doubled tile would not exceed ``min(args["K"], 8192)``.

    NOTE(review): the processor count is read through ``torch.cuda`` even
    though this module imports ``TOTAL_CORE_NUM`` -- confirm that is the
    intended source on this backend.
    """
    tile_k_cap = 8192
    device_index = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device_index).multi_processor_count
    limit = min(args["K"], tile_k_cap)

    tile_k = 1
    while tile_k <= limit:
        blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        waves = blocks / sm_count
        # Stop once the work fits in one wave or doubling would pass the limit.
        if not (waves > 1 and tile_k * 2 <= limit):
            break
        tile_k *= 2
    return tile_k

108 

109 

def softmax_heur_tile_mode_non_inner(args):
    """Return the ``TILE_MODE`` code (0, 1 or 2) for the non-inner softmax.

    * 0 -- ``TILE_N >= N`` and ``TILE_K * max(TOTAL_CORE_NUM // M, 1) >= K``
    * 1 -- ``TILE_N >= N`` but the K condition does not hold
    * 2 -- ``TILE_N < N``
    """
    m, n, k = args["M"], args["N"], args["K"]
    tile_n, tile_k = args["TILE_N"], args["TILE_K"]

    core_multiplier = max(TOTAL_CORE_NUM // m, 1)
    k_fits = tile_k * core_multiplier >= k
    n_fits = tile_n >= n

    if not n_fits:
        return 2
    return 0 if k_fits else 1

126 

127 

def softmax_heur_tile_mode_inner(args):
    """Return the ``TILE_MODE`` code (0, 1 or 2) for the inner softmax.

    * 0 -- ``BLOCK_N >= N`` and ``BLOCK_M * TOTAL_CORE_NUM >= M``
    * 1 -- ``BLOCK_N >= N`` but the M condition does not hold
    * 2 -- ``BLOCK_N < N``
    """
    m_fits = args["BLOCK_M"] * TOTAL_CORE_NUM >= args["M"]
    n_fits = args["BLOCK_N"] >= args["N"]

    if not n_fits:
        return 2
    return 0 if m_fits else 1

137 

138 

def uniform_heur_block(args):
    """Return the uniform ``BLOCK``: 512 when ``args["N"]`` <= 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

144 

145 

def uniform_heur_num_warps(args):
    """Return a warp count stepped on ``args["N"]``.

    4 for N <= 512, 8 for N <= 1024, and 16 for anything larger.
    """
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16

153 

154 

def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"]`` equals ``args["IH"]``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

157 

158 

def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"]`` equals ``args["IW"]``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

161 

162 

def vdot_heur_block_size(args):
    """Return the vdot ``BLOCK_SIZE`` stepped on ``args["n_elements"]``.

    32 below 1024 elements, 256 below 8192 elements, otherwise 1024.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024

171 

172 

def linspace_heur_inner_block_size(args):
    """Return the linspace ``INNER_BLOCK_SIZE`` stepped on ``args["BLOCK_SIZE"]``.

    64 when BLOCK_SIZE < 1024, 1024 when BLOCK_SIZE < 8192, otherwise 8192.
    """
    outer = args["BLOCK_SIZE"]
    if outer < 1024:
        return 64
    if outer < 8192:
        return 1024
    return 8192

181 

182 

def simple_elementwise_blocksize_heur(args):
    """Return a fixed block size of 1024; ``args`` is not consulted."""
    fixed_block_size = 1024
    return fixed_block_size

185 

186 

# Registry of heuristics: operator name -> {meta-parameter name -> callable}.
# Every callable takes a single ``args`` mapping and returns the value to use
# for that meta-parameter.
HEURISTICS_CONFIGS = {
    # --- reductions over an index -------------------------------------
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # --- matrix multiplication ----------------------------------------
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    # --- random / dropout ---------------------------------------------
    "dropout": {
        "BLOCK": dropout_heur_block,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
    },
    # --- indexing -----------------------------------------------------
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
    },
    "randn": {
        "BLOCK": randn_heur_block,
    },
    # --- softmax (forward and backward share the same tile-mode rules) --
    "softmax_non_inner": {
        "TILE_MODE": softmax_heur_tile_mode_non_inner,
    },
    "softmax_inner": {
        "TILE_MODE": softmax_heur_tile_mode_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_MODE": softmax_heur_tile_mode_non_inner,
    },
    "softmax_backward_inner": {
        "TILE_MODE": softmax_heur_tile_mode_inner,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    # Registered with no heuristics.
    "var_mean": {},
    "batch_norm": {},
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "linspace": {
        "INNER_BLOCK_SIZE": linspace_heur_inner_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
    # --- variable-length attention: constant launch parameters ---------
    "mha_varlen_prefill": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_varlen_decode": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
}