Coverage for src/flag_gems/runtime/backend/_mthreads/ops/bmm.py: 0%
137 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13from .utils import create_tma_device_descriptor, should_enable_sqmma
# Per-op logger named after this module's dotted backend path plus the op
# name, e.g. "flag_gems.runtime.backend._mthreads.ops.bmm".
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
    strategy=["log", "log", "log"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,  # pointer to batched LHS; indexed as contiguous (batch, M, K) row-major
    B,  # pointer to batched RHS; indexed as contiguous (batch, K, N) row-major
    O,  # pointer to batched output; indexed as contiguous (batch, M, N)
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # CTA-reordering group width along M (1 = no reorder)
    DIVISIBLE_M: tl.constexpr,  # True when M % TILE_M == 0 (skip M masking)
    DIVISIBLE_N: tl.constexpr,  # True when N % TILE_N == 0 (skip N masking)
    DIVISIBLE_K: tl.constexpr,  # True when K % TILE_K == 0 (skip K masking)
):
    """Batched matmul: O[b] = A[b] @ B[b], one (TILE_M, TILE_N) tile per CTA.

    Grid layout: axis 0 tiles M, axis 1 tiles N, axis 2 is the batch index.
    Accumulation is in float32 with allow_tf32=False; the store casts back to
    O's element type implicitly.
    """
    # batch offsets — advance each base pointer to the pid_b-th matrix.
    # These offsets assume fully contiguous inputs (strides M*K / K*N / M*N).
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: flatten the 2-D grid and re-tile it into groups of
        # GROUP_M rows (presumably to improve cache locality of B tiles —
        # standard grouped-ordering trick; exact benefit is hardware-dependent).
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group may be ragged when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Row/column bounds masks are only materialized when the dimension is
    # ragged; DIVISIBLE_* lets the fast path elide them entirely.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Build per-iteration load masks; mask_k must be recomputed because
        # offs_k advances by TILE_K at the end of each iteration.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # Output mask mirrors the load masks: only ragged dimensions are masked.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm_fma(A, B):
    """Batched matmul on the tuned FMA Triton kernel.

    A is (batch, M, K) and B is (batch, K, N); returns a freshly allocated
    (batch, M, N) tensor with A's dtype on A's device.
    """
    logger.debug("GEMS_MTHREADS BMM(FMA)")
    batch, M, K = A.shape
    N = B.shape[-1]
    # The kernel addresses elements with raw row-major strides, so both
    # operands must be contiguous before launch.
    A = A.contiguous()
    B = B.contiguous()
    result = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid(meta):
        # One CTA per (TILE_M, TILE_N) output tile; the z-axis walks batches.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid](A, B, result, M, N, K)
    return result
@triton.jit
def bmm_sqmma_kernel(
    a_desc_ptr,  # TMA descriptor over A viewed as a (batch*M, K) matrix
    b_desc_ptr,  # TMA descriptor over B viewed as a (batch*K, N) matrix
    c_desc_ptr,  # TMA descriptor over C viewed as a (batch*M, N) matrix
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_type: tl.constexpr,  # triton element type of the A/B loads
    d_type: tl.constexpr,  # triton element type the accumulator is stored as
):
    """Batched matmul via TMA descriptor loads/stores (SQMMA path).

    Grid: axis 0 enumerates output tiles of one matrix (column-major over
    tiles: pid_m varies fastest), axis 1 is the batch index.  Batching is
    realized by offsetting rows into the stacked 2-D views: A rows by
    batch_index * M, B rows by batch_index * K.
    """
    pid = tl.program_id(axis=0)
    batch_index = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Row offset selects both the tile row and the batch slab of A / C.
    offs_am = pid_m * BLOCK_SIZE_M + batch_index * M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_ak = 0
    offs_bk = batch_index * K
    tme_load_type = ab_type
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # March along K one BLOCK_SIZE_K panel at a time; descriptor loads handle
    # any out-of-bounds tails according to the descriptor's box semantics.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr, [offs_am, offs_ak], [BLOCK_SIZE_M, BLOCK_SIZE_K], tme_load_type
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr, [offs_bk, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tme_load_type
        )
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_ak += BLOCK_SIZE_K
        offs_bk += BLOCK_SIZE_K
    # Cast the f32 accumulator down to the output element type before storing.
    accumulator = accumulator.to(d_type)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
def get_triton_type(elem_type):
    """Translate a torch dtype into its Triton element type.

    Returns None for dtypes outside the SQMMA-supported set.
    """
    if elem_type == torch.float16:
        return tl.float16
    if elem_type == torch.bfloat16:
        return tl.bfloat16
    if elem_type == torch.float8_e4m3fn:
        return tl.float8e4nv
    return None
def bmm_sqmma(
    A, B, elem_type, batch, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages
):
    """Launch the TMA/SQMMA batched-matmul kernel and return its output.

    A is (batch, M, K), B is (batch, K, N); the result is (batch, M, N).
    bfloat16 outputs are materialized as float16 (the only dtype remapped).
    """
    device = "musa"
    ab_type = elem_type
    if elem_type != torch.bfloat16:
        c_type = elem_type
    else:
        c_type = torch.float16
    # Allocate as float16 first, then convert — kept as-is from the original
    # (presumably a MUSA allocation quirk for non-standard dtypes — verify).
    C = torch.empty((batch, M, N), dtype=torch.float16, device=device).to(c_type)
    # Each TMA descriptor sees its batched tensor as one tall 2-D matrix;
    # the kernel offsets rows by batch_index to pick the right slab.
    a_view = A.reshape(batch * M, K)
    b_view = B.reshape(batch * K, N)
    c_view = C.reshape(batch * M, N)
    a_desc = create_tma_device_descriptor(a_view, BLOCK_M, BLOCK_K, device)
    b_desc = create_tma_device_descriptor(b_view, BLOCK_K, BLOCK_N, device)
    c_desc = create_tma_device_descriptor(c_view, BLOCK_M, BLOCK_N, device)
    # Axis 0: all output tiles of one matrix; axis 1: batch; axis 2 unused.
    tiles_per_matrix = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
    bmm_sqmma_kernel[(tiles_per_matrix, batch, 1)](
        a_desc,
        b_desc,
        c_desc,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        get_triton_type(ab_type),
        get_triton_type(c_type),
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return C
def bmm(a, b):
    """Batched matrix multiply: (batch, M, K) @ (batch, K, N) -> (batch, M, N).

    Dispatches to the TMA/SQMMA kernel when should_enable_sqmma approves the
    dtype pair and problem size; otherwise falls back to the tuned FMA kernel
    with MUSA_ENABLE_SQMMA temporarily removed from the environment.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    batch, M, K = a.shape
    _, _, N = b.shape
    use_sqmma = should_enable_sqmma(a_dtype, b_dtype, M, N, K)
    if use_sqmma:
        BLOCK_M = 128
        BLOCK_N = BLOCK_M
        BLOCK_K = 64
        # NOTE(review): BLOCK_M is fixed to 128 above, so this always yields 4;
        # the 256 branch is kept for experimenting with larger tiles.
        num_warps = 16 if BLOCK_M == 256 else 4
        num_stages = 1
        return bmm_sqmma(
            a,
            b,
            a_dtype,
            batch,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps,
            num_stages,
        )
    else:
        # Hide MUSA_ENABLE_SQMMA while running the FMA path, then restore it.
        # Fixes two defects in the original: (1) a falsy value such as "" was
        # popped but never restored; (2) an exception in bmm_fma skipped the
        # restore entirely, leaking the cleared state into later calls.
        enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None)
        try:
            return bmm_fma(a, b)
        finally:
            if enable_sqmma is not None:
                os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma