Coverage for src/flag_gems/runtime/backend/_iluvatar/heuristics_config_utils.py: 0%

131 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3 

4 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels.

    ``args`` is accepted to match the heuristic-callback signature but is not read.
    """
    fixed_block_size = 1024
    return fixed_block_size

7 

8 

def argmax_heur_block_m(args):
    """Return BLOCK_M for argmax: 8 once ``args["M"]`` reaches 4096, else 4."""
    if args["M"] >= 4096:
        return 8
    return 4

11 

12 

def argmax_heur_block_n(args):
    """Return BLOCK_N for argmax: ``args["N"]`` rounded up to a power of two, capped at 4096."""
    cap = 4096
    rounded = triton.next_power_of_2(args["N"])
    return rounded if rounded < cap else cap

15 

16 

def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 8 once ``args["M"]`` reaches 4096, else 4."""
    if args["M"] >= 4096:
        return 8
    return 4

19 

20 

def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin: ``args["N"]`` rounded up to a power of two, capped at 4096."""
    cap = 4096
    rounded = triton.next_power_of_2(args["N"])
    return rounded if rounded < cap else cap

23 

24 

def dropout_heur_block(args):
    """Return BLOCK for dropout: 512 when ``args["N"]`` <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

30 

31 

def dropout_heur_num_warps(args):
    """Return num_warps for dropout.

    4 for ``args["N"]`` <= 512, 8 for <= 1024, and 16 above that.
    """
    n = args["N"]
    # Thresholds are checked in ascending order; the first match wins.
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16

39 

40 

def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 when ``args["N"]`` <= 512, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512

46 

47 

def exponential_heur_num_warps(args):
    """Return num_warps for exponential_.

    4 for ``args["N"]`` <= 512, 8 for <= 1024, and 16 above that.
    """
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4

55 

56 

def gather_heur_block_m(args):
    """Return BLOCK_M for gather.

    Counts how many 2048-wide chunks ``args["N"]`` needs (ceil division), rounds
    that count up to a power of two, and caps the result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    rounded = triton.next_power_of_2(chunks)
    return rounded if rounded < 4 else 4

59 

60 

def gather_heur_block_n(args):
    """Return BLOCK_N for gather: ``args["N"]`` rounded up to a power of two, capped at 2048."""
    cap = 2048
    rounded = triton.next_power_of_2(args["N"])
    return rounded if rounded < cap else cap

63 

64 

def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select.

    Computes ceil(256 / N), rounds it up to a power of two, and caps it at 4.
    As in the original, an ``args["N"]`` of 0 raises ZeroDivisionError.
    """
    per_256 = triton.cdiv(256, args["N"])
    rounded = triton.next_power_of_2(per_256)
    return rounded if rounded < 4 else 4

67 

68 

def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    Takes ceil(N / 16), rounds it up to a power of two, and clamps it to [16, 512].
    """
    floor_value, ceiling_value = 16, 512
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(floor_value, min(ceiling_value, candidate))

72 

73 

def mm_heur_even_k(args):
    """Return True when ``args["K"]`` is an exact multiple of BLOCK_K * SPLIT_K."""
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    remainder = args["K"] % k_step
    return remainder == 0

76 

77 

def rand_heur_block(args):
    """Return BLOCK for rand: 512 when ``args["N"]`` <= 512, otherwise 1024."""
    small_problem = args["N"] <= 512
    return 512 if small_problem else 1024

83 

84 

def rand_heur_num_warps(args):
    """Return num_warps for rand.

    4 for ``args["N"]`` <= 512, 8 for <= 1024, and 16 above that.
    """
    n = args["N"]
    # Ascending (upper bound, warps) pairs; the first bound n fits under wins.
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16

92 

93 

def randn_heur_block(args):
    """Return BLOCK for randn: 512 when ``args["N"]`` <= 512, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512

99 

100 

def randn_heur_num_warps(args):
    """Return num_warps for randn: 4 for ``args["N"]`` <= 512, otherwise 8.

    Unlike the other random-op heuristics in this file, this one never returns 16.
    """
    return 8 if args["N"] > 512 else 4

106 

107 

def softmax_heur_tile_k(args):
    """Return TILE_K for the non-inner softmax kernel.

    Starting at 1, TILE_K doubles while both conditions hold:
    - the launch of M * ceil(K / TILE_K) blocks would still exceed one wave
      over the current CUDA device's multiprocessor count;
    - the doubled TILE_K stays within min(K, 8192).
    """
    max_tile_k = 8192
    device_index = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device_index).multi_processor_count
    m, k = args["M"], args["K"]
    limit = min(k, max_tile_k)
    tile_k = 1
    while tile_k <= limit:
        waves = m * triton.cdiv(k, tile_k) / sm_count
        # Stop once one wave suffices or doubling would overshoot the limit.
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k

123 

124 

def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for non-inner softmax: ceil(8192 / TILE_K).

    This makes TILE_N * TILE_K cover at least 8192 elements.
    """
    element_target = 8192
    tile_k = args["TILE_K"]
    return triton.cdiv(element_target, tile_k)

127 

128 

def softmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile spans all of ``args["N"]``."""
    tile_n, n = args["TILE_N"], args["N"]
    return n <= tile_n

131 

132 

def softmax_heur_num_warps_non_inner(args):
    """Return num_warps for non-inner softmax based on TILE_N * TILE_K.

    16 for >= 4096 elements, 8 for >= 2048, otherwise 4.
    """
    tile_elems = args["TILE_N"] * args["TILE_K"]
    if tile_elems >= 4096:
        return 16
    if tile_elems >= 2048:
        return 8
    return 4

141 

142 

def softmax_heur_tile_n_inner(args):
    """Return TILE_N for inner softmax.

    For ``args["N"]`` up to 32 * 1024, this is N rounded up to a power of two;
    above that it is fixed at 4096.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)

148 

149 

def softmax_heur_num_warps_inner(args):
    """Return num_warps for inner softmax based on TILE_N.

    16 for TILE_N >= 4096, 8 for >= 2048, otherwise 4.
    """
    tile_n = args["TILE_N"]
    if tile_n >= 4096:
        return 16
    if tile_n >= 2048:
        return 8
    return 4

158 

159 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for non-inner softmax backward: 1024 // TILE_K, at least 1."""
    per_tile = 1024 // args["TILE_K"]
    return per_tile if per_tile > 1 else 1

162 

163 

def softmax_heur_tile_m(args):
    """Return TILE_M for inner softmax backward: 1024 // TILE_N, at least 1."""
    rows = 1024 // args["TILE_N"]
    return rows if rows > 1 else 1

166 

167 

def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 when ``args["N"]`` <= 512, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512

173 

174 

def uniform_heur_num_warps(args):
    """Return num_warps for uniform.

    4 for ``args["N"]`` <= 512, 8 for <= 1024, and 16 above that.
    """
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4

182 

183 

def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: ``args["BLOCK_NUM"]`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

186 

187 

def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"]`` equals ``args["IH"]``.

    Presumably these are the output and input heights; the names suggest this.
    """
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

190 

191 

def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"]`` equals ``args["IW"]``.

    Presumably these are the output and input widths; the names suggest this.
    """
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

194 

195 

def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when N * C * OH * OW fits in a signed 32-bit int (<= 2**31 - 1)."""
    int32_max = 2**31 - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= int32_max

198 

199 

def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm: ``args["batch_dim"]`` rounded up to a power of two, capped at 2048."""
    cap = 2048
    padded = triton.next_power_of_2(args["batch_dim"])
    return padded if padded < cap else cap

202 

203 

def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm under a 2**14-element tile budget.

    BLOCK_N is ``args["spatial_dim"]`` rounded up to a power of two. It is limited
    to max(1, 16384 // BLOCK_M), where BLOCK_M comes from batch_norm_heur_block_m.
    """
    element_budget = 2**14  # a maximum of 16384 elements are loaded at once
    rows = batch_norm_heur_block_m(args)
    cols_cap = max(1, element_budget // rows)
    spatial_pow2 = triton.next_power_of_2(args["spatial_dim"])
    return min(spatial_pow2, cols_cap)

209 

210 

def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot based on ``args["n_elements"]``.

    32 below 1024 elements, 256 below 8192, otherwise 1024.
    """
    n = args["n_elements"]
    if n >= 8192:
        return 1024
    if n >= 1024:
        return 256
    return 32

219 

220 

def _mha_block_config(block_m):
    """Return the fixed MHA config dict for one BLOCK_M variant.

    Every variant shares BLOCK_N=32, num_warps=4 and num_stages=3; only BLOCK_M
    differs. Each call builds a fresh dict of constant callbacks. ``block_m`` is
    a parameter here, so each lambda closes over its own value.
    """
    return {
        "BLOCK_M": lambda args: block_m,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    }


# Per-operator heuristic table: operator key -> {meta-parameter name: callback}.
# Each callback takes the kernel argument dict ``args`` and returns the value for
# that meta-parameter. Presumably callers hand these to triton.heuristics; the
# consumer is not visible in this file.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Fixed-size MHA variants differing only in BLOCK_M.
    "mha_block_128": _mha_block_config(128),
    "mha_block_64": _mha_block_config(64),
    "mha_block_32": _mha_block_config(32),
    "mha_block_16": _mha_block_config(16),
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 16,
    },
}