Coverage for src/flag_gems/fused/cutlass_scaled_mm.py: 19%

194 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2from typing import Callable, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils.device_info import get_device_capability 

9 

# Module-level logger, following the package-wide logging convention.
logger = logging.getLogger(__name__)

# Granularity of the block-wise quantization scales: the block-wise kernel
# reads one scale per (row, SCALE_BLOCK_K-wide K block) of `a` and one per
# (SCALE_BLOCK_K x SCALE_BLOCK_N) block of `b`. Consumed at import time by
# get_block_wise_smm_configs() below.
SCALE_BLOCK_K, SCALE_BLOCK_N = 128, 128

13 

14 

def get_sm_version_num():
    """Return the CUDA compute capability packed into one integer.

    For example, capability ``(9, 0)`` becomes ``90``.
    """
    capability = get_device_capability()
    return 10 * capability[0] + capability[1]

18 

19 

# Cached once at import time; assumes the active device does not change later.
SM_VERSION_NUM = get_sm_version_num()

21 

22 

def get_block_wise_smm_configs():
    """Build the autotuning search space for the block-wise scaled-mm kernel.

    Every config uses a ``(TILE_M, TILE_N)`` output tile, fixes ``TILE_K``
    to ``SCALE_BLOCK_K`` so each K step consumes exactly one scale block,
    and keeps a swizzle group of 8.
    """
    # (TILE_M, TILE_N, num_stages, num_warps)
    tile_shapes = (
        (32, 64, 5, 2),
        (64, 32, 5, 2),
        (64, 128, 4, 4),
        (64, 256, 4, 4),
        (128, 32, 4, 4),
        (128, 64, 4, 4),
        (128, 128, 4, 4),
        (128, 256, 3, 8),
        (256, 64, 4, 4),
        (256, 128, 3, 8),
    )

    configs = []
    for tile_m, tile_n, num_stages, num_warps in tile_shapes:
        configs.append(
            triton.Config(
                {
                    "TILE_M": tile_m,
                    "TILE_N": tile_n,
                    "TILE_K": SCALE_BLOCK_K,
                    "SWIZZLE_GROUP_M": 8,
                },
                num_stages=num_stages,
                num_warps=num_warps,
            )
        )
    return configs

51 

52 

@triton.jit
def grouped_launch(
    pid, M, N, TILE_M: tl.constexpr, TILE_N: tl.constexpr, SWIZZLE_GROUP_M: tl.constexpr
):
    # Map a linear program id onto a (pid_m, pid_n) tile coordinate using
    # grouped ("swizzled") ordering: consecutive pids walk down a column of
    # up to SWIZZLE_GROUP_M row-tiles before advancing to the next column,
    # improving L2 reuse across neighbouring programs.
    grid_m = tl.cdiv(M, TILE_M)
    grid_n = tl.cdiv(N, TILE_N)

    # Number of programs covered by one group of SWIZZLE_GROUP_M tile-rows.
    width = SWIZZLE_GROUP_M * grid_n
    group_id = pid // width
    # The last group may contain fewer than SWIZZLE_GROUP_M tile-rows.
    group_size = tl.minimum(grid_m - group_id * SWIZZLE_GROUP_M, SWIZZLE_GROUP_M)

    pid_m = group_id * SWIZZLE_GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n

68 

69 

# block-wise dequantization kernel implementation
# this kernel supports many `SCALE_BLOCK_K, SCALE_BLOCK_N` cases
# as long as `TILE_K == SCALE_BLOCK_K` and `TILE_N % SCALE_BLOCK_N == 0`
@triton.autotune(
    configs=get_block_wise_smm_configs(),
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _block_wise_smm_kernel(
    # Computes c = (a * a_scale) @ (b * b_scale), where a_scale holds one
    # scale per (row, K-block) of `a` and b_scale one per (K-block, N-block)
    # of `b`.
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    M,
    N,
    K,
    # next_power_of_2(M); used only as an autotune cache key (see `key=`).
    _M_NPO2: tl.constexpr,
    SCALE_BLOCK_N,
    SCALE_BLOCK_K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_Ascale_m,
    stride_Ascale_k,
    stride_Bscale_k,
    stride_Bscale_n,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    SWIZZLE_GROUP_M: tl.constexpr,
):
    pid = tl.program_id(0)
    # Swizzled 1D -> 2D tile mapping for better L2 reuse.
    pid_m, pid_n = grouped_launch(pid, M, N, TILE_M, TILE_N, SWIZZLE_GROUP_M)

    # Row/column offsets wrap modulo M/N so edge tiles load in-bounds
    # (duplicated) data; the final store is masked to the true extent.
    offs_am = (pid_m * TILE_M + tl.arange(0, TILE_M)) % M
    offs_bn = (pid_n * TILE_N + tl.arange(0, TILE_N)) % N
    offs_k = tl.arange(0, TILE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Per-row scale pointers for `a`; per-SCALE_BLOCK_N-column for `b`.
    a_scale_ptrs = a_scale_ptr + offs_am * stride_Ascale_m
    offs_bsn = offs_bn // SCALE_BLOCK_N
    b_scale_ptrs = b_scale_ptr + offs_bsn * stride_Bscale_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, TILE_K)):
        # Mask the K tail of the last iteration.
        k_remaining = K - k * TILE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # TILE_K == SCALE_BLOCK_K, so each iteration maps to one scale block.
        offs_ks = k * TILE_K // SCALE_BLOCK_K
        a_scale = tl.load(a_scale_ptrs + offs_ks * stride_Ascale_k)
        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_Bscale_k)
        # Dequantize each partial product before accumulating.
        acc += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    acc = acc.to(c_ptr.dtype.element_ty)

    # Masked store of the output tile.
    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask)

137 

138 

def _block_wise_128_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
) -> torch.Tensor:
    """Launch the block-wise (128x128) scaled matmul kernel.

    Computes ``c = (a * a_scale) @ (b * b_scale)`` where ``a_scale`` holds
    one float32 scale per (row, 128-wide K block) of ``a`` and ``b_scale``
    one per 128x128 block of ``b``.

    Args:
        c: output tensor of shape (M, N); written in place.
        a, b: quantized operands of shape (M, K) and (K, N).
        a_scale, b_scale: 2D float32 scale tensors (see above).

    Returns:
        ``c`` for convenience.
    """
    # FIX: the previous version rebound the module globals SCALE_BLOCK_K /
    # SCALE_BLOCK_N here via a `global` statement. That was dead code — the
    # autotune configs (which bake TILE_K = SCALE_BLOCK_K in) are built once
    # at import time, so a later rebinding has no effect. Use locals instead.
    scale_block_k, scale_block_n = 128, 128
    M, K = a.shape
    _, N = b.shape
    # Autotune cache key; avoids retuning for every distinct M.
    _M_NPO2 = triton.next_power_of_2(M)

    # 1D launch: one program per output tile; the kernel swizzles the id.
    grid = lambda META: (
        triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
    )

    _block_wise_smm_kernel[grid](
        a,
        b,
        c,
        a_scale,
        b_scale,
        M,
        N,
        K,
        _M_NPO2,
        scale_block_n,
        scale_block_k,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scale.stride(0),
        a_scale.stride(1),
        b_scale.stride(0),
        b_scale.stride(1),
    )

    return c

181 

182 

# per-tensor and per-token dequantization kernel implementation
@triton.autotune(
    configs=[
        triton.Config({"TILE_M": 64, "TILE_N": 64, "TILE_K": 256}),
        triton.Config({"TILE_M": 64, "TILE_N": 128, "TILE_K": 128}),
        triton.Config({"TILE_M": 128, "TILE_N": 128, "TILE_K": 128}),
    ],
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _pertensor_or_pertoken_smm_kernel(
    # Computes c = a_scale * (a @ b) * b_scale [+ bias], where each scale is
    # either a single scalar (per-tensor) or one value per row of `a` /
    # column of `b` (per-token).
    c_ptr,
    a_ptr,
    b_ptr,
    a_scale_ptr,
    b_scale_ptr,
    bias_ptr,
    M,
    N,
    K,
    # next_power_of_2(M); used only as an autotune cache key.
    _M_NPO2,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    ACC_DTYPE: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    IS_PER_TOKEN_A: tl.constexpr,
    IS_PER_TOKEN_B: tl.constexpr,
):
    # Per-token scaling loads one scale per row/column of the tile;
    # per-tensor scaling loads a single scalar (scale tile size 1).
    if IS_PER_TOKEN_A:
        TILE_SIZE_SCALE_A: tl.constexpr = TILE_M
    else:
        TILE_SIZE_SCALE_A: tl.constexpr = 1

    if IS_PER_TOKEN_B:
        TILE_SIZE_SCALE_B: tl.constexpr = TILE_N
    else:
        TILE_SIZE_SCALE_B: tl.constexpr = 1

    # Plain row-major rasterization of the 1D program id (no swizzle here).
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, TILE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=ACC_DTYPE)

    # int64 offsets guard against address overflow on large tensors.
    offsets_am = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, TILE_K).to(tl.int64)
    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]

    # For per-tensor scales (tile size 1) the (TILE_SIZE_... > 1) factor
    # zeroes the block offset, so every program reads element 0.
    offsets_scale_am = (
        tl.arange(0, TILE_SIZE_SCALE_A) + (TILE_SIZE_SCALE_A > 1) * pid_m * TILE_M
    )
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (
        tl.arange(0, TILE_SIZE_SCALE_B) + (TILE_SIZE_SCALE_B > 1) * pid_n * TILE_N
    )
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = a_scale_ptr + offsets_scale_am
    scale_b_ptrs = b_scale_ptr + offsets_scale_bn

    # Unscaled accumulation over K (K tail masked each iteration).
    for k in range(0, tl.cdiv(K, TILE_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        acc = tl.dot(a, b, acc, out_dtype=ACC_DTYPE)

        offsets_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    # Apply row scales: acc[m, n] *= a_scale[m].
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    a_scale = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    a_scale = a_scale.broadcast_to((TILE_M, 1))
    acc = a_scale * acc.to(tl.float32)

    # Apply column scales: acc[m, n] *= b_scale[n] (loaded as a column
    # vector, applied transposed).
    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    b_scale = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    b_scale = b_scale.broadcast_to((TILE_N, 1))
    acc = b_scale.T * acc.to(tl.float32)

    c = acc.to(c_ptr.type.element_ty)

    # Bias (one value per output column) is added after conversion to the
    # output dtype.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    # NOTE(review): the two casts below are redundant — the offsets above are
    # already int64.
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)

301 

302 

def _pertensor_or_pertoken_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Launch the per-tensor / per-token scaled matmul kernel.

    A scale tensor with one element per row of ``a`` (resp. column of
    ``b``) is treated as per-token; a single-element scale is per-tensor.
    The result is written into ``c``, which is also returned.
    """
    M, K = a.shape
    N = b.shape[1]

    def grid(meta):
        # One program per output tile, flattened to a 1D launch.
        return (triton.cdiv(M, meta["TILE_M"]) * triton.cdiv(N, meta["TILE_N"]),)

    # Integer inputs (int8) accumulate in int32; floating inputs in float32.
    if a.is_floating_point():
        acc_dtype = tl.float32
    else:
        acc_dtype = tl.int32

    per_token_a = a_scale.numel() == M
    per_token_b = b_scale.numel() == N

    _pertensor_or_pertoken_smm_kernel[grid](
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        M,
        N,
        K,
        triton.next_power_of_2(M),
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        ACC_DTYPE=acc_dtype,
        IS_PER_TOKEN_A=per_token_a,
        IS_PER_TOKEN_B=per_token_b,
    )

    return c

348 

349 

# SM90 (Hopper) entry points. The fp8 and int8 per-tensor/per-token paths
# share one Triton launcher; the kernel picks its accumulator dtype from the
# input dtype.
cutlass_scaled_mm_sm90_fp8 = _pertensor_or_pertoken_smm_launcher

cutlass_scaled_mm_sm90_int8 = _pertensor_or_pertoken_smm_launcher

cutlass_scaled_mm_blockwise_sm90_fp8 = _block_wise_128_smm_launcher

355 

356 

def dispatch_scaled_mm(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    fp8_func: Callable,
    int8_func: Optional[Callable],
    blockwise_func: Callable,
) -> None:
    """Route a scaled matmul to the per-tensor/per-token or block-wise path.

    The scale shapes select the path: a scalar or one-scale-per-row
    ``a_scale`` combined with a scalar or one-scale-per-column ``b_scale``
    goes to ``fp8_func`` / ``int8_func`` (chosen by ``a.dtype``); any other
    shape is treated as block-wise 2D scales and goes to ``blockwise_func``.
    The result is written into ``c``; nothing is returned.

    Raises:
        RuntimeError: int8 input when no ``int8_func`` is available.
    """
    assert a_scale.dtype == torch.float32, "a_scale must be float32"
    assert b_scale.dtype == torch.float32, "b_scale must be float32"

    if (a_scale.numel() == 1 or a_scale.numel() == a.size(0)) and (
        b_scale.numel() == 1 or b_scale.numel() == b.size(1)
    ):
        # Per-tensor or per-token scaling.
        assert a_scale.is_contiguous(), "a_scale must be contiguous"
        assert b_scale.is_contiguous(), "b_scale must be contiguous"

        if a.dtype == torch.float8_e4m3fn:
            fp8_func(c, a, b, a_scale, b_scale, bias)
        else:
            assert a.dtype == torch.int8, f"Unsupported dtype: {a.dtype}"

            if int8_func is not None:
                int8_func(c, a, b, a_scale, b_scale, bias)
            else:
                raise RuntimeError(
                    f"Int8 not supported on SM{SM_VERSION_NUM}. "
                    f"Use FP8 quantization instead, or run on older arch (SM < 100)."
                )
    else:
        # Block-wise scaling: 2D scale tensors, one scale per 128-block
        # (see the shape checks below).
        assert a_scale.dim() == 2, "a_scale must be 2D tensor for blockwise scaling"
        assert b_scale.dim() == 2, "b_scale must be 2D tensor for blockwise scaling"

        # NOTE(review): the 128-block shape validation only runs on SM >= 90;
        # confirm whether older archs should reject blockwise input instead.
        if SM_VERSION_NUM >= 90:
            assert a.size(0) == a_scale.size(0), (
                f"a_scale must have same first dimension as a: "
                f"a.shape[0]={a.size(0)}, a_scale.shape[0]={a_scale.size(0)}"
            )
            assert triton.cdiv(a.size(1), 128) == a_scale.size(1), (
                f"a_scale second dimension mismatch: "
                f"triton.cdiv({a.size(1)}, 128)={triton.cdiv(a.size(1), 128)} != "
                f"a_scale.shape[1]={a_scale.size(1)}"
            )

            assert triton.cdiv(b.size(0), 128) == b_scale.size(0), (
                f"b_scale first dimension mismatch: "
                f"triton.cdiv({b.size(0)}, 128)={triton.cdiv(b.size(0), 128)} != "
                f"b_scale.shape[0]={b_scale.size(0)}"
            )
            assert triton.cdiv(b.size(1), 128) == b_scale.size(1), (
                f"b_scale second dimension mismatch: "
                f"triton.cdiv({b.size(1)}, 128)={triton.cdiv(b.size(1), 128)} != "
                f"b_scale.shape[1]={b_scale.size(1)}"
            )

        assert bias is None, "Bias not yet supported for blockwise scaled_mm"

        blockwise_func(c, a, b, a_scale, b_scale)

418 

419 

def cutlass_scaled_mm_sm90(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> None:
    """Scaled matmul entry point for SM90 (Hopper); writes the result into ``c``."""
    dispatch_scaled_mm(
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        fp8_func=cutlass_scaled_mm_sm90_fp8,
        int8_func=cutlass_scaled_mm_sm90_int8,
        blockwise_func=cutlass_scaled_mm_blockwise_sm90_fp8,
    )

439 

440 

def cutlass_scaled_mm_sm120(*args, **kwargs):
    """Placeholder for the SM120 path; always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm120 is not yet implemented. "
    raise NotImplementedError(message)

443 

444 

def cutlass_scaled_mm_sm100(*args, **kwargs):
    """Placeholder for the SM100 path; always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm100 is not yet implemented. "
    raise NotImplementedError(message)

447 

448 

def cutlass_scaled_mm_sm89(*args, **kwargs):
    """Placeholder for the SM89 path; always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm89 is not yet implemented. "
    raise NotImplementedError(message)

451 

452 

def cutlass_scaled_mm_sm80(*args, **kwargs):
    """Placeholder for the SM80 path; always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm80 is not yet implemented. "
    raise NotImplementedError(message)

455 

456 

def cutlass_scaled_mm_sm75(*args, **kwargs):
    """Placeholder for the SM75 path; always raises NotImplementedError."""
    message = "cutlass_scaled_mm_sm75 is not yet implemented. "
    raise NotImplementedError(message)

459 

460 

def cutlass_scaled_mm(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Quantized matmul: dequantize-and-multiply ``a`` and ``b`` into ``c``.

    ``a`` and ``c`` must be row-major, ``b`` column-major; the scales are
    float32 and may be per-tensor, per-token, or 128-block-wise (routing is
    done by ``dispatch_scaled_mm`` on the arch-specific path). The result is
    written into ``c``, which is also returned.

    Raises:
        RuntimeError: if the device compute capability is below SM75.
        NotImplementedError: for arch paths that are not implemented yet.
    """
    logger.debug("GEMS CUTLASS SCALED MM")
    assert (
        a.dim() == 2 and b.dim() == 2 and c.dim() == 2
    ), "All inputs must be 2D tensors"

    assert c.size(0) == a.size(0), "Number of rows in c must equal number of rows in a"
    assert a.size(1) == b.size(
        0
    ), "Number of columns in a must equal number of rows in b"
    assert b.size(1) == c.size(
        1
    ), "Number of columns in b must equal number of columns in c"

    assert a.stride(1) == 1 and c.stride(1) == 1, "a and c must be row-major"

    assert b.stride(0) == 1, "b must be column-major"

    assert c.stride(0) % 16 == 0, "Row stride of c must be 16-byte aligned"
    assert b.stride(1) % 16 == 0, "Column stride of b must be 16-byte aligned"

    if bias is not None:
        assert bias.numel() == b.size(
            1
        ), f"Bias size {bias.numel()} must equal number of columns in b {b.size(1)}"
        assert bias.is_contiguous(), "Bias must be contiguous"
        assert bias.dim() == 1, "Bias must be a 1D tensor"

    if SM_VERSION_NUM >= 120:
        cutlass_scaled_mm_sm120(c, a, b, a_scale, b_scale, bias)

    elif SM_VERSION_NUM >= 100:
        cutlass_scaled_mm_sm100(c, a, b, a_scale, b_scale, bias)

    elif SM_VERSION_NUM >= 90:
        # Hopper
        cutlass_scaled_mm_sm90(c, a, b, a_scale, b_scale, bias)

    elif SM_VERSION_NUM >= 80:
        # Ampere
        cutlass_scaled_mm_sm80(c, a, b, a_scale, b_scale, bias)

    elif SM_VERSION_NUM >= 75:
        # Turing
        cutlass_scaled_mm_sm75(c, a, b, a_scale, b_scale, bias)

    else:
        # FIX: previously this fell through silently on SM < 75, leaving `c`
        # untouched without any indication of failure.
        raise RuntimeError(
            f"cutlass_scaled_mm is not supported on SM{SM_VERSION_NUM} (< 75)."
        )

    # FIX: the function is annotated to return torch.Tensor but never did.
    return c