Coverage for src/flag_gems/runtime/backend/_metax/ops/mm.py: 0%
116 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.heuristics(
    {
        # True when the tile count (grid size) may exceed a signed 32-bit
        # program id; the kernel then reads the id via the extension helper.
        "UPGRADE": lambda args: math.ceil(
            (args["M"] * args["N"]) / (args["BLOCK_M"] * args["BLOCK_N"])
        ).bit_length()
        > 31,
    }
)
@triton.heuristics(
    {
        # True when a linear offset into A (up to M*K elements) may
        # overflow 32 bits, forcing int64 index arithmetic below.
        "UPGRADE_A_OFFS": lambda args: math.ceil(args["M"] * args["K"]).bit_length()
        > 31,
    }
)
@triton.heuristics(
    {
        # Same overflow guard for offsets into B (up to K*N elements).
        "UPGRADE_B_OFFS": lambda args: math.ceil(args["K"] * args["N"]).bit_length()
        > 31,
    }
)
@triton.heuristics(
    {
        # Same overflow guard for offsets into C (up to M*N elements).
        "UPGRADE_C_OFFS": lambda args: math.ceil(args["M"] * args["N"]).bit_length()
        > 31,
    }
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    UPGRADE: tl.constexpr,
    UPGRADE_A_OFFS: tl.constexpr,
    UPGRADE_B_OFFS: tl.constexpr,
    UPGRADE_C_OFFS: tl.constexpr,
):
    """Tiled GEMM kernel computing C = A @ B with optional split-K.

    Each program computes one BLOCK_M x BLOCK_N tile of C, accumulating in
    ``dot_out_dtype``. With SPLIT_K > 1 the K dimension is partitioned over
    axis 1 of the grid and partial tiles are combined via atomic adds. The
    UPGRADE* flags (set by the heuristics above) widen index arithmetic to
    int64 where a 32-bit offset could overflow.
    """
    # matrix multiplication
    if UPGRADE:
        # NOTE(review): tle.program_id is presumably a wider (int64-safe)
        # variant of tl.program_id - confirm in triton_lang_extension.
        pid = tle.program_id(0)
        pid_z = tle.program_id(1)
    else:
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: programs are grouped so
    # that GROUP_M consecutive row-tiles reuse the same columns of B.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    if UPGRADE_A_OFFS:
        # int64 row indices: offsets into A may not fit in 32 bits.
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        ram = (tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)).to(tl.int64)
    else:
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        # rm % M wraps out-of-range rows; the max_contiguous/multiple_of
        # hints let Triton vectorize the loads.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    if UPGRADE_B_OFFS:
        # int64 column indices: offsets into B may not fit in 32 bits.
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        rbn = (tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)).to(tl.int64)
    else:
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    # K offsets for this program's split-K slice.
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    # Each iteration consumes one BLOCK_K slab; with SPLIT_K > 1 this
    # program only visits every SPLIT_K-th slab.
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # K divides evenly: unmasked loads.
            a = tl.load(A)
            b = tl.load(B)
        else:
            # Tail slab: mask out-of-range K positions, padding with zeros
            # so they contribute nothing to the dot product.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if a.dtype != b.dtype:
            # Mixed-dtype inputs: promote both to the output element type
            # so tl.dot sees matching operand types.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    if UPGRADE_C_OFFS:
        # int64 output offsets: C spans more than 2^31 elements.
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn).to(tl.int64)
    else:
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # Multiple programs contribute to the same tile: combine atomically.
        tl.atomic_add(C, acc, mask=mask)
# Floating-point dtypes in ascending promotion order.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of *a* / *b* ranks higher in ``_ordered_datatypes``.

    Both arguments must be members of ``_ordered_datatypes``. When the two
    dtypes are the same object, it is returned directly.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype appearing later in the ordering wins.
    return max(a, b, key=_ordered_datatypes.index)
def mm(a, b):
    """Matrix-multiply two 2-D tensors on the Metax backend.

    Args:
        a: Tensor of shape (M, K).
        b: Tensor of shape (K, N); ``a.shape[1]`` must equal ``b.shape[0]``.

    Returns:
        A freshly allocated (M, N) tensor whose dtype is the higher of the
        two input dtypes (per ``get_higher_dtype``). Accumulation inside the
        kernel is done in float32.
    """
    logger.debug("METAX GEMS MM")
    device = a.device
    # An operand with some stride equal to 1 is row- or column-major and can
    # be consumed directly; otherwise fall back to a contiguous copy.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # Inner dimensions must agree.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Allocate the output in the promoted dtype.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    result = torch.empty((M, N), device=device, dtype=out_dtype)
    acc_dtype = tl.float32
    logger.debug(
        "METAX GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    # One program per output tile, times SPLIT_K programs along K.
    def grid(meta):
        tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (tiles, meta["SPLIT_K"])

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            result,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            result.stride(0),
            result.stride(1),
            dot_out_dtype=acc_dtype,
            GROUP_M=8,
        )
    return result
def mm_out(a, b, *, out):
    """Matrix-multiply two 2-D tensors, writing into a caller-provided buffer.

    Args:
        a: Tensor of shape (M, K).
        b: Tensor of shape (K, N); ``a.shape[1]`` must equal ``b.shape[0]``.
        out: Pre-allocated destination tensor for the (M, N) product.

    Returns:
        ``out``, after the kernel has written the result into it.
    """
    logger.debug("METAX GEMS MM_OUT")
    # An operand with some stride equal to 1 is row- or column-major and can
    # be consumed directly; otherwise fall back to a contiguous copy.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # Inner dimensions must agree.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # The caller supplies the output buffer; accumulate in float32.
    dst = out
    acc_dtype = tl.float32
    logger.debug(
        "METAX GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    # One program per output tile, times SPLIT_K programs along K.
    def grid(meta):
        tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (tiles, meta["SPLIT_K"])

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            dst,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            dst.stride(0),
            dst.stride(1),
            dot_out_dtype=acc_dtype,
            GROUP_M=8,
        )
    return dst