Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/mm.py: 0%
94 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
10from flag_gems.utils import triton_lang_extension as tle
# Module logger, nested under the package-wide "flag_gems" logger so handler
# configuration on "flag_gems" applies here too.
# NOTE(review): lstrip(".") removes leading dots only; for an ordinary dotted
# module name it is a no-op — presumably defensive, confirm intent.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def prev_multiple_of(a, b):
    # the largest x<a that x%b ==0
    # i.e. the greatest multiple of b strictly below a (0 when a <= b).
    # Used by mm_kernel_general to split the K loop into unmasked full
    # BLOCK_K tiles plus one masked tail iteration.
    return tl.cdiv(a, b) * b - b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general(
    A,  # pointer to the (M, K) left operand
    B,  # pointer to the (K, N) right operand
    C,  # pointer to the (M, N) output
    M,
    N,
    K,
    stride_am,  # element strides of A along M / K
    stride_ak,
    stride_bk,  # element strides of B along K / N
    stride_bn,
    stride_cm,  # element strides of C along M / N
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # tile-rows grouped per super-row for L2 reuse
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C = A @ B per program.

    The K dimension is consumed in BLOCK_K slices: all full slices run in an
    unmasked main loop, and the final (possibly partial) slice is handled by a
    peeled, masked epilogue. Accumulation is in float32; the result is cast to
    C's element type just before the masked store.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range row/col indices back into bounds (mod M / mod N) so the
    # main loop can load without masks; the final store mask discards any
    # contributions from wrapped positions.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # Greatest multiple of BLOCK_K below K: the main loop covers [0, prev_multiple).
    prev_multiple = prev_multiple_of(K, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed-precision inputs: promote both operands to the output dtype so
        # tl.dot sees matching operand types.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    # loop peeling: last K slice (covers [prev_multiple, K)) with masked loads,
    # zero-filling out-of-range elements so they don't affect the dot product.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)
# Promotion order, narrowest to widest.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes per the ``_ordered_datatypes`` ranking.

    Identical inputs are returned unchanged; otherwise whichever dtype sits
    later in the ordering wins. Both dtypes must appear in the ordering.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype with the larger index in the ordering is the "higher" one.
    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return a if rank_a > rank_b else b
def general_mm(a, b, c, M, N, K):
    """Launch the general MM kernel, writing ``a @ b`` into ``c`` in place.

    Args:
        a: (M, K) left operand.
        b: (K, N) right operand.
        c: (M, N) output tensor, overwritten by the kernel.
        M, N, K: GEMM problem sizes.

    Returns:
        The output tensor ``c``.
    """
    logger.debug(
        "GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    # One program per output tile; tile sizes are chosen by the autotuner.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel_general[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c
def mm(a, b):
    """Multiply 2-D tensors ``a`` (M, K) and ``b`` (K, N) into a new tensor.

    The output dtype is the promotion of the two input dtypes. Inputs that are
    neither row- nor column-major (both strides > 1) are densified first; pure
    row-/column-major layouts are consumed directly via their strides.
    """
    logger.debug("GEMS_TSINGMICRO mm")
    device = a.device
    # handle non-contiguous inputs if necessary
    a = a.contiguous() if (a.stride(0) > 1 and a.stride(1) > 1) else a
    b = b.contiguous() if (b.stride(0) > 1 and b.stride(1) > 1) else b
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    # allocates output at the promoted dtype
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    out = torch.empty((M, N), device=device, dtype=out_dtype)
    return general_mm(a, b, out, M, N, K)
def mm_out(a, b, *, out):
    """Multiply 2-D tensors ``a`` (M, K) and ``b`` (K, N) into preallocated ``out``.

    Inputs that are neither row- nor column-major (both strides > 1) are
    densified first; otherwise they are consumed directly via their strides.
    Returns ``out``.
    """
    logger.debug("GEMS_TSINGMICRO mm_out")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    return general_mm(a, b, out, M, N, K)