Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mm.py: 0%
97 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
# Module logger nested under the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk", "stride_ak", "stride_bn"],
    strategy=[
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    UPCAST: tl.constexpr,
):
    # Tiled matrix multiplication C = A @ B with optional split-K reduction.
    # Axis 0 of the launch grid enumerates (M, N) tiles; axis 1 enumerates
    # the SPLIT_K slices of the K dimension.
    # UPCAST widens the program ids to int64 so every offset derived from
    # them is 64-bit, avoiding 32-bit overflow on very large problems.
    if UPCAST:
        pid = tl.program_id(0).to(tl.int64)
        pid_z = tl.program_id(1).to(tl.int64)
    else:
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: programs are grouped in
    # super-rows of GROUP_M tile rows so neighbouring programs share B tiles.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication: element offsets for this program's tile.
    # rk starts at this split-K slice's base offset (pid_z * BLOCK_K).
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers to the first A (BLOCK_M x BLOCK_K) and B (BLOCK_K x BLOCK_N) tiles
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    # Each program walks K in strides of BLOCK_K * SPLIT_K, so the SPLIT_K
    # programs along grid axis 1 together cover the whole K dimension.
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # EVEN_K: K divides the stride evenly, no masking needed along K.
            a = tl.load(A, mask=(rm < M)[:, None], other=0.0)
            b = tl.load(B, mask=(rn < N)[None, :], other=0.0)
        else:
            # Mask out-of-range K elements of the final partial tile.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(
                A, mask=(rk[None, :] < k_remaining) & (rm < M)[:, None], other=0.0
            )
            b = tl.load(
                B, mask=(rk[:, None] < k_remaining) & (rn < N)[None, :], other=0.0
            )
        # Mixed-precision inputs: promote both operands to the output dtype
        # so tl.dot sees matching element types.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # With split-K > 1 each program holds only a partial sum over its
        # K-slice; combine them with atomic adds (C must be zero-initialized
        # by the caller for this path to be correct — presumably handled by
        # the tuned configs restricting SPLIT_K; TODO confirm).
        tl.atomic_add(C, acc, mask=mask)
# Promotion order: a dtype later in this list wins over any earlier one.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes according to ``_ordered_datatypes``.

    Both ``a`` and ``b`` must be members of ``_ordered_datatypes``.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # Rank each dtype by its position; the higher-ranked dtype is "wider".
    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return a if rank_a > rank_b else b
def mm(a, b):
    """Compute ``a @ b`` on Cambricon via the tuned Triton kernel.

    The output dtype is the promotion of the two input dtypes; the result
    is a freshly allocated (M, N) tensor on ``a``'s device.
    """
    logger.debug("GEMS_CAMBRICON MM")
    device = a.device

    # Compact an input when it is strided in both dimensions (neither
    # dimension is unit-stride), so the kernel sees at most one large stride.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Shape check: inner dimensions must agree.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Output buffer in the promoted dtype; accumulation happens in fp32.
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    dot_out_dtype = tl.float32

    # Switch the kernel's index math to int64 whenever any addressing extent
    # could reach the int32 limit.
    limit = 1 << 31
    upcast = (
        M * max(a.stride(0), c.stride(0)) >= limit
        or N * max(b.stride(1), c.stride(1)) >= limit
        or K * max(a.stride(1), b.stride(0)) >= limit
    )

    # Grid: axis 0 covers the (M, N) tiles, axis 1 the split-K slices.
    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
            meta["SPLIT_K"],
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
            UPCAST=upcast,
        )
    return c
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided ``out`` tensor.

    Same launch path as :func:`mm`, but the destination buffer (and hence
    the output dtype) is supplied by the caller. Returns ``out``.
    """
    logger.debug("GEMS_CAMBRICON MM_OUT")

    # Compact an input when it is strided in both dimensions (neither
    # dimension is unit-stride), so the kernel sees at most one large stride.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Shape check: inner dimensions must agree.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Write directly into the caller's buffer; accumulation stays in fp32.
    c = out
    dot_out_dtype = tl.float32

    # Switch the kernel's index math to int64 whenever any addressing extent
    # could reach the int32 limit.
    limit = 1 << 31
    upcast = (
        M * max(a.stride(0), c.stride(0)) >= limit
        or N * max(b.stride(1), c.stride(1)) >= limit
        or K * max(a.stride(1), b.stride(0)) >= limit
    )

    # Grid: axis 0 covers the (M, N) tiles, axis 1 the split-K slices.
    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
            meta["SPLIT_K"],
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
            UPCAST=upcast,
        )
    return c