Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%
120 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import broadcastable_to, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13from .utils import create_tma_device_descriptor, should_enable_sqmma
# Module-scoped logger named after this op module, e.g.
# "flag_gems.runtime.backend._mthreads.ops.addmm".
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Tiled addmm kernel: C = alpha * (A @ B) + beta * I.

    Each program instance computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of
    the output, iterating over K in BLOCK_SIZE_K chunks.  A is (M, K),
    B is (K, N), I (the bias) and C are (M, N); all are addressed via
    explicit strides, so non-contiguous layouts are supported.
    """
    # 2D launch grid: axis 0 tiles M, axis 1 tiles N.
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)
    # Row/column offsets of this tile within A and B.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # Accumulate the partial products in fp32 for precision; cast at the end.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Masked loads: out-of-range rows/cols (and the K tail) read 0.0,
        # which contributes nothing to the dot product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        # allow_tf32=False keeps full fp32 accuracy in the dot product.
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both operand pointers one K-block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Output-tile offsets (same tile coordinates as the A/B offsets above).
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # Bias tile shares the output tile's coordinates and mask.
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)
    accumulator = accumulator * alpha + bias * beta
    # Cast back to the bias dtype (the host launcher allocates the output
    # with the same dtype as the inputs).
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with the FMA kernel.

    mat1 is (M, K) and mat2 is (K, N); bias must be broadcastable to
    (M, N).  Inputs are made contiguous before launch, and the bias is
    materialized at the full output shape.  Returns a new (M, N) tensor
    with mat1's dtype and device.
    """
    logger.debug("GEMS_MTHREADS ADDMM(FMA)")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    N = mat2.shape[1]

    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # Expand the bias to the output shape so the kernel can index it with
    # plain 2D strides.
    bias = bias.broadcast_to(out.shape).contiguous()

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out
@triton.jit
def addmm_sqmma_kernel(
    a_desc_ptr,
    b_desc_ptr,
    bias_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
    ab_type: tl.constexpr,
    c_type: tl.constexpr,
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = False,
):
    """SQMMA addmm kernel using experimental TMA device descriptors.

    Computes C = alpha * (A @ B) + beta * Bias one output tile per
    program, loading operand tiles through the descriptor pointers
    instead of raw pointer arithmetic.  alpha/beta are compile-time
    constants here (tl.constexpr), unlike the FMA kernel.
    """
    # 1D launch grid, decomposed column-major: pid runs fastest over M.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Element offsets of this tile's top-left corner.
    offs_am = pid_m * BLOCK_SIZE_M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_k = 0
    input_type = ab_type
    output_type = c_type
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # NOTE(review): descriptor loads take (coords, tile shape, dtype,
        # transpose flag); presumably out-of-bounds handling is done by the
        # TMA engine/descriptor — confirm against the backend docs.
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [offs_am, offs_k],
            [BLOCK_SIZE_M, BLOCK_SIZE_K],
            input_type,
            is_transpose_a,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [offs_k, offs_bn],
            [BLOCK_SIZE_K, BLOCK_SIZE_N],
            input_type,
            is_transpose_b,
        )
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_SIZE_K
    # Bias tile shares the output tile's coordinates; it is loaded at the
    # input dtype and converted below.
    bias = tl._experimental_descriptor_load(
        bias_desc_ptr, [offs_am, offs_bn], [BLOCK_SIZE_M, BLOCK_SIZE_N], input_type
    )
    # Scale/add in the output dtype, then cast the final expression as well
    # (the outer .to() also covers the case where alpha/beta promote).
    result = (alpha * accumulator.to(output_type) + beta * bias.to(output_type)).to(
        output_type
    )
    tl._experimental_descriptor_store(c_desc_ptr, result, [offs_am, offs_bn])
def get_triton_type(elem_type):
    """Map a torch dtype to its Triton counterpart (None if unsupported)."""
    if elem_type == torch.float16:
        return tl.float16
    if elem_type == torch.bfloat16:
        return tl.bfloat16
    if elem_type == torch.float8_e4m3fn:
        return tl.float8e4nv
    return None
def addmm_sqmma(
    A,
    B,
    Bias,
    elem_type,
    alpha,
    beta,
    M,
    N,
    K,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
):
    """Compute ``beta * Bias + alpha * (A @ B)`` via the SQMMA/TMA kernel.

    A is (M, K), B is (K, N); Bias must be broadcastable to (M, N).
    elem_type selects the Triton input dtype for the descriptor loads;
    the output takes A's dtype.  Returns a new (M, N) tensor on "musa".
    """
    logger.debug("GEMS_MTHREADS ADDMM(SQMMA)")
    device = "musa"
    assert broadcastable_to(
        Bias.shape, (A.shape[0], B.shape[1])
    ), "Incompatible input shape"

    def _mma_layout(mat):
        # A column-major (transposed) matrix can be fed to the kernel
        # directly via the transpose flag; any other non-contiguous layout
        # is materialized as row-major.
        if mat.is_contiguous():
            return mat, False
        if mat.stride(0) == 1 and mat.stride(1) == mat.shape[0]:
            return mat, True
        return mat.contiguous(), False

    A, is_transpose_a = _mma_layout(A)
    B, is_transpose_b = _mma_layout(B)

    assert A.dtype == B.dtype, "Mat A and Mat B should have the same dtype"
    out_dtype = A.dtype
    C = torch.empty((M, N), dtype=out_dtype, device=device)
    # Materialize the bias at the output shape so one descriptor covers it.
    Bias = Bias.broadcast_to(C.shape).contiguous()

    desc_a = create_tma_device_descriptor(A, BLOCK_M, BLOCK_K, device)
    desc_b = create_tma_device_descriptor(B, BLOCK_K, BLOCK_N, device)
    desc_bias = create_tma_device_descriptor(Bias, BLOCK_M, BLOCK_N, device)
    desc_c = create_tma_device_descriptor(C, BLOCK_M, BLOCK_N, device)

    # 1D grid: one program per output tile.
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
    addmm_sqmma_kernel[grid](
        desc_a,
        desc_b,
        desc_bias,
        desc_c,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        alpha,
        beta,
        get_triton_type(elem_type),
        get_triton_type(out_dtype),
        is_transpose_a,
        is_transpose_b,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return C
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)``.

    Dispatches to the SQMMA (TMA-descriptor) path when
    ``should_enable_sqmma`` approves the dtypes and problem size;
    otherwise runs the FMA kernel with the ``MUSA_ENABLE_SQMMA``
    environment variable temporarily removed.
    """
    a_dtype = mat1.dtype
    b_dtype = mat2.dtype
    M, K = mat1.shape
    _, N = mat2.shape
    use_sqmma = should_enable_sqmma(a_dtype, b_dtype, M, N, K)

    if use_sqmma:
        # Tile heuristics: use the largest square tile that divides M
        # evenly, with warp count scaled to the tile size.
        BLOCK_M = 256 if M % 256 == 0 else 128
        BLOCK_N = BLOCK_M
        BLOCK_K = 64
        num_warps = 16 if BLOCK_M == 256 else 4
        num_stages = 1
        return addmm_sqmma(
            mat1,
            mat2,
            bias,
            a_dtype,
            alpha,
            beta,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps,
            num_stages,
        )

    # Temporarily remove MUSA_ENABLE_SQMMA so the FMA path runs without
    # SQMMA, then restore it.  Fixes two defects in the original code:
    # (1) the restore was gated on truthiness, so a value of "" was
    # silently dropped from the environment, and (2) an exception raised
    # by addmm_fma skipped the restore entirely.
    enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None)
    try:
        return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta)
    finally:
        if enable_sqmma is not None:
            os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma