Coverage for src/flag_gems/ops/mm.py: 43%
104 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.mm_streamk import streamk_mm
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.device_info import get_device_capability, get_sm_count
# Fraction of cache capacity considered "in use" before behavior changes.
# NOTE(review): not referenced anywhere in this module's visible code —
# presumably consumed by tuning logic elsewhere; confirm before removing.
CACHE_USAGE_THRESHOLD = 0.8

# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly less than a.
    # ceil(a / b) * b is the first multiple >= a; stepping back one multiple
    # yields a value < a even when a itself is already a multiple of b.
    next_or_equal_multiple = tl.cdiv(a, b) * b
    return next_or_equal_multiple - b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Tiled GEMM kernel: C[M, N] = A[M, K] @ B[K, N].

    Each program instance computes one BLOCK_M x BLOCK_N tile of C,
    accumulating over K in BLOCK_K steps in float32
    (allow_tf32=False keeps the fp32 path exact). Program IDs are
    swizzled in groups of GROUP_M row-tiles for better L2 reuse.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance:
    # consecutive pids walk down GROUP_M row-tiles before moving to the next
    # column-tile, so A rows stay hot in cache.
    width = GROUP_M * grid_n
    group_id = pid // width
    # Last group may hold fewer than GROUP_M row-tiles.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range row/col indices modulo M/N so every lane loads valid
    # memory; the bogus lanes are discarded by the mask at the final store.
    # int64 promotion avoids 32-bit overflow in the address arithmetic.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # Largest multiple of BLOCK_K strictly below K: the main loop covers
    # [0, prev_multiple) unmasked, the peeled tail covers the rest with a mask.
    prev_multiple = prev_multiple_of(K, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            # Mixed-precision inputs: promote both operands to C's dtype
            # before the dot so tl.dot sees matching types.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    # loop peeling: final K slice, masked because it may be partial
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    # Downcast the fp32 accumulator to the output dtype.
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # Mask off lanes whose (wrapped) indices fell outside the real M x N tile.
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)
# Floating-point dtypes ordered from narrowest to widest for promotion.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes per the `_ordered_datatypes` order.

    Both `a` and `b` must be members of `_ordered_datatypes`.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype appearing later in the promotion order wins.
    if _ordered_datatypes.index(a) < _ordered_datatypes.index(b):
        return b
    return a
def general_mm(a, b, c, M, N, K):
    """Run the general tiled GEMM kernel, writing a @ b into c.

    Launches `mm_kernel_general` with one program per BLOCK_M x BLOCK_N
    output tile and returns `c`.
    """
    logger.debug(
        "GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    def grid(meta):
        # 1-D launch: one program per output tile.
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel_general[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c
def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-k GEMM path should be used.

    True only for contiguous fp16/bf16 inputs on capability-8 devices with a
    strongly K-dominated problem (K more than 5x both M and N).
    """
    # TODO: this may change according to real benchmark results.
    # Currently, the best configuration for streamk has only been tested on
    # A100 (capability[0] == 8); optimal settings for other devices need to
    # be determined through real testing.
    if get_device_capability()[0] != 8:
        return False
    if a.dtype not in [torch.float16, torch.bfloat16]:
        return False
    if b.dtype not in [torch.float16, torch.bfloat16]:
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    return K > M * 5 and K > N * 5
def mm(a, b):
    """Matrix product of 2-D tensors: returns a @ b as a new tensor.

    Output dtype is the wider of the input dtypes per `get_higher_dtype`.
    Dispatches to the stream-k kernel when `streamk_scenario` applies,
    otherwise to the general tiled kernel.
    """
    device = a.device
    # A tensor with no unit-stride dimension is genuinely strided
    # (neither row- nor column-major); compact it first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)
def mm_out(a, b, *, out):
    """Matrix product of 2-D tensors written into `out`: out = a @ b.

    NOTE(review): unlike `mm`, no dtype/shape validation is performed on
    `out` here — presumably the caller guarantees it; confirm upstream.
    """
    # A tensor with no unit-stride dimension is genuinely strided
    # (neither row- nor column-major); compact it first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    return general_mm(a, b, out, M, N, K)