Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%
197 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
7from triton.tools.tensor_descriptor import TensorDescriptor
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as ext
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

# Expansion/tuning YAML, resolved relative to this file: one directory up.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "mm_mthreads_expand.yaml")
)
# Capability switch for the SQMMA path. It is computed a single time when the
# module is imported and then used as a process-wide constant, so the version
# string is never re-parsed.
# True for Triton >= 3.2, False for older releases such as 3.1.
SQMMA_ON = tuple(map(int, triton.__version__.split(".")[:2])) >= (3, 2)
def is_supported_sqmma_layout(tensor):
    """Return True when ``tensor`` is contiguous or column-major in dims 0/1.

    The column-major case is recognised by a unit stride along dim 0 together
    with a dim-1 stride equal to the size of dim 0.
    """
    if tensor.is_contiguous():
        return True
    unit_dim0_stride = tensor.stride(0) == 1
    return unit_dim0_stride and tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, N, K):
    """Decide whether ``a @ b`` may be dispatched to the SQMMA kernel.

    All of the following must hold: SQMMA is enabled for the installed Triton,
    both operands are 2-D and share one dtype that is float16 or bfloat16,
    both pass ``is_supported_sqmma_layout``, and N and K are multiples of 8.
    """
    if not SQMMA_ON:
        return False
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return N % 8 == 0 and K % 8 == 0
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` strictly below ``a`` (b > 0)."""
    # round a up to a whole number of b-sized steps, then go one step back
    steps = tl.cdiv(a, b)
    return steps * b - b
@libentry()
@libtuner(
    configs=runtime.ops_get_configs("mm", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("mm", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of ``C = A @ B``.

    Each program turns its 1-D program id into a tile coordinate (tiles are
    visited in groups of GROUP_M tile-rows), walks K in BLOCK_K steps and
    stores the finished tile under an M/N bounds mask.

    Args:
        A, B, C: pointers to the two inputs and the output.
        M, N, K: problem sizes; A is (M, K), B is (K, N), C is (M, N).
        stride_*: element strides of A, B and C along each dimension.
        dtype: compile-time dtype name; not referenced in the kernel body.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes chosen by the tuner.
        GROUP_M: number of tile-rows per group in the program-id reordering.
        IS_FP64: accumulate in float64 instead of float32 when True.
    """
    # matrix multiplication
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row/column indices are wrapped with % M / % N, so the loads below need
    # no M/N mask; out-of-range lanes are discarded by the store mask instead.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # NOTE(review): for K == 0 this is -BLOCK_K, which makes the peeled tail
    # below use negative k offsets that still satisfy `rk < K` -- callers are
    # presumably expected not to launch with K == 0; confirm.
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Full BLOCK_K steps: every k index is < K, so no k mask is required.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last (possibly partial) BLOCK_K step
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    # `other=0.0` is required: without it the masked-off lanes hold undefined
    # values that would be multiplied and accumulated by tl.dot.
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # write the tile back, clipped to the real M x N extent
    tl.store(C, acc, mask=mask)
@libentry()
@libtuner(
    configs=runtime.ops_get_configs("gemv", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [triton.Config({"BLOCK_M": 64, "BLOCK_K": 64})],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute BLOCK_M entries of the matrix-vector product ``C = A @ B``.

    Each program handles BLOCK_M consecutive rows of A, reduces over K in
    BLOCK_K steps and stores one value per row.

    Args:
        A: pointer to the (M, K) matrix.
        B: pointer to the length-K vector (stepped with ``stride_bk``).
        C: pointer to the length-M output (stepped with ``stride_cm``).
        M, K: problem sizes.
        stride_am, stride_ak, stride_bk, stride_cm: element strides.
        BLOCK_M, BLOCK_K: tile sizes chosen by the tuner.
    """
    pid = ext.program_id(0)

    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    # Widen before multiplying by strides (as mm_kernel does) so that
    # row * stride_am cannot overflow 32-bit arithmetic on large matrices.
    row_offset = row_offset.to(tl.int64)

    # Accumulate in fp32, except for float64 output where fp32 would throw
    # away precision; this mirrors the IS_FP64 switch of mm_kernel.
    if C.dtype.element_ty == tl.float64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K
        k_offset = k_offset.to(tl.int64)

        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Reduce in the accumulator precision so the N=1 GEMV result matches
        # the mm path closely.
        a = a.to(acc.dtype)
        b = b.to(acc.dtype)
        acc += tl.sum(a * b[None, :], axis=1)

    c_ptrs = C + row_offset * stride_cm
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
# Floating-point dtypes ranked from lowest to highest promotion priority.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of the two dtypes ranks later in ``_ordered_datatypes``.

    Identical dtypes are returned as-is without consulting the ranking;
    otherwise both must be members of ``_ordered_datatypes``.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return b if rank_a < rank_b else a
def mm_fma(a, b):
    """Multiply two 2-D tensors with the generic ``mm_kernel`` and return C.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).

    Returns:
        A new (M, N) tensor on ``a``'s device whose dtype is
        ``get_higher_dtype(a.dtype, b.dtype)``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints (explicit raise so the check survives `python -O`)
    if a.shape[1] != b.shape[0]:
        raise AssertionError("incompatible dimensions")
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    # Degenerate shapes: nothing to launch for an empty output, and for K == 0
    # the product is all zeros (the kernel's peeled tail is not safe for K == 0).
    if M == 0 or N == 0 or K == 0:
        return c.zero_()

    # launch kernel
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            # Follow the output dtype: with mixed inputs the kernel casts both
            # operands to C's element type, so fp64 output needs fp64 accumulation
            # even when only `b` is float64.
            IS_FP64=c_dtype == torch.float64,
        )
    return c
def gemv_mm(a, b, c, M, K):
    """Launch ``gemv_kernel`` to fill ``c`` for the N == 1 case; return ``c``."""
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def grid(meta):
        # one program per BLOCK_M rows
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    kernel_args = (a, b, c, M, K, a.stride(0), a.stride(1), b.stride(0), c.stride(0))
    with torch_device_fn.device(a.device):
        gemv_kernel[grid](*kernel_args)
    return c
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided ``out`` tensor.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).
        out: destination tensor; written in place and returned.

    Returns:
        ``out``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints (explicit raise so the check survives `python -O`)
    if a.shape[1] != b.shape[0]:
        raise AssertionError("incompatible dimensions")
    M, K = a.shape
    _, N = b.shape
    c = out
    # Degenerate shapes: nothing to launch for an empty output, and for K == 0
    # the product is all zeros (mm_kernel's peeled tail is not safe for K == 0).
    if M == 0 or N == 0 or K == 0:
        return c.zero_()
    if N == 1:
        return gemv_mm(a, b, c, M, K)

    # launch kernel
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            # Accumulate in fp64 whenever either operand is float64, not only
            # when `a` is.
            IS_FP64=torch.float64 in (a.dtype, b.dtype),
        )
    return c
@triton.jit
def mm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) output tile through tensor descriptors.

    The program id is mapped to a tile with the same GROUP_M grouping used by
    ``mm_kernel``; A and B blocks are fetched via ``a_desc`` / ``b_desc``,
    accumulated in float32 and written through ``c_desc`` in its dtype.
    ``dtype`` is accepted but not referenced in the body.
    """
    pid = ext.program_id(0)
    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)
    # grouped ordering: GROUP_M tile-rows are walked together
    tiles_per_group = GROUP_M * num_tiles_n
    group_idx = pid // tiles_per_group
    rows_in_group = min(num_tiles_m - group_idx * GROUP_M, GROUP_M)
    tile_m = group_idx * GROUP_M + (pid % rows_in_group)
    tile_n = (pid % tiles_per_group) // (rows_in_group)
    # descriptor coordinates are passed as int32
    row_start = (tile_m * BLOCK_M).to(tl.int32)
    col_start = (tile_n * BLOCK_N).to(tl.int32)
    k_start = 0
    k_start = k_start.to(tl.int32)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_iter in range(0, tl.cdiv(K, BLOCK_K)):
        lhs = tl.load_tensor_descriptor(a_desc, [row_start, k_start])
        rhs = tl.load_tensor_descriptor(b_desc, [k_start, col_start])
        acc = tl.dot(lhs, rhs, acc=acc)
        k_start += BLOCK_K
    tl.store_tensor_descriptor(c_desc, [row_start, col_start], acc.to(c_desc.dtype))
def get_triton_type(elem_type):
    """Map a torch dtype to its Triton dtype; return None when unmapped."""
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    return torch_to_triton.get(elem_type)
def mm_sqmma(A, B, M, N, K, GROUP_M):
    """Multiply ``A @ B`` with the tensor-descriptor based SQMMA kernel.

    Args:
        A: left operand, (M, K); copied to contiguous storage if needed.
        B: right operand, (K, N); copied to contiguous storage if needed.
        M, N, K: problem sizes.
        GROUP_M: tile-row group size forwarded to the kernel.

    Returns:
        A new (M, N) tensor on ``A``'s device with the operands' dtype.

    Raises:
        AssertionError: if the two operands have different dtypes.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    if not A.is_contiguous():
        A = A.contiguous()
    if not B.is_contiguous():
        B = B.contiguous()
    a_type = A.dtype
    b_type = B.dtype
    # explicit raise so the check survives `python -O`
    if a_type != b_type:
        raise AssertionError("Mat A and Mat B should have the same dtype")
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    # fixed tile shape for this path (not autotuned)
    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64
    desc_a = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K])
    desc_b = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N])
    desc_c = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N])
    grid = lambda META: (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        1,
        1,
    )
    # Launch under the operands' device, like every other launch site in this
    # module, so the kernel runs on the right device when it is not the
    # current one.
    with torch_device_fn.device(device):
        mm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            str(a_type).split(".")[-1],
            GROUP_M,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps=4,
            num_stages=1,
        )
    return C
def mm(a, b):
    """Matrix-multiply two 2-D tensors, dispatching to the best kernel.

    Dispatch order: the GEMV kernel when N == 1, the SQMMA kernel when
    ``is_sqmma_compatible`` accepts the operands, otherwise the generic
    FMA kernel.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).

    Returns:
        The (M, N) product.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    M, K = a.shape
    b_rows, N = b.shape
    # Validate before dispatch: the GEMV and SQMMA paths do no shape check of
    # their own, so a mismatch would otherwise index past the smaller operand.
    # Same exception type and message as the check in mm_fma.
    if K != b_rows:
        raise AssertionError("incompatible dimensions")

    if N == 1:
        c_dtype = get_higher_dtype(a_dtype, b_dtype)
        c = torch.empty((M, N), device=a.device, dtype=c_dtype)
        return gemv_mm(a, b, c, M, K)

    if is_sqmma_compatible(a, b, N, K):
        GROUP_M = 8
        return mm_sqmma(
            a,
            b,
            M,
            N,
            K,
            GROUP_M,
        )
    return mm_fma(a, b)