Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13from .utils import create_tma_device_descriptor, should_enable_sqmma
# Module-level logger scoped under the flag_gems mthreads backend namespace,
# e.g. "flag_gems.runtime.backend._mthreads.ops.mm" for this file.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
@triton.jit
def prev_multiple_of(a, b):
    # the largest x<a that x%b ==0
    # Equivalent to ((a - 1) // b) * b for a >= 1: the greatest multiple of b
    # strictly below a. Used to split K into full BLOCK_K tiles plus a peeled,
    # masked tail iteration.
    return tl.cdiv(a, b) * b - b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
)
@triton.jit
def mm_kernel(
    A,  # pointer to the left operand, logical shape (M, K)
    B,  # pointer to the right operand, logical shape (K, N)
    C,  # pointer to the output, logical shape (M, N)
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # matrix multiplication
    # Each program instance computes one BLOCK_M x BLOCK_N tile of C = A @ B,
    # accumulating in float32 and storing in C's element type.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    # (grouped/swizzled mapping: GROUP_M rows of tiles share column reuse)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap row/col indices modulo M/N so the unmasked main-loop loads stay in
    # bounds; the hints let the compiler assume aligned, contiguous access.
    # 64-bit offsets guard against overflow on large tensors.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop over full K tiles: no K mask needed by construction.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            # Mixed-dtype inputs: promote both operands to the output dtype.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling
    # Final (possibly partial) K tile, masked so out-of-range K reads zero.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak), mask=mask_k[None, :]
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn), mask=mask_k[:, None]
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)
# Promotion order from narrowest to widest; index position decides the winner.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes per ``_ordered_datatypes`` ordering.

    Both ``a`` and ``b`` must appear in ``_ordered_datatypes``; identical
    inputs short-circuit to that dtype. Note this ranks bfloat16 above
    float16 (positional order), not by torch's own promotion rules.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # Whichever dtype sits later in the ordering wins.
    return a if _ordered_datatypes.index(a) > _ordered_datatypes.index(b) else b
def mm_fma(a, b):
    """Compute ``a @ b`` with the tiled FMA Triton kernel.

    Inputs must be 2D with matching inner dimensions. The output tensor is
    allocated on ``a``'s device with the wider of the two input dtypes
    (see ``get_higher_dtype``).
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    out_device = a.device
    # Densify any operand strided in both dimensions; plain transposed
    # layouts (one stride == 1) are handled by the kernel via strides.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=out_device, dtype=out_dtype)

    # One program per output tile; tile sizes come from the autotuner config.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided ``out`` tensor.

    Mirrors ``mm_fma`` but writes to ``out`` instead of allocating; returns
    ``out`` for convenience.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # Densify any operand strided in both dimensions; plain transposed
    # layouts (one stride == 1) are handled by the kernel via strides.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # One program per output tile; tile sizes come from the autotuner config.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            out,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            out.stride(0),
            out.stride(1),
            GROUP_M=8,
        )
    return out
@triton.jit
def mm_sqmma_kernel(
    a_desc_ptr,  # TMA device descriptor for A
    b_desc_ptr,  # TMA device descriptor for B
    c_desc_ptr,  # TMA device descriptor for C
    M,
    N,
    K,
    GROUP_M: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_dtype: tl.constexpr,  # triton dtype of A/B elements
    c_dtype: tl.constexpr,  # triton dtype of the stored C elements
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = False,
):
    # C = A @ B via descriptor (TMA-style) loads/stores; each program computes
    # one BLOCK_SIZE_M x BLOCK_SIZE_N tile, accumulating in float32.
    # NOTE(review): the extra dtype/transpose arguments to
    # tl._experimental_descriptor_load differ from upstream Triton's
    # signature — presumably a MUSA extension; confirm against the mthreads
    # Triton fork.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped/swizzled program-id mapping for L2 reuse, same scheme as mm_kernel.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    offs_am = pid_m * BLOCK_SIZE_M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_k = 0
    # Descriptor offsets must be 32-bit.
    offs_am = offs_am.to(tl.int32)
    offs_bn = offs_bn.to(tl.int32)
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    tme_load_ab_dtype = ab_dtype
    c_store_dtype = c_dtype
    # No explicit K-tail mask here: out-of-bounds tail reads are assumed to be
    # zero-filled by the descriptor load — TODO confirm.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [offs_am, offs_k],
            [BLOCK_SIZE_M, BLOCK_SIZE_K],
            tme_load_ab_dtype,
            is_transpose_a,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [offs_k, offs_bn],
            [BLOCK_SIZE_K, BLOCK_SIZE_N],
            tme_load_ab_dtype,
            is_transpose_b,
        )
        accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
        offs_k += BLOCK_SIZE_K
    accumulator = accumulator.to(c_store_dtype)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
def get_triton_type(elem_type):
    """Map a torch dtype to its Triton counterpart.

    Returns ``None`` for dtypes outside the supported set (float16,
    bfloat16, float8_e4m3fn).
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    return torch_to_triton.get(elem_type)
def mm_sqmma(A, B, M, N, K, GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
    """Compute ``A @ B`` via the TMA/SQMMA descriptor kernel.

    Both inputs must share a dtype. Column-major (transposed) operands are
    passed through untouched and flagged to the kernel; other non-contiguous
    layouts are densified first. Returns a new (M, N) tensor on the "musa"
    device.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = "musa"

    # Classify an operand's layout: (tensor, needs_transpose_flag).
    def _resolve_layout(mat):
        if mat.is_contiguous():
            return mat, False
        if mat.stride(0) == 1 and mat.stride(1) == mat.shape[0]:
            # Column-major view of a contiguous buffer: let the kernel
            # transpose during the descriptor load.
            return mat, True
        return mat.contiguous(), False

    A, is_transpose_a = _resolve_layout(A)
    B, is_transpose_b = _resolve_layout(B)

    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)

    desc_a = create_tma_device_descriptor(A, BLOCK_M, BLOCK_K, device)
    desc_b = create_tma_device_descriptor(B, BLOCK_K, BLOCK_N, device)
    desc_c = create_tma_device_descriptor(C, BLOCK_M, BLOCK_N, device)

    launch_grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
    mm_sqmma_kernel[launch_grid](
        desc_a,
        desc_b,
        desc_c,
        M,
        N,
        K,
        GROUP_M,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        get_triton_type(a_type),
        get_triton_type(c_dtype),
        num_warps=num_warps,
        num_stages=num_stages,
        is_transpose_a=is_transpose_a,
        is_transpose_b=is_transpose_b,
    )
    return C
def mm(a, b):
    """Matrix-multiply two 2D tensors, dispatching between kernels.

    Uses the TMA/SQMMA path when ``should_enable_sqmma`` approves the
    dtype/shape combination; otherwise runs the FMA kernel with the
    ``MUSA_ENABLE_SQMMA`` environment toggle temporarily removed so the
    fallback cannot re-trigger SQMMA behavior downstream.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    M, K = a.shape
    _, N = b.shape
    use_sqmma = should_enable_sqmma(a_dtype, b_dtype, M, N, K)
    if use_sqmma:
        GROUP_M = 8
        BLOCK_M = 128
        BLOCK_N = BLOCK_M
        BLOCK_K = 64
        # BLOCK_M is fixed at 128 above, so this picks 4 warps today; the
        # 256-block branch is kept to ease future retuning.
        num_warps = 16 if BLOCK_M == 256 else 4
        num_stages = 1
        return mm_sqmma(
            a,
            b,
            M,
            N,
            K,
            GROUP_M,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps,
            num_stages,
        )
    # Hide MUSA_ENABLE_SQMMA while the FMA path runs, then put it back.
    enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None)
    try:
        return mm_fma(a, b)
    finally:
        # Bug fix: the old truthiness check (`if enable_sqmma:`) silently
        # dropped falsy-but-set values such as "" or "0", and an exception
        # in mm_fma skipped restoration entirely. Restore whenever the
        # variable existed, on every exit path.
        if enable_sqmma is not None:
            os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma