Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%
140 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import broadcastable_to, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor
# Module logger. The final component of this module's dotted name is appended
# to the fixed mthreads-ops prefix so every op file logs under its own name.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)
# Normalized path of the expand-config YAML, which sits one directory above
# the directory containing this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        os.pardir,
        "addmm_mthreads_expand.yaml",
    )
)
def is_supported_sqmma_layout(tensor):
    """Return whether ``tensor`` is laid out in a way the SQMMA path accepts.

    Two layouts qualify:

    * a contiguous tensor, or
    * a tensor whose dimension 0 has unit stride and whose dimension 1 stride
      equals ``tensor.shape[0]`` (the same test ``addmm_sqmma`` uses to flag
      an operand as transposed).
    """
    if tensor.is_contiguous():
        return True
    # Check stride(0) first so stride(1) is never queried when it cannot
    # matter.
    if tensor.stride(0) != 1:
        return False
    return tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, N, K):
    """Return whether ``a`` and ``b`` can be multiplied by the SQMMA kernel.

    All of the following must hold:

    * both operands are 2-D;
    * both share one dtype, which is ``torch.float16`` or ``torch.bfloat16``;
    * both pass ``is_supported_sqmma_layout``;
    * ``N`` and ``K`` are multiples of 8.

    Args:
        a: Left operand.
        b: Right operand.
        N: Number of columns of the product.
        K: Shared inner dimension.
    """
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype:
        return False
    if a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not is_supported_sqmma_layout(a):
        return False
    if not is_supported_sqmma_layout(b):
        return False
    return N % 8 == 0 and K % 8 == 0
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one output tile of ``alpha * (A @ B) + beta * bias``.

    Program ids 0 and 1 select the row tile and column tile respectively.
    The K dimension is walked in ``BLOCK_SIZE_K`` steps, products are
    accumulated in float32, and out-of-range elements are masked and read as
    0. The result is cast to the bias element type before it is stored.

    Args:
        a_ptr, b_ptr: Base pointers of the two operands.
        i_ptr: Base pointer of the bias, addressed as an M x N matrix.
        c_ptr: Base pointer of the output.
        alpha, beta: Scale factors for the product and the bias.
        M, N, K: Problem sizes used for the bounds masks.
        stride_*: Element strides of the corresponding matrix dimension.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Compile-time tile sizes.
    """
    tile_m = tle.program_id(0)
    tile_n = tle.program_id(1)

    rows = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    k_lane = tl.arange(0, BLOCK_SIZE_K)

    # The row/column bounds do not change across K steps.
    rows_ok = rows[:, None] < M
    cols_ok = cols[None, :] < N

    a_ptrs = a_ptr + (rows[:, None] * stride_am + k_lane[None, :] * stride_ak)
    b_ptrs = b_ptr + (k_lane[:, None] * stride_bk + cols[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of K elements still to be consumed; masks the tail step.
        k_left = K - k * BLOCK_SIZE_K
        a = tl.load(
            a_ptrs,
            mask=rows_ok & (k_lane[None, :] < k_left),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(k_lane[:, None] < k_left) & cols_ok,
            other=0.0,
        )
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: the output tile uses the same row/column offsets as the
    # operand tiles.
    out_mask = rows_ok & cols_ok
    i_ptrs = i_ptr + stride_im * rows[:, None] + stride_in * cols[None, :]
    c_ptrs = c_ptr + stride_cm * rows[:, None] + stride_cn * cols[None, :]
    bias = tl.load(i_ptrs, mask=out_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    tl.store(c_ptrs, accumulator.to(bias.dtype), mask=out_mask)
def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with ``addmm_kernel``.

    Both matrices are made contiguous, and the bias is broadcast to the output
    shape and materialized before the launch.

    Args:
        bias: Tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the matrix product.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype, on ``mat1``'s device.

    Raises:
        AssertionError: If the inner dimensions differ or ``bias`` cannot be
            broadcast to the output shape.
    """
    logger.debug("GEMS_MTHREADS ADDMM(FMA)")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    lhs = mat1.contiguous()
    rhs = mat2.contiguous()
    out = torch.empty((M, N), dtype=lhs.dtype, device=lhs.device)
    dense_bias = bias.broadcast_to(out.shape).contiguous()

    def grid(meta):
        # One program per (row tile, column tile) pair.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(lhs.device):
        addmm_kernel[grid](
            lhs,
            rhs,
            dense_bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *lhs.stride(),
            *rhs.stride(),
            *dense_bias.stride(),
            *out.stride(),
        )
    return out
def addmm_sqmma_descriptor_pre_hook(nargs):
    """Fill the four descriptor buffers for an ``addmm_sqmma_kernel`` launch.

    The buffers under ``a_desc_ptr``, ``b_desc_ptr``, ``bias_desc_ptr`` and
    ``c_desc_ptr`` are overwritten in place with descriptors built for the
    block sizes found in ``nargs``. ``A``, ``B`` and ``Bias`` use the cached
    descriptor helper; ``C`` always gets a newly created descriptor.

    Args:
        nargs: Mapping of kernel argument names to values. It must contain the
            tensors ``A``, ``B``, ``Bias``, ``C``, the three ``BLOCK_SIZE_*``
            values and the four ``*_desc_ptr`` buffers.
    """
    lhs, rhs, bias, out = (nargs[key] for key in ("A", "B", "Bias", "C"))
    block_m, block_n, block_k = (
        nargs[key] for key in ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K")
    )
    device = out.device

    # (descriptor buffer name, tensor, tile rows, tile columns)
    cached_descriptors = (
        ("a_desc_ptr", lhs, block_m, block_k),
        ("b_desc_ptr", rhs, block_k, block_n),
        ("bias_desc_ptr", bias, block_m, block_n),
    )
    for buffer_name, tensor, tile_rows, tile_cols in cached_descriptors:
        nargs[buffer_name].copy_(
            get_cached_tma_device_descriptor(tensor, tile_rows, tile_cols, device)
        )

    nargs["c_desc_ptr"].copy_(
        create_tma_device_descriptor(out, block_m, block_n, device)
    )
@libentry()
@libtuner(
    # With USE_FLAGTUNE=1 the candidate configs come from the expand YAML;
    # otherwise a single fixed 128x128x64 configuration is used. Both attach
    # the descriptor pre-hook.
    configs=(
        runtime.ops_get_configs(
            "addmm_sqmma",
            pre_hook=addmm_sqmma_descriptor_pre_hook,
            yaml_path=EXPAND_CONFIG_FILENAME,
        )
        if os.environ.get("USE_FLAGTUNE") == "1"
        else [
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
                num_stages=1,
                num_warps=4,
                pre_hook=addmm_sqmma_descriptor_pre_hook,
            )
        ]
    ),
    key=["M", "N", "K"],
    strategy=(
        runtime.get_expand_config(
            "addmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME
        )["strategy"]
        if os.environ.get("USE_FLAGTUNE") == "1"
        else ["default", "default", "default"]
    ),
    warmup=5,
    rep=5,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_sqmma_kernel(
    A,
    B,
    Bias,
    C,
    a_desc_ptr,
    b_desc_ptr,
    bias_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_type: tl.constexpr,
    c_type: tl.constexpr,
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = False,
):
    """Compute one output tile of ``alpha * (A @ B) + beta * Bias``.

    All memory traffic goes through the four descriptor buffers. ``A``, ``B``,
    ``Bias`` and ``C`` are not dereferenced here; the descriptor pre-hook
    reads them from the launch arguments to build the descriptors.

    Programs form a 1-D grid in which consecutive program ids advance along
    the M tiles first. Products are accumulated in float32.

    Args:
        A, B, Bias, C: Operand, bias and output tensors (used by the
            pre-hook).
        a_desc_ptr, b_desc_ptr, bias_desc_ptr, c_desc_ptr: Descriptor buffers.
        M, N, K: Problem sizes.
        alpha, beta: Scale factors for the product and the bias.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Compile-time tile sizes.
        ab_type: Element type used for the operand and bias loads.
        c_type: Element type of the stored result.
        is_transpose_a, is_transpose_b: Forwarded to the operand descriptor
            loads.
    """
    pid = tl.program_id(axis=0)
    tiles_along_m = tl.cdiv(M, BLOCK_SIZE_M)
    tile_m = pid % tiles_along_m
    tile_n = pid // tiles_along_m

    row_start = tile_m * BLOCK_SIZE_M
    col_start = tile_n * BLOCK_SIZE_N
    k_start = 0

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [row_start, k_start],
            [BLOCK_SIZE_M, BLOCK_SIZE_K],
            ab_type,
            is_transpose_a,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [k_start, col_start],
            [BLOCK_SIZE_K, BLOCK_SIZE_N],
            ab_type,
            is_transpose_b,
        )
        accumulator = tl.dot(a, b, acc=accumulator)
        k_start += BLOCK_SIZE_K

    bias = tl._experimental_descriptor_load(
        bias_desc_ptr, [row_start, col_start], [BLOCK_SIZE_M, BLOCK_SIZE_N], ab_type
    )
    # NOTE(review): the accumulator is cast to c_type before it is scaled,
    # whereas addmm_kernel scales in float32 and casts last. Confirm the
    # reduced-precision epilogue is intended.
    scaled = alpha * accumulator.to(c_type) + beta * bias.to(c_type)
    tl._experimental_descriptor_store(
        c_desc_ptr, scaled.to(c_type), [row_start, col_start]
    )
def get_triton_type(elem_type):
    """Translate a torch dtype into the corresponding Triton dtype.

    Args:
        elem_type: A torch dtype.

    Returns:
        The Triton dtype for ``torch.float16``, ``torch.bfloat16`` or
        ``torch.float8_e4m3fn``; ``None`` for any other value.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    return torch_to_triton.get(elem_type)
def _prepare_sqmma_operand(mat):
    """Return ``(mat, is_transposed)`` in a layout the SQMMA kernel reads."""
    if mat.is_contiguous():
        return mat, False
    # Unit stride along dim 0 with a dim-1 stride of shape[0]: pass the tensor
    # through unchanged and let the kernel load it as transposed.
    if mat.stride(0) == 1 and mat.stride(1) == mat.shape[0]:
        return mat, True
    # Any other layout is copied into contiguous memory.
    return mat.contiguous(), False


def addmm_sqmma(mat1, mat2, bias, elem_type, alpha, beta, M, N, K):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with the SQMMA kernel.

    Args:
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``; must share ``mat1``'s dtype.
        bias: Tensor broadcastable to ``(M, N)``; it is broadcast and
            materialized before the launch.
        elem_type: Torch dtype the kernel uses for the operand/bias loads.
        alpha: Scale applied to the matrix product.
        beta: Scale applied to ``bias``.
        M, N, K: Problem sizes, as computed by the caller.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype, on ``mat1``'s device.

    Raises:
        AssertionError: If the inner dimensions differ, ``bias`` cannot be
            broadcast to the output shape, or the operand dtypes differ.
    """
    # Same inner-dimension check as addmm_fma; without it a mismatched mat2
    # would reach the kernel. Checked first so bad input fails before any
    # other work is done.
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    logger.debug("GEMS_MTHREADS ADDMM(SQMMA)")
    device = mat1.device
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"

    # handle non-contiguous inputs if necessary
    mat1, is_transpose_a = _prepare_sqmma_operand(mat1)
    mat2, is_transpose_b = _prepare_sqmma_operand(mat2)

    assert mat1.dtype == mat2.dtype, "Mat A and Mat B should have the same dtype"
    c_type = mat1.dtype
    C = torch.empty((M, N), dtype=c_type, device=device)
    bias = bias.broadcast_to(C.shape).contiguous()

    # 64-byte scratch buffers that the kernel's pre-hook fills with the
    # descriptors for A, B, bias and C.
    desc_a, desc_b, desc_bias, desc_c = (
        torch.empty((64,), dtype=torch.int8, device=device) for _ in range(4)
    )

    def grid(meta):
        # 1-D launch: one program per output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
            1,
            1,
        )

    # Launch under the operands' device, as addmm_fma does, so the kernel
    # does not depend on whichever device happens to be current.
    with torch_device_fn.device(device):
        addmm_sqmma_kernel[grid](
            mat1,
            mat2,
            bias,
            C,
            desc_a,
            desc_b,
            desc_bias,
            desc_c,
            M,
            N,
            K,
            alpha,
            beta,
            ab_type=get_triton_type(elem_type),
            c_type=get_triton_type(c_type),
            is_transpose_a=is_transpose_a,
            is_transpose_b=is_transpose_b,
        )
    return C
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)``.

    Dispatches to ``addmm_sqmma`` when ``is_sqmma_compatible`` accepts the
    operands and to ``addmm_fma`` otherwise.

    For the duration of the call the ``MUSA_ENABLE_SQMMA`` environment
    variable is set to ``"1"`` when neither operand is float32 and removed
    when either is. Its previous state is restored before returning, even if
    the computation raises.

    Args:
        bias: Tensor broadcastable to ``(M, N)``.
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the matrix product.

    Returns:
        The ``(M, N)`` result tensor produced by the selected implementation.
    """
    a_dtype = mat1.dtype
    b_dtype = mat2.dtype
    M, K = mat1.shape
    _, N = mat2.shape

    env_key = "MUSA_ENABLE_SQMMA"
    saved_value = os.environ.get(env_key)
    # NOTE(review): presumably the MUSA toolchain reads this variable when the
    # kernel is compiled or launched - confirm.
    if torch.float32 in (a_dtype, b_dtype):
        os.environ.pop(env_key, None)
    else:
        os.environ[env_key] = "1"

    try:
        if not is_sqmma_compatible(mat1, mat2, N, K):
            return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta)
        return addmm_sqmma(mat1, mat2, bias, a_dtype, alpha, beta, M, N, K)
    finally:
        # Restore the caller's environment exactly as it was found.
        if saved_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = saved_value