Coverage for src/flag_gems/ops/bmm.py: 38%
93 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger, named after this module per the package convention.
logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["log", "log", "log", "align32", "align32"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    stride_ob,
    stride_om,
    stride_on,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    # Batched GEMM tile kernel: each program instance computes one
    # (TILE_M, TILE_N) tile of O[b] = A[b] @ B[b], with the batch index b
    # carried on grid axis 2. Accumulation is in fp32 with tf32 disabled.
    #
    # When a DIVISIBLE_* flag is set, the corresponding boundary mask is
    # skipped entirely (loads/stores are unmasked on that axis).
    # NOTE(review): these flags presumably come from the "bmm" heuristics
    # config asserting M/N/K are multiples of the tile sizes — confirm
    # against runtime.get_heuristic_config("bmm").

    # batch offsets: shift all three base pointers to batch pid_b
    pid_b = tle.program_id(2)
    A += pid_b * stride_ab
    B += pid_b * stride_bb
    O += pid_b * stride_ob

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: remap the linear program id so that GROUP_M
        # consecutive row-tiles are scheduled together (presumably for
        # cache reuse of the shared B columns — standard grouped swizzle).
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group is ragged when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # Element offsets of this tile within the (M, N) output and along K.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialized for non-divisible axes.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on

    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Build load masks; a None mask means an unmasked load.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # offs_k holds this iteration's absolute k indices
            # (it is advanced after the loads, below).
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # Advance along K for the next iteration.
        offs_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

        o += tl.dot(a, b, allow_tf32=False)

    # Store mask mirrors the M/N boundary masks used for the loads.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply: ``out[i] = A[i] @ B[i]``.

    Args:
        A: Tensor of shape ``(batch, M, K)``.
        B: Tensor of shape ``(batch, K, N)``; must share A's batch size.

    Returns:
        A newly allocated tensor of shape ``(batch, M, N)`` with A's dtype
        and device.
    """
    logger.debug("GEMS BMM")
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    # A named function instead of a lambda bound to a name (PEP 8 E731).
    # Grid: one program per (TILE_M, TILE_N) output tile, one z-slice
    # per batch; tile sizes are chosen by the autotuner via `meta`.
    def grid_fn(meta):
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
        )
    return out
def bmm_out(A, B, out):
    """Batched matrix multiply into a caller-provided output tensor.

    Computes ``out[i] = A[i] @ B[i]`` in place of ``out``'s storage.

    Args:
        A: Tensor of shape ``(batch, M, K)``.
        B: Tensor of shape ``(batch, K, N)``.
        out: Pre-allocated tensor of shape ``(batch, M, N)``.

    Returns:
        ``out``.
    """
    logger.debug("GEMS BMM_OUT")
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    # Validate out's M/N dims too: the kernel derives M and N from A/B but
    # writes through out's strides, so a mis-shaped `out` would be written
    # out of bounds.
    assert A.shape[1] == out.shape[1], "M dim mismatch"
    assert B.shape[2] == out.shape[2], "N dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape

    # A named function instead of a lambda bound to a name (PEP 8 E731).
    # Same launch geometry as `bmm`: one program per output tile, one
    # z-slice per batch.
    def grid_fn(meta):
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
        )
    return out