Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%
265 statements
import logging
import os
from typing import Optional

import torch
import triton
import triton.language as tl
import yaml

from flag_gems import runtime
from flag_gems.ops.mm_streamk import streamk_mm
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.device_info import get_device_capability, get_sm_count

logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")

CACHE_USAGE_THRESHOLD = 0.8


def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements × 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements × 4 bytes = 16 bytes)

    Args:
        a, b: Input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's 128-bit alignment requirement
    """
    return (
        a.dtype in (torch.float16, torch.bfloat16)
        and b.dtype in (torch.float16, torch.bfloat16)
        and N % 8 == 0
        and K % 8 == 0
    ) or (
        a.dtype in (torch.float32,)
        and b.dtype in (torch.float32,)
        and N % 4 == 0
        and K % 4 == 0
    )


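# A worked example of the alignment rule above (hypothetical shapes): for two
# fp16 operands with N = K = 4096, both dimensions are multiples of 8, so each
# row segment is a multiple of 16 bytes and is_tma_compatible(...) is True.
# For fp32 operands with N = 130, 130 % 4 != 0 breaks the 16-byte alignment,
# so the check fails and the non-TMA fallback path is used instead.

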
@triton.jit
def prev_multiple_of(a, b):
    # the largest multiple of b that is strictly less than a,
    # i.e. the greatest x < a with x % b == 0
    return tl.cdiv(a, b) * b - b


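# For example, prev_multiple_of(10, 4) == 8 and prev_multiple_of(8, 4) == 4.
# mm_kernel_general below uses this bound to run full, unmasked BLOCK_K tiles
# first and handle the final partial tile with a single peeled, masked step.

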
def matmul_tma_set_block_size_hook(nargs):
    BLOCK_M = nargs["BLOCK_M"]
    BLOCK_N = nargs["BLOCK_N"]
    BLOCK_K = nargs["BLOCK_K"]
    if nargs["A_ROW_MAJOR"]:
        nargs["a_desc"].block_shape = [BLOCK_M, BLOCK_K]
    else:
        nargs["a_desc"].block_shape = [BLOCK_K, BLOCK_M]

    if nargs["B_ROW_MAJOR"]:
        nargs["b_desc"].block_shape = [BLOCK_K, BLOCK_N]
    else:
        nargs["b_desc"].block_shape = [BLOCK_N, BLOCK_K]

    nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N]


def get_expand_config(op):
    # Parse mm_hopper_expand.yaml into {"ranges": ..., "strategy": ...};
    # return -1 if the file, the op entry, or any expected key is missing.
    default_strategies = {
        "matmul": ["align32", "align32", "align32", "align32", "align32", "default"],
        "gemv": ["align32", "align32", "align32", "default"],
    }
    op_key_orders = {
        "matmul": ["M", "N", "K", "stride_am", "stride_bk", "dtype"],
        "gemv": ["M", "K", "stride_am", "stride_bk"],
    }
    op_meta_map = {
        "matmul": {
            "BM": "BLOCK_M",
            "BN": "BLOCK_N",
            "BK": "BLOCK_K",
        },
        "gemv": {
            "BM": "BLOCK_M",
            "BK": "BLOCK_K",
        },
    }

    if op not in default_strategies:
        return -1

    default_strategy = default_strategies[op]
    config_path = os.path.join(os.path.dirname(__file__), "..", "mm_hopper_expand.yaml")
    if not os.path.exists(config_path):
        return -1

    try:
        with open(config_path, "r") as file:
            config = yaml.safe_load(file) or {}

        expand_configs = config.get(op)

        gen_config = None
        strategy_config = None
        for single_config in expand_configs:
            if isinstance(single_config, dict) and "param_map" in single_config:
                gen_config = single_config
            if isinstance(single_config, dict) and "strategy" in single_config:
                strategy_config = single_config.get("strategy")

        param_map = gen_config["param_map"]
        meta_map = param_map["META"]

        strategy = default_strategy
        if isinstance(strategy_config, dict):
            strategy = [
                strategy_config.get(k, default_strategy[idx])
                for idx, k in enumerate(op_key_orders[op])
            ]

        ranges = {}
        for range_key, meta_key in op_meta_map[op].items():
            ranges[range_key] = gen_config[meta_map[meta_key]]
        ranges["s"] = gen_config[param_map["num_stages"]]
        ranges["w"] = gen_config[param_map["num_warps"]]

        return {
            "ranges": ranges,
            "strategy": strategy,
        }
    except Exception:
        return -1


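# A minimal sketch of the mm_hopper_expand.yaml layout this parser expects
# (the shipped file may differ; the range key names below are hypothetical,
# inferred from the lookups above). Each op maps to a list of dicts: one
# carrying "param_map" plus the value ranges it points at, and optionally one
# carrying a per-tuning-key "strategy" override:
#
#   matmul:
#     - param_map:
#         META:
#           BLOCK_M: block_m
#           BLOCK_N: block_n
#           BLOCK_K: block_k
#         num_stages: stages
#         num_warps: warps
#       block_m: [64, 128]
#       block_n: [64, 128]
#       block_k: [32, 64]
#       stages: [3, 4]
#       warps: [4, 8]
#     - strategy:
#         M: align32
#         dtype: default

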
def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("matmul")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_expand.yaml for matmul kernel autotuning"
            )
            ranges = expand_config["ranges"]
            return [
                triton.Config(
                    {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
                    num_stages=s,
                    num_warps=w,
                    pre_hook=pre_hook,
                )
                for BM in ranges["BM"]
                for BN in ranges["BN"]
                for BK in ranges["BK"]
                for s in ranges["s"]
                for w in ranges["w"]
            ]
    return [
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
            num_stages=s,
            num_warps=w,
            pre_hook=pre_hook,
        )
        for BM in [32, 64, 128, 256]
        for BN in [32, 64, 128]
        for BK in [32, 64, 128]
        for s in [2, 3, 4]
        for w in [4, 8]
    ]


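# The fallback sweep above enumerates 4 × 3 × 3 × 3 × 2 = 216 candidate
# configs. The libtuner decorators below key their cached choice on
# ["M", "N", "K", "stride_am", "stride_bk", "dtype"], so (assuming the usual
# autotune caching behavior) the sweep cost is paid once per distinct key.

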
@libentry()
@libtuner(
    configs=matmul_get_configs(pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=get_expand_config("matmul")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else ["default", "default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    if M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0:
        # fast path: every dimension tiles evenly, so device-side TMA
        # descriptors can be used without boundary masks (the strides
        # [K, 1] and [N, 1] assume contiguous row-major layouts)
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        # column-major
        # b_desc = tl.make_tensor_descriptor(
        #     base=B,
        #     shape=[N, K],
        #     strides=[K, 1],
        #     block_shape=[BLOCK_N, BLOCK_K],
        # )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        acc = acc.to(a_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)
    else:
        # slow path: plain pointer arithmetic with masking on the K tail
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the final partial K tile is loaded with masks
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # masked write-back of the output tile
        tl.store(offsets, acc, mask=mask)


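# Worked example of the grouped launch order used by the matmul kernels in
# this file: with grid_m = 4, grid_n = 4 and GROUP_M = 2, width = 8, so
# pid 0 -> (0, 0), pid 1 -> (1, 0), pid 2 -> (0, 1), ...; pids 0..7 sweep the
# columns across rows 0-1 before pids 8..15 move on to rows 2-3. Consecutive
# CTAs therefore share B tiles (and a small set of A rows) while they are
# still hot in L2.

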
@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=get_expand_config("matmul")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # re-order program ID for better L2 performance (same swizzle as above)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        # column-major operands were described via their transposed view on
        # the host, so load the transposed tile and transpose it back
        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 inputs: tf32x3 performs three tf32 passes to recover
            # near-fp32 accuracy while still using tensor cores
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)


def get_higher_dtype(a, b):
    _ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]

    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    for d in _ordered_datatypes:
        if a is d:
            return b
        if b is d:
            return a


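# The loop walks dtypes from lowest to highest, so whichever operand's dtype
# is hit first loses and the other is returned:
# get_higher_dtype(torch.float16, torch.float32) -> torch.float32, and
# get_higher_dtype(torch.float16, torch.bfloat16) -> torch.bfloat16.

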
def general_mm(a, b, c, M, N, K):
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    if hasattr(
        triton.tools.tensor_descriptor, "TensorDescriptor"
    ) and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # placeholder block shape; matmul_tma_set_block_size_hook overwrites
        # it with the autotuned BLOCK_M/BLOCK_N/BLOCK_K before each launch
        dummy_block = [1, 1]
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor

        # column-major operands are described via their transposed view so
        # the descriptor's inner dimension stays contiguous
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:
        # device-side tl.make_tensor_descriptor needs a host-registered
        # allocator for its TMA workspace
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
            )
    return c


def gemv_get_configs():
    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("gemv")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_expand.yaml for gemv kernel autotuning"
            )
            ranges = expand_config["ranges"]
            return [
                triton.Config(
                    {"BLOCK_M": BM, "BLOCK_K": BK},
                    num_stages=s,
                    num_warps=w,
                )
                for BM in ranges["BM"]
                for BK in ranges["BK"]
                for s in ranges["s"]
                for w in ranges["w"]
            ]
    return [
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ]


@libentry()
@libtuner(
    configs=gemv_get_configs(),
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=get_expand_config("gemv")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("gemv") != -1
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)"""
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        acc += tl.sum(a * b[None, :], axis=1)

    # Store result
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)


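# Design note: the kernel reduces with a broadcast multiply plus tl.sum rather
# than tl.dot, since tl.dot requires tile dimensions of at least 16 in each
# operand and would waste tensor-core lanes on an N=1 output; a plain
# multiply-reduce in fp32 registers suits this memory-bound gemv pattern.

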
def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case"""
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
        )
    return c


def streamk_scenario(a, b, M, N, K):
    # TODO: this may change according to real benchmark results.
    # Currently, the best configuration for streamk has only been tested on A100 (capability[0] == 8).
    # The optimal settings for other devices need to be determined through real testing.
    capability = get_device_capability()
    return (
        capability[0] == 8
        and a.dtype in [torch.float16, torch.bfloat16]
        and b.dtype in [torch.float16, torch.bfloat16]
        and a.is_contiguous()
        and b.is_contiguous()
        and K > M * 5
        and K > N * 5
    )


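# Example of a qualifying shape: contiguous fp16 inputs with M = N = 128 and
# K = 8192 satisfy K > 5 * M and K > 5 * N, so on an sm_80 device this
# K-dominated matmul is routed to streamk_mm, which spreads the long K
# reduction across all SMs instead of giving each output tile to a single CTA.

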
def mm(a, b):
    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # check constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocate the output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    # Optimize for the N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    else:
        return general_mm(a, b, c, M, N, K)


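# Minimal usage sketch (hypothetical shapes; assumes a CUDA device with Hopper
# TMA support):
#
#   a = torch.randn(512, 4096, device="cuda", dtype=torch.float16)
#   b = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)
#   c = mm(a, b)  # N and K are multiples of 8, so the TMA path is eligible
#
# With b of shape (4096, 1), the same call would take the gemv fast path.

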
def mm_out(a, b, *, out):
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # check constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Optimize for the N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    else:
        return general_mm(a, b, out, M, N, K)