Coverage for src/flag_gems/runtime/backend/_kunlunxin/heuristics_config_utils.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3 

4 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels.

    ``args`` is accepted only to fit the heuristic-callback signature; it is
    not inspected.
    """
    block_size = 1024
    return block_size

7 

8 

def argmax_heur_block_m(args):
    """Choose BLOCK_M for argmax: 4 when ``M`` is below 4096, otherwise 8."""
    if args["M"] < 4096:
        return 4
    return 8

11 

12 

def argmax_heur_block_n(args):
    """Choose BLOCK_N for argmax: ``N`` rounded up to a power of two, capped at 4096."""
    padded_n = triton.next_power_of_2(args["N"])
    return min(padded_n, 4096)

15 

16 

def argmin_heur_block_m(args):
    """Choose BLOCK_M for argmin.

    ``M`` is split over 12 parts (labelled ``cluster_num`` in the original
    comment). The per-part count is rounded up to a power of two.
    """
    cluster_num = 12
    per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(per_cluster)

19 

20 

def argmin_heur_block_n(args):
    """Choose BLOCK_N for argmin: ``N`` rounded up to a power of two, capped at 8192.

    The builtin ``min`` is used directly. Nothing in this module shadows it,
    and ``argmax_heur_block_n`` already does the same, so the former
    function-local ``import builtins`` was unnecessary.
    """
    return min(triton.next_power_of_2(args["N"]), 8192)

25 

26 

def bmm_heur_divisible_m(args):
    """Return True when ``M`` is an exact multiple of ``TILE_M``."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0

29 

30 

def bmm_heur_divisible_n(args):
    """Return True when ``N`` is an exact multiple of ``TILE_N``."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0

33 

34 

def bmm_heur_divisible_k(args):
    """Return True when ``K`` is an exact multiple of ``TILE_K``."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

37 

38 

def dropout_heur_block(args):
    """Choose BLOCK for dropout: 512 when ``N`` <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

44 

45 

def dropout_heur_num_warps(args):
    """Choose num_warps for dropout from ``N``.

    Returns 4 for N <= 512, 8 for N <= 1024, and 16 otherwise.
    """
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16

53 

54 

def exponential_heur_block(args):
    """Choose BLOCK for exponential_: 512 when ``N`` <= 512, otherwise 1024."""
    n = args["N"]
    if n <= 512:
        return 512
    return 1024

60 

61 

def exponential_heur_num_warps(args):
    """Choose num_warps for exponential_: 4 / 8 / 16 at N <= 512 / <= 1024 / above."""
    size = args["N"]
    if size <= 512:
        return 4
    if size <= 1024:
        return 8
    return 16

69 

70 

def gather_heur_block_m(args):
    """Choose BLOCK_M for gather.

    The value is the number of 2048-wide chunks of ``N``, rounded up to a
    power of two and capped at 4. NOTE(review): this is derived from ``N``,
    not ``M``; confirm that is intended.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))

73 

74 

def gather_heur_block_n(args):
    """Choose BLOCK_N for gather: ``N`` rounded up to a power of two, capped at 2048."""
    padded_n = triton.next_power_of_2(args["N"])
    return min(2048, padded_n)

77 

78 

def index_add_heur_block_m(args):
    """Choose BLOCK_M for index_add.

    ``M`` is split over 12 parts (labelled ``cluster_num`` in the original
    comment). The per-part count is rounded up to a power of two.
    """
    cluster_num = 12
    per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(per_cluster)

81 

82 

def index_add_heur_block_n(args):
    """Choose BLOCK_N for index_add: ``N`` rounded up to a power of two, capped at 8192.

    A stale, commented-out threshold scheme (64 / 256 / N) was removed; it
    had no effect at runtime.
    """
    return min(8192, triton.next_power_of_2(args["N"]))

91 

92 

def index_select_heur_block_m(args):
    """Choose BLOCK_M for index_select.

    ``M`` is split over 12 parts (labelled ``cluster_num`` in the original
    comment). The per-part count is rounded up to a power of two.
    """
    cluster_num = 12
    per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(per_cluster)

95 

96 

def index_select_heur_block_n(args):
    """Return the fixed BLOCK_N (64) for index_select; ``args`` is not inspected."""
    block_n = 64
    return block_n

99 

100 

def mm_heur_even_k(args):
    """Return True when ``K`` divides evenly into ``BLOCK_K * SPLIT_K`` chunks."""
    k = args["K"]
    chunk = args["BLOCK_K"] * args["SPLIT_K"]
    return k % chunk == 0

103 

104 

def rand_heur_block(args):
    """Choose BLOCK for rand.

    ``N`` is divided by ``12 * 4`` and the result is rounded up to a power of
    two. The original comment labels 12 as ``CLUSTER_NUM``. NOTE(review): the
    meaning of the extra factor 4 is not stated in this file; confirm.

    The former ``N <= 512 -> 512 else 1024`` branch followed an unconditional
    ``return`` and could never run, so it was removed.
    """
    cluster_num = 12
    return triton.next_power_of_2(triton.cdiv(args["N"], cluster_num * 4))

111 

112 

def rand_heur_num_warps(args):
    """Choose num_warps for rand: 4 / 8 / 16 at N <= 512 / <= 1024 / above."""
    n = args["N"]
    return 4 if n <= 512 else (8 if n <= 1024 else 16)

120 

121 

def randn_heur_block(args):
    """Choose BLOCK for randn: 512 when ``N`` <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

127 

128 

def randn_heur_num_warps(args):
    """Choose num_warps for randn: 4 / 8 / 16 at N <= 512 / <= 1024 / above."""
    count = args["N"]
    if count <= 512:
        return 4
    if count <= 1024:
        return 8
    return 16

136 

137 

def softmax_heur_tile_k(args):
    """Choose TILE_K for the non-inner softmax kernel.

    Starting from 1, TILE_K is doubled while both of these hold:
    * the launch of ``M * cdiv(K, TILE_K)`` blocks is still more than one
      wave over the device's ``multi_processor_count``;
    * the doubled tile stays within ``min(K, 8192)``.
    If ``K < 1`` the loop never runs and 1 is returned.

    The device properties come from ``torch.cuda``. NOTE(review): presumably
    this is mapped to the Kunlunxin device on this backend; confirm.
    """
    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    num_sms = device_props.multi_processor_count
    tile_k = 1
    upper_bound = min(args["K"], 8192)  # 8192 == MAX_TILE_K
    while tile_k <= upper_bound:
        grid_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        if not ((grid_blocks / num_sms > 1) and (tile_k * 2 <= upper_bound)):
            break
        tile_k *= 2
    return tile_k

153 

154 

def softmax_heur_tile_n_non_inner(args):
    """Choose TILE_N for the non-inner softmax kernel so TILE_N * TILE_K covers about 8192."""
    budget = 8192
    return triton.cdiv(budget, args["TILE_K"])

157 

158 

def softmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile is at least as large as ``N``."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return tile_n >= n

161 

162 

def softmax_heur_num_warps_non_inner(args):
    """Choose num_warps for the non-inner softmax kernel from TILE_N * TILE_K.

    Returns 4 below 2048 elements, 8 below 4096, and 16 otherwise.
    """
    elements = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if elements < bound:
            return warps
    return 16

171 

172 

def softmax_heur_tile_n_inner(args):
    """Choose TILE_N for the inner softmax kernel.

    Up to 32 * 1024 columns, ``N`` is rounded up to a power of two; beyond
    that a fixed tile of 4096 is used.
    """
    n = args["N"]
    if n <= 32 * 1024:
        return triton.next_power_of_2(n)
    return 4096

178 

179 

def softmax_heur_num_warps_inner(args):
    """Choose num_warps for the inner softmax kernel: 4 / 8 / 16 by TILE_N < 2048 / < 4096 / above."""
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16

188 

189 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Choose TILE_N for the non-inner softmax backward: 1024 // TILE_K, but at least 1."""
    tile_n = 1024 // args["TILE_K"]
    return tile_n if tile_n > 1 else 1

192 

193 

def softmax_heur_tile_m(args):
    """Choose TILE_M for the inner softmax backward: 1024 // TILE_N, but at least 1."""
    tile_m = 1024 // args["TILE_N"]
    return tile_m if tile_m > 1 else 1

196 

197 

def uniform_heur_block(args):
    """Choose BLOCK for uniform: 512 when ``N`` <= 512, otherwise 1024."""
    n = args["N"]
    if n <= 512:
        return 512
    return 1024

203 

204 

def uniform_heur_num_warps(args):
    """Choose num_warps for uniform: 4 / 8 / 16 at N <= 512 / <= 1024 / above."""
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16

212 

213 

def var_mean_heur_block_n(args):
    """Choose BLOCK_N for var_mean: ``BLOCK_NUM`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

216 

217 

def upsample_nearest2d_SAME_H(args):
    """Return True when the output height ``OH`` equals the input height ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

220 

221 

def upsample_nearest2d_SAME_W(args):
    """Return True when the output width ``OW`` equals the input width ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

224 

225 

def batch_norm_heur_block_m(args):
    """Choose BLOCK_M for batch_norm: ``batch_dim`` rounded up to a power of two, capped at 2048."""
    padded = triton.next_power_of_2(args["batch_dim"])
    return min(2048, padded)

228 

229 

def batch_norm_heur_block_n(args):
    """Choose BLOCK_N for batch_norm.

    ``spatial_dim`` is rounded up to a power of two and then limited so that
    BLOCK_M * BLOCK_N stays within 2**14 (16384) elements. The limit is never
    below 1.
    """
    max_elements = 2**14
    block_m = batch_norm_heur_block_m(args)
    padded_spatial = triton.next_power_of_2(args["spatial_dim"])
    block_n_cap = max(1, max_elements // block_m)
    return min(padded_spatial, block_n_cap)

235 

236 

def vdot_heur_block_size(args):
    """Choose BLOCK_SIZE for vdot from ``n_elements``.

    Returns 32 below 1024 elements, 256 below 8192, and 1024 otherwise.
    """
    total = args["n_elements"]
    for bound, size in ((1024, 32), (8192, 256)):
        if total < bound:
            return size
    return 1024

245 

246 

def _elementwise_generic_num_warps(args):
    """Return the fixed num_warps (8) for generic elementwise kernels.

    ``args`` is accepted only to fit the heuristic-callback signature; it is
    not inspected. This named function replaces the earlier anonymous
    ``lambda args: 8``, so every table entry is a named heuristic.
    """
    return 8


# Heuristic table for this backend.
# Outer key: operator name. Inner dict: kernel parameter name -> callable.
# Each callable takes the kernel-argument mapping (indexed by names such as
# "M", "N", "K", "TILE_N") and returns that parameter's value.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "index_add": {
        "BLOCK_M": index_add_heur_block_m,
        "BLOCK_N": index_add_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": _elementwise_generic_num_warps,
    },
}