Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%

197 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7from triton.tools.tensor_descriptor import TensorDescriptor 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "mm_mthreads_expand.yaml")
)


def _parse_major_minor(version):
    """Return ``(major, minor)`` as ints from a version string.

    Only the leading ASCII digits of each of the first two dot-separated
    fields are used, so local / pre-release suffixes such as ``"3.2+mt"`` or
    ``"3.2rc1"`` do not raise. A missing field, or a field without leading
    digits, counts as 0.
    """
    numbers = []
    for field in version.split(".")[:2]:
        digits = ""
        for ch in field:
            if not ("0" <= ch <= "9"):
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers)


# Module-level capability flag: evaluated once at import time, then reused as
# a constant for the entire process lifetime with no repeated parsing overhead.
# False when Triton < 3.2 (e.g. 3.1), True when Triton >= 3.2.
SQMMA_ON = _parse_major_minor(triton.__version__) >= (3, 2)

24 

25 

def is_supported_sqmma_layout(tensor):
    """Tell whether a 2-D tensor's memory layout is accepted by the SQMMA path.

    Accepted layouts are row-major contiguous storage, or the plain transpose
    of a contiguous tensor (unit stride along dim 0 and a second-dim stride
    equal to the number of rows).
    """
    if tensor.is_contiguous():
        return True
    rows = tensor.shape[0]
    return tensor.stride(0) == 1 and tensor.stride(1) == rows

30 

31 

def is_sqmma_compatible(a, b, N, K):
    """Return True when ``a @ b`` can take the SQMMA (tensor-descriptor) path.

    Checks, in order:
      * Triton >= 3.2 (``SQMMA_ON``);
      * both operands are 2-D with the same fp16/bf16 dtype;
      * each operand is row-major contiguous or a plain transpose of one
        (``is_supported_sqmma_layout``);
      * N and K are multiples of 8, i.e. 16-byte row strides for these
        2-byte dtypes once the operands are contiguous;
      * operands that will be used in place start on a 16-byte boundary.

    Args:
        a: left operand, expected (M, K).
        b: right operand, expected (K, N).
        N: number of output columns.
        K: reduction length.
    """

    def _base_aligned(t):
        # mm_sqmma copies non-contiguous operands into a fresh allocation, so
        # only tensors handed to TensorDescriptor as-is need this check.
        # Upstream Triton's TensorDescriptor asserts a 16-byte aligned base
        # pointer; a view with an odd storage offset would otherwise fail
        # there instead of falling back to mm_fma.
        return not t.is_contiguous() or t.data_ptr() % 16 == 0

    return (
        SQMMA_ON
        and a.dim() == 2
        and b.dim() == 2
        and a.dtype == b.dtype
        and a.dtype in (torch.float16, torch.bfloat16)
        and is_supported_sqmma_layout(a)
        and is_supported_sqmma_layout(b)
        and N % 8 == 0
        and K % 8 == 0
        and _base_aligned(a)
        and _base_aligned(b)
    )

44 

45 

@triton.jit
def prev_multiple_of(a, b):
    # For non-negative a: the largest multiple of b strictly below a,
    # computed as (ceil(a / b) - 1) * b. Note a == 0 yields -b.
    return (tl.cdiv(a, b) - 1) * b

50 

51 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("mm", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("mm", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # Tiled C = A @ B with arbitrary strides: each program computes one
    # BLOCK_M x BLOCK_N tile of C, accumulating in fp64 when IS_FP64 and in
    # fp32 otherwise. Mixed-dtype operands are cast to C's element type.
    # `dtype` is accepted but not referenced in this body.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: consecutive program ids
    # cycle through up to GROUP_M tile rows before advancing pid_n.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # `% M` / `% N` wrap out-of-range rows/cols back into bounds so the loads
    # need no M/N mask; those rows/cols are dropped by the store mask below.
    # int64 keeps the pointer arithmetic from overflowing on large tensors.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Full K tiles: every rk is < K here, so no K mask is needed.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last (possibly partial) K tile is loaded under a mask.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    # `rk >= 0` only matters when K == 0: prev_multiple_of(0, BLOCK_K) is
    # -BLOCK_K, and every negative rk would otherwise pass `rk < K` and be
    # loaded from before the start of A/B.
    mask_k = (rk >= 0) & (rk < K)
    # other=0.0 makes masked-out lanes contribute exactly zero to the dot
    # product instead of unspecified values.
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # masked write-back of the output tile (drops the wrapped rows/cols)
    tl.store(C, acc, mask=mask)

147 

148 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("gemv", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [triton.Config({"BLOCK_M": 64, "BLOCK_K": 64})],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Matrix-vector product C[m] = sum_k A[m, k] * B[k]: each program owns
    # BLOCK_M consecutive rows and walks K in BLOCK_K chunks.
    pid = ext.program_id(0)

    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    # Widen to int64 before multiplying by strides (as mm_kernel does) so the
    # address math cannot overflow int32 once row * stride_am >= 2**31.
    row_offset_i64 = row_offset.to(tl.int64)

    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K
        k_offset_i64 = k_offset.to(tl.int64)

        a_ptrs = (
            A
            + row_offset_i64[:, None] * stride_am
            + k_offset_i64[None, :] * stride_ak
        )
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        b_ptrs = B + k_offset_i64 * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Keep the reduction in fp32 so N=1 GEMV matches the mm path more closely.
        a = a.to(tl.float32)
        b = b.to(tl.float32)
        acc += tl.sum(a * b[None, :], axis=1)

    c_ptrs = C + row_offset_i64 * stride_cm
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

203 

204 

205_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64] 

206 

207 

208def get_higher_dtype(a, b): 

209 if a is b: 

210 return a 

211 

212 assert a in _ordered_datatypes 

213 assert b in _ordered_datatypes 

214 

215 for d in _ordered_datatypes: 

216 if a is d: 

217 return b 

218 if b is d: 

219 return a 

220 

221 

def mm_fma(a, b):
    """Compute ``a @ b`` with the generic strided Triton kernel ``mm_kernel``.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).

    Returns:
        A new (M, N) tensor on ``a.device`` whose dtype is the higher of
        ``a.dtype`` and ``b.dtype`` according to ``get_higher_dtype``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # handle non-contiguous inputs if necessary: the kernel takes arbitrary
    # strides, so a copy is only made when neither dimension has stride 1.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    # launch kernel: one program per BLOCK_M x BLOCK_N output tile
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            # mm_kernel casts mixed-dtype operands to C's dtype, so fp64
            # accumulation is needed whenever either operand is fp64 (e.g.
            # a=fp32, b=fp64), not only when `a` is.
            IS_FP64=torch.float64 in (a.dtype, b.dtype),
        )
    return c

260 

261 

def gemv_mm(a, b, c, M, K):
    """Matrix-vector product for the N == 1 case, written into ``c``.

    Args:
        a: (M, K) matrix.
        b: right operand; callers in this module only take this path when
            ``b`` has a single column, and only ``b.stride(0)`` is used.
        c: preallocated output; only ``c.stride(0)`` is used.
        M: number of rows of ``a``.
        K: reduction length.

    Returns:
        ``c``, the same object that was passed in.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def _grid(meta):
        # one program per BLOCK_M rows of the output
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    strides = (a.stride(0), a.stride(1), b.stride(0), c.stride(0))
    with torch_device_fn.device(a.device):
        gemv_kernel[_grid](a, b, c, M, K, *strides)
    return c

282 

283 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the preallocated tensor ``out``.

    N == 1 is routed to ``gemv_mm``; every other shape uses ``mm_kernel``.
    ``out`` is not validated here. It is presumably (M, N) with the desired
    result dtype; TODO confirm callers guarantee this.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).
        out: destination tensor (keyword-only).

    Returns:
        ``out``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # handle non-contiguous inputs if necessary: the kernel takes arbitrary
    # strides, so a copy is only made when neither dimension has stride 1.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # output is caller-provided
    c = out
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    # launch kernel: one program per BLOCK_M x BLOCK_N output tile
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            # mm_kernel casts mixed-dtype operands to C's dtype, so fp64
            # accumulation is needed whenever either operand is fp64 (e.g.
            # a=fp32, b=fp64), not only when `a` is.
            IS_FP64=torch.float64 in (a.dtype, b.dtype),
        )
    return c

322 

323 

@triton.jit
def mm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Tiled C = A @ B through tensor descriptors: each program computes one
    # BLOCK_M x BLOCK_N tile of C, reading A/B tiles with
    # tl.load_tensor_descriptor and writing the tile with
    # tl.store_tensor_descriptor. `dtype` is accepted but not referenced here.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Same grouped program-id ordering as mm_kernel: consecutive program ids
    # cycle through up to GROUP_M tile rows before advancing pid_n.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Tile origins (element coordinates) passed to the descriptor loads/stores.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    # Running K coordinate, advanced by BLOCK_K each iteration.
    offs_k = 0
    offs_k = offs_k.to(tl.int32)
    # fp32 accumulator regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # cdiv(K, BLOCK_K) iterations with no explicit K mask; the last partial
    # tile relies on the descriptor load's out-of-bounds handling
    # (NOTE(review): presumably zero padding, TensorDescriptor's default — confirm
    # on the MUSA backend).
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_K
    # Cast to the output descriptor's dtype and write the tile back.
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], accumulator.to(c_desc.dtype))

357 

358 

def get_triton_type(elem_type):
    """Map a torch dtype to its Triton counterpart, or ``None`` if unmapped.

    Supported: float16, bfloat16 and float8_e4m3fn (-> ``tl.float8e4nv``).
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    if elem_type in torch_to_triton:
        return torch_to_triton[elem_type]
    return None

366 

367 

def mm_sqmma(A, B, M, N, K, GROUP_M):
    """Compute ``A @ B`` with the tensor-descriptor kernel ``mm_sqmma_kernel``.

    Uses fixed 128x128x64 tiles, 4 warps and 1 stage.

    Args:
        A: left operand of shape (M, K).
        B: right operand of shape (K, N); must have the same dtype as ``A``.
        M, N, K: problem sizes.
        GROUP_M: tile-row grouping factor forwarded to the kernel.

    Returns:
        A new (M, N) tensor on ``A.device`` with ``A``'s dtype.

    Raises:
        AssertionError: if ``A.dtype != B.dtype``.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    # Descriptors are built from row-major contiguous storage, so copy any
    # other layout (including the transposed layout accepted upstream).
    if not A.is_contiguous():
        A = A.contiguous()
    if not B.is_contiguous():
        B = B.contiguous()
    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64
    desc_a = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K])
    desc_b = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N])
    desc_c = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N])
    # Tile sizes are fixed here, so the grid does not depend on META.
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
    # Launch on the operands' device, like every other launch in this module;
    # otherwise the kernel runs on whatever device is current.
    with torch_device_fn.device(device):
        mm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            str(a_type).split(".")[-1],
            GROUP_M,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps=4,
            num_stages=1,
        )
    return C

407 

408 

def mm(a, b):
    """Compute ``a @ b`` for 2-D ``a`` (M, K) and ``b`` (K, N).

    Dispatch:
      * N == 1: ``gemv_mm`` into a new tensor whose dtype comes from
        ``get_higher_dtype``;
      * operands accepted by ``is_sqmma_compatible``: ``mm_sqmma``;
      * otherwise: ``mm_fma``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    M, K = a.shape
    K_b, N = b.shape
    # Validate up front: the GEMV and SQMMA paths do not check this
    # themselves, and a mismatched K would make them read past the end of
    # ``b`` or compute a wrong result.
    assert K == K_b, "incompatible dimensions"
    if N == 1:
        c_dtype = get_higher_dtype(a.dtype, b.dtype)
        c = torch.empty((M, N), device=a.device, dtype=c_dtype)
        return gemv_mm(a, b, c, M, K)

    if is_sqmma_compatible(a, b, N, K):
        GROUP_M = 8
        return mm_sqmma(a, b, M, N, K, GROUP_M)
    return mm_fma(a, b)