Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%

140 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import broadcastable_to, libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor 

14 

# Logger registered under the mthreads backend ops namespace, suffixed with
# the last dotted component of this module's name.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.rpartition(".")[2]}'
)


# Normalized path of "addmm_mthreads_expand.yaml", located one directory above
# the directory that contains this file.
_THIS_DIR = os.path.dirname(__file__)
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(_THIS_DIR, "..", "addmm_mthreads_expand.yaml")
)
del _THIS_DIR

23 

24 

def is_supported_sqmma_layout(tensor):
    """Return True if ``tensor`` has a memory layout the SQMMA path accepts.

    Two layouts are accepted:

    * any contiguous tensor, and
    * a 2-D tensor whose strides are ``(1, shape[0])`` -- the same pattern
      that ``addmm_sqmma`` flags as a transposed operand.

    The strided check is restricted to 2-D tensors so that a non-contiguous
    tensor of another rank yields ``False`` instead of raising ``IndexError``
    from ``stride(1)`` (1-D input) or being accepted by coincidence
    (3-D and higher input).

    Args:
        tensor: the ``torch.Tensor`` whose layout is inspected.

    Returns:
        bool: whether the layout is one of the two accepted forms.
    """
    if tensor.is_contiguous():
        return True
    return (
        tensor.dim() == 2
        and tensor.stride(0) == 1
        and tensor.stride(1) == tensor.shape[0]
    )

29 

30 

def is_sqmma_compatible(a, b, N, K):
    """Decide whether the SQMMA kernel may be used for ``a @ b``.

    All of the following must hold, checked in this order:

    * both operands are 2-D,
    * both operands share one dtype, which is float16 or bfloat16,
    * both operands pass ``is_supported_sqmma_layout``,
    * ``N`` and ``K`` are multiples of 8.

    Args:
        a: left matrix operand.
        b: right matrix operand.
        N: number of output columns.
        K: shared inner dimension.

    Returns:
        bool: True only when every condition above is met.
    """
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not is_supported_sqmma_layout(a):
        return False
    if not is_supported_sqmma_layout(b):
        return False
    return N % 8 == 0 and K % 8 == 0

42 

43 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one output tile of ``alpha * (A @ B) + beta * bias``.

    Program ids 0 and 1 select the tile along the M and N axes. Partial
    products are accumulated in float32; the final tile is cast to the dtype
    of the loaded bias values before being stored through ``c_ptr``.
    Out-of-range rows, columns and K positions are masked on every load and
    on the store.
    """
    tile_row = tle.program_id(0)
    tile_col = tle.program_id(1)

    rows = tile_row * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = tile_col * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    k_lane = tl.arange(0, BLOCK_SIZE_K)

    # Row/column bounds do not change across the K loop, so build them once.
    rows_ok = rows[:, None] < M
    cols_ok = cols[None, :] < N

    a_tile_ptrs = a_ptr + (rows[:, None] * stride_am + k_lane[None, :] * stride_ak)
    b_tile_ptrs = b_ptr + (k_lane[:, None] * stride_bk + cols[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of K elements still unread; masks the tail of the last step.
        k_left = K - k * BLOCK_SIZE_K
        a_tile = tl.load(
            a_tile_ptrs,
            mask=rows_ok & (k_lane[None, :] < k_left),
            other=0.0,
        )
        b_tile = tl.load(
            b_tile_ptrs,
            mask=(k_lane[:, None] < k_left) & cols_ok,
            other=0.0,
        )
        acc += tl.dot(a_tile, b_tile, allow_tf32=False)
        a_tile_ptrs += BLOCK_SIZE_K * stride_ak
        b_tile_ptrs += BLOCK_SIZE_K * stride_bk

    # The output and bias tiles use the same row/column offsets as A and B.
    out_mask = rows_ok & cols_ok
    out_ptrs = c_ptr + stride_cm * rows[:, None] + stride_cn * cols[None, :]
    bias_ptrs = i_ptr + stride_im * rows[:, None] + stride_in * cols[None, :]
    bias = tl.load(bias_ptrs, mask=out_mask, other=0.0)

    acc = acc * alpha + bias * beta
    tl.store(out_ptrs, acc.to(bias.dtype), mask=out_mask)

107 

108 

def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with ``addmm_kernel``.

    Both matrices are made contiguous and ``bias`` is broadcast to the output
    shape and materialized before the launch.

    Args:
        bias: tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: left matrix of shape ``(M, K)``.
        mat2: right matrix of shape ``(K, N)``.
        beta: multiplier applied to ``bias``.
        alpha: multiplier applied to the matrix product.

    Returns:
        A new ``(M, N)`` tensor with the dtype and device of ``mat1``.

    Raises:
        AssertionError: if the inner dimensions differ or ``bias`` cannot be
            broadcast to the output shape.
    """
    logger.debug("GEMS_MTHREADS ADDMM(FMA)")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    lhs = mat1.contiguous()
    rhs = mat2.contiguous()
    out = torch.empty((M, N), device=lhs.device, dtype=lhs.dtype)
    dense_bias = bias.broadcast_to(out.shape).contiguous()

    def grid(META):
        # One program per (row tile, column tile) pair.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    lhs_stride_m, lhs_stride_k = lhs.stride(0), lhs.stride(1)
    rhs_stride_k, rhs_stride_n = rhs.stride(0), rhs.stride(1)
    bias_stride_m, bias_stride_n = dense_bias.stride(0), dense_bias.stride(1)
    out_stride_m, out_stride_n = out.stride(0), out.stride(1)

    with torch_device_fn.device(lhs.device):
        addmm_kernel[grid](
            lhs,
            rhs,
            dense_bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            lhs_stride_m,
            lhs_stride_k,
            rhs_stride_k,
            rhs_stride_n,
            bias_stride_m,
            bias_stride_n,
            out_stride_m,
            out_stride_n,
        )
    return out

148 

149 

def addmm_sqmma_descriptor_pre_hook(nargs):
    """Fill the four descriptor buffers before ``addmm_sqmma_kernel`` runs.

    Reads the tensors ``A``, ``B``, ``Bias`` and ``C`` plus the three block
    sizes from ``nargs`` and copies a device descriptor into each matching
    ``*_desc_ptr`` buffer, in the order a, b, bias, c. The three inputs go
    through ``get_cached_tma_device_descriptor``; the output goes through
    ``create_tma_device_descriptor``.

    Args:
        nargs: mapping of kernel argument names to their values for the
            launch being prepared.
    """
    lhs = nargs["A"]
    rhs = nargs["B"]
    bias = nargs["Bias"]
    out = nargs["C"]
    tile_m = nargs["BLOCK_SIZE_M"]
    tile_n = nargs["BLOCK_SIZE_N"]
    tile_k = nargs["BLOCK_SIZE_K"]
    device = out.device

    # (descriptor buffer name, source tensor, tile rows, tile cols)
    cached_inputs = (
        ("a_desc_ptr", lhs, tile_m, tile_k),
        ("b_desc_ptr", rhs, tile_k, tile_n),
        ("bias_desc_ptr", bias, tile_m, tile_n),
    )
    for desc_name, tensor, tile_rows, tile_cols in cached_inputs:
        nargs[desc_name].copy_(
            get_cached_tma_device_descriptor(tensor, tile_rows, tile_cols, device)
        )

    # NOTE(review): the output descriptor bypasses the cache -- presumably
    # because C is allocated anew for each call; confirm against .utils.
    nargs["c_desc_ptr"].copy_(
        create_tma_device_descriptor(out, tile_m, tile_n, device)
    )

170 

171 

# With USE_FLAGTUNE=1 the configs and strategy come from the expand-config
# YAML; otherwise a single fixed 128x128x64 config with default strategies is
# used. Every config carries the descriptor pre-hook.
@libentry()
@libtuner(
    configs=(
        [
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
                num_stages=1,
                num_warps=4,
                pre_hook=addmm_sqmma_descriptor_pre_hook,
            )
        ]
        if os.environ.get("USE_FLAGTUNE") != "1"
        else runtime.ops_get_configs(
            "addmm_sqmma",
            pre_hook=addmm_sqmma_descriptor_pre_hook,
            yaml_path=EXPAND_CONFIG_FILENAME,
        )
    ),
    key=["M", "N", "K"],
    strategy=(
        ["default", "default", "default"]
        if os.environ.get("USE_FLAGTUNE") != "1"
        else runtime.get_expand_config(
            "addmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME
        )["strategy"]
    ),
    warmup=5,
    rep=5,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_sqmma_kernel(
    A,
    B,
    Bias,
    C,
    a_desc_ptr,
    b_desc_ptr,
    bias_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_type: tl.constexpr,
    c_type: tl.constexpr,
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = False,
):
    """Compute one output tile of ``alpha * (A @ B) + beta * Bias`` via descriptors.

    A single program id is split into a row-tile index (``pid % tiles``) and a
    column-tile index (``pid // tiles``). All loads and the store go through
    the ``*_desc_ptr`` descriptors. ``A``, ``B``, ``Bias`` and ``C`` are not
    referenced in this body; ``addmm_sqmma_descriptor_pre_hook`` reads them
    from the launch arguments to build the descriptors.
    """
    pid = tl.program_id(axis=0)
    tiles_along_m = tl.cdiv(M, BLOCK_SIZE_M)
    row_start = (pid % tiles_along_m) * BLOCK_SIZE_M
    col_start = (pid // tiles_along_m) * BLOCK_SIZE_N

    k_start = 0
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_tile = tl._experimental_descriptor_load(
            a_desc_ptr,
            [row_start, k_start],
            [BLOCK_SIZE_M, BLOCK_SIZE_K],
            ab_type,
            is_transpose_a,
        )
        b_tile = tl._experimental_descriptor_load(
            b_desc_ptr,
            [k_start, col_start],
            [BLOCK_SIZE_K, BLOCK_SIZE_N],
            ab_type,
            is_transpose_b,
        )
        acc = tl.dot(a_tile, b_tile, acc=acc)
        k_start += BLOCK_SIZE_K

    bias_tile = tl._experimental_descriptor_load(
        bias_desc_ptr, [row_start, col_start], [BLOCK_SIZE_M, BLOCK_SIZE_N], ab_type
    )
    # NOTE(review): the accumulator is cast to c_type before scaling, whereas
    # addmm_kernel scales the float32 accumulator first -- confirm intended.
    scaled_product = alpha * acc.to(c_type)
    scaled_bias = beta * bias_tile.to(c_type)
    result = (scaled_product + scaled_bias).to(c_type)
    tl._experimental_descriptor_store(c_desc_ptr, result, [row_start, col_start])

254 

255 

def get_triton_type(elem_type):
    """Map a torch dtype to its Triton counterpart.

    Args:
        elem_type: one of ``torch.float16``, ``torch.bfloat16`` or
            ``torch.float8_e4m3fn``.

    Returns:
        The matching ``tl`` dtype, or ``None`` for any other value.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    try:
        return torch_to_triton[elem_type]
    except KeyError:
        return None

263 

264 

def _prepare_sqmma_operand(mat):
    """Return ``(operand, is_transposed)`` for one SQMMA matrix input.

    A contiguous matrix is passed through unchanged. A matrix whose strides
    are ``(1, shape[0])`` is passed through and flagged as transposed. Any
    other layout is copied into a contiguous tensor.
    """
    if mat.is_contiguous():
        return mat, False
    if mat.stride(0) == 1 and mat.stride(1) == mat.shape[0]:
        return mat, True
    return mat.contiguous(), False


def addmm_sqmma(mat1, mat2, bias, elem_type, alpha, beta, M, N, K):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with ``addmm_sqmma_kernel``.

    Args:
        mat1: left matrix of shape ``(M, K)``.
        mat2: right matrix of shape ``(K, N)``.
        bias: tensor broadcastable to ``(M, N)``.
        elem_type: torch dtype of the matrix operands; translated with
            ``get_triton_type`` and passed to the kernel as ``ab_type``.
        alpha: multiplier applied to the matrix product.
        beta: multiplier applied to ``bias``.
        M: number of output rows.
        N: number of output columns.
        K: shared inner dimension.

    Returns:
        A new ``(M, N)`` tensor with the dtype and device of ``mat1``.

    Raises:
        AssertionError: if the inner dimensions differ, ``bias`` cannot be
            broadcast to the output shape, or the operand dtypes differ.
    """
    logger.debug("GEMS_MTHREADS ADDMM(SQMMA)")
    device = mat1.device
    # Same inner-dimension check that addmm_fma performs; without it a
    # mismatched mat2 would reach the kernel unchecked.
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"

    # Handle non-contiguous inputs: keep the (1, shape[0]) strided layout and
    # flag it as transposed, copy anything else into a contiguous tensor.
    mat1, is_transpose_a = _prepare_sqmma_operand(mat1)
    mat2, is_transpose_b = _prepare_sqmma_operand(mat2)

    ab_type = elem_type
    a_type = mat1.dtype
    b_type = mat2.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_type = a_type

    C = torch.empty((M, N), dtype=c_type, device=device)
    bias = bias.broadcast_to(C.shape).contiguous()

    # 64-byte int8 buffers that the kernel's pre-hook fills with descriptors.
    desc_a = torch.empty((64,), dtype=torch.int8, device=device)
    desc_b = torch.empty((64,), dtype=torch.int8, device=device)
    desc_bias = torch.empty((64,), dtype=torch.int8, device=device)
    desc_c = torch.empty((64,), dtype=torch.int8, device=device)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        1,
        1,
    )
    # Launch under the operands' device, matching addmm_fma, so the kernel
    # does not depend on whichever device happens to be current.
    with torch_device_fn.device(device):
        addmm_sqmma_kernel[grid](
            mat1,
            mat2,
            bias,
            C,
            desc_a,
            desc_b,
            desc_bias,
            desc_c,
            M,
            N,
            K,
            alpha,
            beta,
            ab_type=get_triton_type(ab_type),
            c_type=get_triton_type(c_type),
            is_transpose_a=is_transpose_a,
            is_transpose_b=is_transpose_b,
        )
    return C

320 

321 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)``, choosing a kernel path.

    ``addmm_sqmma`` is used when ``is_sqmma_compatible`` accepts the operands;
    otherwise ``addmm_fma`` is used. For the duration of the call the
    ``MUSA_ENABLE_SQMMA`` environment variable is set to ``"1"`` when neither
    operand is float32 and removed otherwise; its previous state is restored
    on exit, including when the selected path raises.

    Args:
        bias: tensor broadcastable to ``(M, N)``.
        mat1: left matrix of shape ``(M, K)``.
        mat2: right matrix of shape ``(K, N)``.
        beta: multiplier applied to ``bias``.
        alpha: multiplier applied to the matrix product.

    Returns:
        The ``(M, N)`` result tensor produced by the selected path.
    """
    lhs_dtype = mat1.dtype
    rhs_dtype = mat2.dtype
    M, K = mat1.shape
    _, N = mat2.shape

    env_key = "MUSA_ENABLE_SQMMA"
    saved_value = os.environ.get(env_key)
    if lhs_dtype == torch.float32 or rhs_dtype == torch.float32:
        os.environ.pop(env_key, None)
    else:
        os.environ[env_key] = "1"

    try:
        if not is_sqmma_compatible(mat1, mat2, N, K):
            return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta)
        return addmm_sqmma(mat1, mat2, bias, lhs_dtype, alpha, beta, M, N, K)
    finally:
        # Put the environment back exactly as it was found.
        if saved_value is not None:
            os.environ[env_key] = saved_value
        else:
            os.environ.pop(env_key, None)