Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import broadcastable_to, libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from .utils import create_tma_device_descriptor, should_enable_sqmma 

14 

# Module-scoped logger, namespaced under the _mthreads backend and keyed by
# this file's basename (keeps log records short instead of the full dotted
# module path).
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

18 

19 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,  # mat1, logical shape (M, K)
    b_ptr,  # mat2, logical shape (K, N)
    i_ptr,  # bias ("input" in torch.addmm terms), shape (M, N)
    c_ptr,  # output, shape (M, N)
    alpha,  # scale on the matmul term (not specialized: avoids recompiles)
    beta,  # scale on the bias term (not specialized: avoids recompiles)
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Tiled GEMM with fused epilogue: each program computes one
    # (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of c = alpha * (a @ b) + beta * i,
    # accumulating in fp32. Grid is 2-D: axis 0 over M tiles, axis 1 over N.
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Row/column offsets of this tile, and pointers to its first K-slice.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Masked loads zero-fill out-of-range elements (M/N edges and the
        # ragged final K-slice), so they contribute 0 to the dot product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        # tf32 disabled: keep full fp32 accuracy in the accumulation.
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both operand pointers one K-slice.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: load the bias tile, apply alpha/beta in fp32, then cast to
    # the bias dtype for the store.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)

83 

84 

def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``alpha * (mat1 @ mat2) + beta * bias`` via the FMA Triton kernel.

    Inputs are forced contiguous and ``bias`` is broadcast to the output
    shape before launch; the result tensor uses ``mat1``'s dtype and device.
    """
    logger.debug("GEMS_MTHREADS ADDMM(FMA)")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"

    M, K = mat1.shape
    N = mat2.shape[1]

    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    result = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # Materialize the broadcast so the kernel can index bias with 2-D strides.
    bias = bias.broadcast_to(result.shape).contiguous()

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            result,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *result.stride(),
        )
    return result

124 

125 

@triton.jit
def addmm_sqmma_kernel(
    a_desc_ptr,  # TMA device descriptor for mat1
    b_desc_ptr,  # TMA device descriptor for mat2
    bias_desc_ptr,  # TMA device descriptor for the pre-broadcast bias
    c_desc_ptr,  # TMA device descriptor for the output
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
    ab_type: tl.constexpr,  # triton element type of A/B (from get_triton_type)
    c_type: tl.constexpr,  # triton element type of the output
    is_transpose_a: tl.constexpr = False,  # A stored column-major
    is_transpose_b: tl.constexpr = False,  # B stored column-major
):
    # SQMMA addmm using experimental TMA descriptor loads/stores:
    # c = alpha * (a @ b) + beta * bias, one output tile per program.
    # The launch grid is 1-D; pid is decomposed column-major over M tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Descriptor loads take element offsets, not pointers.
    offs_am = pid_m * BLOCK_SIZE_M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_k = 0
    input_type = ab_type
    output_type = c_type
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [offs_am, offs_k],
            [BLOCK_SIZE_M, BLOCK_SIZE_K],
            input_type,
            is_transpose_a,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [offs_k, offs_bn],
            [BLOCK_SIZE_K, BLOCK_SIZE_N],
            input_type,
            is_transpose_b,
        )
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_SIZE_K
    # NOTE(review): the bias tile is loaded as input_type (== ab_type); the
    # host path makes bias, A and B share a dtype in practice — confirm.
    bias = tl._experimental_descriptor_load(
        bias_desc_ptr, [offs_am, offs_bn], [BLOCK_SIZE_M, BLOCK_SIZE_N], input_type
    )
    # NOTE(review): the fp32 accumulator is cast to output_type BEFORE the
    # alpha/beta scaling, so for fp16/bf16 outputs rounding happens pre-scale
    # (the FMA kernel scales in fp32 first) — confirm this is intentional.
    result = (alpha * accumulator.to(output_type) + beta * bias.to(output_type)).to(
        output_type
    )
    tl._experimental_descriptor_store(c_desc_ptr, result, [offs_am, offs_bn])

179 

180 

def get_triton_type(elem_type):
    """Translate a torch dtype into its triton element type.

    Returns ``None`` when ``elem_type`` is not one of the dtypes supported
    by the SQMMA path (fp16, bf16, fp8-e4m3).
    """
    return {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }.get(elem_type)

188 

189 

def addmm_sqmma(
    A,
    B,
    Bias,
    elem_type,
    alpha,
    beta,
    M,
    N,
    K,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
):
    """Compute ``alpha * (A @ B) + beta * Bias`` with the SQMMA/TMA kernel.

    Operands are wrapped in TMA device descriptors; column-major (transposed)
    layouts are consumed in place, any other non-contiguous layout is copied.
    """
    logger.debug("GEMS_MTHREADS ADDMM(SQMMA)")
    # NOTE(review): device is hard-coded to the default "musa" device rather
    # than A.device — confirm multi-device inputs are not expected here.
    device = "musa"
    assert broadcastable_to(
        Bias.shape, (A.shape[0], B.shape[1])
    ), "Incompatible input shape"

    def _as_tma_operand(mat):
        # Returns (tensor, is_transposed). A pure column-major layout is
        # handled by the descriptor's transpose flag; anything else that is
        # non-contiguous gets copied to a contiguous buffer.
        if mat.is_contiguous():
            return mat, False
        if mat.stride(0) == 1 and mat.stride(1) == mat.shape[0]:
            return mat, True
        return mat.contiguous(), False

    A, is_transpose_a = _as_tma_operand(A)
    B, is_transpose_b = _as_tma_operand(B)

    assert A.dtype == B.dtype, "Mat A and Mat B should have the same dtype"
    c_type = A.dtype
    C = torch.empty((M, N), dtype=c_type, device=device)
    # Materialize the broadcast so the bias descriptor sees a plain 2-D tile.
    Bias = Bias.broadcast_to(C.shape).contiguous()

    desc_a = create_tma_device_descriptor(A, BLOCK_M, BLOCK_K, device)
    desc_b = create_tma_device_descriptor(B, BLOCK_K, BLOCK_N, device)
    desc_bias = create_tma_device_descriptor(Bias, BLOCK_M, BLOCK_N, device)
    desc_c = create_tma_device_descriptor(C, BLOCK_M, BLOCK_N, device)

    # 1-D grid: one program per output tile.
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
    addmm_sqmma_kernel[grid](
        desc_a,
        desc_b,
        desc_bias,
        desc_c,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        alpha,
        beta,
        get_triton_type(elem_type),
        get_triton_type(c_type),
        is_transpose_a,
        is_transpose_b,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return C

256 

257 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)``.

    Dispatches to the SQMMA/TMA kernel when ``should_enable_sqmma`` approves
    the dtypes and problem size, otherwise falls back to the FMA kernel with
    the SQMMA environment override temporarily removed.
    """
    a_dtype = mat1.dtype
    b_dtype = mat2.dtype
    M, K = mat1.shape
    _, N = mat2.shape

    if should_enable_sqmma(a_dtype, b_dtype, M, N, K):
        # Larger tiles when M divides evenly by 256; warp count scales with
        # the tile size.
        BLOCK_M = 256 if M % 256 == 0 else 128
        BLOCK_N = BLOCK_M
        BLOCK_K = 64
        num_warps = 16 if BLOCK_M == 256 else 4
        num_stages = 1
        return addmm_sqmma(
            mat1,
            mat2,
            bias,
            a_dtype,
            alpha,
            beta,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps,
            num_stages,
        )

    # Temporarily remove MUSA_ENABLE_SQMMA so the environment cannot force
    # the SQMMA path while running the FMA kernel.
    enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None)
    try:
        return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta)
    finally:
        # Fix: restore on `is not None` rather than truthiness — an
        # empty-string value must also be put back — and do it in `finally`
        # so the variable survives even if the kernel launch raises.
        if enable_sqmma is not None:
            os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma