Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%
265 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
2import os
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
8import yaml
10from flag_gems import runtime
11from flag_gems.ops.mm_streamk import streamk_mm
12from flag_gems.runtime import torch_device_fn
13from flag_gems.utils import libentry, libtuner
14from flag_gems.utils import triton_lang_extension as tle
15from flag_gems.utils.device_info import get_device_capability, get_sm_count
logger = logging.getLogger(__name__)

# Cache-usage ratio cutoff (0..1). NOTE(review): not referenced anywhere in
# this file — presumably consumed by another module; confirm before removing.
CACHE_USAGE_THRESHOLD = 0.8
def is_tma_compatible(a, b, N, K):
    """
    Report whether ``a`` and ``b`` meet TMA's 128-bit alignment rule.

    TMA (Tensor Memory Accelerator) transfers require 16-byte aligned
    tiles:
    - FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements x 2 bytes = 16 bytes)
    - FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements x 4 bytes = 16 bytes)

    Args:
        a, b: Input tensors.
        N, K: Matrix dimensions.

    Returns:
        bool: True if compatible with TMA's 128-bit alignment requirement.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    if a.dtype in half_dtypes and b.dtype in half_dtypes:
        return N % 8 == 0 and K % 8 == 0
    if a.dtype is torch.float32 and b.dtype is torch.float32:
        return N % 4 == 0 and K % 4 == 0
    return False
@triton.jit
def prev_multiple_of(a, b):
    # Greatest multiple of b strictly below a (for a % b == 0 this is a - b).
    # ceil(a / b) rounds up to the next tile count; stepping back one tile
    # of size b yields the last fully covered multiple.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,  # pointer to left operand, logically (M, K)
    B,  # pointer to right operand, logically (K, N)
    C,  # pointer to output, logically (M, N)
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """General tiled matmul C = A @ B with grouped program ordering.

    Two paths: when M, N and K all divide their block sizes exactly, the
    kernel builds device-side tensor descriptors and runs unmasked loads;
    otherwise it falls back to masked pointer arithmetic with the K loop
    peeled into full tiles plus one masked remainder tile.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    if M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0:
        # Fast path: exact tiling, no boundary masks needed.
        # NOTE(review): descriptors hard-code row-major contiguous strides
        # ([K, 1], [N, 1]) rather than the stride_* arguments, so this
        # branch implicitly assumes contiguous inputs — confirm callers
        # only reach it in that case.
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        # column-major
        # b_desc = tl.make_tensor_descriptor(
        #     B,
        #     shape = [N, K],
        #     strides = [K, 1],
        #     block_shape = [BLOCK_N, BLOCK_K],
        # )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        # fp32 accumulation over the K tiles.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # NOTE(review): cast targets A's dtype here, while the masked path
        # below casts to C's dtype — these differ for mixed-dtype inputs;
        # confirm this asymmetry is intended.
        acc = acc.to(a_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap indices (% M / % N) so the unmasked full-tile loads below
        # stay in bounds; int64 avoids offset overflow on large tensors.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Full (unmasked) K tiles first...
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                # Mixed-dtype inputs: promote both operands to C's dtype.
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling
        # ...then one masked tile covering the K remainder.
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCM_M) if False else (pid_m * BLOCK_M + tl.arange(0, BLOCK_M))).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)
def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: sync TMA descriptor tile shapes with the config.

    Host-side tensor descriptors are created with placeholder block
    shapes, so before each candidate config launches, this hook rewrites
    them to that config's BLOCK_M/BLOCK_N/BLOCK_K (swapping the two axes
    for column-major operands, whose descriptors view the transpose).
    """
    bm = nargs["BLOCK_M"]
    bn = nargs["BLOCK_N"]
    bk = nargs["BLOCK_K"]

    nargs["a_desc"].block_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    nargs["b_desc"].block_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]
    nargs["c_desc"].block_shape = [bm, bn]
def get_expand_config(op):
    """Load an expanded autotune search space for ``op`` from
    ``mm_hopper_tma_expand.yaml`` (one directory above this file).

    Args:
        op: Operator name; only "matmul" and "gemv" are recognized.

    Returns:
        dict with keys "ranges" (value lists keyed by short names "BM",
        "BN"/"BK", "s", "w") and "strategy" (per-tuning-key strategy
        list), or the sentinel ``-1`` when the op is unknown, the yaml
        file is missing, or parsing fails for any reason.
    """
    # Fallback strategies used when the yaml does not override them.
    default_strategies = {
        "matmul": ["align32", "align32", "align32", "align32", "align32", "default"],
        "gemv": ["align32", "align32", "align32", "default"],
    }
    # Order of libtuner keys; strategy overrides are matched positionally.
    op_key_orders = {
        "matmul": ["M", "N", "K", "stride_am", "stride_bk", "dtype"],
        "gemv": ["M", "K", "stride_am", "stride_bk"],
    }
    # Short range name -> kernel meta-parameter name, resolved through the
    # yaml's param_map/META indirection below.
    op_meta_map = {
        "matmul": {
            "BM": "BLOCK_M",
            "BN": "BLOCK_N",
            "BK": "BLOCK_K",
        },
        "gemv": {
            "BM": "BLOCK_M",
            "BK": "BLOCK_K",
        },
    }

    if op not in default_strategies:
        return -1

    default_strategy = default_strategies[op]
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "mm_hopper_tma_expand.yaml"
    )
    if not os.path.exists(config_path):
        return -1

    try:
        with open(config_path, "r") as file:
            config = yaml.safe_load(file) or {}

        expand_configs = config.get(op)

        gen_config = None
        strategy_config = None
        # NOTE: if the op key is absent, expand_configs is None and the
        # iteration raises TypeError, which the broad except below turns
        # into the -1 sentinel (deliberate best-effort parsing).
        for single_config in expand_configs:
            if isinstance(single_config, dict) and "param_map" in single_config:
                gen_config = single_config
            if isinstance(single_config, dict) and "strategy" in single_config:
                strategy_config = single_config.get("strategy")

        # Likewise, a missing param_map entry raises here and yields -1.
        param_map = gen_config["param_map"]
        meta_map = param_map["META"]

        strategy = default_strategy
        if isinstance(strategy_config, dict):
            # Per-key override, falling back positionally to the default.
            strategy = [
                strategy_config.get(k, default_strategy[idx])
                for idx, k in enumerate(op_key_orders[op])
            ]

        ranges = {}
        for range_key, meta_key in op_meta_map[op].items():
            ranges[range_key] = gen_config[meta_map[meta_key]]
        ranges["s"] = gen_config[param_map["num_stages"]]
        ranges["w"] = gen_config[param_map["num_warps"]]

        return {
            "ranges": ranges,
            "strategy": strategy,
        }
    except Exception:
        # Any malformed yaml degrades to the sentinel so callers fall back
        # to their built-in config grids.
        return -1
def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the autotune config list for the host-TMA matmul kernel.

    With USE_FLAGTUNE=1 and a readable expand yaml, the search space
    comes from that file; otherwise a fixed grid of block sizes, stage
    counts and warp counts is used. Every config installs ``pre_hook``
    so descriptor block shapes track the chosen config.
    """
    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("matmul")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_tma_expand.yaml for matmul kernel autotuning"
            )
            ranges = expand_config["ranges"]
            configs = []
            for BM in ranges["BM"]:
                for BN in ranges["BN"]:
                    for BK in ranges["BK"]:
                        for s in ranges["s"]:
                            for w in ranges["w"]:
                                configs.append(
                                    triton.Config(
                                        {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
                                        num_stages=s,
                                        num_warps=w,
                                        pre_hook=pre_hook,
                                    )
                                )
            return configs

    # Default search grid when FLAGTUNE expansion is disabled/unavailable.
    configs = []
    for BM in (32, 64, 128, 256):
        for BN in (32, 64, 128):
            for BK in (32, 64, 128):
                for s in (2, 3, 4):
                    for w in (4, 8):
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
                                num_stages=s,
                                num_warps=w,
                                pre_hook=pre_hook,
                            )
                        )
    return configs
@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    # Prefer the yaml-provided strategy when FLAGTUNE expansion is active.
    strategy=get_expand_config("matmul")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,  # host-side TensorDescriptor for A (layout per A_ROW_MAJOR)
    b_desc,  # host-side TensorDescriptor for B (layout per B_ROW_MAJOR)
    c_desc,  # host-side TensorDescriptor for C (row-major)
    M,
    N,
    K,
    stride_am,  # strides are never read below; they serve as autotune keys
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,  # dtype string; unused in the body, autotune key only
    enable_warp_specialization=True,
):
    """Matmul C = A @ B driven by host-constructed TMA descriptors.

    Column-major operands are loaded through a descriptor built over the
    transpose and flipped back with tl.trans in registers. fp16/bf16
    accumulate with a plain tl.dot (tf32 disabled); other dtypes use the
    "tf32x3" precision mode.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped program ordering for better L2 reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # Descriptor views A^T; transpose the tile back in registers.
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # tf32x3: multi-pass tf32 dot for near-fp32 accuracy.
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)
def get_higher_dtype(a, b):
    """Return the wider of two dtypes, ordered fp16 < bf16 < fp32.

    Both arguments must be one of torch.float16, torch.bfloat16 or
    torch.float32; the later one in that ordering wins.
    """
    priority = [torch.float16, torch.bfloat16, torch.float32]

    if a is b:
        return a

    assert a in priority
    assert b in priority

    return a if priority.index(a) > priority.index(b) else b
def general_mm(a, b, c, M, N, K):
    """Run the general matmul path, preferring host-side TMA descriptors.

    When the installed Triton exposes TensorDescriptor and shapes/dtypes
    satisfy TMA's 128-bit alignment rule, the host-TMA kernel is
    launched (column-major operands are described via their transpose).
    Otherwise the pointer-based kernel runs, with a device allocator
    installed for its on-device tensor-descriptor workspace.

    Args:
        a, b: Input matrices, logically (M, K) and (K, N).
        c: Preallocated output (M, N); written in place and returned.
        M, N, K: Problem sizes.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # One program per (BLOCK_M, BLOCK_N) output tile, flattened to 1D.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    if hasattr(
        triton.tools.tensor_descriptor, "TensorDescriptor"
    ) and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the autotune pre-hook
        # (matmul_tma_set_block_size_hook) rewrites it per config.
        dummy_block = [1, 1]
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor

        # Column-major operands are described through their transpose so
        # the descriptor always sees a row-major layout.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        # e.g. "float16"; forwarded as a constexpr used only as a tuning key.
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:

        # Workspace allocator for the device-side tensor descriptors used
        # by mm_kernel_general's fast path. NOTE(review): installed
        # process-wide via triton.set_allocator on every call.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
            )
    return c
def gemv_get_configs():
    """Autotune configs for the GEMV kernel.

    Under USE_FLAGTUNE=1 with a readable expand yaml the search space is
    taken from the file; otherwise a single hand-picked config is used.
    """
    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("gemv")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_tma_expand.yaml for gemv kernel autotuning"
            )
            ranges = expand_config["ranges"]
            configs = []
            for BM in ranges["BM"]:
                for BK in ranges["BK"]:
                    for s in ranges["s"]:
                        for w in ranges["w"]:
                            configs.append(
                                triton.Config(
                                    {"BLOCK_M": BM, "BLOCK_K": BK},
                                    num_stages=s,
                                    num_warps=w,
                                )
                            )
            return configs

    # Single default config when FLAGTUNE expansion is disabled/unavailable.
    return [triton.Config({"BLOCK_M": 32, "BLOCK_K": 256})]
@libentry()
@libtuner(
    configs=gemv_get_configs(),
    key=["M", "K", "stride_am", "stride_bk"],
    # Prefer the yaml-provided strategy when FLAGTUNE expansion is active.
    strategy=get_expand_config("gemv")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("gemv") != -1
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,  # (M, K) matrix
    B,  # (K,) vector
    C,  # (M,) output; indexed without a stride, so assumed contiguous
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)"""
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows (fp32 for accuracy)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        acc += tl.sum(a * b[None, :], axis=1)

    # Store result, cast down to the output dtype
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
def gemv_mm(a, b, c, M, K):
    """Matrix-vector product (N == 1): write a @ b into ``c``.

    Launches gemv_kernel with one program per BLOCK_M rows of ``a`` and
    returns ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
        )
    return c
def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K matmul should handle this problem.

    TODO: this heuristic may change with new benchmark results. The
    best stream-K configuration has only been validated on A100
    (compute capability 8.x); optimal settings for other devices still
    need real-world testing. Targets contiguous fp16/bf16 inputs whose
    K dimension dominates both M and N ("skinny" reduction-heavy GEMMs).
    """
    capability = get_device_capability()
    if capability[0] != 8:
        return False
    if a.dtype not in [torch.float16, torch.bfloat16]:
        return False
    if b.dtype not in [torch.float16, torch.bfloat16]:
        return False
    if not a.is_contiguous() or not b.is_contiguous():
        return False
    return K > M * 5 and K > N * 5
def mm(a, b):
    """Matrix multiply: return a @ b, dispatching to a specialized path.

    N == 1 goes to the GEMV kernel, skinny-K fp16/bf16 problems on
    capability-8 devices go to stream-K, everything else to the general
    tiled kernel. The result dtype is the wider of the input dtypes.
    """
    device = a.device
    # Copy to contiguous storage when neither stride is 1 — layouts the
    # kernels cannot address directly.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Shape constraint: inner dimensions must agree.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Allocate the output in the promoted dtype.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=out_dtype)

    # Fast path for matrix-vector multiplication.
    if N == 1:
        return gemv_mm(a, b, c, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)
def mm_out(a, b, *, out):
    """Matrix multiply into a caller-provided tensor: out = a @ b.

    Mirrors :func:`mm`'s dispatch (gemv / stream-K / general) but writes
    into ``out`` instead of allocating a fresh result tensor.
    """
    # Copy to contiguous storage when neither stride is 1.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Shape constraint: inner dimensions must agree.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Fast path for matrix-vector multiplication.
    if N == 1:
        return gemv_mm(a, b, out, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    return general_mm(a, b, out, M, N, K)