Coverage for src/flag_gems/fused/cutlass_scaled_mm.py: 18% (191 statements) — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1from typing import Callable, Optional 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils.device_info import get_device_capability 

8 

# Quantization scale-block granularity (K, N) used by the block-wise path.
SCALE_BLOCK_K, SCALE_BLOCK_N = 128, 128

10 

11 

def get_sm_version_num():
    """Fold the device compute capability into one integer, e.g. (9, 0) -> 90."""
    cap_major, cap_minor = get_device_capability()
    sm_version = 10 * cap_major + cap_minor
    return sm_version

15 

16 

# Computed once at import time and reused by the dispatch functions below.
SM_VERSION_NUM = get_sm_version_num()

18 

19 

def get_block_wise_smm_configs():
    """Build the autotuning configs for the block-wise scaled-mm kernel.

    TILE_K is pinned to SCALE_BLOCK_K so each K-tile consumes exactly one
    scale block; SWIZZLE_GROUP_M is fixed at 8 for every candidate.
    """
    # (TILE_M, TILE_N, num_stages, num_warps)
    tile_shapes = [
        (32, 64, 5, 2),
        (64, 32, 5, 2),
        (64, 128, 4, 4),
        (64, 256, 4, 4),
        (128, 32, 4, 4),
        (128, 64, 4, 4),
        (128, 128, 4, 4),
        (128, 256, 3, 8),
        (256, 64, 4, 4),
        (256, 128, 3, 8),
    ]

    configs = []
    for tile_m, tile_n, num_stages, num_warps in tile_shapes:
        configs.append(
            triton.Config(
                {
                    "TILE_M": tile_m,
                    "TILE_N": tile_n,
                    "TILE_K": SCALE_BLOCK_K,
                    "SWIZZLE_GROUP_M": 8,
                },
                num_stages=num_stages,
                num_warps=num_warps,
            )
        )
    return configs

48 

49 

# Map a linear program id to a swizzled (pid_m, pid_n) tile coordinate.
# Programs are grouped along M in bands of SWIZZLE_GROUP_M tile-rows so that
# consecutive program ids work on nearby output rows (L2-friendly ordering).
@triton.jit
def grouped_launch(
    pid, M, N, TILE_M: tl.constexpr, TILE_N: tl.constexpr, SWIZZLE_GROUP_M: tl.constexpr
):
    grid_m = tl.cdiv(M, TILE_M)
    grid_n = tl.cdiv(N, TILE_N)

    # Number of programs in one full band of SWIZZLE_GROUP_M tile-rows.
    width = SWIZZLE_GROUP_M * grid_n
    group_id = pid // width
    # The last band may hold fewer than SWIZZLE_GROUP_M tile-rows.
    group_size = tl.minimum(grid_m - group_id * SWIZZLE_GROUP_M, SWIZZLE_GROUP_M)

    pid_m = group_id * SWIZZLE_GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n

65 

66 

# Block-wise dequantization matmul kernel implementation.
# This kernel supports many `SCALE_BLOCK_K, SCALE_BLOCK_N` cases
# as long as `TILE_K == SCALE_BLOCK_K` and `TILE_N % SCALE_BLOCK_N == 0`.
@triton.autotune(
    configs=get_block_wise_smm_configs(),
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _block_wise_smm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,  # scales for a, indexed by (row, K-block)
    b_scale_ptr,  # scales for b, indexed by (K-block, N-block)
    M,
    N,
    K,
    _M_NPO2: tl.constexpr,  # next power of two of M; autotune cache key only
    SCALE_BLOCK_N,
    SCALE_BLOCK_K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_Ascale_m,
    stride_Ascale_k,
    stride_Bscale_k,
    stride_Bscale_n,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    SWIZZLE_GROUP_M: tl.constexpr,
):
    pid = tl.program_id(0)
    pid_m, pid_n = grouped_launch(pid, M, N, TILE_M, TILE_N, SWIZZLE_GROUP_M)

    # Wrap out-of-range rows/cols back in range with `% M` / `% N`; the
    # masked store at the end keeps those lanes from writing out of bounds.
    offs_am = (pid_m * TILE_M + tl.arange(0, TILE_M)) % M
    offs_bn = (pid_n * TILE_N + tl.arange(0, TILE_N)) % N
    offs_k = tl.arange(0, TILE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One scale per row of a; columns of b share one scale per SCALE_BLOCK_N.
    a_scale_ptrs = a_scale_ptr + offs_am * stride_Ascale_m
    offs_bsn = offs_bn // SCALE_BLOCK_N
    b_scale_ptrs = b_scale_ptr + offs_bsn * stride_Bscale_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, TILE_K)):
        # Mask the K tail in the final iteration.
        k_remaining = K - k * TILE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # One scale index per K-block; relies on TILE_K == SCALE_BLOCK_K.
        offs_ks = k * TILE_K // SCALE_BLOCK_K
        a_scale = tl.load(a_scale_ptrs + offs_ks * stride_Ascale_k)
        b_scale = tl.load(b_scale_ptrs + offs_ks * stride_Bscale_k)
        # Dequantize each partial product before accumulating in fp32.
        acc += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    acc = acc.to(c_ptr.dtype.element_ty)

    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask)

134 

135 

def _block_wise_128_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
) -> torch.Tensor:
    """Launch the block-wise scaled matmul with 128x128 scale blocks.

    Writes ``dequant(a) @ dequant(b)`` into the preallocated output ``c``,
    where ``a_scale`` holds one scale per (row, 128-wide K block) of ``a``
    and ``b_scale`` one scale per 128x128 block of ``b`` (shapes validated
    by ``dispatch_scaled_mm``).

    Args:
        c: (M, N) output tensor, written in place and returned.
        a: (M, K) left operand.
        b: (K, N) right operand.
        a_scale: 2D scales for ``a``.
        b_scale: 2D scales for ``b``.

    Returns:
        The output tensor ``c``.
    """
    # Fix: this launcher is specialized to 128x128 scale blocks, so keep the
    # values local instead of rebinding the module-level SCALE_BLOCK_K /
    # SCALE_BLOCK_N globals (the previous `global` statement mutated module
    # state as a side effect while assigning the very same values).
    scale_block_k, scale_block_n = 128, 128
    M, K = a.shape
    _, N = b.shape
    # Autotune cache key: bucket M by next power of two to limit retuning.
    _M_NPO2 = triton.next_power_of_2(M)

    # 1D launch grid; the kernel swizzles the linear pid into (pid_m, pid_n).
    grid = lambda META: (
        triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
    )

    _block_wise_smm_kernel[grid](
        a,
        b,
        c,
        a_scale,
        b_scale,
        M,
        N,
        K,
        _M_NPO2,
        scale_block_n,
        scale_block_k,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scale.stride(0),
        a_scale.stride(1),
        b_scale.stride(0),
        b_scale.stride(1),
    )

    return c

178 

179 

# Per-tensor and per-token dequantization matmul kernel implementation.
@triton.autotune(
    configs=[
        triton.Config({"TILE_M": 64, "TILE_N": 64, "TILE_K": 256}),
        triton.Config({"TILE_M": 64, "TILE_N": 128, "TILE_K": 128}),
        triton.Config({"TILE_M": 128, "TILE_N": 128, "TILE_K": 128}),
    ],
    key=["_M_NPO2", "N", "K"],
)
@triton.jit
def _pertensor_or_pertoken_smm_kernel(
    c_ptr,
    a_ptr,
    b_ptr,
    a_scale_ptr,  # M scales (per-token) or 1 scale (per-tensor) for a
    b_scale_ptr,  # N scales (per-token) or 1 scale (per-tensor) for b
    bias_ptr,  # length-N bias; a null pointer (bias=None) disables it
    M,
    N,
    K,
    _M_NPO2,  # next power of two of M; autotune cache key only
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    ACC_DTYPE: tl.constexpr,  # tl.float32 for fp8 inputs, tl.int32 for int8
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    IS_PER_TOKEN_A: tl.constexpr,
    IS_PER_TOKEN_B: tl.constexpr,
):
    # Per-token scaling loads a full tile of scales; per-tensor loads one.
    if IS_PER_TOKEN_A:
        TILE_SIZE_SCALE_A: tl.constexpr = TILE_M
    else:
        TILE_SIZE_SCALE_A: tl.constexpr = 1

    if IS_PER_TOKEN_B:
        TILE_SIZE_SCALE_B: tl.constexpr = TILE_N
    else:
        TILE_SIZE_SCALE_B: tl.constexpr = 1

    # Plain row-major tile rasterization (no swizzle, unlike the block-wise
    # kernel above).
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, TILE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    acc = tl.zeros((TILE_M, TILE_N), dtype=ACC_DTYPE)

    # int64 offsets so the pointer arithmetic cannot overflow 32 bits.
    offsets_am = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, TILE_K).to(tl.int64)
    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]

    # The `(TILE_SIZE_SCALE_A > 1) *` factor zeroes the tile offset in the
    # per-tensor case, so the single scale is always read from index 0.
    offsets_scale_am = (
        tl.arange(0, TILE_SIZE_SCALE_A) + (TILE_SIZE_SCALE_A > 1) * pid_m * TILE_M
    )
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (
        tl.arange(0, TILE_SIZE_SCALE_B) + (TILE_SIZE_SCALE_B > 1) * pid_n * TILE_N
    )
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = a_scale_ptr + offsets_scale_am
    scale_b_ptrs = b_scale_ptr + offsets_scale_bn

    for k in range(0, tl.cdiv(K, TILE_K)):
        # offsets_k advances by TILE_K each iteration, so this masks the
        # K tail in the final iteration.
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        acc = tl.dot(a, b, acc, out_dtype=ACC_DTYPE)

        offsets_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

    # Load the a-scales as a column vector and broadcast to one per row
    # (a no-op in the per-token case, a fan-out in the per-tensor case).
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    a_scale = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    a_scale = a_scale.broadcast_to((TILE_M, 1))
    acc = a_scale * acc.to(tl.float32)

    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    b_scale = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    b_scale = b_scale.broadcast_to((TILE_N, 1))
    # Transposed so the (TILE_N, 1) scales broadcast along the columns.
    # NOTE(review): acc is already fp32 after the a_scale multiply above,
    # so this second .to(tl.float32) looks redundant.
    acc = b_scale.T * acc.to(tl.float32)

    c = acc.to(c_ptr.type.element_ty)

    # Bias is added after the downcast to the output dtype.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M).to(tl.int64)
    offs_cn = pid_n * TILE_N + tl.arange(0, TILE_N).to(tl.int64)
    # NOTE(review): these casts are redundant -- both offsets are already
    # int64 from the .to(tl.int64) above.
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)

298 

299 

def _pertensor_or_pertoken_smm_launcher(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Launch the per-tensor / per-token scaled matmul kernel.

    The scaling mode is inferred from the scale sizes: ``a_scale`` with M
    elements means per-token (per-row) scaling, one element means per-tensor;
    likewise ``b_scale`` with N elements vs one. Writes into the preallocated
    ``c`` and returns it.

    Args:
        c: (M, N) output tensor.
        a: (M, K) left operand (fp8 or int8).
        b: (K, N) right operand.
        a_scale, b_scale: float32 dequantization scales.
        bias: optional length-N bias added to the result.
            (Annotation uses ``Optional[...]`` for consistency with the rest
            of this module; was ``torch.Tensor | None``, which also needs
            Python >= 3.10 to evaluate at definition time.)

    Returns:
        The output tensor ``c``.
    """
    M, K = a.shape
    _, N = b.shape

    # 1D grid; the kernel derives (pid_m, pid_n) from the linear pid.
    grid = lambda META: (
        triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
    )

    # fp8 inputs accumulate in fp32; int8 inputs accumulate in int32.
    ACC_DTYPE = tl.float32 if a.is_floating_point() else tl.int32

    # Autotune cache key: bucket M by next power of two to limit retuning.
    _M_NPO2 = triton.next_power_of_2(M)

    IS_PER_TOKEN_A = a_scale.numel() == M
    IS_PER_TOKEN_B = b_scale.numel() == N

    _pertensor_or_pertoken_smm_kernel[grid](
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        M,
        N,
        K,
        _M_NPO2,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        ACC_DTYPE=ACC_DTYPE,
        IS_PER_TOKEN_A=IS_PER_TOKEN_A,
        IS_PER_TOKEN_B=IS_PER_TOKEN_B,
    )

    return c

345 

346 

# SM90 dispatch names. Both the fp8 and int8 paths use the per-tensor /
# per-token launcher; the block-wise path uses the 128x128 launcher.
cutlass_scaled_mm_sm90_fp8 = _pertensor_or_pertoken_smm_launcher

cutlass_scaled_mm_sm90_int8 = _pertensor_or_pertoken_smm_launcher

cutlass_scaled_mm_blockwise_sm90_fp8 = _block_wise_128_smm_launcher

352 

353 

def dispatch_scaled_mm(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    fp8_func: Callable,
    int8_func: Optional[Callable],
    blockwise_func: Callable,
) -> None:
    """Route a scaled matmul to the per-tensor/per-token or block-wise path.

    Per-tensor/per-token scaling is detected from scale element counts
    (1 or M for ``a_scale``, 1 or N for ``b_scale``); any other shape is
    treated as block-wise scaling with 2D scale tensors. The result is
    written into ``c``.

    Args:
        c: preallocated (M, N) output.
        a, b: (M, K) and (K, N) operands.
        a_scale, b_scale: float32 dequantization scales.
        bias: optional bias, forwarded on the per-tensor/per-token path only.
        fp8_func: launcher used for ``torch.float8_e4m3fn`` inputs.
        int8_func: launcher for int8 inputs, or None when unsupported.
        blockwise_func: launcher for block-wise (SCALE_BLOCK_K x
            SCALE_BLOCK_N) scaling.

    Raises:
        RuntimeError: int8 input on an arch where ``int8_func`` is None.
    """
    assert a_scale.dtype == torch.float32, "a_scale must be float32"
    assert b_scale.dtype == torch.float32, "b_scale must be float32"

    if (a_scale.numel() == 1 or a_scale.numel() == a.size(0)) and (
        b_scale.numel() == 1 or b_scale.numel() == b.size(1)
    ):
        # Per-tensor or per-token scaling.
        assert a_scale.is_contiguous(), "a_scale must be contiguous"
        assert b_scale.is_contiguous(), "b_scale must be contiguous"

        if a.dtype == torch.float8_e4m3fn:
            fp8_func(c, a, b, a_scale, b_scale, bias)
        else:
            assert a.dtype == torch.int8, f"Unsupported dtype: {a.dtype}"

            if int8_func is not None:
                int8_func(c, a, b, a_scale, b_scale, bias)
            else:
                raise RuntimeError(
                    f"Int8 not supported on SM{SM_VERSION_NUM}. "
                    f"Use FP8 quantization instead, or run on older arch (SM < 100)."
                )
    else:
        # Block-wise scaling. The shape checks use the module-level
        # SCALE_BLOCK_K / SCALE_BLOCK_N constants instead of repeating the
        # magic number 128 (both are 128, so behavior is unchanged).
        assert a_scale.dim() == 2, "a_scale must be 2D tensor for blockwise scaling"
        assert b_scale.dim() == 2, "b_scale must be 2D tensor for blockwise scaling"

        if SM_VERSION_NUM >= 90:
            assert a.size(0) == a_scale.size(0), (
                f"a_scale must have same first dimension as a: "
                f"a.shape[0]={a.size(0)}, a_scale.shape[0]={a_scale.size(0)}"
            )
            assert triton.cdiv(a.size(1), SCALE_BLOCK_K) == a_scale.size(1), (
                f"a_scale second dimension mismatch: "
                f"triton.cdiv({a.size(1)}, {SCALE_BLOCK_K})="
                f"{triton.cdiv(a.size(1), SCALE_BLOCK_K)} != "
                f"a_scale.shape[1]={a_scale.size(1)}"
            )

            assert triton.cdiv(b.size(0), SCALE_BLOCK_K) == b_scale.size(0), (
                f"b_scale first dimension mismatch: "
                f"triton.cdiv({b.size(0)}, {SCALE_BLOCK_K})="
                f"{triton.cdiv(b.size(0), SCALE_BLOCK_K)} != "
                f"b_scale.shape[0]={b_scale.size(0)}"
            )
            assert triton.cdiv(b.size(1), SCALE_BLOCK_N) == b_scale.size(1), (
                f"b_scale second dimension mismatch: "
                f"triton.cdiv({b.size(1)}, {SCALE_BLOCK_N})="
                f"{triton.cdiv(b.size(1), SCALE_BLOCK_N)} != "
                f"b_scale.shape[1]={b_scale.size(1)}"
            )

        assert bias is None, "Bias not yet supported for blockwise scaled_mm"

        blockwise_func(c, a, b, a_scale, b_scale)

415 

416 

def cutlass_scaled_mm_sm90(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> None:
    """Scaled matmul entry point for SM90: bind the SM90 launchers and let
    the shared dispatcher pick the scaling mode. Writes the result into ``c``."""
    dispatch_scaled_mm(
        c,
        a,
        b,
        a_scale,
        b_scale,
        bias,
        cutlass_scaled_mm_sm90_fp8,
        cutlass_scaled_mm_sm90_int8,
        cutlass_scaled_mm_blockwise_sm90_fp8,
    )

436 

437 

def cutlass_scaled_mm_sm120(*args, **kwargs):
    """Placeholder entry point; always raises NotImplementedError."""
    msg = "cutlass_scaled_mm_sm120 is not yet implemented. "
    raise NotImplementedError(msg)

440 

441 

def cutlass_scaled_mm_sm100(*args, **kwargs):
    """Placeholder entry point; always raises NotImplementedError."""
    msg = "cutlass_scaled_mm_sm100 is not yet implemented. "
    raise NotImplementedError(msg)

444 

445 

def cutlass_scaled_mm_sm89(*args, **kwargs):
    """Placeholder entry point; always raises NotImplementedError."""
    msg = "cutlass_scaled_mm_sm89 is not yet implemented. "
    raise NotImplementedError(msg)

448 

449 

def cutlass_scaled_mm_sm80(*args, **kwargs):
    """Placeholder entry point; always raises NotImplementedError."""
    msg = "cutlass_scaled_mm_sm80 is not yet implemented. "
    raise NotImplementedError(msg)

452 

453 

def cutlass_scaled_mm_sm75(*args, **kwargs):
    """Placeholder entry point; always raises NotImplementedError."""
    msg = "cutlass_scaled_mm_sm75 is not yet implemented. "
    raise NotImplementedError(msg)

456 

457 

def cutlass_scaled_mm(
    c: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Validate layouts and dispatch a scaled matmul to the arch-specific path.

    Writes ``dequant(a) @ dequant(b) [+ bias]`` into the preallocated,
    row-major output ``c`` and returns it.

    Args:
        c: (M, N) row-major output tensor, row stride 16-byte aligned.
        a: (M, K) row-major operand.
        b: (K, N) column-major operand, column stride 16-byte aligned.
        a_scale, b_scale: float32 dequantization scales.
        bias: optional contiguous 1D bias of length N.

    Returns:
        ``c``. (Fix: the function previously fell through without a return
        statement, yielding None despite its ``-> torch.Tensor`` annotation.)

    Raises:
        RuntimeError: on devices older than SM75. (Fix: previously this case
        silently did nothing, leaving ``c`` untouched.)
    """
    assert (
        a.dim() == 2 and b.dim() == 2 and c.dim() == 2
    ), "All inputs must be 2D tensors"

    assert c.size(0) == a.size(0), "Number of rows in c must equal number of rows in a"
    assert a.size(1) == b.size(
        0
    ), "Number of columns in a must equal number of rows in b"
    assert b.size(1) == c.size(
        1
    ), "Number of columns in b must equal number of columns in c"

    assert a.stride(1) == 1 and c.stride(1) == 1, "a and c must be row-major"

    assert b.stride(0) == 1, "b must be column-major"

    assert c.stride(0) % 16 == 0, "Row stride of c must be 16-byte aligned"
    assert b.stride(1) % 16 == 0, "Column stride of b must be 16-byte aligned"

    if bias is not None:
        assert bias.numel() == b.size(
            1
        ), f"Bias size {bias.numel()} must equal number of columns in b {b.size(1)}"
        assert bias.is_contiguous(), "Bias must be contiguous"
        assert bias.dim() == 1, "Bias must be a 1D tensor"

    if SM_VERSION_NUM >= 120:
        cutlass_scaled_mm_sm120(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 100:
        cutlass_scaled_mm_sm100(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 90:
        # Hopper
        cutlass_scaled_mm_sm90(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 80:
        # Ampere
        cutlass_scaled_mm_sm80(c, a, b, a_scale, b_scale, bias)
    elif SM_VERSION_NUM >= 75:
        # Turing
        cutlass_scaled_mm_sm75(c, a, b, a_scale, b_scale, bias)
    else:
        raise RuntimeError(
            f"cutlass_scaled_mm is not supported on SM{SM_VERSION_NUM} (requires SM >= 75)."
        )

    return c