Coverage for src/flag_gems/runtime/backend/_hygon/ops/mm.py: 0%
94 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly below ``a``.

    Computed as ``cdiv(a, b) * b - b``: round ``a`` up to a whole number of
    ``b``-sized blocks, then step back by one block.
    """
    num_blocks = tl.cdiv(a, b)
    return num_blocks * b - b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
    strategy=["log", "log", "log"],
)
@triton.jit
def mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Compute one ``BLOCK_M x BLOCK_N`` tile of ``C = A @ B``.

    Each program reads its index on grid axis 0, maps it to a tile
    ``(pid_m, pid_n)`` of C, accumulates the tile in float32 over K in steps
    of ``BLOCK_K``, casts to C's element type and stores it.

    Args:
        a_ptr, b_ptr, c_ptr: Base pointers of A, B and C.
        M, N, K: A is indexed as (M, K), B as (K, N), C as (M, N).
        stride_am, stride_ak: Element strides of A along M and K.
        stride_bk, stride_bn: Element strides of B along K and N.
        stride_cm, stride_cn: Element strides of C along M and N.
        BLOCK_M, BLOCK_N, BLOCK_K: Compile-time tile sizes.
        GROUP_M: Compile-time number of tile rows per group in the
            program-id reordering below.
    """
    pid = tle.program_id(0)

    # Number of tiles along each output dimension.
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Reorder for L2: programs are grouped GROUP_M tile rows at a time and
    # walk each group column by column; the last group may be shorter.
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    group_size_m = min(num_pid_m - group_id * GROUP_M, GROUP_M)

    pid_m = group_id * GROUP_M + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices of this tile and the per-step K offsets.
    offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Indices are wrapped with % M / % N so every load along M and N stays in
    # range without a mask; rows/columns past the edge are dropped by the
    # store mask instead. multiple_of / max_contiguous are compiler hints.
    offs_am_cont = tl.max_contiguous(tl.multiple_of(offs_am % M, BLOCK_M), BLOCK_M)
    offs_bn_cont = tl.max_contiguous(tl.multiple_of(offs_bn % N, BLOCK_N), BLOCK_N)

    # Start of the last K block: the largest multiple of BLOCK_K below K.
    # Everything before it is a full block and needs no K mask.
    prev_k_mult = tl.cdiv(K, BLOCK_K) * BLOCK_K - BLOCK_K

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Main K loop over full blocks (unmasked loads).
    for start_k in range(0, prev_k_mult, BLOCK_K):
        rk = start_k + offs_k

        a = tl.load(
            a_ptr + (offs_am_cont[:, None] * stride_am + rk[None, :] * stride_ak)
        )
        b = tl.load(
            b_ptr + (rk[:, None] * stride_bk + offs_bn_cont[None, :] * stride_bn)
        )

        # Mixed input dtypes: bring both operands to the output element type.
        if a.dtype != b.dtype:
            a = a.to(c_ptr.dtype.element_ty)
            b = b.to(c_ptr.dtype.element_ty)

        accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # Peeled last block, which may be partial.
    rk = prev_k_mult + offs_k
    # The lower bound matters when K == 0: prev_k_mult is then -BLOCK_K, so
    # every offset is negative and would otherwise satisfy rk < K.
    mask_k = (rk >= 0) & (rk < K)

    # Masked-out lanes feed tl.dot, so they must be explicit zeros; without
    # `other` their contents are not defined.
    a = tl.load(
        a_ptr + (offs_am_cont[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        b_ptr + (rk[:, None] * stride_bk + offs_bn_cont[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )

    if a.dtype != b.dtype:
        a = a.to(c_ptr.dtype.element_ty)
        b = b.to(c_ptr.dtype.element_ty)

    accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # Cast to the output dtype.
    accumulator = accumulator.to(c_ptr.dtype.element_ty)

    # Recompute the unwrapped tile indices for the store and mask off the
    # rows/columns that fall outside C.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)
    mask_store = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]

    tl.store(c_ptrs, accumulator, mask=mask_store)
# Supported floating-point dtypes, ordered from lowest to highest rank.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of the dtypes ``a`` and ``b`` ranks higher.

    Rank is the position in ``_ordered_datatypes``. Identical dtypes are
    returned as-is without being checked against that list; otherwise both
    must be members of it.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # a and b differ here, so their positions differ and max is unambiguous.
    return max((a, b), key=_ordered_datatypes.index)
def mm(a, b):
    """Return the matrix product ``a @ b`` computed by ``mm_kernel``.

    Args:
        a: 2-D tensor indexed as (M, K).
        b: 2-D tensor indexed as (K, N).

    Returns:
        A new (M, N) tensor on ``a``'s device whose dtype is the higher of
        the two input dtypes as ranked by ``get_higher_dtype``.

    Raises:
        AssertionError: If the inner dimensions differ, or the dtypes differ
            and ``get_higher_dtype`` does not support one of them.
    """
    logger.debug("GEMS MM")
    device = a.device
    # An input is copied to contiguous layout only when neither of its
    # strides is <= 1; layouts with a unit stride are passed through as-is.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c_dtype = get_higher_dtype(a.dtype, b.dtype)

    if K == 0:
        # Empty reduction dimension: every output element is an empty sum,
        # so the result is all zeros and there is nothing to launch.
        return torch.zeros((M, N), device=device, dtype=c_dtype)

    # allocates output
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    def grid(META):
        # 1-D launch: one program per output tile.
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        )

    # launch kernel
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c
def mm_out(a, b, *, out):
    """Compute ``a @ b`` with ``mm_kernel`` and write the result into ``out``.

    Args:
        a: 2-D tensor indexed as (M, K).
        b: 2-D tensor indexed as (K, N).
        out: Destination tensor, written through its own strides. Its shape,
            dtype and device are not checked here; it is assumed to be an
            (M, N) tensor on the same device as ``a`` (caller to confirm).

    Returns:
        ``out``, after it has been filled.

    Raises:
        AssertionError: If the inner dimensions of ``a`` and ``b`` differ.
    """
    logger.debug("GEMS MM_OUT")
    # An input is copied to contiguous layout only when neither of its
    # strides is <= 1; layouts with a unit stride are passed through as-is.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # The caller-provided tensor is the output; nothing is allocated here.
    c = out

    if K == 0:
        # Empty reduction dimension: every output element is an empty sum,
        # so fill with zeros and skip the launch.
        c.zero_()
        return c

    def grid(META):
        # 1-D launch: one program per output tile.
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        )

    # launch kernel
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c