Coverage for src/flag_gems/fused/cutlass_scaled_mm.py: 18%
191 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1from typing import Callable, Optional
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils.device_info import get_device_capability
# Block-wise quantization granularity: number of elements covered by one
# scale value along the K and N dimensions, respectively.
SCALE_BLOCK_K, SCALE_BLOCK_N = 128, 128
def get_sm_version_num():
    """Return the device compute capability as a two-digit int (e.g. 9.0 -> 90)."""
    capability = get_device_capability()
    return capability[0] * 10 + capability[1]
# Queried once at import time; assumes all visible devices share one capability.
SM_VERSION_NUM = get_sm_version_num()
def get_block_wise_smm_configs():
    """Build the autotuning search space for the block-wise scaled-matmul kernel.

    TILE_K is pinned to SCALE_BLOCK_K so each K-loop iteration consumes
    exactly one scale entry; only the M/N tiling and launch parameters vary.
    """
    # (TILE_M, TILE_N, num_stages, num_warps)
    tile_configs = [
        (32, 64, 5, 2),
        (64, 32, 5, 2),
        (64, 128, 4, 4),
        (64, 256, 4, 4),
        (128, 32, 4, 4),
        (128, 64, 4, 4),
        (128, 128, 4, 4),
        (128, 256, 3, 8),
        (256, 64, 4, 4),
        (256, 128, 3, 8),
    ]

    configs = []
    for tile_m, tile_n, num_stages, num_warps in tile_configs:
        configs.append(
            triton.Config(
                {
                    "TILE_M": tile_m,
                    "TILE_N": tile_n,
                    "TILE_K": SCALE_BLOCK_K,
                    "SWIZZLE_GROUP_M": 8,
                },
                num_stages=num_stages,
                num_warps=num_warps,
            )
        )
    return configs
@triton.jit
def grouped_launch(
    pid, M, N, TILE_M: tl.constexpr, TILE_N: tl.constexpr, SWIZZLE_GROUP_M: tl.constexpr
):
    # Map a linear program id to a (pid_m, pid_n) tile coordinate using
    # group-major ("swizzled") ordering: consecutive program ids walk down
    # SWIZZLE_GROUP_M rows before moving to the next column of tiles,
    # improving reuse of the B tiles across neighboring programs.
    grid_m = tl.cdiv(M, TILE_M)
    grid_n = tl.cdiv(N, TILE_N)

    # Number of programs that make up one full group of rows.
    width = SWIZZLE_GROUP_M * grid_n
    group_id = pid // width
    # The last group may hold fewer than SWIZZLE_GROUP_M rows.
    group_size = tl.minimum(grid_m - group_id * SWIZZLE_GROUP_M, SWIZZLE_GROUP_M)

    pid_m = group_id * SWIZZLE_GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n
# block-wise dequantization kernel implementation
# this kernel supports many `SCALE_BLOCK_K, SCALE_BLOCK_N` cases
# as long as `TILE_K == SCALE_BLOCK_K` and `TILE_N % SCALE_BLOCK_N == 0`
@triton.autotune(
    configs=get_block_wise_smm_configs(),
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _block_wise_smm_kernel(
    a_ptr,  # quantized A, logical shape (M, K)
    b_ptr,  # quantized B, logical shape (K, N)
    c_ptr,  # output C, logical shape (M, N)
    a_scale_ptr,  # A scales, indexed by (row, K-block)
    b_scale_ptr,  # B scales, indexed by (K-block, N-block)
    M,
    N,
    K,
    _M_NPO2: tl.constexpr,  # next_power_of_2(M); used only as an autotune key
    SCALE_BLOCK_N,
    SCALE_BLOCK_K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_Ascale_m,
    stride_Ascale_k,
    stride_Bscale_k,
    stride_Bscale_n,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    SWIZZLE_GROUP_M: tl.constexpr,
):
    # Computes C = (A * a_scale) @ (B * b_scale), dequantizing on the fly:
    # each K-block's partial product is rescaled inside the accumulation loop.
    pid = tl.program_id(0)
    pid_m, pid_n = grouped_launch(pid, M, N, TILE_M, TILE_N, SWIZZLE_GROUP_M)

    # Out-of-range rows/cols wrap via `% M` / `% N`; the store mask at the
    # end discards the duplicated results, so no masking is needed on loads.
    offs_am = (pid_m * TILE_M + tl.arange(0, TILE_M)) % M
    offs_bn = (pid_n * TILE_N + tl.arange(0, TILE_N)) % N
    offs_k = tl.arange(0, TILE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # A has one scale per (row, K-block); B has one scale per
    # (K-block, SCALE_BLOCK_N-wide column block).
    a_scale_ptrs = a_scale_ptr + offs_am * stride_Ascale_m
    offs_bsn = offs_bn // SCALE_BLOCK_N
    b_scale_ptrs = b_scale_ptr + offs_bsn * stride_Bscale_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, TILE_K)):
        # Mask the tail of K on the final iteration.
        k_remaining = K - k * TILE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # One scale entry per iteration — relies on TILE_K == SCALE_BLOCK_K.
        offs_ks = k * TILE_K // SCALE_BLOCK_K
        a_scale = tl.load(a_scale_ptrs + offs_ks * stride_Ascale_k)
        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_Bscale_k)
        acc += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    acc = acc.to(c_ptr.dtype.element_ty)

    # Unwrapped output offsets plus a boundary mask for the store.
    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask)
def _block_wise_128_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
) -> torch.Tensor:
    """Launch the block-wise scaled matmul with 128x128 scale blocks.

    Computes c = (a * a_scale) @ (b * b_scale) in place, where a_scale has
    one entry per (row, 128-wide K-block) of a and b_scale has one entry per
    (128x128 block) of b.

    Returns:
        The output tensor ``c`` (written in place).
    """
    # This launcher is hard-wired to 128x128 scale blocks. Previously it
    # reassigned the module-level SCALE_BLOCK_K/SCALE_BLOCK_N globals on
    # every call, which was redundant (they already default to 128) and not
    # thread-safe; the kernel receives the values as arguments, so plain
    # locals suffice.
    scale_block_k, scale_block_n = 128, 128
    M, K = a.shape
    _, N = b.shape
    # Autotune key: bucket M by its next power of two to limit recompiles.
    _M_NPO2 = triton.next_power_of_2(M)

    grid = lambda META: (
        triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
    )

    _block_wise_smm_kernel[grid](
        a,
        b,
        c,
        a_scale,
        b_scale,
        M,
        N,
        K,
        _M_NPO2,
        scale_block_n,
        scale_block_k,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scale.stride(0),
        a_scale.stride(1),
        b_scale.stride(0),
        b_scale.stride(1),
    )

    return c
# per-tensor and per-token dequantization kernel implementation
@triton.autotune(
    configs=[
        triton.Config({"TILE_M": 64, "TILE_N": 64, "TILE_K": 256}),
        triton.Config({"TILE_M": 64, "TILE_N": 128, "TILE_K": 128}),
        triton.Config({"TILE_M": 128, "TILE_N": 128, "TILE_K": 128}),
    ],
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _pertensor_or_pertoken_smm_kernel(
    c_ptr,  # output C, logical shape (M, N)
    a_ptr,  # quantized A, logical shape (M, K)
    b_ptr,  # quantized B, logical shape (K, N)
    a_scale_ptr,  # scalar (per-tensor) or length-M (per-token) scales for A
    b_scale_ptr,  # scalar (per-tensor) or length-N (per-token) scales for B
    bias_ptr,  # optional length-N bias; null pointer when absent
    M,
    N,
    K,
    _M_NPO2,  # next_power_of_2(M); used only as an autotune key
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    ACC_DTYPE: tl.constexpr,  # tl.int32 for int8 inputs, tl.float32 otherwise
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    IS_PER_TOKEN_A: tl.constexpr,
    IS_PER_TOKEN_B: tl.constexpr,
):
    # Computes C = (A @ B) * a_scale * b_scale (+ bias), applying the
    # dequantization scales once after the full K accumulation.
    # Per-token scaling loads one scale per row/column; per-tensor scaling
    # loads a single broadcast scale (tile size 1).
    if IS_PER_TOKEN_A:
        TILE_SIZE_SCALE_A: tl.constexpr = TILE_M
    else:
        TILE_SIZE_SCALE_A: tl.constexpr = 1

    if IS_PER_TOKEN_B:
        TILE_SIZE_SCALE_B: tl.constexpr = TILE_N
    else:
        TILE_SIZE_SCALE_B: tl.constexpr = 1

    # Plain row-major tile ordering (no swizzle, unlike the block-wise kernel).
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, TILE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=ACC_DTYPE)

    # int64 offsets guard against overflow for large tensors.
    offsets_am = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, TILE_K).to(tl.int64)
    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]

    # For per-tensor scales the (TILE_SIZE > 1) factor zeroes the pid term,
    # so every program reads the single scale at offset 0.
    offsets_scale_am = (
        tl.arange(0, TILE_SIZE_SCALE_A) + (TILE_SIZE_SCALE_A > 1) * pid_m * TILE_M
    )
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (
        tl.arange(0, TILE_SIZE_SCALE_B) + (TILE_SIZE_SCALE_B > 1) * pid_n * TILE_N
    )
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = a_scale_ptr + offsets_scale_am
    scale_b_ptrs = b_scale_ptr + offsets_scale_bn

    for k in range(0, tl.cdiv(K, TILE_K)):
        # offsets_k advances each iteration, so `< K` masks the K tail.
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        acc = tl.dot(a, b, acc, out_dtype=ACC_DTYPE)

        offsets_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    # Apply the A scale along rows (broadcast over columns).
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    a_scale = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    a_scale = a_scale.broadcast_to((TILE_M, 1))
    acc = a_scale * acc.to(tl.float32)

    # Apply the B scale along columns (loaded as a column, then transposed).
    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    b_scale = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    b_scale = b_scale.broadcast_to((TILE_N, 1))
    acc = b_scale.T * acc.to(tl.float32)

    # Bias is added after conversion to the output dtype.
    c = acc.to(c_ptr.type.element_ty)

    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    # NOTE(review): these casts are redundant — offs_cm/offs_cn are already
    # int64 from the lines above. Kept as-is to preserve the code byte-for-byte.
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)
def _pertensor_or_pertoken_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Launch the per-tensor / per-token scaled matmul kernel.

    Computes c = (a @ b) * a_scale * b_scale (+ bias) in place. Each scale is
    treated as per-token when it has one entry per row of ``a`` (resp. column
    of ``b``), and per-tensor otherwise.

    Returns:
        The output tensor ``c`` (written in place).
    """
    # NOTE: annotation changed from `torch.Tensor | None` to
    # `Optional[torch.Tensor]` for consistency with the rest of this file
    # and compatibility with Python < 3.10 (PEP 604 syntax is evaluated at
    # definition time).
    M, K = a.shape
    _, N = b.shape

    grid = lambda META: (
        triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
    )

    # int8 inputs accumulate in int32; floating-point inputs in float32.
    ACC_DTYPE = tl.float32 if a.is_floating_point() else tl.int32

    # Autotune key: bucket M by its next power of two to limit recompiles.
    _M_NPO2 = triton.next_power_of_2(M)

    # Per-token scaling when there is one scale per row of a / column of b.
    IS_PER_TOKEN_A = a_scale.numel() == M
    IS_PER_TOKEN_B = b_scale.numel() == N

    _pertensor_or_pertoken_smm_kernel[grid](
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        M,
        N,
        K,
        _M_NPO2,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        ACC_DTYPE=ACC_DTYPE,
        IS_PER_TOKEN_A=IS_PER_TOKEN_A,
        IS_PER_TOKEN_B=IS_PER_TOKEN_B,
    )

    return c
# SM90 (Hopper) entry points. The per-tensor/per-token Triton launcher is
# shared between fp8 and int8; block-wise fp8 uses the 128x128 launcher.
cutlass_scaled_mm_sm90_fp8 = _pertensor_or_pertoken_smm_launcher

cutlass_scaled_mm_sm90_int8 = _pertensor_or_pertoken_smm_launcher

cutlass_scaled_mm_blockwise_sm90_fp8 = _block_wise_128_smm_launcher
def dispatch_scaled_mm(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    fp8_func: Callable,
    int8_func: Optional[Callable],
    blockwise_func: Callable,
) -> None:
    """Route a scaled matmul to the per-tensor/per-token or block-wise path.

    The per-tensor/per-token path is chosen when a_scale is scalar or has one
    entry per row of ``a`` AND b_scale is scalar or has one entry per column
    of ``b``; anything else is treated as 2D block-wise scaling. The selected
    launcher writes the result into ``c`` in place; nothing is returned.
    """
    assert a_scale.dtype == torch.float32, "a_scale must be float32"
    assert b_scale.dtype == torch.float32, "b_scale must be float32"

    if (a_scale.numel() == 1 or a_scale.numel() == a.size(0)) and (
        b_scale.numel() == 1 or b_scale.numel() == b.size(1)
    ):
        # Per-tensor or per-token scaling.
        assert a_scale.is_contiguous(), "a_scale must be contiguous"
        assert b_scale.is_contiguous(), "b_scale must be contiguous"

        if a.dtype == torch.float8_e4m3fn:
            fp8_func(c, a, b, a_scale, b_scale, bias)
        else:
            assert a.dtype == torch.int8, f"Unsupported dtype: {a.dtype}"

            if int8_func is not None:
                int8_func(c, a, b, a_scale, b_scale, bias)
            else:
                raise RuntimeError(
                    f"Int8 not supported on SM{SM_VERSION_NUM}. "
                    f"Use FP8 quantization instead, or run on older arch (SM < 100)."
                )
    else:
        # Block-wise scaling: both scales must be 2D grids.
        assert a_scale.dim() == 2, "a_scale must be 2D tensor for blockwise scaling"
        assert b_scale.dim() == 2, "b_scale must be 2D tensor for blockwise scaling"

        # NOTE(review): the 128-granularity shape checks below only run when
        # SM_VERSION_NUM >= 90, yet blockwise_func is invoked regardless —
        # confirm whether the unchecked sub-90 fall-through is intentional.
        if SM_VERSION_NUM >= 90:
            assert a.size(0) == a_scale.size(0), (
                f"a_scale must have same first dimension as a: "
                f"a.shape[0]={a.size(0)}, a_scale.shape[0]={a_scale.size(0)}"
            )
            assert triton.cdiv(a.size(1), 128) == a_scale.size(1), (
                f"a_scale second dimension mismatch: "
                f"triton.cdiv({a.size(1)}, 128)={triton.cdiv(a.size(1), 128)} != "
                f"a_scale.shape[1]={a_scale.size(1)}"
            )

            assert triton.cdiv(b.size(0), 128) == b_scale.size(0), (
                f"b_scale first dimension mismatch: "
                f"triton.cdiv({b.size(0)}, 128)={triton.cdiv(b.size(0), 128)} != "
                f"b_scale.shape[0]={b_scale.size(0)}"
            )
            assert triton.cdiv(b.size(1), 128) == b_scale.size(1), (
                f"b_scale second dimension mismatch: "
                f"triton.cdiv({b.size(1)}, 128)={triton.cdiv(b.size(1), 128)} != "
                f"b_scale.shape[1]={b_scale.size(1)}"
            )

        assert bias is None, "Bias not yet supported for blockwise scaled_mm"

        blockwise_func(c, a, b, a_scale, b_scale)
def cutlass_scaled_mm_sm90(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> None:
    """Scaled matmul for SM90 (Hopper): pick the launcher from the scale shapes.

    Writes the result into ``c`` in place via ``dispatch_scaled_mm``.
    """
    dispatch_scaled_mm(
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        fp8_func=cutlass_scaled_mm_sm90_fp8,
        int8_func=cutlass_scaled_mm_sm90_int8,
        blockwise_func=cutlass_scaled_mm_blockwise_sm90_fp8,
    )
def cutlass_scaled_mm_sm120(*args, **kwargs):
    """Placeholder for the SM120 scaled-matmul path; always raises."""
    raise NotImplementedError("cutlass_scaled_mm_sm120 is not yet implemented. ")
def cutlass_scaled_mm_sm100(*args, **kwargs):
    """Placeholder for the SM100 scaled-matmul path; always raises."""
    raise NotImplementedError("cutlass_scaled_mm_sm100 is not yet implemented. ")
def cutlass_scaled_mm_sm89(*args, **kwargs):
    """Placeholder for the SM89 scaled-matmul path; always raises."""
    raise NotImplementedError("cutlass_scaled_mm_sm89 is not yet implemented. ")
def cutlass_scaled_mm_sm80(*args, **kwargs):
    """Placeholder for the SM80 scaled-matmul path; always raises."""
    raise NotImplementedError("cutlass_scaled_mm_sm80 is not yet implemented. ")
def cutlass_scaled_mm_sm75(*args, **kwargs):
    """Placeholder for the SM75 scaled-matmul path; always raises."""
    raise NotImplementedError("cutlass_scaled_mm_sm75 is not yet implemented. ")
def cutlass_scaled_mm(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute a dequantizing matmul c = (a * a_scale) @ (b * b_scale) (+ bias).

    Validates layouts, then dispatches to the implementation for the current
    device architecture. The result is written into ``c`` in place.

    Returns:
        The output tensor ``c``. (The original fell through without a
        return despite the ``-> torch.Tensor`` annotation.)

    Raises:
        RuntimeError: if the device architecture is older than SM75.
            (Previously an unsupported arch silently left ``c`` untouched.)
        NotImplementedError: from the per-arch stubs that are not yet written.
    """
    assert (
        a.dim() == 2 and b.dim() == 2 and c.dim() == 2
    ), "All inputs must be 2D tensors"

    assert c.size(0) == a.size(0), "Number of rows in c must equal number of rows in a"
    assert a.size(1) == b.size(
        0
    ), "Number of columns in a must equal number of rows in b"
    assert b.size(1) == c.size(
        1
    ), "Number of columns in b must equal number of columns in c"

    # Layout requirements of the underlying kernels.
    assert a.stride(1) == 1 and c.stride(1) == 1, "a and c must be row-major"
    assert b.stride(0) == 1, "b must be column-major"
    assert c.stride(0) % 16 == 0, "Row stride of c must be 16-byte aligned"
    assert b.stride(1) % 16 == 0, "Column stride of b must be 16-byte aligned"

    if bias is not None:
        assert bias.numel() == b.size(
            1
        ), f"Bias size {bias.numel()} must equal number of columns in b {b.size(1)}"
        assert bias.is_contiguous(), "Bias must be contiguous"
        assert bias.dim() == 1, "Bias must be a 1D tensor"

    if SM_VERSION_NUM >= 120:
        cutlass_scaled_mm_sm120(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 100:
        cutlass_scaled_mm_sm100(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 90:
        # Hopper
        cutlass_scaled_mm_sm90(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 80:
        # Ampere
        cutlass_scaled_mm_sm80(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 75:
        # Turing
        cutlass_scaled_mm_sm75(c, a, b, a_scale, b_scale, bias)
    else:
        # Fail loudly instead of silently returning an unmodified c.
        raise RuntimeError(
            f"cutlass_scaled_mm requires SM >= 75, got SM{SM_VERSION_NUM}."
        )

    return c