Coverage for src/flag_gems/runtime/backend/_ascend/ops/mm.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
# Module logger, namespaced under the Ascend backend ops package and
# suffixed with this module's own (unqualified) name.
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + __name__.split(".")[-1])
# Tiled GEMM kernel: each program computes one BLOCK_M x BLOCK_N tile of C = A @ B.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.jit
def mm_kernel(
    A,  # pointer to the left operand, logically (M, K)
    B,  # pointer to the right operand, logically (K, N)
    C,  # pointer to the output, logically (M, N)
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,  # element strides of A along M and K
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,  # element strides of B along K and N
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,  # element strides of C along M and N
    stride_cn: tl.constexpr,
    dot_out_dtype: tl.constexpr,  # accumulator dtype used by tl.dot
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # number of tile-rows swizzled together for L2 reuse
    EVEN_K: tl.constexpr,  # when true, K-dim loads skip masking (presumably K % BLOCK_K == 0; set by heuristic config — verify)
):
    # matrix multiplication
    pid = tl.program_id(0)
    # NOTE(review): axis 1 would index a split-K slice, but mm() below launches
    # a 1-D grid, so pid_z is always 0 there. The unguarded tl.store at the end
    # would race under a real split-K launch — confirm before enabling one.
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    # The trailing group may hold fewer than GROUP_M tile-rows.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    # Row / column indices of this program's output tile.
    ram = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rbn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    # rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    # March both operand pointers along K, accumulating BLOCK_K-wide partial products.
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # Mask out the tail beyond K on the final iteration; out-of-range
            # elements read as zero so they contribute nothing to the dot.
            k_remaining = K - k * (BLOCK_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        # Mixed-dtype inputs are promoted to the output dtype before the dot.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # Downcast the accumulator to the output element type for the store.
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # Mask off out-of-bounds rows/columns of edge tiles.
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)
# Dtype promotion order for mm inputs, narrowest first.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of dtypes *a* and *b* per ``_ordered_datatypes``.

    Both dtypes must appear in ``_ordered_datatypes`` (asserted).
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype sitting later in the promotion order wins.
    return b if _ordered_datatypes.index(a) < _ordered_datatypes.index(b) else a
def mm(a, b):
    """Matrix-multiply 2-D tensors: (M, K) @ (K, N) -> (M, N) via mm_kernel.

    The output dtype is the wider of the two input dtypes; the accumulator
    is float32. AssertionError is raised on incompatible inner dimensions.
    """
    logger.debug("GEMS_ASCEND MM")
    device = a.device
    # handle non-contiguous inputs if necessary: force a copy when neither
    # dimension of the operand has unit stride.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    # allocates output in the promoted dtype
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    dot_out_dtype = tl.float32

    def grid(meta):
        # 1-D launch: one program per (BLOCK_M x BLOCK_N) output tile.
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
        )
    return c