Coverage for src/flag_gems/runtime/backend/_enflame/heuristics_config_utils.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3 

4 

def argmax_heur_block_m(args):
    """Choose the ``BLOCK_M`` value registered for the ``argmax`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["M"]`` is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    if args["M"] < 4096:
        return 4
    return 8

7 

8 

def argmax_heur_block_n(args):
    """Choose the ``BLOCK_N`` value registered for the ``argmax`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)

11 

12 

def argmin_heur_block_m(args):
    """Choose the ``BLOCK_M`` value registered for the ``argmin`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["M"]`` is read.

    Returns:
        8 when ``args["M"]`` is at least 4096, otherwise 4.
    """
    block_m = 4
    if not args["M"] < 4096:
        block_m = 8
    return block_m

15 

16 

def argmin_heur_block_n(args):
    """Choose the ``BLOCK_N`` value registered for the ``argmin`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        The smaller of 4096 and ``triton.next_power_of_2(args["N"])``.
    """
    block_n_limit = 4096
    return min(block_n_limit, triton.next_power_of_2(args["N"]))

19 

20 

21# def bmm_heur_divisible_m(args): 

22# return args["M"] % args["BLOCK_M"] == 0 

23 

24 

25# def bmm_heur_divisible_n(args): 

26# return args["N"] % args["BLOCK_N"] == 0 

27 

28 

29# def bmm_heur_divisible_k(args): 

30# return args["K"] % args["BLOCK_K"] == 0 

31 

32 

def dropout_heur_block(args):
    """Choose the ``BLOCK`` value registered for the ``dropout`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 4096.
    """
    return 512 if args["N"] <= 512 else 4096

38 

39 

def dropout_heur_num_warps(args):
    """Return the ``num_warps`` value registered for the ``dropout`` heuristics.

    Args:
        args: Kernel arguments; not read by this function.

    Returns:
        The constant 4.
    """
    num_warps = 4
    return num_warps

42 

43 

def exponential_heur_block(args):
    """Choose the ``BLOCK`` value registered for the ``exponential_`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        16384 when ``args["N"]`` exceeds 512, otherwise 512.
    """
    if args["N"] <= 512:
        return 512
    return 16384

49 

50 

def exponential_heur_num_warps(args):
    """Return the ``num_warps`` value registered for the ``exponential_`` heuristics.

    Args:
        args: Kernel arguments; not read by this function.

    Returns:
        The constant 4.
    """
    fixed_warps = 4
    return fixed_warps

53 

54 

def gather_heur_block_m(args):
    """Choose the ``BLOCK_M`` value registered for the ``gather`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(args["N"], 2048))``, capped at 4.
    """
    chunks_of_2048 = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks_of_2048))

57 

58 

def gather_heur_block_n(args):
    """Choose the ``BLOCK_N`` value registered for the ``gather`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 2048.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(2048, padded_n)

61 

62 

def index_select_heur_block_m(args):
    """Choose the ``BLOCK_M`` value registered for the ``index_select`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read and it
            is used as a divisor.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(32768, args["N"]))``, capped
        at 16.
    """
    quotient = triton.cdiv(32768, args["N"])
    return min(16, triton.next_power_of_2(quotient))

65 

66 

def index_select_heur_block_n(args):
    """Choose the ``BLOCK_N`` value registered for the ``index_select`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(args["N"], 16))`` limited to
        at most 512 and then raised to at least 16.
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(candidate, 512)
    return max(capped, 16)

70 

71 

def mm_heur_even_k(args):
    """Compute the ``EVEN_K`` flag registered for the ``mm`` heuristics.

    Args:
        args: Mapping of kernel arguments; ``args["K"]``, ``args["BLOCK_K"]``
            and ``args["SPLIT_K"]`` are read.

    Returns:
        True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``.
    """
    k = args["K"]
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % k_step == 0

74 

75 

def rand_heur_block(args):
    """Choose the ``BLOCK`` value registered for the ``rand`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 16384.
    """
    is_small = args["N"] <= 512
    return 512 if is_small else 16384

81 

82 

def rand_heur_num_warps(args):
    """Return the ``num_warps`` value registered for the ``rand`` heuristics.

    Args:
        args: Kernel arguments; not read by this function.

    Returns:
        The constant 4.
    """
    warps = 4
    return warps

85 

86 

def randn_heur_block(args):
    """Choose the ``BLOCK`` value registered for the ``randn`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        16384 when ``args["N"]`` exceeds 512, otherwise 512.
    """
    block = 16384
    if args["N"] <= 512:
        block = 512
    return block

92 

93 

def randn_heur_num_warps(args):
    """Return the ``num_warps`` value registered for the ``randn`` heuristics.

    Args:
        args: Kernel arguments; not read by this function.

    Returns:
        The constant 4.
    """
    constant_warps = 4
    return constant_warps

96 

97 

def softmax_heur_tile_k(args):
    """Choose the ``TILE_K`` value registered for ``softmax_non_inner``.

    Starting from 1, the tile size is doubled for as long as
    ``args["M"] * cdiv(args["K"], tile_k) / multi_processor_count`` is above 1
    and the doubled value still fits under ``min(args["K"], 8192)``.

    Args:
        args: Mapping of kernel arguments; ``args["M"]`` and ``args["K"]``
            are read.

    Returns:
        The selected tile size: 1 when ``args["K"]`` is below 1, otherwise a
        power of two no larger than ``min(args["K"], 8192)``.
    """
    # NOTE(review): the report header places this module under the _enflame
    # backend, yet the device query goes through torch.cuda -- confirm that
    # this is the intended device API on that backend.
    max_tile_k = 8192
    device_index = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device_index).multi_processor_count

    limit = min(args["K"], max_tile_k)
    tile_k = 1
    if tile_k > limit:
        # Nothing to grow into; the search loop would never start.
        return tile_k

    while True:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / sm_count
        doubled = tile_k * 2
        if waves > 1 and doubled <= limit:
            tile_k = doubled
            continue
        return tile_k

113 

114 

def softmax_heur_tile_n_non_inner(args):
    """Choose the ``TILE_N`` value registered for ``softmax_non_inner``.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_K"]`` is read and
            it is used as a divisor.

    Returns:
        ``triton.cdiv(8192, args["TILE_K"])``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)

117 

118 

def softmax_heur_one_tile_per_cta(args):
    """Compute the ``ONE_TILE_PER_CTA`` flag shared by the softmax heuristics.

    Args:
        args: Mapping of kernel arguments; ``args["TILE_N"]`` and
            ``args["N"]`` are read.

    Returns:
        True when ``args["TILE_N"]`` is greater than or equal to ``args["N"]``.
    """
    tile_n = args["TILE_N"]
    n = args["N"]
    return tile_n >= n

121 

122 

def softmax_heur_num_warps_non_inner(args):
    """Return the ``num_warps`` value registered for ``softmax_non_inner``.

    Args:
        args: Kernel arguments; not read by this function.

    Returns:
        The constant 4.
    """
    num_warps = 4
    return num_warps

125 

126 

def softmax_heur_tile_n_inner(args):
    """Choose the ``TILE_N`` value registered for ``softmax_inner``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])`` when ``args["N"]`` is at most
        32 * 1024, otherwise 4096.
    """
    n = args["N"]
    if n > (32 * 1024):
        return 4096
    return triton.next_power_of_2(args["N"])

132 

133 

def softmax_heur_num_warps_inner(args):
    """Return the ``num_warps`` value registered for ``softmax_inner``.

    Args:
        args: Kernel arguments; not read by this function.

    Returns:
        The constant 4.
    """
    fixed_warps = 4
    return fixed_warps

136 

137 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Choose the ``TILE_N`` value registered for ``softmax_backward_non_inner``.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_K"]`` is read and
            it is used as a divisor.

    Returns:
        ``1024 // args["TILE_K"]``, but never less than 1.
    """
    per_tile = 1024 // args["TILE_K"]
    return max(1, per_tile)

140 

141 

def softmax_heur_tile_m(args):
    """Choose the ``TILE_M`` value registered for ``softmax_backward_inner``.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_N"]`` is read and
            it is used as a divisor.

    Returns:
        ``1024 // args["TILE_N"]``, but never less than 1.
    """
    return max(1, 1024 // args["TILE_N"])


# Backward-compatible alias: the function was originally published under this
# misspelled name ("heru"), and existing references still use it.
softmax_heru_tile_m = softmax_heur_tile_m

144 

145 

def uniform_heur_block(args):
    """Choose the ``BLOCK`` value registered for the ``uniform`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 16384.
    """
    if args["N"] > 512:
        return 16384
    return 512

151 

152 

def uniform_heur_num_warps(args):
    """Return the ``num_warps`` value registered for the ``uniform`` heuristics.

    Args:
        args: Kernel arguments; not read by this function.

    Returns:
        The constant 4.
    """
    warps = 4
    return warps

155 

156 

def var_mean_heur_block_n(args):
    """Choose the ``BLOCK_N`` value registered for the ``var_mean`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["BLOCK_NUM"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["BLOCK_NUM"])``.
    """
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

159 

160 

def upsample_nearest2d_NUM_TILE(args):
    """Compute the ``NUM_TILE`` value registered for ``upsample_nearest2d``.

    Args:
        args: Mapping of kernel arguments; ``args["N"]`` and ``args["C"]``
            are read.

    Returns:
        1 when ``triton.cdiv(N * C, 4)`` is at most 128, otherwise that
        quantity divided by 128 with ``triton.cdiv``.
    """
    grid_y = triton.cdiv(args["N"] * args["C"], 4)
    return 1 if grid_y <= 128 else triton.cdiv(grid_y, 128)

168 

169 

def upsample_nearest2d_TOTAL_TILE(args):
    """Compute the ``TOTAL_TILE`` value registered for ``upsample_nearest2d``.

    Args:
        args: Mapping of kernel arguments; ``args["N"]`` and ``args["C"]``
            are read.

    Returns:
        ``triton.cdiv(args["N"] * args["C"], 4)``.
    """
    n_times_c = args["N"] * args["C"]
    return triton.cdiv(n_times_c, 4)

172 

173 

def upsample_nearest2d_SAME_H(args):
    """Compute the ``SAME_H`` flag registered for ``upsample_nearest2d``.

    Args:
        args: Mapping of kernel arguments; ``args["OH"]`` and ``args["IH"]``
            are read.

    Returns:
        True when ``args["OH"]`` equals ``args["IH"]``.
    """
    out_h = args["OH"]
    in_h = args["IH"]
    return out_h == in_h

176 

177 

def upsample_nearest2d_SAME_W(args):
    """Compute the ``SAME_W`` flag registered for ``upsample_nearest2d``.

    Args:
        args: Mapping of kernel arguments; ``args["OW"]`` and ``args["IW"]``
            are read.

    Returns:
        True when ``args["OW"]`` equals ``args["IW"]``.
    """
    out_w = args["OW"]
    in_w = args["IW"]
    return out_w == in_w

180 

181 

def upsample_nearest2d_USE_INT32_IDX(args):
    """Compute the ``USE_INT32_IDX`` flag registered for ``upsample_nearest2d``.

    Args:
        args: Mapping of kernel arguments; ``args["N"]``, ``args["C"]``,
            ``args["OH"]`` and ``args["OW"]`` are read.

    Returns:
        True when ``N * C * OH * OW`` does not exceed ``2**31 - 1``.
    """
    int32_max = 2**31 - 1  # INT32 MAX
    product = args["N"] * args["C"] * args["OH"] * args["OW"]
    return product <= int32_max

184 

185 

def batch_norm_heur_block_m(args):
    """Choose the ``BLOCK_M`` value registered for the ``batch_norm`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["batch_dim"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["batch_dim"])``, capped at 2048.
    """
    padded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, padded_batch)

188 

189 

def batch_norm_heur_block_n(args):
    """Choose the ``BLOCK_N`` value registered for the ``batch_norm`` heuristics.

    Args:
        args: Mapping of kernel arguments; ``args["spatial_dim"]`` is read
            here and ``args`` is also passed to ``batch_norm_heur_block_m``.

    Returns:
        ``triton.next_power_of_2(args["spatial_dim"])``, limited to
        ``max(1, 2**14 // BLOCK_M)``.
    """
    # A maximum of 16384 elements are loaded at once.
    max_loaded_elements = 2**14
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    block_n_limit = max(1, max_loaded_elements // block_m)
    return min(block_n, block_n_limit)

195 

196 

def vdot_heur_block_size(args):
    """Choose the ``BLOCK_SIZE`` value registered for the ``vdot`` heuristics.

    Args:
        args: Mapping of kernel arguments; only ``args["n_elements"]`` is read.

    Returns:
        32 when ``n_elements`` is below 1024, 256 when it is below 8192,
        otherwise 1024.
    """
    element_count = args["n_elements"]
    # (exclusive upper bound, block size) pairs, checked in ascending order.
    for upper_bound, block_size in ((1024, 32), (8192, 256)):
        if element_count < upper_bound:
            return block_size
    return 1024

205 

206 

def simple_elementwise_blocksize_heur(args):
    """Choose the ``BLOCK_SIZE`` value registered for ``elementwise_generic``.

    Args:
        args: Mapping of kernel arguments; only ``args["n_elements"]`` is read.

    Returns:
        1024 when ``n_elements`` is below 65535, otherwise 16384.
    """
    element_count = args["n_elements"]
    return 1024 if element_count < 65535 else 16384

213 

214 

# Registry of heuristic callables: the outer key is an operator name, the inner
# key is the name of the value being derived, and each inner value is a
# callable that takes the kernel-argument mapping and returns that value.
HEURISTICS_CONFIGS = {
    # --- reductions over an index -------------------------------------------
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # bmm registers no heuristics at present; its DIVISIBLE_M / DIVISIBLE_N /
    # DIVISIBLE_K entries are disabled.
    "bmm": {},
    # --- random-number style operators --------------------------------------
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    # --- indexing -----------------------------------------------------------
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    # --- softmax forward / backward ------------------------------------------
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heru_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "NUM_TILE": upsample_nearest2d_NUM_TILE,
        "TOTAL_TILE": upsample_nearest2d_TOTAL_TILE,
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 4,
    },
    # Fixed values: every entry ignores its argument and returns a constant.
    "mha_varlen_fwd": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
}