Coverage report artifact for src/flag_gems/runtime/backend/_kunlunxin/ops/mm.py:
0% of 108 statements covered (coverage.py v7.6.9, created 2026-03-15 02:11 +0800).
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_split_k(args):
    """Heuristic for SPLIT_K: always use a single K-split (no split-K)."""
    return 1
def heur_even_k(args):
    """True when K divides evenly into the per-iteration K step.

    When this holds, the kernel's K-loop can skip boundary masking.
    """
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0
def heur_group_m(args):
    """Heuristic for GROUP_M, the program-reordering group size.

    Tall tiles (BLOCK_M > BLOCK_N) get no grouping; otherwise the group
    spans all row-blocks, i.e. ceil(M / BLOCK_M).
    """
    block_m = args["BLOCK_M"]
    if block_m > args["BLOCK_N"]:
        return 1
    # Negative floor-division is ceil-division for positive block sizes.
    return -(-args["M"] // block_m)
# Autotune the kernel keyed on the problem shape (M, N, K).
# NOTE(review): `generate_configs` is not a stock Triton keyword — presumably
# the Kunlunxin Triton fork expands "mm" into a backend-specific config sweep;
# confirm against that fork's autotuner.
autotune_decorator = triton.autotune(
    configs=[],
    generate_configs="mm",
    key=["M", "N", "K"],
)


# KLX_USE_AUTOTUNE=0 disables the sweep and pins one fixed 256^3 tiling.
# Defaults to "1" (autotuning on) when the env var is unset.
KLX_USE_AUTOTUNE = os.environ.get("KLX_USE_AUTOTUNE", "1") == "1"

if not KLX_USE_AUTOTUNE:
    autotune_decorator = triton.autotune(
        configs=[
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 256}),
        ],
        key=["M", "N", "K"],
    )
@libentry()
@autotune_decorator
@triton.heuristics(
    {
        "SPLIT_K": heur_split_k,
        "EVEN_K": heur_even_k,
        "GROUP_M": heur_group_m,
    }
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Tiled GEMM kernel computing C = A @ B.

    Each program on axis 0 produces one BLOCK_M x BLOCK_N tile of C,
    accumulating over K in steps of BLOCK_K * SPLIT_K. Axis 1 indexes the
    K-split: with SPLIT_K == 1 the tile is stored directly; with
    SPLIT_K > 1 each split adds its partial sum into C via tl.atomic_add
    (so C must be zero-initialized by the caller in that case).
    """
    # matrix multiplication
    pid = tle.program_id(0)
    pid_z = tle.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols (rm % M, rn % N) so loads stay in bounds;
    # the final store is masked, so wrapped lanes never reach C. The
    # max_contiguous/multiple_of hints let the compiler vectorize the loads.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # K divides the step exactly (see heur_even_k): no masking needed.
            a = tl.load(A)
            b = tl.load(B)
        else:
            # Mask the tail of the K dimension, padding with zeros.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if a.dtype != b.dtype:
            # Mixed-dtype inputs: promote both operands to C's element type
            # so tl.dot sees matching types.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
# Dtypes ordered from narrowest to widest for promotion purposes.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes per the _ordered_datatypes ranking.

    Both dtypes must appear in _ordered_datatypes; identical inputs are
    returned as-is.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype appearing later in the list is the wider one.
    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return b if rank_a < rank_b else a
def mm(a, b):
    """Matrix-multiply two 2-D tensors via the Triton mm kernel.

    Allocates the output in the promoted dtype of the inputs and launches
    mm_kernel on the input's device. Accumulation is done in float32.
    """
    logger.debug("GEMS MM")
    device = a.device
    # Only force contiguity when neither stride is 1 — a tensor with a
    # unit stride in one dim is already consumable via strided loads.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    dot_out_dtype = tl.float32

    # launch kernel: one program per output tile on axis 0, K-splits on axis 1
    def grid(meta):
        tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (tiles, meta["SPLIT_K"])

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            dot_out_dtype=dot_out_dtype,
        )
    return c
def mm_out(a, b, *, out):
    """Matrix-multiply two 2-D tensors into a caller-provided output.

    Same launch path as mm(), but writes into `out` instead of allocating.
    Accumulation is done in float32.

    Raises:
        AssertionError: if the inner dimensions are incompatible or `out`
            does not have shape (M, N).
    """
    logger.debug("GEMS MM_OUT")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Validate the caller-provided buffer before launching: a wrong-sized
    # `out` would otherwise make the kernel write through bad strides.
    assert tuple(out.shape) == (M, N), "incompatible output shape"
    c = out
    dot_out_dtype = tl.float32
    # launch kernel
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        META["SPLIT_K"],
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
        )
    return c