Coverage for src/flag_gems/runtime/backend/_ascend/heuristics_config_utils.py: 0%

143 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import triton 

2 

3 

def argmax_heur_block_m(args):
    """Return the BLOCK_M meta-parameter for the ``argmax`` kernel.

    The value is a fixed constant; ``args`` is not consulted.
    """
    block_m = 16
    return block_m

6 

7 

def argmax_heur_block_n(args):
    """Return the BLOCK_N meta-parameter for the ``argmax`` kernel.

    Always 100, independent of ``args``. 100 is not a power of two;
    presumably the Ascend Triton backend accepts that here -- confirm.
    """
    block_n = 100
    return block_n

10 

11 

def argmax_heur_tile_k(args):
    """Return the fixed TILE_K (64) for the non-inner ``argmax`` kernel.

    ``args`` is not consulted.
    """
    return 64

15 

16 

def argmax_heur_tile_n_non_inner(args):
    """Return the fixed TILE_N (128) for the non-inner ``argmax`` kernel.

    ``args`` is not consulted.
    """
    tile_n = 128
    return tile_n

19 

20 

def argmax_heur_one_tile_per_cta(args):
    """Tell whether a single TILE_N-wide tile covers all of ``N``.

    Returns ``True`` exactly when ``args["TILE_N"] >= args["N"]``.
    """
    tile_n, n = args["TILE_N"], args["N"]
    return tile_n >= n

23 

24 

def argmin_heur_block_m(args):
    """Return the BLOCK_M meta-parameter for the ``argmin`` kernel.

    Fixed at 16 (the same value ``argmax`` uses); ``args`` is not consulted.
    """
    rows_per_block = 16
    return rows_per_block

27 

28 

def argmin_heur_block_n(args):
    """Return the BLOCK_N meta-parameter for the ``argmin`` kernel.

    Fixed at 100 (not a power of two -- presumably acceptable on the
    Ascend backend, confirm); ``args`` is not consulted.
    """
    cols_per_block = 100
    return cols_per_block

31 

32 

def bmm_heur_divisible_m(args):
    """Return whether ``M`` is an exact multiple of ``TILE_M`` for ``bmm``."""
    m, tile_m = args["M"], args["TILE_M"]
    remainder = m % tile_m
    return remainder == 0

35 

36 

def bmm_heur_divisible_n(args):
    """Return whether ``N`` is an exact multiple of ``TILE_N`` for ``bmm``."""
    n, tile_n = args["N"], args["TILE_N"]
    remainder = n % tile_n
    return remainder == 0

39 

40 

def bmm_heur_divisible_k(args):
    """Return whether ``K`` is an exact multiple of ``TILE_K`` for ``bmm``."""
    k, tile_k = args["K"], args["TILE_K"]
    remainder = k % tile_k
    return remainder == 0

43 

44 

def dropout_heur_block(args):
    """Pick BLOCK for ``dropout``: 512 when ``N <= 512``, otherwise 4096."""
    return 512 if args["N"] <= 512 else 4096

50 

51 

def dropout_heur_num_warps(args):
    """Pick ``num_warps`` for ``dropout`` from ``N``.

    ``N <= 512`` -> 4, ``N <= 1024`` -> 8, otherwise 16.
    """
    n = args["N"]
    # (inclusive upper bound on N, warp count), checked in ascending order.
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16

59 

60 

def exponential_heur_block(args):
    """Pick BLOCK for ``exponential_``: 512 when ``N <= 512``, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512

66 

67 

def exponential_heur_num_warps(args):
    """Pick ``num_warps`` for ``exponential_`` from ``N``.

    ``N <= 512`` -> 4, ``N <= 1024`` -> 8, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

75 

76 

def gather_heur_block_m(args):
    """Pick BLOCK_M for ``gather``.

    Counts how many 2048-wide chunks are needed to cover ``N``
    (``cdiv(N, 2048)``), rounds that up to a power of two, and caps it at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    rounded = triton.next_power_of_2(chunks)
    return min(4, rounded)

79 

80 

def gather_heur_block_n(args):
    """Pick BLOCK_N for ``gather``: ``N`` rounded up to a power of two, capped at 2048."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)

83 

84 

def index_select_heur_block_m(args):
    """Pick BLOCK_M for ``index_select``.

    Computes ``cdiv(256, N)``, rounds it up to a power of two and caps the
    result at 4.

    ``N`` is clamped to at least 1 before the division. Without the clamp,
    ``N == 0`` would raise ``ZeroDivisionError`` inside ``triton.cdiv``.
    With it, such input yields the cap (4). For ``N >= 1`` the result is
    unchanged.
    """
    n = max(args["N"], 1)
    return min(4, triton.next_power_of_2(triton.cdiv(256, n)))

87 

88 

def index_select_heur_block_n(args):
    """Pick BLOCK_N for ``index_select``.

    Rounds ``cdiv(N, 16)`` up to a power of two, then clamps it into the
    range ``[16, 512]``.
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(16, min(512, candidate))

92 

93 

def mm_heur_even_k(args):
    """Return whether ``K`` is an exact multiple of ``BLOCK_K`` for ``mm``."""
    k, block_k = args["K"], args["BLOCK_K"]
    return k % block_k == 0

96 

97 

def rand_heur_block(args):
    """Pick BLOCK for ``rand``: 2048 when ``N <= 512``, otherwise 4097.

    NOTE(review): 4097 is not a power of two. It appears to be a deliberate
    choice for the Ascend backend (``randn`` and ``uniform`` use the same
    value) -- confirm before changing.
    """
    return 2048 if args["N"] <= 512 else 4097

103 

104 

def rand_heur_num_warps(args):
    """Pick ``num_warps`` for ``rand`` from ``N``.

    ``N <= 512`` -> 4, ``N <= 1024`` -> 8, otherwise 16.
    """
    n = args["N"]
    return 4 if n <= 512 else (8 if n <= 1024 else 16)

112 

113 

def randn_heur_block(args):
    """Pick BLOCK for ``randn``: 2048 when ``N <= 512``, otherwise 4097.

    NOTE(review): 4097 is not a power of two. It matches ``rand`` and is
    presumably deliberate for the Ascend backend -- confirm.
    """
    is_small = args["N"] <= 512
    if is_small:
        return 2048
    return 4097

119 

120 

def randn_heur_num_warps(args):
    """Pick ``num_warps`` for ``randn`` from ``N``.

    ``N <= 512`` -> 4, ``N <= 1024`` -> 8, otherwise 16.
    """
    n = args["N"]
    tiers = ((512, 4), (1024, 8))
    return next((warps for limit, warps in tiers if n <= limit), 16)

128 

129 

def softmax_heur_tile_k(args):
    """Choose TILE_K for the non-inner softmax kernel.

    Starting from 1, TILE_K is doubled while both of these hold:

    * the doubled tile still fits within ``min(K, 4096)``;
    * the launch at the current TILE_K would still need more than one
      "wave" of programs, i.e. ``M * cdiv(K, TILE_K) / 40 > 1``.

    The result is therefore a power of two no larger than
    ``max(1, min(K, 4096))``.
    """
    max_tile_k = 4096
    # FIXME: the core count should be obtained from a device-query API.
    # It is the number of AIV cores, which depends on the Ascend version.
    num_sms = 40
    limit = min(args["K"], max_tile_k)
    tile_k = 1
    while (
        tile_k * 2 <= limit
        and args["M"] * triton.cdiv(args["K"], tile_k) / num_sms > 1
    ):
        tile_k *= 2
    return tile_k

146 

147 

def softmax_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner softmax kernel.

    Returns ``cdiv(768, TILE_K)``. For positive TILE_K this is the smallest
    TILE_N with ``TILE_N * TILE_K >= 768``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(768, tile_k)

150 

151 

def softmax_heur_one_tile_per_cta(args):
    """Tell whether one TILE_N-wide tile covers all of ``N``.

    This helper is shared by all four softmax entries in
    ``HEURISTICS_CONFIGS``. Returns ``args["TILE_N"] >= args["N"]``.
    """
    tile_n = args["TILE_N"]
    n = args["N"]
    covers_all = tile_n >= n
    return covers_all

154 

155 

def softmax_heur_num_warps_non_inner(args):
    """Pick ``num_warps`` for the non-inner softmax kernel from the tile area.

    The tile area is ``TILE_N * TILE_K``: below 2048 -> 4 warps, below
    4096 -> 8 warps, otherwise 16 warps.
    """
    area = args["TILE_N"] * args["TILE_K"]
    for upper, warps in ((2048, 4), (4096, 8)):
        if area < upper:
            return warps
    return 16

164 

165 

def softmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner softmax kernel.

    Returns ``N`` rounded up to a power of two when ``N <= 32 * 1024``;
    otherwise returns a fixed 4096.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)

171 

172 

def softmax_heur_num_warps_inner(args):
    """Pick ``num_warps`` for the inner softmax kernel from TILE_N.

    ``TILE_N < 2048`` -> 4, ``TILE_N < 4096`` -> 8, otherwise 16.
    """
    tile_n = args["TILE_N"]
    if tile_n >= 4096:
        return 16
    if tile_n >= 2048:
        return 8
    return 4

181 

182 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Pick TILE_N for the non-inner softmax backward kernel.

    Returns ``1024 // TILE_K``, but never less than 1.
    """
    quotient = 1024 // args["TILE_K"]
    return quotient if quotient > 1 else 1

185 

186 

def softmax_heur_tile_m(args):
    """Pick TILE_M for the inner softmax backward kernel.

    Returns ``1024 // TILE_N``, but never less than 1.
    """
    quotient = 1024 // args["TILE_N"]
    return quotient if quotient > 1 else 1

189 

190 

def uniform_heur_block(args):
    """Pick BLOCK for ``uniform`` from ``N``.

    ``N <= 512`` -> 512, ``N >= 2**30`` -> 4097, otherwise 1024.

    NOTE(review): 4097 is not a power of two and matches ``rand`` and
    ``randn``. It is presumably a deliberate Ascend choice -- confirm.
    """
    n = args["N"]
    if n <= 512:
        return 512
    if n >= 1 << 30:  # 1073741824
        return 4097
    return 1024

198 

199 

def uniform_heur_num_warps(args):
    """Pick ``num_warps`` for ``uniform`` from ``N``.

    ``N <= 512`` -> 4, ``N <= 1024`` -> 8, otherwise 16.
    """
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4

207 

208 

def var_mean_heur_block_n(args):
    """Pick BLOCK_N for ``var_mean``: ``BLOCK_NUM`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

211 

212 

def upsample_nearest2d_SAME_H(args):
    """Return whether the output height ``OH`` equals the input height ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

215 

216 

def upsample_nearest2d_SAME_W(args):
    """Return whether the output width ``OW`` equals the input width ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

219 

220 

def batch_norm_heur_block_m(args):
    """Pick BLOCK_M for ``batch_norm``: ``batch_dim`` rounded up to a power of two, capped at 128."""
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(128, rounded_batch)

223 

224 

def batch_norm_heur_block_n(args):
    """Pick BLOCK_N for ``batch_norm``.

    Starts from ``spatial_dim`` rounded up to a power of two. It is then
    limited so that ``BLOCK_M * BLOCK_N`` stays within a 4096-element
    budget, with BLOCK_N never dropping below 1. BLOCK_M comes from
    ``batch_norm_heur_block_m``.
    """
    element_budget = 2**12  # at most 4096 elements are loaded at once
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    per_row_budget = max(1, element_budget // block_m)
    return min(block_n, per_row_budget)

230 

231 

def vdot_heur_block_size(args):
    """Pick BLOCK_SIZE for ``vdot`` from ``n_elements``.

    ``n < 1024`` -> 32, ``n < 8192`` -> 256, otherwise 1024.
    """
    n = args["n_elements"]
    # (exclusive upper bound on n, block size), checked in ascending order.
    for bound, size in ((1024, 32), (8192, 256)):
        if n < bound:
            return size
    return 1024

240 

241 

# Per-operator heuristics table. Each operator name maps kernel
# meta-parameter names (including ``num_warps``) to the function above that
# computes that value from the kernel's argument mapping.
HEURISTICS_CONFIGS = {
    "argmax_non_inner": dict(
        TILE_K=argmax_heur_tile_k,
        TILE_N=argmax_heur_tile_n_non_inner,
        ONE_TILE_PER_CTA=argmax_heur_one_tile_per_cta,
    ),
    "argmax": dict(
        BLOCK_M=argmax_heur_block_m,
        BLOCK_N=argmax_heur_block_n,
    ),
    "argmin": dict(
        BLOCK_M=argmin_heur_block_m,
        BLOCK_N=argmin_heur_block_n,
    ),
    "bmm": dict(
        DIVISIBLE_M=bmm_heur_divisible_m,
        DIVISIBLE_N=bmm_heur_divisible_n,
        DIVISIBLE_K=bmm_heur_divisible_k,
    ),
    "dropout": dict(
        BLOCK=dropout_heur_block,
        num_warps=dropout_heur_num_warps,
    ),
    "exponential_": dict(
        BLOCK=exponential_heur_block,
        num_warps=exponential_heur_num_warps,
    ),
    "gather": dict(
        BLOCK_M=gather_heur_block_m,
        BLOCK_N=gather_heur_block_n,
    ),
    "index_select": dict(
        BLOCK_M=index_select_heur_block_m,
        BLOCK_N=index_select_heur_block_n,
    ),
    "mm": dict(
        EVEN_K=mm_heur_even_k,
    ),
    "rand": dict(
        BLOCK=rand_heur_block,
        num_warps=rand_heur_num_warps,
    ),
    "randn": dict(
        BLOCK=randn_heur_block,
        num_warps=randn_heur_num_warps,
    ),
    # The four softmax variants share ``softmax_heur_one_tile_per_cta``.
    "softmax_non_inner": dict(
        TILE_K=softmax_heur_tile_k,
        TILE_N=softmax_heur_tile_n_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_non_inner,
    ),
    "softmax_inner": dict(
        TILE_N=softmax_heur_tile_n_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_inner,
    ),
    "softmax_backward_non_inner": dict(
        TILE_N=softmax_heur_tile_n_bwd_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    "softmax_backward_inner": dict(
        TILE_M=softmax_heur_tile_m,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    "uniform": dict(
        BLOCK=uniform_heur_block,
        num_warps=uniform_heur_num_warps,
    ),
    "upsample_nearest2d": dict(
        SAME_H=upsample_nearest2d_SAME_H,
        SAME_W=upsample_nearest2d_SAME_W,
    ),
    "var_mean": dict(
        BLOCK_N=var_mean_heur_block_n,
    ),
    "batch_norm": dict(
        BLOCK_M=batch_norm_heur_block_m,
        BLOCK_N=batch_norm_heur_block_n,
    ),
    "vdot": dict(
        BLOCK_SIZE=vdot_heur_block_size,
    ),
}