Coverage for src/flag_gems/fused/cutlass_scaled_mm.py: 19%
194 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
2from typing import Callable, Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils.device_info import get_device_capability
10logger = logging.getLogger(__name__)
12SCALE_BLOCK_K, SCALE_BLOCK_N = 128, 128
def get_sm_version_num():
    """Return the device compute capability packed as one int (e.g. 9.0 -> 90)."""
    capability = get_device_capability()
    return capability[0] * 10 + capability[1]
# Cached once at import time: compute capability as major*10 + minor (90 = Hopper).
SM_VERSION_NUM = get_sm_version_num()
def get_block_wise_smm_configs():
    """Build the autotune search space for the block-wise scaled-mm kernel.

    TILE_K is pinned to SCALE_BLOCK_K so each K iteration of the kernel
    consumes exactly one scale block; only the output tile shape, pipeline
    depth, and warp count vary across candidates.
    """
    # (TILE_M, TILE_N, num_stages, num_warps)
    tile_shapes = [
        (32, 64, 5, 2),
        (64, 32, 5, 2),
        (64, 128, 4, 4),
        (64, 256, 4, 4),
        (128, 32, 4, 4),
        (128, 64, 4, 4),
        (128, 128, 4, 4),
        (128, 256, 3, 8),
        (256, 64, 4, 4),
        (256, 128, 3, 8),
    ]

    configs = []
    for tile_m, tile_n, num_stages, num_warps in tile_shapes:
        meta = {
            "TILE_M": tile_m,
            "TILE_N": tile_n,
            "TILE_K": SCALE_BLOCK_K,
            "SWIZZLE_GROUP_M": 8,
        }
        configs.append(
            triton.Config(meta, num_stages=num_stages, num_warps=num_warps)
        )
    return configs
@triton.jit
def grouped_launch(
    pid, M, N, TILE_M: tl.constexpr, TILE_N: tl.constexpr, SWIZZLE_GROUP_M: tl.constexpr
):
    # Map a linear program id onto (pid_m, pid_n) tile coordinates using
    # grouped ("swizzled") ordering: programs walk SWIZZLE_GROUP_M tile-rows
    # at a time so nearby programs reuse the same B tiles.
    grid_m = tl.cdiv(M, TILE_M)
    grid_n = tl.cdiv(N, TILE_N)

    # Number of programs contained in one full row-group.
    width = SWIZZLE_GROUP_M * grid_n
    group_id = pid // width
    # The last group may hold fewer than SWIZZLE_GROUP_M rows.
    group_size = tl.minimum(grid_m - group_id * SWIZZLE_GROUP_M, SWIZZLE_GROUP_M)

    pid_m = group_id * SWIZZLE_GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n
# Block-wise dequantization kernel implementation.
# This kernel supports many `SCALE_BLOCK_K, SCALE_BLOCK_N` cases
# as long as `TILE_K == SCALE_BLOCK_K` and `TILE_N % SCALE_BLOCK_N == 0`.
@triton.autotune(
    configs=get_block_wise_smm_configs(),
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _block_wise_smm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    M,
    N,
    K,
    _M_NPO2: tl.constexpr,
    SCALE_BLOCK_N,
    SCALE_BLOCK_K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_Ascale_m,
    stride_Ascale_k,
    stride_Bscale_k,
    stride_Bscale_n,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    SWIZZLE_GROUP_M: tl.constexpr,
):
    # Computes c = (a * a_scale) @ (b * b_scale): a_scale is indexed per row
    # of a and per SCALE_BLOCK_K slab of K; b_scale per (K-block, N-block).
    pid = tl.program_id(0)
    # Swizzled tile ordering for L2 reuse — see grouped_launch above.
    pid_m, pid_n = grouped_launch(pid, M, N, TILE_M, TILE_N, SWIZZLE_GROUP_M)

    # Row/col offsets wrap with % so edge tiles still read valid memory; the
    # masked store at the end drops the duplicated lanes.
    offs_am = (pid_m * TILE_M + tl.arange(0, TILE_M)) % M
    offs_bn = (pid_n * TILE_N + tl.arange(0, TILE_N)) % N
    offs_k = tl.arange(0, TILE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One a_scale per row; one b_scale per SCALE_BLOCK_N-wide column group.
    a_scale_ptrs = a_scale_ptr + offs_am * stride_Ascale_m
    offs_bsn = offs_bn // SCALE_BLOCK_N
    b_scale_ptrs = b_scale_ptr + offs_bsn * stride_Bscale_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, TILE_K)):
        k_remaining = K - k * TILE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # Scale index along K; with TILE_K == SCALE_BLOCK_K (as all autotune
        # configs guarantee) this is simply k.
        offs_ks = k * TILE_K // SCALE_BLOCK_K
        a_scale = tl.load(a_scale_ptrs + offs_ks * stride_Ascale_k)
        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_Bscale_k)
        # Dequantize each partial product before accumulating in fp32.
        acc += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    acc = acc.to(c_ptr.dtype.element_ty)

    # Unwrapped output offsets + bounds mask for the final store.
    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask)
def _block_wise_128_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
) -> torch.Tensor:
    """Launch the block-wise scaled matmul with 128x128 scale granularity.

    Computes c = (a * a_scale) @ (b * b_scale) where a_scale is per-row /
    per-128-K-block and b_scale is per-128x128 block of b. Writes into `c`
    and returns it.
    """
    # Bug fix: the original declared `global SCALE_BLOCK_K, SCALE_BLOCK_N` and
    # rebound the module globals on every call. That rebind could not affect
    # the autotune configs (they are built once at import, already with 128),
    # so it only mutated shared module state for no benefit. Use locals.
    scale_block_k, scale_block_n = 128, 128
    M, K = a.shape
    _, N = b.shape
    # Autotune key: power-of-2 bucket of M so close sizes share a config.
    _M_NPO2 = triton.next_power_of_2(M)

    def grid(meta):
        # One program per output tile.
        return (triton.cdiv(M, meta["TILE_M"]) * triton.cdiv(N, meta["TILE_N"]),)

    _block_wise_smm_kernel[grid](
        a,
        b,
        c,
        a_scale,
        b_scale,
        M,
        N,
        K,
        _M_NPO2,
        scale_block_n,
        scale_block_k,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scale.stride(0),
        a_scale.stride(1),
        b_scale.stride(0),
        b_scale.stride(1),
    )

    return c
# Per-tensor and per-token dequantization kernel implementation.
@triton.autotune(
    configs=[
        triton.Config({"TILE_M": 64, "TILE_N": 64, "TILE_K": 256}),
        triton.Config({"TILE_M": 64, "TILE_N": 128, "TILE_K": 128}),
        triton.Config({"TILE_M": 128, "TILE_N": 128, "TILE_K": 128}),
    ],
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _pertensor_or_pertoken_smm_kernel(
    c_ptr,
    a_ptr,
    b_ptr,
    a_scale_ptr,
    b_scale_ptr,
    bias_ptr,
    M,
    N,
    K,
    _M_NPO2,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    ACC_DTYPE: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    IS_PER_TOKEN_A: tl.constexpr,
    IS_PER_TOKEN_B: tl.constexpr,
):
    # c = a_scale * b_scale * (a @ b) [+ bias], where each scale is either a
    # single scalar (per-tensor) or one value per row of a / column of b
    # (per-token). The launcher passes ACC_DTYPE = int32 for integer inputs,
    # float32 otherwise.

    # Scale tiles collapse to a single element in the per-tensor case.
    if IS_PER_TOKEN_A:
        TILE_SIZE_SCALE_A: tl.constexpr = TILE_M
    else:
        TILE_SIZE_SCALE_A: tl.constexpr = 1

    if IS_PER_TOKEN_B:
        TILE_SIZE_SCALE_B: tl.constexpr = TILE_N
    else:
        TILE_SIZE_SCALE_B: tl.constexpr = 1

    # Plain row-major tile mapping (no swizzling here).
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, TILE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=ACC_DTYPE)

    # int64 offsets guard address arithmetic against 32-bit overflow.
    offsets_am = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, TILE_K).to(tl.int64)
    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]

    # (TILE_SIZE_SCALE_X > 1) is 0 or 1, so per-tensor scales always read
    # index 0 while per-token scales read this tile's row/column span.
    offsets_scale_am = (
        tl.arange(0, TILE_SIZE_SCALE_A) + (TILE_SIZE_SCALE_A > 1) * pid_m * TILE_M
    )
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (
        tl.arange(0, TILE_SIZE_SCALE_B) + (TILE_SIZE_SCALE_B > 1) * pid_n * TILE_N
    )
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = a_scale_ptr + offsets_scale_am
    scale_b_ptrs = b_scale_ptr + offsets_scale_bn

    # Accumulate the full K reduction before any scale is applied.
    for k in range(0, tl.cdiv(K, TILE_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        acc = tl.dot(a, b, acc, out_dtype=ACC_DTYPE)

        offsets_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    # Apply the A-side scale as a (TILE_M, 1) column, broadcasting over N.
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    a_scale = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    a_scale = a_scale.broadcast_to((TILE_M, 1))
    acc = a_scale * acc.to(tl.float32)

    # Apply the B-side scale via a transposed (1, TILE_N) broadcast.
    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    b_scale = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    b_scale = b_scale.broadcast_to((TILE_N, 1))
    acc = b_scale.T * acc.to(tl.float32)

    c = acc.to(c_ptr.type.element_ty)

    # Bias (if provided) is added after downcasting to the output dtype.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    # NOTE(review): both offsets are already int64 from the .to() above —
    # these two casts are no-ops.
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)
def _pertensor_or_pertoken_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Launch the per-tensor / per-token scaled matmul.

    A scale is treated as per-token when its element count matches the
    corresponding dimension (M rows of a, N columns of b); otherwise it is a
    single per-tensor scalar. Writes into `c` and returns it.
    """
    m, k = a.shape
    n = b.shape[1]

    def grid(meta):
        # One program per output tile.
        return (triton.cdiv(m, meta["TILE_M"]) * triton.cdiv(n, meta["TILE_N"]),)

    # Integer inputs accumulate in int32, floating-point inputs in fp32.
    if a.is_floating_point():
        acc_dtype = tl.float32
    else:
        acc_dtype = tl.int32

    _pertensor_or_pertoken_smm_kernel[grid](
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        m,
        n,
        k,
        triton.next_power_of_2(m),
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        ACC_DTYPE=acc_dtype,
        IS_PER_TOKEN_A=a_scale.numel() == m,
        IS_PER_TOKEN_B=b_scale.numel() == n,
    )

    return c
# Public SM90 (Hopper) dispatch targets. FP8 and INT8 per-tensor/per-token
# scaling share one Triton launcher; blockwise 128x128 FP8 has its own.
cutlass_scaled_mm_sm90_fp8 = _pertensor_or_pertoken_smm_launcher
cutlass_scaled_mm_sm90_int8 = _pertensor_or_pertoken_smm_launcher
cutlass_scaled_mm_blockwise_sm90_fp8 = _block_wise_128_smm_launcher
357def dispatch_scaled_mm(
358 c: torch.Tensor,
359 a: torch.Tensor,
360 b: torch.Tensor,
361 a_scale: torch.Tensor,
362 b_scale: torch.Tensor,
363 bias: Optional[torch.Tensor],
364 fp8_func: Callable,
365 int8_func: Optional[Callable],
366 blockwise_func: Callable,
367) -> None:
368 assert a_scale.dtype == torch.float32, "a_scale must be float32"
369 assert b_scale.dtype == torch.float32, "b_scale must be float32"
371 if (a_scale.numel() == 1 or a_scale.numel() == a.size(0)) and (
372 b_scale.numel() == 1 or b_scale.numel() == b.size(1)
373 ):
374 assert a_scale.is_contiguous(), "a_scale must be contiguous"
375 assert b_scale.is_contiguous(), "b_scale must be contiguous"
377 if a.dtype == torch.float8_e4m3fn:
378 fp8_func(c, a, b, a_scale, b_scale, bias)
379 else:
380 assert a.dtype == torch.int8, f"Unsupported dtype: {a.dtype}"
382 if int8_func is not None:
383 int8_func(c, a, b, a_scale, b_scale, bias)
384 else:
385 raise RuntimeError(
386 f"Int8 not supported on SM{SM_VERSION_NUM}. "
387 f"Use FP8 quantization instead, or run on older arch (SM < 100)."
388 )
389 else:
390 assert a_scale.dim() == 2, "a_scale must be 2D tensor for blockwise scaling"
391 assert b_scale.dim() == 2, "b_scale must be 2D tensor for blockwise scaling"
393 if SM_VERSION_NUM >= 90:
394 assert a.size(0) == a_scale.size(0), (
395 f"a_scale must have same first dimension as a: "
396 f"a.shape[0]={a.size(0)}, a_scale.shape[0]={a_scale.size(0)}"
397 )
398 assert triton.cdiv(a.size(1), 128) == a_scale.size(1), (
399 f"a_scale second dimension mismatch: "
400 f"triton.cdiv({a.size(1)}, 128)={triton.cdiv(a.size(1), 128)} != "
401 f"a_scale.shape[1]={a_scale.size(1)}"
402 )
404 assert triton.cdiv(b.size(0), 128) == b_scale.size(0), (
405 f"b_scale first dimension mismatch: "
406 f"triton.cdiv({b.size(0)}, 128)={triton.cdiv(b.size(0), 128)} != "
407 f"b_scale.shape[0]={b_scale.size(0)}"
408 )
409 assert triton.cdiv(b.size(1), 128) == b_scale.size(1), (
410 f"b_scale second dimension mismatch: "
411 f"triton.cdiv({b.size(1)}, 128)={triton.cdiv(b.size(1), 128)} != "
412 f"b_scale.shape[1]={b_scale.size(1)}"
413 )
415 assert bias is None, "Bias not yet supported for blockwise scaled_mm"
417 blockwise_func(c, a, b, a_scale, b_scale)
def cutlass_scaled_mm_sm90(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> None:
    """Hopper (SM90) entry point: route to the fp8 / int8 / blockwise launchers."""
    dispatch_scaled_mm(
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        fp8_func=cutlass_scaled_mm_sm90_fp8,
        int8_func=cutlass_scaled_mm_sm90_int8,
        blockwise_func=cutlass_scaled_mm_blockwise_sm90_fp8,
    )
def cutlass_scaled_mm_sm120(*args, **kwargs):
    """SM120 scaled-mm placeholder — always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm120 is not yet implemented. "
    raise NotImplementedError(message)
def cutlass_scaled_mm_sm100(*args, **kwargs):
    """SM100 scaled-mm placeholder — always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm100 is not yet implemented. "
    raise NotImplementedError(message)
def cutlass_scaled_mm_sm89(*args, **kwargs):
    """SM89 scaled-mm placeholder — always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm89 is not yet implemented. "
    raise NotImplementedError(message)
def cutlass_scaled_mm_sm80(*args, **kwargs):
    """SM80 scaled-mm placeholder — always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm80 is not yet implemented. "
    raise NotImplementedError(message)
def cutlass_scaled_mm_sm75(*args, **kwargs):
    """SM75 scaled-mm placeholder — always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm75 is not yet implemented. "
    raise NotImplementedError(message)
def cutlass_scaled_mm(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute c = (a_scale * a) @ (b_scale * b) [+ bias], dispatched by SM arch.

    Validates shapes, layouts, and alignment, then routes to the
    architecture-specific implementation. Writes into `c` and returns it.

    Raises:
        RuntimeError: if the device compute capability is below SM75.
        NotImplementedError: from arch stubs that are not yet implemented.
    """
    logger.debug("GEMS CUTLASS SCALED MM")
    assert (
        a.dim() == 2 and b.dim() == 2 and c.dim() == 2
    ), "All inputs must be 2D tensors"

    assert c.size(0) == a.size(0), "Number of rows in c must equal number of rows in a"
    assert a.size(1) == b.size(
        0
    ), "Number of columns in a must equal number of rows in b"
    assert b.size(1) == c.size(
        1
    ), "Number of columns in b must equal number of columns in c"

    # Layout requirements of the underlying kernels.
    assert a.stride(1) == 1 and c.stride(1) == 1, "a and c must be row-major"
    assert b.stride(0) == 1, "b must be column-major"
    assert c.stride(0) % 16 == 0, "Row stride of c must be 16-byte aligned"
    assert b.stride(1) % 16 == 0, "Column stride of b must be 16-byte aligned"

    if bias is not None:
        assert bias.numel() == b.size(
            1
        ), f"Bias size {bias.numel()} must equal number of columns in b {b.size(1)}"
        assert bias.is_contiguous(), "Bias must be contiguous"
        assert bias.dim() == 1, "Bias must be a 1D tensor"

    if SM_VERSION_NUM >= 120:
        cutlass_scaled_mm_sm120(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 100:
        cutlass_scaled_mm_sm100(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 90:
        # Hopper
        cutlass_scaled_mm_sm90(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 80:
        # Ampere
        cutlass_scaled_mm_sm80(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 75:
        # Turing
        cutlass_scaled_mm_sm75(c, a, b, a_scale, b_scale, bias)
    else:
        # Bug fix: the original fell through silently for SM < 75, returning
        # with c untouched and no indication of failure.
        raise RuntimeError(
            f"cutlass_scaled_mm requires compute capability >= SM75, "
            f"got SM{SM_VERSION_NUM}"
        )

    # Bug fix: the signature promises a torch.Tensor but the original
    # implicitly returned None.
    return c