Coverage for src/flag_gems/runtime/backend/_mthreads/ops/bmm.py: 0%

159 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor 

14 

# Module logger: fixed backend/ops namespace plus this module's own basename
# (the last dotted component of __name__).
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.rpartition(".")[2]}'
)

18 

# Normalised path of the expand-config YAML, which sits one directory above
# the directory containing this module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "bmm_mthreads_expand.yaml",
    )
)

22 

23 

def is_supported_sqmma_layout(tensor):
    """Return True when ``tensor`` has a memory layout the SQMMA path accepts.

    Two layouts are accepted: a contiguous tensor, or one whose first
    dimension has stride 1 and whose second dimension has a stride equal to
    the size of the first dimension.

    NOTE(review): the stride test looks at dimensions 0 and 1, while the
    callers in this file pass 3-D (batch, rows, cols) operands -- confirm the
    intended dimensions.
    """
    if tensor.is_contiguous():
        return True
    # Keep the short-circuit: stride(1) is only queried once stride(0) == 1.
    if tensor.stride(0) != 1:
        return False
    return tensor.stride(1) == tensor.shape[0]

28 

29 

def is_sqmma_compatible(a, b, N, K):
    """Decide whether operands ``a`` and ``b`` may take the SQMMA kernel path.

    All of the following must hold:
      * both operands share one dtype, and it is float16 or bfloat16;
      * both operands pass ``is_supported_sqmma_layout``;
      * ``N`` and ``K`` are each a multiple of 8.

    Returns:
        bool: True when every condition above is met.
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype not in (torch.float16, torch.bfloat16):
        return False
    # Layout checks run in the same order as the dtype checks: a, then b.
    if not is_supported_sqmma_layout(a):
        return False
    if not is_supported_sqmma_layout(b):
        return False
    return N % 8 == 0 and K % 8 == 0

39 

40 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    """Compute one (TILE_M, TILE_N) output tile of O[b] = A[b] @ B[b].

    The pointer arithmetic below indexes A as (batch, M, K), B as
    (batch, K, N) and O as (batch, M, N), all densely packed row-major
    (row strides K, N and N; batch strides M*K, K*N and M*N).

    Grid axes 0 and 1 enumerate output tiles; grid axis 2 is the batch index.

    The constexpr parameters (TILE_*, GROUP_M, DIVISIBLE_*) are not passed by
    the launcher in this file; presumably they come from the tuned and
    heuristic configs named "bmm" -- verify against the runtime config.
    DIVISIBLE_M/N/K select whether bounds masks are needed along each axis.
    """
    # batch offsets
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        # No regrouping: grid axis 0 is the tile row, axis 1 the tile column.
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs
        # Linearise the 2-D program id, then hand out tiles in groups of up
        # to GROUP_M tile rows; within a group the row index varies fastest.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group holds only gridx % GROUP_M rows when it would
        # otherwise run past gridx.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Row/column bounds masks are only built when the tile size does not
    # divide the corresponding dimension.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate in float32 regardless of the input element type.
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Select the load masks for this K step from the DIVISIBLE_* flags.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # offs_k advances every iteration, so the K mask is recomputed.
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        # NOTE(review): no `other=` fill value is passed to these masked
        # loads, so the accumulation relies on masked-out lanes reading as
        # zero -- confirm that this holds for tl.load on this backend.
        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # Step every K-dependent offset/pointer to the next TILE_K slice.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # Output mask mirrors the M/N bounds masks; None when both divide evenly.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    # The float32 accumulator is stored without an explicit cast here.
    tl.store(o_ptrs, o, mask_c)

145 

146 

def bmm_fma(A, B):
    """Batched matmul of ``A`` (batch, M, K) and ``B`` (batch, K, N) via ``bmm_kernel``.

    Both operands are made contiguous before launch. The result is a new
    (batch, M, N) tensor with ``A``'s dtype, allocated on ``A``'s device.

    Raises:
        ValueError: from tuple unpacking if ``A`` or ``B`` is not 3-D.
    """
    logger.debug("GEMS_MTHREADS BMM(FMA)")
    batch, M, K = A.shape
    _, _, N = B.shape
    lhs = A.contiguous()
    rhs = B.contiguous()
    out = torch.empty((batch, M, N), dtype=lhs.dtype, device=lhs.device)

    def launch_grid(meta):
        # One program per output tile (axes 0 and 1) per batch entry (axis 2).
        tiles_m = triton.cdiv(meta["M"], meta["TILE_M"])
        tiles_n = triton.cdiv(meta["N"], meta["TILE_N"])
        return (tiles_m, tiles_n, batch)

    with torch_device_fn.device(lhs.device):
        bmm_kernel[launch_grid](lhs, rhs, out, M, N, K)
    return out

163 

164 

def bmm_sqmma_descriptor_pre_hook(nargs):
    """Fill the A/B/C descriptor buffers found in ``nargs`` before a launch.

    Each operand is reshaped to 2-D by folding the batch dimension into the
    rows -- A to (batch*M, K), B to (batch*K, N), C to (batch*M, N) -- and a
    descriptor for the matching block shape is copied into ``a_desc_ptr``,
    ``b_desc_ptr`` and ``c_desc_ptr`` respectively. A and B go through the
    cached descriptor helper; C's descriptor is created anew on every call.

    Args:
        nargs: mapping of kernel argument names to values; must contain the
            tensors, the problem sizes, the BLOCK_SIZE_* meta-parameters and
            the three descriptor buffers.
    """
    a, b, c = nargs["A"], nargs["B"], nargs["C"]
    batch = nargs["batch"]
    M, N, K = nargs["M"], nargs["N"], nargs["K"]
    block_m = nargs["BLOCK_SIZE_M"]
    block_n = nargs["BLOCK_SIZE_N"]
    block_k = nargs["BLOCK_SIZE_K"]
    device = c.device

    # (destination buffer key, descriptor factory, tensor, 2-D shape, block shape)
    descriptor_plan = (
        (
            "a_desc_ptr",
            get_cached_tma_device_descriptor,
            a,
            (batch * M, K),
            (block_m, block_k),
        ),
        (
            "b_desc_ptr",
            get_cached_tma_device_descriptor,
            b,
            (batch * K, N),
            (block_k, block_n),
        ),
        (
            "c_desc_ptr",
            create_tma_device_descriptor,
            c,
            (batch * M, N),
            (block_m, block_n),
        ),
    )
    for slot, make_descriptor, tensor, (rows, cols), (blk_rows, blk_cols) in (
        descriptor_plan
    ):
        nargs[slot].copy_(
            make_descriptor(tensor.reshape(rows, cols), blk_rows, blk_cols, device)
        )

191 

192 

# Config selection: with USE_FLAGTUNE=1 both the configs and the tuning
# strategy are read from the expand-config YAML; otherwise a single fixed
# 128x128x64 config is used. Every config carries the descriptor pre-hook.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "bmm_sqmma",
        pre_hook=bmm_sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=bmm_sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("bmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ][:3]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def bmm_sqmma_kernel(
    A,
    B,
    C,
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    batch,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_type: tl.constexpr,
    d_type: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of a batched matmul via descriptors.

    All memory traffic goes through ``a_desc_ptr`` / ``b_desc_ptr`` /
    ``c_desc_ptr``. ``A``, ``B``, ``C`` and ``batch`` are not read in this
    body; the pre-hook ``bmm_sqmma_descriptor_pre_hook`` reads them from the
    launch arguments to build the descriptors over the 2-D views
    (batch*M, K), (batch*K, N) and (batch*M, N).

    Grid axis 0 is a flattened tile index (tile row varies fastest); grid
    axis 1 is the batch index.

    NOTE(review): a tile's row range starts at ``batch_index * M +
    pid_m * BLOCK_SIZE_M`` in the flattened view. When M is not a multiple of
    BLOCK_SIZE_M, the last tile of a batch appears to extend into the next
    batch's rows of A and C -- confirm whether the descriptors or the
    dispatcher rule this out.
    """
    pid = tl.program_id(axis=0)
    batch_index = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Row offsets include the batch offset because the batch dimension is
    # folded into the rows of the 2-D views.
    offs_am = pid_m * BLOCK_SIZE_M + batch_index * M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_ak = 0
    offs_bk = batch_index * K
    tme_load_type = ab_type
    # Accumulate in float32; converted to d_type once after the K loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr, [offs_am, offs_ak], [BLOCK_SIZE_M, BLOCK_SIZE_K], tme_load_type
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr, [offs_bk, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tme_load_type
        )
        accumulator = tl.dot(a, b, acc=accumulator)
        # Advance both operands by one K block.
        offs_ak += BLOCK_SIZE_K
        offs_bk += BLOCK_SIZE_K
    accumulator = accumulator.to(d_type)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

259 

260 

def get_triton_type(elem_type):
    """Translate a torch dtype into the matching Triton element type.

    Only float16, bfloat16 and float8_e4m3fn are mapped.

    Returns:
        The Triton dtype, or None when ``elem_type`` has no mapping.
    """
    torch_dtypes = (torch.float16, torch.bfloat16, torch.float8_e4m3fn)
    triton_dtypes = (tl.float16, tl.bfloat16, tl.float8e4nv)
    torch_to_triton = dict(zip(torch_dtypes, triton_dtypes))
    if elem_type in torch_to_triton:
        return torch_to_triton[elem_type]
    return None

268 

269 

def bmm_sqmma(A, B, elem_type, batch, M, N, K):
    """Batched matmul of ``A`` and ``B`` through the descriptor-based SQMMA kernel.

    Args:
        A: left operand, logically (batch, M, K).
        B: right operand, logically (batch, K, N).
        elem_type: torch dtype of the operands; also selects the kernel's
            load type.
        batch, M, N, K: problem sizes.

    Returns:
        A new (batch, M, N) tensor on ``A``'s device. Its dtype is
        ``elem_type``, except that bfloat16 inputs produce float16 output.

    NOTE(review): returning float16 for bfloat16 inputs differs from the
    input dtype -- confirm this is intentional before relying on it.
    """
    logger.debug("GEMS_MTHREADS BMM(SQMMA)")
    # Allocate on the operands' device rather than the bare "musa" default, so
    # inputs living on a non-current device get a matching output and launch.
    device = A.device
    ab_type = elem_type
    c_type = elem_type if (elem_type != torch.bfloat16) else torch.float16
    C = torch.empty((batch, M, N), dtype=torch.float16, device=device).to(c_type)
    # 64-byte int8 buffers that the kernel's pre-hook fills with descriptors.
    desc_a = torch.empty((64,), dtype=torch.int8, device=device)
    desc_b = torch.empty((64,), dtype=torch.int8, device=device)
    desc_c = torch.empty((64,), dtype=torch.int8, device=device)

    def grid(META):
        # Axis 0: flattened (tile_m, tile_n) index; axis 1: batch index.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            batch,
            1,
        )

    # Same device guard as bmm_fma: launch on the device that owns the data.
    with torch_device_fn.device(device):
        bmm_sqmma_kernel[grid](
            A,
            B,
            C,
            desc_a,
            desc_b,
            desc_c,
            batch,
            M,
            N,
            K,
            ab_type=get_triton_type(ab_type),
            d_type=get_triton_type(c_type),
        )
    return C

298 

299 

def bmm(a, b):
    """Batched matrix multiply of ``a`` (batch, M, K) with ``b`` (batch, K, N).

    Dispatches to ``bmm_sqmma`` when ``is_sqmma_compatible`` accepts the
    operands and to ``bmm_fma`` otherwise.

    For the duration of the call the ``MUSA_ENABLE_SQMMA`` environment
    variable is set to "1" when neither operand is float32 and removed
    otherwise; its previous state is restored on exit, including when the
    dispatch raises. (Presumably the variable is read by the MUSA Triton
    backend -- not visible in this file.)

    Raises:
        ValueError: from tuple unpacking if ``a`` or ``b`` is not 3-D.
    """
    a_dtype, b_dtype = a.dtype, b.dtype
    batch, M, K = a.shape
    _, _, N = b.shape

    saved_flag = os.environ.get("MUSA_ENABLE_SQMMA")
    if torch.float32 in (a_dtype, b_dtype):
        os.environ.pop("MUSA_ENABLE_SQMMA", None)
    else:
        os.environ["MUSA_ENABLE_SQMMA"] = "1"

    try:
        if not is_sqmma_compatible(a, b, N, K):
            return bmm_fma(a, b)
        return bmm_sqmma(a, b, a_dtype, batch, M, N, K)
    finally:
        # Put the variable back exactly as it was found (absent or prior value).
        if saved_flag is not None:
            os.environ["MUSA_ENABLE_SQMMA"] = saved_flag
        else:
            os.environ.pop("MUSA_ENABLE_SQMMA", None)