Coverage for src/flag_gems/runtime/backend/_mthreads/ops/w8a8_block_fp8_matmul.py: 0%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import os
3from typing import List
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
14from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor
# Logger for this op, registered under the module's fully-qualified path.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.w8a8_block_fp8_matmul"
)

# Expand/tuning YAML consulted when USE_FLAGTUNE=1; it lives one directory
# above this ops/ module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_mthreads_expand.yaml"
    )
)

# Master switch for the TMA-descriptor (SQMMA) path; is_sqmma_compatible()
# returns False for 2-D inputs while this is False.
SQMMA_ON = False
def is_supported_sqmma_layout(tensor):
    """Return True if ``tensor`` is contiguous or laid out column-major.

    The column-major form is recognised by a unit stride on dim 0 together
    with a dim-1 stride equal to ``shape[0]`` (the transpose of a contiguous
    2-D tensor).
    """
    if tensor.is_contiguous():
        return True
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, output_dtype, n, k):
    """Decide whether the SQMMA (TMA-descriptor) kernel can serve this call.

    All of the following must hold:
    ``a`` is 2-D, the module switch ``SQMMA_ON`` is set, ``b`` is 2-D, both
    operands are ``torch.float8_e4m3fn``, ``output_dtype`` is fp16 or bf16,
    both operands pass :func:`is_supported_sqmma_layout`, and ``n`` and
    ``k`` are multiples of 16.
    """
    # Same short-circuit order as a single chained `and` expression.
    if a.dim() != 2 or not SQMMA_ON or b.dim() != 2:
        return False
    fp8 = torch.float8_e4m3fn
    if not (a.dtype == b.dtype == fp8):
        return False
    if output_dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return n % 16 == 0 and k % 16 == 0
def get_triton_type(elem_type):
    """Translate a torch dtype into the corresponding Triton dtype.

    Only float16, bfloat16, float32 and float8_e4m3fn are known; any other
    dtype yields ``None``.
    """
    torch_to_triton = dict(
        (
            (torch.float16, tl.float16),
            (torch.bfloat16, tl.bfloat16),
            (torch.float32, tl.float32),
            (torch.float8_e4m3fn, tl.float8e4nv),
        )
    )
    return torch_to_triton.get(elem_type)
def matmul_get_configs():
    """Return the default (non-FlagTune) tuning space for the general kernel.

    A single config: 64x64 output tiles, a K step of 128, GROUP_M=8,
    3 stages and 4 warps.
    """
    tile = dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, GROUP_M=8)
    return [triton.Config(tile, num_stages=3, num_warps=4)]
# Pointer-based block-scaled GEMM kernel (no TMA descriptors).
# For each output tile, every BLOCK_K slice contributes
#   dot(A[m, slice], B[n, slice]) * As[m, kb] * Bs[n // group_n, kb]
# with kb = (slice start) // group_k. B is addressed as (N, K) through
# (stride_bk, stride_bn), as passed by general_w8a8_block_fp8_matmul.
# Autotuning: configs/strategy come from the expand YAML when
# USE_FLAGTUNE=1, otherwise matmul_get_configs() with "align32" for each key.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Map the 1-D program id to an output tile (pid_m, pid_n): consecutive
    # ids step through up to GROUP_M row tiles before advancing the column
    # tile.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices wrap (% M, % N) so out-of-range lanes still read
    # valid memory; those lanes are discarded by c_mask at the store.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one As entry per row, one Bs entry per group_n columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Zero-fill the K tail of the final slice.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # One scale column is used for the whole BLOCK_K slice, picked from
        # the slice's first K index. NOTE(review): exact only if a slice never
        # straddles two group_k groups (e.g. group_k a multiple of BLOCK_K);
        # confirm for YAML/tuned configs.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the fp32 accumulator to C's element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Unwrapped indices here, masked to the real M x N extent.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def sqmma_descriptor_pre_hook(nargs):
    """Fill the kernel's TMA descriptor buffers before a launch.

    ``nargs`` is the mapping of kernel arguments handed to a config
    pre-hook. The A and B descriptors come from
    ``get_cached_tma_device_descriptor``. The C descriptor is always rebuilt
    with ``create_tma_device_descriptor``; presumably this is because the
    output tensor is freshly allocated per call (TODO confirm). Each
    descriptor is copied into the matching ``*_desc_ptr`` buffer.
    """
    lhs, rhs, out = nargs["A"], nargs["B"], nargs["C"]
    tile_m, tile_n, tile_k = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    dev = out.device

    # A tile is (BLOCK_M, BLOCK_K), B tile is (BLOCK_K, BLOCK_N),
    # C tile is (BLOCK_M, BLOCK_N).
    nargs["a_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(lhs, tile_m, tile_k, dev)
    )
    nargs["b_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(rhs, tile_k, tile_n, dev)
    )
    nargs["c_desc_ptr"].copy_(
        create_tma_device_descriptor(out, tile_m, tile_n, dev)
    )
def sqmma_get_configs(pre_hook=sqmma_descriptor_pre_hook):
    """Return the default (non-FlagTune) tuning space for the SQMMA kernel.

    A single config: 64x64 output tiles, a K step of 128, GROUP_M=8,
    3 stages and 4 warps. ``pre_hook`` is attached so descriptor setup runs
    before the kernel is launched with this config.
    """
    tile = dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, GROUP_M=8)
    return [triton.Config(tile, num_stages=3, num_warps=4, pre_hook=pre_hook)]
# TMA-descriptor ("SQMMA") variant of the block-scaled GEMM. A/B tiles are
# read, and the C tile written, through device descriptors that
# sqmma_descriptor_pre_hook fills from nargs["A"/"B"/"C"]. That is why A, B
# and C are passed even though the body never dereferences them. stride_am,
# stride_bk and dtype are also unused in the body; they appear in the
# autotuning key list.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else sqmma_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_sqmma_kernel(
    A,
    B,
    C,
    As,
    Bs,
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_bk,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,
    stride_Bs_k,
    dtype: tl.constexpr,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = True,
):
    # Map the 1-D program id to an output tile: consecutive ids step through
    # up to GROUP_M row tiles before advancing the column tile.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    # Scalar tile origins used as descriptor coordinates; offs_k walks K.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = tl.zeros((), dtype=tl.int32)

    # Per-lane row/column indices, used only to address the scale tensors.
    row_offset = offs_am + tl.arange(0, BLOCK_M)
    col_offset = offs_bn + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    tme_load_input_dtype = input_dtype
    c_store_dtype = output_dtype

    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        # is_transpose_a / is_transpose_b are set by the launcher from each
        # operand's memory layout (row- vs column-major).
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [offs_am, offs_k],
            [BLOCK_M, BLOCK_K],
            tme_load_input_dtype,
            is_transpose_a,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [offs_k, offs_bn],
            [BLOCK_K, BLOCK_N],
            tme_load_input_dtype,
            is_transpose_b,
        )

        # One scale column per BLOCK_K slice, picked from the slice start.
        # NOTE(review): exact only if a slice never straddles two group_k
        # groups. Rows >= M / columns >= N get scale 0.0 via the masks.
        scale_k = offs_k // group_k
        a_s = tl.load(
            As + row_offset * stride_As_m + scale_k * stride_As_k,
            mask=row_offset < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + (col_offset // group_n) * stride_Bs_n + scale_k * stride_Bs_k,
            mask=col_offset < N,
            other=0.0,
        )
        acc += (
            tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            * a_s[:, None]
            * b_s[None, :]
        )
        offs_k += BLOCK_K

    # No explicit mask on the store. NOTE(review): edge tiles presumably rely
    # on the descriptor's bounds; confirm in the .utils descriptor helpers.
    tl._experimental_descriptor_store(
        c_desc_ptr, acc.to(c_store_dtype), [offs_am, offs_bn]
    )
def general_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the pointer-based block-scaled kernel, writing into ``c``.

    ``a`` is (M, K), ``b`` is (N, K) and ``c`` is (M, N). ``a_s`` has one
    scale per row and K group; ``b_s`` has one scale per (group of
    ``group_n`` columns, K group). Strides are forwarded, so the kernel
    handles non-contiguous operands directly.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(general), [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        M,
        N,
        K,
    )

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    # b is (N, K) and b_s is (N-groups, K-groups): their K stride is dim 1
    # and their N stride is dim 0, matching the kernel's (k, n) parameter order.
    a_strides = (a.stride(0), a.stride(1))  # stride_am, stride_ak
    b_strides = (b.stride(1), b.stride(0))  # stride_bk, stride_bn
    c_strides = (c.stride(0), c.stride(1))  # stride_cm, stride_cn
    as_strides = (a_s.stride(0), a_s.stride(1))  # stride_As_m, stride_As_k
    bs_strides = (b_s.stride(1), b_s.stride(0))  # stride_Bs_k, stride_Bs_n

    with torch_device_fn.device(a.device):
        w8a8_block_fp8_matmul_kernel[grid](
            a,
            b,
            c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            *a_strides,
            *b_strides,
            *c_strides,
            *as_strides,
            *bs_strides,
        )
    return c
def sqmma_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the TMA-descriptor (SQMMA) block-scaled kernel, writing into ``c``.

    Layout handling:
    - Contiguous operands are used as-is.
    - A column-major operand (unit stride on dim 0, dim-1 stride equal to the
      row count) is kept, and its ``is_transpose_*`` flag is flipped from the
      default.
    - Any other layout is made contiguous first.

    Three 64-byte int8 buffers are allocated to receive the descriptors
    written by ``sqmma_descriptor_pre_hook``.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(sqmma), [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    device = a.device

    def is_col_major(t):
        return t.stride(0) == 1 and t.stride(1) == t.shape[0]

    # Defaults correspond to contiguous a and contiguous b.
    is_transpose_a = False
    is_transpose_b = True
    if not a.is_contiguous():
        if is_col_major(a):
            is_transpose_a = True
        else:
            a = a.contiguous()
    if not b.is_contiguous():
        if is_col_major(b):
            is_transpose_b = False
        else:
            b = b.contiguous()

    desc_a, desc_b, desc_c = (
        torch.empty((64,), dtype=torch.int8, device=device) for _ in range(3)
    )

    def grid(meta):
        tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (tiles, 1, 1)

    with torch_device_fn.device(device):
        w8a8_block_fp8_matmul_sqmma_kernel[grid](
            a,
            b,
            c,
            a_s,
            b_s,
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            group_n,
            group_k,
            a.stride(0),
            b.stride(1),
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(0),
            b_s.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            input_dtype=get_triton_type(a.dtype),
            output_dtype=get_triton_type(c.dtype),
            is_transpose_a=is_transpose_a,
            is_transpose_b=is_transpose_b,
        )
    return c
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Block-scaled (W8A8) matmul: ``out[..., n] = sum_k A[..., k] * B[n, k]`` with scales.

    Each BLOCK_K partial product is multiplied by ``As[m, kb]`` and
    ``Bs[n // block_n, kb]`` (``kb`` is the K-group index) before it is
    accumulated.

    Args:
        A: Activations whose last dim is K. Leading dims are flattened into M
            for the kernel and restored in the output.
        B: 2-D weights of shape ``(N, K)``.
        As: Scales for ``A`` with shape ``A.shape[:-1] + (ceil(K / block_k),)``.
        Bs: 2-D scales for ``B`` with shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: ``[block_n, block_k]`` scale-group sizes.
        output_dtype: Result dtype (default ``torch.bfloat16``).

    Returns:
        A new tensor of shape ``A.shape[:-1] + (N,)`` with dtype
        ``output_dtype``. If M or N is 0, the empty result is returned
        without a kernel launch. If K is 0, the result is all zeros.

    Raises:
        AssertionError: if ``block_size`` or any operand/scale shape is
            inconsistent. Validation uses ``assert``, so it is skipped under
            ``python -O``.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Copy only when neither of the last two dims is unit-stride. Row-major
    # and column-major operands are passed to the kernels with their strides.
    if A.ndim >= 2 and A.stride(-2) > 1 and A.stride(-1) > 1:
        A = A.contiguous()
    if B.ndim == 2 and B.stride(0) > 1 and B.stride(1) > 1:
        B = B.contiguous()
    if As.ndim >= 2 and As.stride(-2) > 1 and As.stride(-1) > 1:
        As = As.contiguous()
    if Bs.ndim == 2 and Bs.stride(0) > 1 and Bs.stride(1) > 1:
        Bs = Bs.contiguous()

    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    # Product of the leading dims. Unlike A.numel() // A.shape[-1], this does
    # not raise ZeroDivisionError when K == 0.
    M = A.shape[:-1].numel()
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    output_shape = A.shape[:-1] + (N,)
    c = torch.empty(output_shape, device=device, dtype=output_dtype)

    # Degenerate problems: skip autotuning and an empty-grid launch.
    if c.numel() == 0:
        return c
    if K == 0:
        # Empty reduction over K: the product is all zeros.
        return c.zero_()

    a_2d = A.reshape(M, K)
    as_2d = As.reshape(M, As.shape[-1])
    c_2d = c.reshape(M, N)

    # NOTE(review): MUSA_ENABLE_SQMMA is forced to "1" process-wide while the
    # kernel is dispatched, then restored. Presumably the MUSA toolchain reads
    # it at compile/launch time (confirm). Because os.environ is global,
    # concurrent callers in other threads can observe this temporary value.
    prev_sqmma = os.environ.get("MUSA_ENABLE_SQMMA")
    os.environ["MUSA_ENABLE_SQMMA"] = "1"
    try:
        if is_sqmma_compatible(a_2d, B, output_dtype, N, K):
            return sqmma_w8a8_block_fp8_matmul(
                a_2d,
                B,
                c_2d,
                as_2d,
                Bs,
                M,
                N,
                K,
                block_n,
                block_k,
            ).reshape(c.shape)

        return general_w8a8_block_fp8_matmul(
            a_2d,
            B,
            c_2d,
            as_2d,
            Bs,
            M,
            N,
            K,
            block_n,
            block_k,
        ).reshape(c.shape)
    finally:
        if prev_sqmma is None:
            os.environ.pop("MUSA_ENABLE_SQMMA", None)
        else:
            os.environ["MUSA_ENABLE_SQMMA"] = prev_sqmma