Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%
216 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2from typing import Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.ops.mm_streamk import streamk_mm
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
13from flag_gems.utils.device_info import get_device_capability, get_sm_count
# Module-level logger for this backend op.
logger = logging.getLogger(__name__)

# NOTE(review): not referenced anywhere in this file — presumably consumed by a
# cache-sizing heuristic elsewhere (cf. the commented-out get_l2_cache_size()
# calls further down); confirm before removing.
CACHE_USAGE_THRESHOLD = 0.8
def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) aligned accesses, so N and K must be
    multiples of (16 bytes / element size):
    - FP16/BF16 (2 bytes/element): N and K must be multiples of 8
    - FP32 (4 bytes/element): N and K must be multiples of 4

    Args:
        a, b: Input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's 128-bit alignment requirement
    """
    half_types = (torch.float16, torch.bfloat16)
    if a.dtype in half_types and b.dtype in half_types:
        required_multiple = 8
    elif a.dtype == torch.float32 and b.dtype == torch.float32:
        required_multiple = 4
    else:
        # Mixed or unsupported dtype combination: never TMA-eligible.
        return False
    return N % required_multiple == 0 and K % required_multiple == 0
@triton.jit
def prev_multiple_of(a, b):
    # Greatest x < a with x % b == 0, i.e. (ceil(a / b) - 1) * b.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,  # pointer to the (M, K) left operand
    B,  # pointer to the (K, N) right operand
    C,  # pointer to the (M, N) output
    M,
    N,
    K,
    stride_am,  # element strides of A, B, C along each dimension
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # M-tiles grouped together for L2 swizzling
):
    # General C = A @ B tile kernel: each program computes one
    # BLOCK_M x BLOCK_N tile of C. Two code paths:
    #   * a device-side tensor-descriptor path when every dimension is an
    #     exact multiple of its block size (no masking needed);
    #   * a masked pointer-arithmetic path with loop peeling otherwise.
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    if M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0:
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0
        # NOTE(review): the descriptors hard-code contiguous row-major strides
        # ([K, 1] / [N, 1]) and ignore the stride_* arguments on this path —
        # confirm callers only reach this branch with contiguous inputs.
        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )
        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )
        # column-major
        # b_desc = tl.make_tensor_descriptor(
        #     B,
        #     shape = [N, K],
        #     strides = [K, 1],
        #     block_shape = [BLOCK_N, BLOCK_K],
        # )
        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )
        # Accumulate in fp32 regardless of input dtype.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K
        # NOTE(review): this path casts the accumulator to A's dtype while the
        # masked path below casts to C's dtype — identical when A and C share
        # a dtype, but verify for mixed-dtype calls.
        acc = acc.to(a_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)
    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap row/col indices modulo M/N so out-of-range lanes still read
        # valid memory; the final store mask discards their results.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        # Largest multiple of BLOCK_K strictly below K: the main loop runs
        # unmasked up to it, the (always non-empty) remainder is peeled below.
        prev_multiple = prev_multiple_of(K, BLOCK_K)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                # Promote mixed-dtype operands to the output dtype before dot.
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
        # loop peeling
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)
def matmul_tma_set_block_size_hook(nargs):
    """Autotuner pre-hook: install the chosen config's block shapes.

    Host-built TensorDescriptors are created with placeholder block shapes;
    this hook rewrites them to match the (BLOCK_M, BLOCK_N, BLOCK_K) of the
    config about to run. Column-major operands were described through their
    transposed view, so their block shapes are flipped accordingly.
    """
    bm = nargs["BLOCK_M"]
    bn = nargs["BLOCK_N"]
    bk = nargs["BLOCK_K"]
    nargs["a_desc"].block_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    nargs["b_desc"].block_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]
    nargs["c_desc"].block_shape = [bm, bn]
def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Enumerate the autotuning search space for the host-TMA matmul kernel.

    Returns the cross-product of block shapes, pipeline stages, and warp
    counts, each config carrying `pre_hook` so descriptor block shapes are
    fixed up before launch.
    """
    block_m_choices = [32, 64, 128, 256]
    block_n_choices = [32, 64, 128]
    block_k_choices = [32, 64, 128]
    stage_choices = [2, 3, 4]
    warp_choices = [4, 8]
    configs = []
    for bm in block_m_choices:
        for bn in block_n_choices:
            for bk in block_k_choices:
                for stages in stage_choices:
                    for warps in warp_choices:
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
                                num_stages=stages,
                                num_warps=warps,
                                pre_hook=pre_hook,
                            )
                        )
    return configs
@libentry()
@libtuner(
    configs=matmul_get_configs(),
    # 'dtype' in the key makes fp16/bf16/fp32 problems tune independently.
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,  # host-built TensorDescriptor for A (transposed view if col-major)
    b_desc,  # host-built TensorDescriptor for B (transposed view if col-major)
    c_desc,  # host-built TensorDescriptor for C
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # M-tiles grouped together for L2 swizzling
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    # Matmul kernel driven by host-constructed TMA tensor descriptors (built
    # in general_mm). Each program produces one BLOCK_M x BLOCK_N tile of C.
    #
    # NOTE(review): the stride_* arguments, `dtype`, and
    # `enable_warp_specialization` are never read in the body; strides and
    # dtype only feed the autotune key above — confirm the last parameter is
    # not dead.
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Group-swizzle program ids across M-tiles for better L2 reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Accumulate in fp32 regardless of input dtype.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)
        # Column-major operands were described through their transposed view,
        # so load the transposed tile and flip it back in registers.
        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)
        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)
        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 inputs: three-pass tf32 emulation keeps near-fp32 accuracy
            # while still using tensor cores.
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")
    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)
def get_higher_dtype(a, b):
    """Return the wider of two dtypes under the promotion order
    float16 < bfloat16 < float32.

    Both inputs must be one of those three dtypes.
    """
    promotion_order = [torch.float16, torch.bfloat16, torch.float32]
    if a is b:
        return a
    assert a in promotion_order
    assert b in promotion_order
    # Later position in the list wins.
    return a if promotion_order.index(a) > promotion_order.index(b) else b
def general_mm(a, b, c, M, N, K):
    """Dispatch a general (non stream-k, non gemv) matmul c = a @ b.

    Prefers the host-side TMA kernel when the Triton build exposes
    ``TensorDescriptor`` and the problem meets TMA's 128-bit alignment
    requirement; otherwise falls back to the generic kernel, which builds
    device-side descriptors and therefore needs a Triton allocator.

    Args:
        a: (M, K) input tensor.
        b: (K, N) input tensor.
        c: (M, N) preallocated output tensor.
        M, N, K: matrix dimensions.

    Returns:
        c, filled with the product.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    # Bug fix: the previous `hasattr(triton.tools.tensor_descriptor, ...)`
    # check evaluated `triton.tools.tensor_descriptor` BEFORE hasattr ran, so
    # on Triton builds lacking that submodule it raised AttributeError instead
    # of falling through to the non-TMA path. A guarded import is robust on
    # all versions.
    try:
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor
    except (ImportError, AttributeError):
        TensorDescriptor = None
    if TensorDescriptor is not None and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the real one is installed per-config by
        # matmul_tma_set_block_size_hook before each launch.
        dummy_block = [1, 1]
        # Column-major operands are described through their transposed view so
        # the descriptor stays row-major; the kernel transposes the tile back.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)
        input_dtype = a.dtype
        # e.g. "float16" — participates in the autotune key only.
        dtype_str = str(input_dtype).split(".")[-1]
        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:
        # mm_kernel_general builds descriptors on device, which requires a
        # global allocator for TMA workspace.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)
        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
            )
    return c
@libentry()
@triton.jit
def gemv_kernel(
    A,  # pointer to the (M, K) matrix
    B,  # pointer to the length-K vector
    C,  # pointer to the length-M output
    M,
    K,
    stride_am,  # strides of A; B is indexed with stride_bk only
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)"""
    pid = tl.program_id(0)
    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    # Accumulator for this block of rows (fp32 regardless of input dtype)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K
        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)
        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)
        # Accumulate: sum over K dimension
        acc += tl.sum(a * b[None, :], axis=1)
    # Store result
    # NOTE(review): C is indexed with an implicit unit stride — this holds for
    # mm()'s torch.empty((M, 1)) output, but verify for the caller-provided
    # buffer reaching here via mm_out().
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
def gemv_mm(a, b, c, M, K):
    """Matrix-vector product (N == 1): launch the dedicated gemv kernel.

    Args:
        a: (M, K) matrix.
        b: (K, 1) vector.
        c: (M, 1) preallocated output.
        M, K: matrix dimensions.

    Returns:
        c, filled with a @ b.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )
    # Fixed tiling: each program reduces BLOCK_K-wide chunks for BLOCK_M rows.
    block_m = 32
    block_k = 256
    launch_grid = lambda META: (triton.cdiv(M, block_m),)
    with torch_device_fn.device(a.device):
        gemv_kernel[launch_grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            BLOCK_M=block_m,
            BLOCK_K=block_k,
        )
    return c
def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-k kernel should handle this problem.

    The stream-k configuration has only been validated on compute capability
    8 devices (A100); everything else falls through to the general path.

    TODO: this may change according to real benchmark results; optimal
    settings for other devices still need to be measured.
    """
    capability = get_device_capability()
    if capability[0] != 8:
        return False
    half_types = (torch.float16, torch.bfloat16)
    # K-dominant shapes (K much larger than both M and N) benefit from
    # splitting the reduction across SMs.
    return (
        a.dtype in half_types
        and b.dtype in half_types
        and a.is_contiguous()
        and b.is_contiguous()
        and K > M * 5
        and K > N * 5
    )
def mm(a, b):
    """Compute a @ b for 2-D tensors, routed to the best-fitting kernel.

    Routing: gemv kernel for N == 1, stream-k for K-dominant shapes on
    supported devices, otherwise the general kernel.

    Returns:
        A new (M, N) tensor in the wider of the two input dtypes.
    """
    device = a.device
    # Inputs strided in both dimensions (neither row- nor column-major)
    # are densified first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Output dtype follows the wider input dtype.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=out_dtype)
    # Matrix-vector product gets its own reduction kernel.
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)
def mm_out(a, b, *, out):
    """Compute a @ b into a caller-provided `out` tensor.

    Same routing as mm(); `out` is forwarded as-is.

    NOTE(review): the gemv path indexes the output with unit stride —
    confirm callers always pass a contiguous `out`.
    """
    # Inputs strided in both dimensions are densified first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Matrix-vector product gets its own reduction kernel.
    if N == 1:
        return gemv_mm(a, b, out, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    return general_mm(a, b, out, M, N, K)