Coverage for src/flag_gems/runtime/backend/_cambricon/ops/bmm.py: 0% (92 statements).
Report generated by coverage.py v7.6.9 at 2026-03-18 02:36 +0800.
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
14@libentry()
15@libtuner(
16 configs=runtime.get_tuned_config("bmm"),
17 key=["M", "N", "K", "stride_am", "stride_bk"],
18 strategy=["log", "log", "log", "align32", "align32"],
19)
20@triton.heuristics(runtime.get_heuristic_config("bmm"))
21@triton.jit
22def bmm_kernel(
23 A,
24 B,
25 O,
26 M,
27 N,
28 K,
29 stride_ab,
30 stride_am,
31 stride_ak,
32 stride_bb,
33 stride_bk,
34 stride_bn,
35 stride_ob,
36 stride_om,
37 stride_on,
38 TILE_M: tl.constexpr,
39 TILE_N: tl.constexpr,
40 TILE_K: tl.constexpr,
41 GROUP_M: tl.constexpr,
42 DIVISIBLE_M: tl.constexpr,
43 DIVISIBLE_N: tl.constexpr,
44 DIVISIBLE_K: tl.constexpr,
45):
46 # batch offsets
47 pid_b = tl.program_id(2)
48 A += pid_b * stride_ab
49 B += pid_b * stride_bb
50 O += pid_b * stride_ob
52 pidx = tl.program_id(0)
53 pidy = tl.program_id(1)
55 if GROUP_M == 1:
56 pid_m, pid_n = pidx, pidy
57 else:
58 # reorder CTAs
59 gridx = tl.num_programs(0)
60 gridy = tl.num_programs(1)
61 pid = pidx + pidy * gridx
63 num_CTA_per_group = gridy * GROUP_M
65 group_id = pid // num_CTA_per_group
66 inner_group_id = pid % num_CTA_per_group
67 GROUP_SIZE = tl.where(
68 (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
69 )
70 pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
71 pid_n = inner_group_id // GROUP_SIZE
73 offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
74 offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
75 offs_k = tl.arange(0, TILE_K)
77 if not DIVISIBLE_M:
78 mask_m = offs_m < M
79 if not DIVISIBLE_N:
80 mask_n = offs_n < N
82 a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
83 b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
84 o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
86 num_iters = tl.cdiv(K, TILE_K)
87 o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
88 for _ in range(num_iters):
89 if DIVISIBLE_K:
90 if DIVISIBLE_M:
91 mask_a = None
92 else:
93 mask_a = mask_m[:, None]
94 if DIVISIBLE_N:
95 mask_b = None
96 else:
97 mask_b = mask_n[None, :]
98 else:
99 mask_k = offs_k < K
100 if DIVISIBLE_M:
101 mask_a = mask_k[None, :]
102 else:
103 mask_a = mask_m[:, None] & mask_k[None, :]
104 if DIVISIBLE_N:
105 mask_b = mask_k[:, None]
106 else:
107 mask_b = mask_k[:, None] & mask_n[None, :]
109 a = tl.load(a_ptrs, mask_a)
110 b = tl.load(b_ptrs, mask_b)
112 offs_k += TILE_K
113 a_ptrs += TILE_K * stride_ak
114 b_ptrs += TILE_K * stride_bk
116 o += tl.dot(a, b, allow_tf32=False)
118 if DIVISIBLE_M and DIVISIBLE_N:
119 mask_c = None
120 elif DIVISIBLE_M and not DIVISIBLE_N:
121 mask_c = mask_n[None, :]
122 elif not DIVISIBLE_M and DIVISIBLE_N:
123 mask_c = mask_m[:, None]
124 else:
125 mask_c = mask_m[:, None] & mask_n[None, :]
126 tl.store(o_ptrs, o, mask_c)
129def bmm(A, B):
130 logger.debug("GEMS_CAMBRICON BMM")
131 assert A.shape[0] == B.shape[0], "Batch dim mismatch"
132 assert A.shape[2] == B.shape[1], "K dim mismatch"
133 batch, M, K = A.shape
134 _, _, N = B.shape
135 out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)
137 grid_fn = lambda meta: (
138 triton.cdiv(meta["M"], meta["TILE_M"]),
139 triton.cdiv(meta["N"], meta["TILE_N"]),
140 batch,
141 )
142 with torch_device_fn.device(A.device):
143 bmm_kernel[grid_fn](
144 A,
145 B,
146 out,
147 M,
148 N,
149 K,
150 A.stride(0),
151 A.stride(1),
152 A.stride(2),
153 B.stride(0),
154 B.stride(1),
155 B.stride(2),
156 out.stride(0),
157 out.stride(1),
158 out.stride(2),
159 )
160 return out
163def bmm_out(A, B, out):
164 logger.debug("GEMS_CAMBRICON BMM_OUT")
165 assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
166 assert A.shape[2] == B.shape[1], "K dim mismatch"
167 batch, M, K = A.shape
168 _, _, N = B.shape
170 grid_fn = lambda meta: (
171 triton.cdiv(meta["M"], meta["TILE_M"]),
172 triton.cdiv(meta["N"], meta["TILE_N"]),
173 batch,
174 )
175 with torch_device_fn.device(A.device):
176 bmm_kernel[grid_fn](
177 A,
178 B,
179 out,
180 M,
181 N,
182 K,
183 A.stride(0),
184 A.stride(1),
185 A.stride(2),
186 B.stride(0),
187 B.stride(1),
188 B.stride(2),
189 out.stride(0),
190 out.stride(1),
191 out.stride(2),
192 )
193 return out