Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from .utils import create_tma_device_descriptor, should_enable_sqmma 

14 

# Module-level logger namespaced under the mthreads backend package so this
# op's log output can be filtered independently of the rest of flag_gems.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

18 

19 

@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly less than ``a``
    when ``b`` divides ``a``, i.e. the largest ``x < a`` with ``x % b == 0``
    (for non-multiples this is ``floor(a / b) * b``).

    Equivalent to ``tl.cdiv(a, b) * b - b``.
    """
    return (tl.cdiv(a, b) - 1) * b

24 

25 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # matrix multiplication: each program computes one BLOCK_M x BLOCK_N tile
    # of C = A @ B, accumulating in float32.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: GROUP_M consecutive
    # row-tiles are scheduled together so they reuse the same B columns.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols modulo M/N so the unmasked loads below stay
    # in bounds; the surplus products are discarded by the masked store.
    # The int64 casts avoid 32-bit offset overflow on large tensors.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop over full K tiles; rk < prev_multiple <= K, so no K mask needed.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            # Mixed-precision operands: promote both to the output dtype
            # before the dot so the tensor cores see matching types.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last (possibly partial) K tile is masked with rk < K.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak), mask=mask_k[None, :]
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn), mask=mask_k[:, None]
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)

102 

103 

# Floating-point dtypes ordered from narrowest to widest for promotion.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes per ``_ordered_datatypes`` ordering.

    Both arguments must appear in ``_ordered_datatypes``.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype that appears later in the ordered list wins.
    return a if _ordered_datatypes.index(a) > _ordered_datatypes.index(b) else b

119 

120 

def mm_fma(a, b):
    """FMA-path matmul: compute ``a @ b`` (2-D) via ``mm_kernel``.

    The output dtype is the wider of the two input dtypes; the result is
    allocated on ``a``'s device.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # Materialize a contiguous copy when neither dimension has unit stride
    # (the kernel expects at least one unit-stride axis per operand).
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=out_dtype)

    # 1-D launch grid: one program per output tile.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c

157 

158 

def mm_out(a, b, *, out):
    """Matmul into a caller-provided tensor: ``out = a @ b``.

    Returns ``out``. The kernel writes through ``out``'s strides with a mask
    derived only from ``a``/``b`` shapes, so ``out`` must already be (M, N).
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Fix: validate the output buffer. The store mask is (rm < M) & (rn < N),
    # so a smaller ``out`` would be written out of bounds and a larger one
    # would be only partially filled.
    assert out.shape == (M, N), "incompatible output shape"
    c = out
    # launch kernel: one program per BLOCK_M x BLOCK_N output tile
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c

193 

194 

@triton.jit
def mm_sqmma_kernel(
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    GROUP_M: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_dtype: tl.constexpr,
    c_dtype: tl.constexpr,
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = False,
):
    # SQMMA matmul kernel using TMA device descriptors for all loads/stores.
    # Operands arrive as descriptor pointers; their element dtypes are passed
    # explicitly as constexprs since descriptor loads need them up front.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering (same swizzle as mm_kernel) for L2 reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    offs_am = pid_m * BLOCK_SIZE_M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_k = 0
    # Descriptor offsets are cast to int32 for descriptor-based addressing.
    offs_am = offs_am.to(tl.int32)
    offs_bn = offs_bn.to(tl.int32)
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    tme_load_ab_dtype = ab_dtype
    c_store_dtype = c_dtype
    # NOTE(review): the K loop rounds up via cdiv with no explicit tail mask —
    # presumably the TMA descriptor bounds-checks the partial tile; confirm.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [offs_am, offs_k],
            [BLOCK_SIZE_M, BLOCK_SIZE_K],
            tme_load_ab_dtype,
            is_transpose_a,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [offs_k, offs_bn],
            [BLOCK_SIZE_K, BLOCK_SIZE_N],
            tme_load_ab_dtype,
            is_transpose_b,
        )
        accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
        offs_k += BLOCK_SIZE_K
    # Downcast the float32 accumulator to the output dtype and store the tile.
    accumulator = accumulator.to(c_store_dtype)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

248 

249 

def get_triton_type(elem_type):
    """Map a torch dtype to its triton counterpart; ``None`` if unsupported."""
    if elem_type is torch.float16:
        return tl.float16
    if elem_type is torch.bfloat16:
        return tl.bfloat16
    if elem_type is torch.float8_e4m3fn:
        return tl.float8e4nv
    return None

257 

258 

def mm_sqmma(A, B, M, N, K, GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
    """SQMMA/TMA matmul path: compute ``A @ B`` via ``mm_sqmma_kernel``.

    Both inputs must share one dtype. Column-major (transposed) operands are
    handled natively through the TMA descriptor; any other non-contiguous
    layout is materialized first.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    # NOTE(review): the output lands on the default "musa" device rather than
    # A.device — confirm multi-device inputs are out of scope for this path.
    device = "musa"
    is_transpose_a = False
    if not A.is_contiguous():
        if A.stride(0) == 1 and A.stride(1) == A.shape[0]:
            is_transpose_a = True
        else:
            A = A.contiguous()
    is_transpose_b = False
    if not B.is_contiguous():
        if B.stride(0) == 1 and B.stride(1) == B.shape[0]:
            is_transpose_b = True
        else:
            B = B.contiguous()
    a_type, b_type = A.dtype, B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    # One TMA descriptor per operand, sized to the kernel's tile shape.
    desc_a = create_tma_device_descriptor(A, BLOCK_M, BLOCK_K, device)
    desc_b = create_tma_device_descriptor(B, BLOCK_K, BLOCK_N, device)
    desc_c = create_tma_device_descriptor(C, BLOCK_M, BLOCK_N, device)
    launch_grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
    mm_sqmma_kernel[launch_grid](
        desc_a,
        desc_b,
        desc_c,
        M,
        N,
        K,
        GROUP_M,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        get_triton_type(a_type),
        get_triton_type(c_dtype),
        num_warps=num_warps,
        num_stages=num_stages,
        is_transpose_a=is_transpose_a,
        is_transpose_b=is_transpose_b,
    )
    return C

302 

303 

def mm(a, b):
    """Dispatch 2-D matmul between the SQMMA (TMA) path and the FMA path.

    ``should_enable_sqmma`` decides the path from dtypes and problem size.
    Returns a new (M, N) tensor.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    M, K = a.shape
    _, N = b.shape
    use_sqmma = should_enable_sqmma(a_dtype, b_dtype, M, N, K)
    if use_sqmma:
        GROUP_M = 8
        BLOCK_M = 128
        BLOCK_N = BLOCK_M
        BLOCK_K = 64
        # 16 warps only for 256-wide tiles; with BLOCK_M = 128 this is 4.
        num_warps = 16 if BLOCK_M == 256 else 4
        num_stages = 1
        return mm_sqmma(
            a,
            b,
            M,
            N,
            K,
            GROUP_M,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps,
            num_stages,
        )
    else:
        # Temporarily remove MUSA_ENABLE_SQMMA around the FMA path — presumably
        # so SQMMA codegen is disabled there; confirm against the backend docs.
        # Fixes: restore whenever the variable previously existed (the old
        # `if enable_sqmma:` check dropped an empty-string value for good),
        # and restore even if mm_fma raises.
        enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None)
        try:
            return mm_fma(a, b)
        finally:
            if enable_sqmma is not None:
                os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma