Coverage for src/flag_gems/runtime/backend/_arm/heuristics_config_utils.py: 0%

122 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import triton 

2 

3 

def argmax_heur_block_m(args):
    """Choose ``BLOCK_M`` for the argmax kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["M"]`` is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    if args["M"] < 4096:
        return 4
    return 8

6 

7 

def argmax_heur_block_n(args):
    """Choose ``BLOCK_N`` for the argmax kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``args["N"]`` rounded up to a power of two, capped at 4.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4, rounded_n)

10 

11 

def argmin_heur_block_m(args):
    """Choose ``BLOCK_M`` for the argmin kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["M"]`` is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    if args["M"] < 4096:
        return 4
    return 8

14 

15 

def argmin_heur_block_n(args):
    """Choose ``BLOCK_N`` for the argmin kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``args["N"]`` rounded up to a power of two, capped at 4.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4, rounded_n)

18 

19 

def bmm_heur_divisible_m(args):
    """Report whether ``args["M"]`` is an exact multiple of ``args["TILE_M"]``.

    Args:
        args: Mapping of kernel arguments; reads ``"M"`` and ``"TILE_M"``.

    Returns:
        True when ``M % TILE_M`` is zero, False otherwise.
    """
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0

22 

23 

def bmm_heur_divisible_n(args):
    """Report whether ``args["N"]`` is an exact multiple of ``args["TILE_N"]``.

    Args:
        args: Mapping of kernel arguments; reads ``"N"`` and ``"TILE_N"``.

    Returns:
        True when ``N % TILE_N`` is zero, False otherwise.
    """
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0

26 

27 

def bmm_heur_divisible_k(args):
    """Report whether ``args["K"]`` is an exact multiple of ``args["TILE_K"]``.

    Args:
        args: Mapping of kernel arguments; reads ``"K"`` and ``"TILE_K"``.

    Returns:
        True when ``K % TILE_K`` is zero, False otherwise.
    """
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

30 

31 

def dropout_heur_block(args):
    """Choose ``BLOCK`` for the dropout kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024

37 

38 

def dropout_heur_num_warps(args):
    """Choose ``num_warps`` for the dropout kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 when ``N <= 512``, 8 when ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

46 

47 

def exponential_heur_block(args):
    """Choose ``BLOCK`` for the ``exponential_`` kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 when ``args["N"]`` is at most 512, otherwise 8.
    """
    return 4 if args["N"] <= 512 else 8

53 

54 

def exponential_heur_num_warps(args):
    """Choose ``num_warps`` for the ``exponential_`` kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 when ``N <= 512``, 8 when ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

62 

63 

def gather_heur_block_m(args):
    """Choose ``BLOCK_M`` for the gather kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``ceil(N / 2048)`` rounded up to a power of two, capped at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))

66 

67 

def gather_heur_block_n(args):
    """Choose ``BLOCK_N`` for the gather kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``args["N"]`` rounded up to a power of two, capped at 16.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(16, rounded_n)

70 

71 

def index_select_heur_block_m(args):
    """Choose ``BLOCK_M`` for the index_select kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``ceil(256 / N)`` rounded up to a power of two, capped at 4.
    """
    rows_per_256 = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows_per_256))

74 

75 

def index_select_heur_block_n(args):
    """Choose ``BLOCK_N`` for the index_select kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``ceil(N / 16)`` rounded up to a power of two, then clamped to the
        inclusive range [16, 512].
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    # Apply the upper bound first, then the lower bound.
    capped = min(candidate, 512)
    return max(capped, 16)

79 

80 

def mm_heur_even_k(args):
    """Report whether ``K`` is a multiple of ``BLOCK_K * SPLIT_K``.

    Args:
        args: Mapping of kernel arguments; reads ``"K"``, ``"BLOCK_K"`` and
            ``"SPLIT_K"``.

    Returns:
        True when ``K % (BLOCK_K * SPLIT_K)`` is zero, False otherwise.
    """
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0

83 

84 

def rand_heur_block(args):
    """Choose ``BLOCK`` for the rand kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 when ``args["N"]`` is at most 512, otherwise 16.
    """
    return 4 if args["N"] <= 512 else 16

90 

91 

def rand_heur_num_warps(args):
    """Choose ``num_warps`` for the rand kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 when ``N <= 512``, 8 when ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

99 

100 

def randn_heur_block(args):
    """Choose ``BLOCK`` for the randn kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024

106 

107 

def randn_heur_num_warps(args):
    """Choose ``num_warps`` for the randn kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 when ``N <= 512``, 8 when ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

115 

116 

def softmax_heur_tile_k(args):
    """Return ``TILE_K`` for the non-inner softmax kernel.

    Args:
        args: Mapping of kernel arguments; currently unused.

    Returns:
        The constant 16.
    """
    # NOTE: the search below, which sizes the tile from the CUDA
    # multiprocessor count, is disabled here; a fixed tile is returned.
    # MAX_TILE_K = 8192
    # NUM_SMS = torch.cuda.get_device_properties(
    #     torch.cuda.current_device()
    # ).multi_processor_count
    # tile_k = 1
    # upper_bound = min(args["K"], MAX_TILE_K)
    # while tile_k <= upper_bound:
    #     num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
    #     num_waves = num_blocks / NUM_SMS
    #     if (num_waves > 1) and (tile_k * 2 <= upper_bound):
    #         tile_k *= 2
    #     else:
    #         break
    # return tile_k
    return 16

133 

134 

def softmax_heur_tile_n_non_inner(args):
    """Return ``TILE_N`` for the non-inner softmax kernel.

    Args:
        args: Mapping of kernel arguments; currently unused.

    Returns:
        The constant 16.
    """
    # NOTE: the TILE_K-dependent formula is disabled in favour of a constant.
    # return triton.cdiv(8192, args["TILE_K"])
    return 16

138 

139 

def softmax_heur_one_tile_per_cta(args):
    """Report whether ``TILE_N`` is at least as large as ``N``.

    Args:
        args: Mapping of kernel arguments; reads ``"TILE_N"`` and ``"N"``.

    Returns:
        True when ``TILE_N >= N``, False otherwise.
    """
    tile_n, n = args["TILE_N"], args["N"]
    return tile_n >= n

142 

143 

def softmax_heur_num_warps_non_inner(args):
    """Choose ``num_warps`` for the non-inner softmax kernel.

    Args:
        args: Mapping of kernel arguments; reads ``"TILE_N"`` and
            ``"TILE_K"``.

    Returns:
        4 when ``TILE_N * TILE_K`` is below 2048, 8 when it is below 4096,
        otherwise 16.
    """
    elements_per_tile = args["TILE_N"] * args["TILE_K"]
    if elements_per_tile < 2048:
        return 4
    if elements_per_tile < 4096:
        return 8
    return 16

152 

153 

def softmax_heur_tile_n_inner(args):
    """Return ``TILE_N`` for the inner softmax kernel.

    Args:
        args: Mapping of kernel arguments; currently unused.

    Returns:
        The constant 4.
    """
    # NOTE: the N-dependent sizing below is disabled in favour of a constant.
    # if args["N"] <= (32 * 1024):
    #     return triton.next_power_of_2(args["N"])
    # else:
    #     return 4096
    return 4

160 

161 

def softmax_heur_num_warps_inner(args):
    """Choose ``num_warps`` for the inner softmax kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_N"]`` is read.

    Returns:
        4 when ``TILE_N`` is below 2048, 8 when it is below 4096,
        otherwise 16.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16

170 

171 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Choose ``TILE_N`` for the non-inner softmax backward kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_K"]`` is read.

    Returns:
        ``1024 // TILE_K``, but never less than 1.
    """
    quotient = 1024 // args["TILE_K"]
    # Floor division yields 0 once TILE_K exceeds 1024; keep at least 1.
    return max(1, quotient)

174 

175 

def softmax_heur_tile_m(args):
    """Choose ``TILE_M`` for the inner softmax backward kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_N"]`` is read.

    Returns:
        ``1024 // TILE_N``, but never less than 1.
    """
    quotient = 1024 // args["TILE_N"]
    # Floor division yields 0 once TILE_N exceeds 1024; keep at least 1.
    return max(1, quotient)

178 

179 

def uniform_heur_block(args):
    """Choose ``BLOCK`` for the uniform kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    return 512 if args["N"] <= 512 else 1024

185 

186 

def uniform_heur_num_warps(args):
    """Choose ``num_warps`` for the uniform kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 when ``N <= 512``, 8 when ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

194 

195 

def var_mean_heur_block_n(args):
    """Choose ``BLOCK_N`` for the var_mean kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["BLOCK_NUM"]`` is
            read.

    Returns:
        ``args["BLOCK_NUM"]`` rounded up to a power of two.
    """
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

198 

199 

def upsample_nearest2d_SAME_H(args):
    """Report whether ``args["OH"]`` equals ``args["IH"]``.

    Args:
        args: Mapping of kernel arguments; reads ``"OH"`` and ``"IH"``
            (presumably output and input heights -- confirm against caller).

    Returns:
        True when the two values are equal, False otherwise.
    """
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

202 

203 

def upsample_nearest2d_SAME_W(args):
    """Report whether ``args["OW"]`` equals ``args["IW"]``.

    Args:
        args: Mapping of kernel arguments; reads ``"OW"`` and ``"IW"``
            (presumably output and input widths -- confirm against caller).

    Returns:
        True when the two values are equal, False otherwise.
    """
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

206 

207 

def batch_norm_heur_block_m(args):
    """Choose ``BLOCK_M`` for the batch_norm kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["batch_dim"]`` is
            read.

    Returns:
        ``args["batch_dim"]`` rounded up to a power of two, capped at 4.
    """
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(4, rounded_batch)

210 

211 

def batch_norm_heur_block_n(args):
    """Choose ``BLOCK_N`` for the batch_norm kernel.

    Args:
        args: Mapping of kernel arguments; reads ``"spatial_dim"`` directly
            and ``"batch_dim"`` through ``batch_norm_heur_block_m``.

    Returns:
        ``args["spatial_dim"]`` rounded up to a power of two, capped so that
        ``BLOCK_M * BLOCK_N`` does not exceed 16384.
    """
    # A maximum of 16384 elements are loaded at once.
    max_elements = 2**14
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    block_n_cap = max(1, max_elements // block_m)
    return min(block_n, block_n_cap)

217 

218 

def vdot_heur_block_size(args):
    """Choose ``BLOCK_SIZE`` for the vdot kernel.

    Args:
        args: Mapping of kernel arguments; only ``args["n_elements"]`` is
            read.

    Returns:
        32 when ``n_elements`` is below 1024, 256 when it is below 8192,
        otherwise 1024.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024

227 

228 

# Registry of heuristics: operator name -> {meta-parameter name -> callable}.
# Each callable takes the kernel-argument mapping and returns the value for
# that meta-parameter. Presumably consumed via ``triton.heuristics`` by the
# operator implementations -- confirm against the callers.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
}