Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%
255 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import os
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.mm_streamk import streamk_mm
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry, libtuner
13from flag_gems.utils import triton_lang_extension as tle
14from flag_gems.utils.device_info import get_device_capability, get_sm_count
logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")

# Cache-usage ratio threshold; not referenced in this part of the module
# (NOTE(review): confirm whether any caller still reads it).
CACHE_USAGE_THRESHOLD = 0.8

# FlagTune expand-config YAML shared by the Hopper mm/gemv kernels, located in
# the parent directory of this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "mm_hopper_expand.yaml")
)

# Headroom added on top of every shared-memory estimate before comparing it
# against the device's per-block opt-in limit.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1024
def _get_shared_memory_limit_bytes():
    """Look up the opt-in shared-memory limit per block of the active CUDA device.

    Returns:
        The ``shared_memory_per_block_optin`` property of the current CUDA
        device, or ``None`` when CUDA is unavailable or the query raises.
    """
    try:
        if torch.cuda.is_available():
            device_index = torch.cuda.current_device()
            properties = torch.cuda.get_device_properties(device_index)
            return properties.shared_memory_per_block_optin
    except Exception:
        # Best effort: an unknown limit simply disables config filtering.
        return None
    return None
def _estimate_tma_shared_memory_bytes(
    block_m, block_n, block_k, num_stages, bytes_per_element=4
):
    """Estimate the shared memory taken by the pipelined A/B tiles of a GEMM config.

    Each pipeline stage buffers one ``block_m x block_k`` tile of A and one
    ``block_k x block_n`` tile of B. A fixed safety margin is then added.

    Args:
        block_m, block_n, block_k: Tile sizes of the config.
        num_stages: Number of software-pipelining stages.
        bytes_per_element: Element size of the operands. The default of 4
            (fp32, the widest dtype ``is_tma_compatible`` accepts) keeps the
            estimate conservative when the dtype is not known yet, e.g. while
            building the autotuning config list.

    Returns:
        Estimated bytes of shared memory, including the safety margin.
    """
    tile_elements = block_m * block_k + block_k * block_n
    pipeline_bytes = tile_elements * bytes_per_element * num_stages
    return pipeline_bytes + _SHARED_MEM_SAFETY_MARGIN_BYTES
def _tma_operand_is_aligned(t):
    """Return True if 2-D tensor ``t`` can back a TMA descriptor as ``general_mm`` builds it.

    ``general_mm`` describes a row-major tensor (``stride(1) == 1``) directly and
    a column-major one (``stride(0) == 1``) through its transpose. So one of
    the two strides must be 1. The other (outer) stride, in bytes, and the
    base address must both be multiples of 16.
    """
    if t.stride(1) == 1:
        outer_stride = t.stride(0)
    elif t.stride(0) == 1:
        outer_stride = t.stride(1)
    else:
        return False
    return (outer_stride * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0


def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements × 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements × 4 bytes = 16 bytes)

    In addition:
    - ``a`` and ``b`` must share one dtype. The host-TMA kernel feeds both
      tiles straight into ``tl.dot`` without casting.
    - Each operand's actual layout must be TMA-describable. For example, a
      column-major ``a`` needs M itself to be aligned, and a sliced view needs
      an aligned row stride and base pointer.

    Args:
        a, b: Input tensors (2-D)
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype in (torch.float16, torch.bfloat16):
        elements_per_16_bytes = 8
    elif a.dtype == torch.float32:
        elements_per_16_bytes = 4
    else:
        return False
    # N also fixes the row stride of the freshly allocated output C.
    if N % elements_per_16_bytes != 0 or K % elements_per_16_bytes != 0:
        return False
    return _tma_operand_is_aligned(a) and _tma_operand_is_aligned(b)
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly smaller than ``a``.

    Computed as ``cdiv(a, b) * b - b``. This holds for ``a > 0``; for
    ``a == 0`` the result is ``-b``. ``mm_kernel_general`` uses it as the start
    offset of the peeled (masked) final K block.
    """
    # the largest x<a that x%b ==0
    return tl.cdiv(a, b) * b - b
def matmul_tma_set_block_size_hook(nargs):
    """Autotuner pre-hook that sizes the host TMA descriptors for the chosen config.

    The descriptors are created with a placeholder block shape. This hook
    rewrites ``block_shape`` on ``a_desc``, ``b_desc`` and ``c_desc`` from the
    config's BLOCK_M/BLOCK_N/BLOCK_K. A column-major operand is described
    through its transpose, so its block shape is swapped.

    Args:
        nargs: Kernel-argument mapping supplied by the autotuner.
    """
    tile_m, tile_n, tile_k = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    nargs["a_desc"].block_shape = (
        [tile_m, tile_k] if nargs["A_ROW_MAJOR"] else [tile_k, tile_m]
    )
    nargs["b_desc"].block_shape = (
        [tile_k, tile_n] if nargs["B_ROW_MAJOR"] else [tile_n, tile_k]
    )
    nargs["c_desc"].block_shape = [tile_m, tile_n]
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute C = A @ B (A is M x K, B is K x N); each program owns one BLOCK_M x BLOCK_N tile of C.

    Two code paths:
    * Device-side TMA path. Taken only when the tiles divide M, N and K
      exactly AND A, B and C are dense row-major. Its descriptors hard-code
      strides [K, 1], [N, 1] and [N, 1], so any other layout would be
      misread.
    * Pointer-arithmetic path. Honours the passed strides and masks the K
      tail. Used for everything else (column-major, sliced, ragged shapes).

    Accumulation is fp64 when IS_FP64, otherwise fp32. Mixed-dtype tiles are
    cast to C's element type before tl.dot.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance (groups of GROUP_M tile rows)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    if (
        M % BLOCK_M == 0
        and N % BLOCK_N == 0
        and K % BLOCK_K == 0
        and stride_ak == 1
        and stride_am == K
        and stride_bn == 1
        and stride_bk == N
        and stride_cn == 1
        and stride_cm == N
    ):
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major B (column-major B is routed to the pointer path by the guard above)
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            # mixed-dtype inputs: compute in C's dtype, like the pointer path
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # store in C's dtype (A's dtype differs from C's for mixed inputs)
        acc = acc.to(C.dtype.element_ty)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # wrap out-of-range rows/cols so loads stay in bounds; the store is masked
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # full K blocks: no masking needed
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last (possibly partial) K block is loaded with a mask.
        # rk >= 0 matters for K == 0, where prev_multiple is -BLOCK_K and every
        # rk would otherwise satisfy rk < K and read out of bounds.
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = (rk >= 0) & (rk < K)
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # masked write-back drops the rows/cols that were wrapped by % M / % N above
        tl.store(offsets, acc, mask=mask)
def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the default autotuning space for the host-TMA mm kernel.

    Enumerates BLOCK_M x BLOCK_N x BLOCK_K tiles crossed with pipeline stages
    and warp counts. When the device's shared-memory limit is known, configs
    whose estimated footprint exceeds it are dropped. If none would remain,
    a warning is logged and the unfiltered list is returned.

    Args:
        pre_hook: Autotuner pre-hook attached to every config.

    Returns:
        A list of ``triton.Config`` objects.
    """
    tile_shapes = [
        (tile_m, tile_n, tile_k)
        for tile_m in (32, 64, 128, 256)
        for tile_n in (32, 64, 128)
        for tile_k in (32, 64, 128)
    ]
    launch_shapes = [(stages, warps) for stages in (2, 3, 4) for warps in (4, 8)]
    candidates = [
        triton.Config(
            {"BLOCK_M": tile_m, "BLOCK_N": tile_n, "BLOCK_K": tile_k},
            num_stages=stages,
            num_warps=warps,
            pre_hook=pre_hook,
        )
        for tile_m, tile_n, tile_k in tile_shapes
        for stages, warps in launch_shapes
    ]

    limit = _get_shared_memory_limit_bytes()
    if limit is None:
        return candidates

    def fits(cfg):
        needed = _estimate_tma_shared_memory_bytes(
            cfg.kwargs["BLOCK_M"],
            cfg.kwargs["BLOCK_N"],
            cfg.kwargs["BLOCK_K"],
            cfg.num_stages,
        )
        return needed <= limit

    fitting = [cfg for cfg in candidates if fits(cfg)]
    if fitting:
        return fitting
    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        limit,
    )
    return candidates
@libentry()
@libtuner(
    # FlagTune (USE_FLAGTUNE=1) reads configs/strategy from the expand YAML;
    # otherwise the shared-memory-filtered default space is used.
    configs=runtime.ops_get_configs(
        "mm_general_tma",
        pre_hook=matmul_tma_set_block_size_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else matmul_get_configs(),
    # stride keys separate row-/column-major layouts of the same shape;
    # "dtype" is the dtype name string passed by general_mm
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    """Tiled GEMM reading A/B and writing C through host-created TMA descriptors.

    Each program computes one BLOCK_M x BLOCK_N tile of C with an fp32
    accumulator. A column-major operand's descriptor describes its transpose
    (see general_mm), so such tiles are loaded with swapped coordinates and
    transposed back with tl.trans.

    The stride_* arguments and ``dtype`` are not read in the body; the listed
    ones serve as autotuning keys. ``enable_warp_specialization`` is likewise
    not referenced here.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # grouped tile ordering: GROUP_M tile-rows per group for L2 reuse
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    # No explicit K-tail masking: presumably relies on the TMA descriptor's
    # out-of-bounds fill (zero padding) for the last partial K block — TODO confirm.
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # descriptor is over A^T (K x M): load [BLOCK_K, BLOCK_M] then transpose
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # descriptor is over B^T (N x K): load [BLOCK_N, BLOCK_K] then transpose
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 inputs: "tf32x3" input precision on the tensor cores
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)
def get_higher_dtype(a, b):
    """Return the higher-ranked of two floating dtypes.

    Ranking is float16 < bfloat16 < float32 < float64. Identical dtypes are
    returned as-is without validation. Otherwise both must be in the ranking
    (checked with ``assert``).
    """
    precedence = [torch.float16, torch.bfloat16, torch.float32, torch.float64]

    if a is b:
        return a

    assert a in precedence
    assert b in precedence

    return a if precedence.index(a) > precedence.index(b) else b
def general_mm(a, b, c, M, N, K):
    """Compute ``a @ b`` into ``c`` with the Hopper GEMM kernels and return ``c``.

    ``a`` is (M, K), ``b`` is (K, N) and ``c`` is (M, N). If Triton exposes the
    host-side ``TensorDescriptor`` and ``is_tma_compatible`` accepts the
    operands, ``mm_kernel_general_host_tma`` is launched. Otherwise
    ``mm_kernel_general`` is launched.

    Both paths assume a dense, 16-byte aligned, row-major output:
    * the host-TMA path builds C's descriptor from ``c.shape``/``c.stride()``
      as-is;
    * the device-side fast path of ``mm_kernel_general`` uses C strides [N, 1].

    A caller-supplied ``c`` that is transposed, sliced or offset (e.g. an
    ``out=`` tensor) is therefore served by a dense scratch buffer, which is
    copied into ``c`` at the end.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()

    c_is_dense = c.stride(1) == 1 and c.stride(0) == N and c.data_ptr() % 16 == 0
    out = c if c_is_dense else torch.empty((M, N), dtype=c.dtype, device=c.device)

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    if hasattr(
        triton.tools.tensor_descriptor, "TensorDescriptor"
    ) and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # placeholder block shape; matmul_tma_set_block_size_hook sets the real one
        dummy_block = [1, 1]
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor

        # column-major operands are described through their (row-major) transpose
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(out, out.shape, out.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                out.stride(0),
                out.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:
        # mm_kernel_general builds device-side descriptors, which need a
        # Triton allocator for their scratch memory.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                out,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                out.stride(0),
                out.stride(1),
                GROUP_M=8,
                IS_FP64=a.dtype == torch.float64,
            )

    if out is not c:
        c.copy_(out)
    return c
@libentry()
@libtuner(
    # FlagTune (USE_FLAGTUNE=1) reads configs/strategy from the expand YAML;
    # otherwise a single fixed config is used.
    configs=runtime.ops_get_configs(
        "gemv", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case).

    Each program reduces BLOCK_M rows of A (M x K) against the vector B
    (K elements, stride ``stride_bk``), walking K in BLOCK_K chunks with
    elementwise multiply + ``tl.sum`` rather than ``tl.dot``. Accumulates in
    fp64 when IS_FP64, otherwise in fp32 (inputs are upcast to fp32).

    C is addressed with unit stride (row i is written to ``C + i``); there is
    no output-stride parameter, so the caller must pass a C with unit row
    stride.
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension; the last chunk is masked with k_mask
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Store result (unit-stride output, masked to the valid rows)
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case.

    Computes ``a @ b`` for ``a`` (M, K) and ``b`` (K, 1) into ``c`` (M, 1) and
    returns ``c``.

    ``gemv_kernel`` writes row ``i`` at ``C + i`` and takes no output stride.
    If ``c`` has a non-unit row stride (e.g. an ``out=`` view taken from a wider
    matrix), writing into it directly would overwrite unrelated elements of
    its storage. In that case the result goes to a contiguous scratch vector
    and is copied into ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # With M <= 1 the row stride is never used, so any layout is fine.
    needs_scratch = M > 1 and c.stride(0) != 1
    out = torch.empty((M, 1), dtype=c.dtype, device=c.device) if needs_scratch else c

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            out,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            IS_FP64=a.dtype == torch.float64,
        )

    if needs_scratch:
        c.copy_(out)
    return c
def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K GEMM should handle this problem.

    Requires all of the following:
    * a device with compute-capability major version 8;
    * fp16/bf16 operands;
    * contiguous ``a`` and ``b``;
    * K more than five times both M and N.
    """
    # TODO: this may change according to real benchmark results.
    # The best stream-K configuration has so far only been tested on A100
    # (capability[0] == 8); settings for other devices still need real testing.
    capability = get_device_capability()
    if capability[0] != 8:
        return False

    half_precision = (torch.float16, torch.bfloat16)
    if a.dtype not in half_precision or b.dtype not in half_precision:
        return False
    if not a.is_contiguous() or not b.is_contiguous():
        return False

    # Only K-dominant problems benefit.
    return K > 5 * M and K > 5 * N
def _mm_prepare_operands(a, b):
    """Normalize operand layouts and return ``(a, b, M, N, K)`` for a 2-D matmul.

    Only operands with neither dimension contiguous are copied. Row- and
    column-major views are passed through as-is.
    """
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    return a, b, M, N, K


def _mm_dispatch(a, b, c, M, N, K):
    """Write ``a @ b`` into ``c`` with the most suitable kernel and return ``c``."""
    # Degenerate shapes: an empty output needs no work, and an empty reduction
    # (K == 0) is all zeros. Handling them here keeps zero-sized operands
    # (whose strides are not TMA-aligned) away from the kernels.
    if M == 0 or N == 0:
        return c
    if K == 0:
        return c.zero_()
    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)


def mm(a, b):
    """Return ``a @ b`` for 2-D ``a`` (M, K) and ``b`` (K, N).

    The result is a new (M, N) tensor on ``a``'s device whose dtype is
    ``get_higher_dtype(a.dtype, b.dtype)``.

    Raises:
        AssertionError: if the inner dimensions differ.
    """
    a, b, M, N, K = _mm_prepare_operands(a, b)
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=a.device, dtype=c_dtype)
    return _mm_dispatch(a, b, c, M, N, K)


def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided (M, N) tensor ``out`` and return it.

    Raises:
        AssertionError: if the inner dimensions differ.
    """
    a, b, M, N, K = _mm_prepare_operands(a, b)
    return _mm_dispatch(a, b, out, M, N, K)