Coverage for src/flag_gems/runtime/backend/_aipu/heuristics_config_utils.py: 0%

136 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import triton 

2 

3 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE used for generic elementwise kernels.

    ``args`` (the kernel argument mapping) is ignored; the result is always 1024.
    """
    fixed_block_size = 1024
    return fixed_block_size

6 

7 

def argmax_heur_block_m(args):
    """Return BLOCK_M for argmax: 4 when ``args["M"] < 4096``, otherwise 8."""
    rows = args["M"]
    if rows < 4096:
        return 4
    return 8

10 

11 

def argmax_heur_block_n(args):
    """Return BLOCK_N for argmax.

    The value is ``args["N"]`` passed through ``triton.next_power_of_2``, capped at 4096.
    """
    n_rounded = triton.next_power_of_2(args["N"])
    return min(4096, n_rounded)

14 

15 

def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 4 when ``args["M"] < 4096``, otherwise 8."""
    rows = args["M"]
    if rows < 4096:
        return 4
    return 8

18 

19 

def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin.

    The value is ``args["N"]`` passed through ``triton.next_power_of_2``, capped at 4096.
    """
    n_rounded = triton.next_power_of_2(args["N"])
    return min(4096, n_rounded)

22 

23 

def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0

26 

27 

def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0

30 

31 

def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

34 

35 

def dropout_heur_block(args):
    """Return BLOCK for dropout: 512 when ``args["N"] <= 512``, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

41 

42 

def dropout_heur_num_warps(args):
    """Return num_warps for dropout, stepping up with ``args["N"]``.

    4 warps for N <= 512, 8 for N <= 1024, and 16 above that.
    """
    n = args["N"]
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16

50 

51 

def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 when ``args["N"] <= 512``, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

57 

58 

def exponential_heur_num_warps(args):
    """Return num_warps for exponential_, stepping up with ``args["N"]``.

    4 warps for N <= 512, 8 for N <= 1024, and 16 above that.
    """
    n = args["N"]
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16

66 

67 

def gather_heur_block_m(args):
    """Return BLOCK_M for gather.

    Take the number of 2048-wide chunks needed to cover ``args["N"]``
    (``triton.cdiv``), pass it through ``triton.next_power_of_2``, and cap it at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))

70 

71 

def gather_heur_block_n(args):
    """Return BLOCK_N for gather.

    The value is ``args["N"]`` passed through ``triton.next_power_of_2``, capped at 2048.
    """
    n_rounded = triton.next_power_of_2(args["N"])
    return min(2048, n_rounded)

74 

75 

def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select.

    Compute ``ceil(256 / N)`` via ``triton.cdiv``, pass it through
    ``triton.next_power_of_2``, and cap it at 4.

    NOTE(review): ``args["N"] == 0`` would make the division below fail;
    presumably callers never reach this with N == 0 — confirm.
    """
    per_n = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(per_n))

78 

79 

def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    Compute ``ceil(N / 16)``, pass it through ``triton.next_power_of_2``, and
    clamp the result into the range [16, 512].
    """
    groups = triton.cdiv(args["N"], 16)
    capped = min(triton.next_power_of_2(groups), 512)
    return max(capped, 16)

83 

84 

def mm_heur_even_k(args):
    """Return True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``."""
    k = args["K"]
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % step == 0

87 

88 

def rand_heur_block(args):
    """Return BLOCK for rand: 512 when ``args["N"] <= 512``, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

94 

95 

def rand_heur_num_warps(args):
    """Return num_warps for rand, stepping up with ``args["N"]``.

    4 warps for N <= 512, 8 for N <= 1024, and 16 above that.
    """
    n = args["N"]
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16

103 

104 

def randn_heur_block(args):
    """Return BLOCK for randn: 512 when ``args["N"] <= 512``, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

110 

111 

def randn_heur_num_warps(args):
    """Return num_warps for randn, stepping up with ``args["N"]``.

    4 warps for N <= 512, 8 for N <= 1024, and 16 above that.
    """
    n = args["N"]
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16

119 

120 

def softmax_heur_tile_k(args):
    """Choose TILE_K for the non-inner softmax kernel.

    Starting from 1, TILE_K is doubled as long as both of these hold:

    * ``M * ceil(K / TILE_K) / 4`` is greater than 1 (4 is the hard-coded
      ``NUM_SMS`` constant);
    * the doubled tile does not exceed ``min(K, 8192)``.

    If ``K < 1`` the loop never runs and 1 is returned.
    """
    max_tile_k = 8192
    # Presumably the number of parallel compute units on this backend.
    num_sms = 4
    limit = min(args["K"], max_tile_k)
    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / num_sms
        if not (waves > 1 and tile_k * 2 <= limit):
            break
        tile_k *= 2
    return tile_k

134 

135 

def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for the non-inner softmax kernel: ``ceil(8192 / TILE_K)``."""
    element_budget = 8192
    return triton.cdiv(element_budget, args["TILE_K"])

138 

139 

def softmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile covers all ``N`` elements."""
    tile_n = args["TILE_N"]
    return tile_n >= args["N"]

142 

143 

def softmax_heur_num_warps_non_inner(args):
    """Return num_warps for the non-inner softmax kernel from the tile area.

    The tile area is ``TILE_N * TILE_K``: 4 warps below 2048, 8 below 4096,
    and 16 otherwise.
    """
    elems = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if elems < bound:
            return warps
    return 16

152 

153 

def softmax_heur_tile_n_inner(args):
    """Return TILE_N for the inner softmax kernel.

    When ``N <= 32 * 1024`` the result is ``triton.next_power_of_2(N)``;
    otherwise it is a fixed 4096.
    """
    n = args["N"]
    return triton.next_power_of_2(n) if n <= (32 * 1024) else 4096

159 

160 

def softmax_heur_num_warps_inner(args):
    """Return num_warps for the inner softmax kernel from ``TILE_N``.

    4 warps below 2048, 8 below 4096, and 16 otherwise.
    """
    width = args["TILE_N"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if width < bound:
            return warps
    return 16

169 

170 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for the non-inner softmax backward kernel.

    The value is ``1024 // TILE_K``, floored at 1.
    """
    per_tile = 1024 // args["TILE_K"]
    return per_tile if per_tile > 1 else 1

173 

174 

def softmax_heur_tile_m(args):
    """Return TILE_M for the inner softmax backward kernel.

    The value is ``1024 // TILE_N``, floored at 1.
    """
    per_tile = 1024 // args["TILE_N"]
    return per_tile if per_tile > 1 else 1

177 

178 

def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 when ``args["N"] <= 512``, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

184 

185 

def uniform_heur_num_warps(args):
    """Return num_warps for uniform, stepping up with ``args["N"]``.

    4 warps for N <= 512, 8 for N <= 1024, and 16 above that.
    """
    n = args["N"]
    for upper, warps in ((512, 4), (1024, 8)):
        if n <= upper:
            return warps
    return 16

193 

194 

def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: ``args["BLOCK_NUM"]`` passed through ``triton.next_power_of_2``."""
    block_count = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_count)

197 

198 

def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"] == args["IH"]``.

    The names presumably mean output vs. input height.
    """
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

201 

202 

def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"] == args["IW"]``.

    The names presumably mean output vs. input width.
    """
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

205 

206 

def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm.

    The value is ``args["batch_dim"]`` passed through ``triton.next_power_of_2``, capped at 2048.
    """
    batch_rounded = triton.next_power_of_2(args["batch_dim"])
    return min(2048, batch_rounded)

209 

210 

def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm.

    Start from ``spatial_dim`` passed through ``triton.next_power_of_2``, then
    cap it so that ``BLOCK_M * BLOCK_N`` stays within 16384 elements. The cap
    is ``16384 // BLOCK_M``, floored at 1.

    NOTE(review): if BLOCK_M can ever be 0 (e.g. an empty batch), the floor
    division below raises ZeroDivisionError — confirm callers.
    """
    # A maximum of 16384 elements are loaded at once.
    max_elems = 1 << 14
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    budget = max(1, max_elems // block_m)
    return min(block_n, budget)

216 

217 

def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot from ``args["n_elements"]``.

    32 below 1024 elements, 256 below 8192, and 1024 otherwise.
    """
    total = args["n_elements"]
    for bound, size in ((1024, 32), (8192, 256)):
        if total < bound:
            return size
    return 1024

226 

227 

# Operator name -> {meta-parameter name: function of the kernel-argument
# mapping that computes it}. Presumably consumed via ``triton.heuristics``
# by the operator implementations — confirm at the call sites.
HEURISTICS_CONFIGS = dict(
    argmax=dict(
        BLOCK_M=argmax_heur_block_m,
        BLOCK_N=argmax_heur_block_n,
    ),
    argmin=dict(
        BLOCK_M=argmin_heur_block_m,
        BLOCK_N=argmin_heur_block_n,
    ),
    bmm=dict(
        DIVISIBLE_M=bmm_heur_divisible_m,
        DIVISIBLE_N=bmm_heur_divisible_n,
        DIVISIBLE_K=bmm_heur_divisible_k,
    ),
    dropout=dict(
        BLOCK=dropout_heur_block,
        num_warps=dropout_heur_num_warps,
    ),
    exponential_=dict(
        BLOCK=exponential_heur_block,
        num_warps=exponential_heur_num_warps,
    ),
    gather=dict(
        BLOCK_M=gather_heur_block_m,
        BLOCK_N=gather_heur_block_n,
    ),
    index_select=dict(
        BLOCK_M=index_select_heur_block_m,
        BLOCK_N=index_select_heur_block_n,
    ),
    mm=dict(
        EVEN_K=mm_heur_even_k,
    ),
    rand=dict(
        BLOCK=rand_heur_block,
        num_warps=rand_heur_num_warps,
    ),
    randn=dict(
        BLOCK=randn_heur_block,
        num_warps=randn_heur_num_warps,
    ),
    softmax_non_inner=dict(
        TILE_K=softmax_heur_tile_k,
        TILE_N=softmax_heur_tile_n_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_non_inner,
    ),
    softmax_inner=dict(
        TILE_N=softmax_heur_tile_n_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_inner,
    ),
    softmax_backward_non_inner=dict(
        TILE_N=softmax_heur_tile_n_bwd_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    softmax_backward_inner=dict(
        TILE_M=softmax_heur_tile_m,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
    ),
    uniform=dict(
        BLOCK=uniform_heur_block,
        num_warps=uniform_heur_num_warps,
    ),
    upsample_nearest2d=dict(
        SAME_H=upsample_nearest2d_SAME_H,
        SAME_W=upsample_nearest2d_SAME_W,
    ),
    var_mean=dict(
        BLOCK_N=var_mean_heur_block_n,
    ),
    batch_norm=dict(
        BLOCK_M=batch_norm_heur_block_m,
        BLOCK_N=batch_norm_heur_block_n,
    ),
    vdot=dict(
        BLOCK_SIZE=vdot_heur_block_size,
    ),
    elementwise_generic=dict(
        BLOCK_SIZE=simple_elementwise_blocksize_heur,
        # Fixed warp count; the argument mapping is ignored.
        num_warps=lambda args: 8,
    ),
)