Coverage for src/flag_gems/runtime/backend/_hygon/heuristics_config_utils.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import torch 

2import triton 

3 

4 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed block size used for generic elementwise kernels.

    Args:
        args: Heuristic argument mapping; it is not inspected.

    Returns:
        Always 1024.
    """
    block_size = 1024
    return block_size

7 

8 

def argmax_heur_block_m(args):
    """Choose the ``BLOCK_M`` heuristic value for argmax.

    Args:
        args: Mapping that provides the key ``"M"``.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    if args["M"] < 4096:
        return 4
    return 8

11 

12 

def argmax_heur_block_n(args):
    """Choose the ``BLOCK_N`` heuristic value for argmax.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        ``args["N"]`` rounded up to a power of two, capped at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)

15 

16 

def argmin_heur_block_m(args):
    """Choose the ``BLOCK_M`` heuristic value for argmin.

    Args:
        args: Mapping that provides the key ``"M"``.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    if args["M"] < 4096:
        return 4
    return 8

19 

20 

def argmin_heur_block_n(args):
    """Choose the ``BLOCK_N`` heuristic value for argmin.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        ``args["N"]`` rounded up to a power of two, capped at 4096.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)

23 

24 

def bmm_heur_divisible_m(args):
    """Report whether ``args["M"]`` is an exact multiple of ``args["TILE_M"]``.

    Args:
        args: Mapping that provides the keys ``"M"`` and ``"TILE_M"``.

    Returns:
        True when the remainder of ``M`` divided by ``TILE_M`` is zero.
    """
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0

27 

28 

def bmm_heur_divisible_n(args):
    """Report whether ``args["N"]`` is an exact multiple of ``args["TILE_N"]``.

    Args:
        args: Mapping that provides the keys ``"N"`` and ``"TILE_N"``.

    Returns:
        True when the remainder of ``N`` divided by ``TILE_N`` is zero.
    """
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0

31 

32 

def bmm_heur_divisible_k(args):
    """Report whether ``args["K"]`` is an exact multiple of ``args["TILE_K"]``.

    Args:
        args: Mapping that provides the keys ``"K"`` and ``"TILE_K"``.

    Returns:
        True when the remainder of ``K`` divided by ``TILE_K`` is zero.
    """
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

35 

36 

def dropout_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for dropout from ``args["N"]``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        512 for ``N <= 512``, 1024 for ``N <= 1024``, otherwise 4096.
    """
    n = args["N"]
    if n <= 512:
        return 512
    if n <= 1024:
        return 1024
    return 4096

44 

45 

def dropout_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for dropout from ``args["N"]``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        4 for ``N <= 512``, 8 for ``N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

53 

54 

def exponential_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for ``exponential_``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        512 when ``args["N"] <= 512``, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024

60 

61 

def exponential_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for ``exponential_``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        4 for ``N <= 512``, 8 for ``N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

69 

70 

def gather_heur_block_m(args):
    """Choose the ``BLOCK_M`` heuristic value for gather.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        The ceiling of ``args["N"] / 2048`` rounded up to a power of two,
        capped at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    padded_chunks = triton.next_power_of_2(chunks)
    return min(4, padded_chunks)

73 

74 

def gather_heur_block_n(args):
    """Choose the ``BLOCK_N`` heuristic value for gather.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        ``args["N"]`` rounded up to a power of two, capped at 2048.
    """
    padded_n = triton.next_power_of_2(args["N"])
    return min(2048, padded_n)

77 

78 

def index_select_heur_block_m(args):
    """Choose the ``BLOCK_M`` heuristic value for index_select.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        The ceiling of ``256 / args["N"]`` rounded up to a power of two,
        capped at 4.
    """
    rows_per_256 = triton.cdiv(256, args["N"])
    padded_rows = triton.next_power_of_2(rows_per_256)
    return min(4, padded_rows)

81 

82 

def index_select_heur_block_n(args):
    """Choose the ``BLOCK_N`` heuristic value for index_select.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        The ceiling of ``args["N"] / 16`` rounded up to a power of two,
        then limited to at most 512 and at least 16.
    """
    sixteenths = triton.cdiv(args["N"], 16)
    capped = min(triton.next_power_of_2(sixteenths), 512)
    return max(capped, 16)

86 

87 

def mm_heur_even_k(args):
    """Report whether ``K`` splits evenly into ``BLOCK_K * SPLIT_K`` chunks.

    Args:
        args: Mapping that provides the keys ``"K"``, ``"BLOCK_K"`` and
            ``"SPLIT_K"``.

    Returns:
        True when ``args["K"]`` is an exact multiple of
        ``args["BLOCK_K"] * args["SPLIT_K"]``.
    """
    k_total = args["K"]
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return k_total % k_step == 0

90 

91 

def rand_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for rand.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        512 when ``args["N"] <= 512``, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024

97 

98 

def rand_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for rand from ``args["N"]``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        4 for ``N <= 512``, 8 for ``N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

106 

107 

def randn_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for randn.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        512 when ``args["N"] <= 512``, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024

113 

114 

def randn_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for randn from ``args["N"]``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        4 for ``N <= 512``, 8 for ``N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

122 

123 

def softmax_heur_tile_k(args):
    """Grow a power-of-two ``TILE_K`` while the launch exceeds one wave.

    Starting from 1, the tile is doubled as long as
    ``args["M"] * cdiv(args["K"], tile)`` blocks outnumber the
    multiprocessors of the current CUDA device and the doubled tile still
    fits within ``min(args["K"], 8192)``.

    Args:
        args: Mapping that provides the keys ``"M"`` and ``"K"``.

    Returns:
        The selected tile size; always at least 1.
    """
    tile_cap = 8192
    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    sm_count = device_props.multi_processor_count
    limit = min(args["K"], tile_cap)
    tile = 1
    while tile <= limit:
        blocks = args["M"] * triton.cdiv(args["K"], tile)
        # Number of "waves": launched blocks per multiprocessor.
        waves = blocks / sm_count
        may_double = tile * 2 <= limit
        if not (waves > 1 and may_double):
            break
        tile *= 2
    return tile

139 

140 

def softmax_heur_tile_n_non_inner(args):
    """Choose ``TILE_N`` for the non-inner softmax from ``args["TILE_K"]``.

    Args:
        args: Mapping that provides the key ``"TILE_K"``.

    Returns:
        The ceiling of ``8192 / args["TILE_K"]``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)

143 

144 

def softmax_heur_one_tile_per_cta(args):
    """Report whether a single ``TILE_N`` tile spans all of ``N``.

    Args:
        args: Mapping that provides the keys ``"TILE_N"`` and ``"N"``.

    Returns:
        True when ``args["TILE_N"] >= args["N"]``.
    """
    tile_n = args["TILE_N"]
    return tile_n >= args["N"]

147 

148 

def softmax_heur_num_warps_non_inner(args):
    """Choose ``num_warps`` for the non-inner softmax from the tile area.

    Args:
        args: Mapping that provides the keys ``"TILE_N"`` and ``"TILE_K"``.

    Returns:
        4 when ``TILE_N * TILE_K`` is below 2048, 8 when it is below 4096,
        otherwise 16.
    """
    elements = args["TILE_N"] * args["TILE_K"]
    if elements < 2048:
        return 4
    if elements < 4096:
        return 8
    return 16

157 

158 

def softmax_heur_tile_n_inner(args):
    """Choose ``TILE_N`` for the inner softmax from ``args["N"]``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        ``args["N"]`` rounded up to a power of two when ``N`` is at most
        32 * 1024, otherwise 4096.
    """
    single_tile_limit = 32 * 1024
    if args["N"] <= single_tile_limit:
        return triton.next_power_of_2(args["N"])
    return 4096

164 

165 

def softmax_heur_num_warps_inner(args):
    """Choose ``num_warps`` for the inner softmax from ``args["TILE_N"]``.

    Args:
        args: Mapping that provides the key ``"TILE_N"``.

    Returns:
        4 when ``TILE_N`` is below 2048, 8 when it is below 4096,
        otherwise 16.
    """
    elements = args["TILE_N"]
    if elements < 2048:
        return 4
    if elements < 4096:
        return 8
    return 16

174 

175 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Choose ``TILE_N`` for the non-inner softmax backward pass.

    Args:
        args: Mapping that provides the key ``"TILE_K"``.

    Returns:
        ``1024 // args["TILE_K"]``, but never less than 1.
    """
    quotient = 1024 // args["TILE_K"]
    return max(1, quotient)

178 

179 

def softmax_heur_tile_m(args):
    """Choose ``TILE_M`` for the inner softmax backward pass.

    Args:
        args: Mapping that provides the key ``"TILE_N"``.

    Returns:
        ``1024 // args["TILE_N"]``, but never less than 1.
    """
    quotient = 1024 // args["TILE_N"]
    return max(1, quotient)

182 

183 

def uniform_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for uniform.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        512 when ``args["N"] <= 512``, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024

189 

190 

def uniform_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for uniform from ``args["N"]``.

    Args:
        args: Mapping that provides the key ``"N"``.

    Returns:
        4 for ``N <= 512``, 8 for ``N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

198 

199 

def var_mean_heur_block_n(args):
    """Choose the ``BLOCK_N`` heuristic value for var_mean.

    Args:
        args: Mapping that provides the key ``"BLOCK_NUM"``.

    Returns:
        ``args["BLOCK_NUM"]`` rounded up to a power of two.
    """
    block_count = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_count)

202 

203 

def upsample_nearest2d_SAME_H(args):
    """Report whether ``args["OH"]`` equals ``args["IH"]``.

    Args:
        args: Mapping that provides the keys ``"OH"`` and ``"IH"``.

    Returns:
        The result of ``args["OH"] == args["IH"]``.
    """
    out_h = args["OH"]
    return out_h == args["IH"]

206 

207 

def upsample_nearest2d_SAME_W(args):
    """Report whether ``args["OW"]`` equals ``args["IW"]``.

    Args:
        args: Mapping that provides the keys ``"OW"`` and ``"IW"``.

    Returns:
        The result of ``args["OW"] == args["IW"]``.
    """
    out_w = args["OW"]
    return out_w == args["IW"]

210 

211 

def batch_norm_heur_block_m(args):
    """Choose the ``BLOCK_M`` heuristic value for batch_norm.

    Args:
        args: Mapping that provides the key ``"batch_dim"``.

    Returns:
        ``args["batch_dim"]`` rounded up to a power of two, capped at 2048.
    """
    padded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, padded_batch)

214 

215 

def batch_norm_heur_block_n(args):
    """Choose the ``BLOCK_N`` heuristic value for batch_norm.

    ``BLOCK_N`` is ``args["spatial_dim"]`` rounded up to a power of two,
    limited so that ``BLOCK_M * BLOCK_N`` stays within 2**14 (16384)
    elements whenever ``BLOCK_M`` itself does not exceed that budget.

    Args:
        args: Mapping that provides ``"spatial_dim"`` plus whatever
            ``batch_norm_heur_block_m`` reads from it.

    Returns:
        The selected block size; always at least 1.
    """
    # A maximum of 16384 elements are loaded at once.
    element_budget = 2**14
    block_m = batch_norm_heur_block_m(args)
    padded_spatial = triton.next_power_of_2(args["spatial_dim"])
    n_limit = max(1, element_budget // block_m)
    return min(padded_spatial, n_limit)

221 

222 

def vdot_heur_block_size(args):
    """Choose the ``BLOCK_SIZE`` heuristic value for vdot.

    Args:
        args: Mapping that provides the key ``"n_elements"``.

    Returns:
        32 when ``n_elements`` is below 1024, 256 when it is below 8192,
        otherwise 1024.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024

231 

232 

# Registry of heuristic callables: the outer key is a configuration name and
# the inner mapping goes from a heuristic/meta-parameter name to a function
# that takes the ``args`` mapping and returns the value for that parameter.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        # Constant heuristic: always 8, independent of the arguments.
        "num_warps": lambda args: 8,
    },
}