Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%

265 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import os 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8import yaml 

9 

10from flag_gems import runtime 

11from flag_gems.ops.mm_streamk import streamk_mm 

12from flag_gems.runtime import torch_device_fn 

13from flag_gems.utils import libentry, libtuner 

14from flag_gems.utils import triton_lang_extension as tle 

15from flag_gems.utils.device_info import get_device_capability, get_sm_count 

16 

# Module-level logger for the Hopper mm kernels.
logger = logging.getLogger(__name__)
# Fraction of cache considered usable.
# NOTE(review): not referenced anywhere in this file's visible code —
# confirm it is consumed elsewhere before relying on it.
CACHE_USAGE_THRESHOLD = 0.8

19 

20 

def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements x 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements x 4 bytes = 16 bytes)

    Args:
        a, b: Input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's 128-bit alignment requirement
    """
    half_precision = (torch.float16, torch.bfloat16)
    if a.dtype in half_precision and b.dtype in half_precision:
        # 2-byte elements: 8 elements make up one 16-byte transaction.
        return N % 8 == 0 and K % 8 == 0
    if a.dtype in (torch.float32,) and b.dtype in (torch.float32,):
        # 4-byte elements: 4 elements make up one 16-byte transaction.
        return N % 4 == 0 and K % 4 == 0
    # Mixed or unsupported dtype pairings never take the TMA path.
    return False

49 

50 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a (for a > 0).
    # Used to split the K loop into full-tile iterations plus one masked
    # tail iteration (loop peeling).
    return tl.cdiv(a, b) * b - b

55 

56 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """General GEMM kernel: C = A @ B with fp32 accumulation.

    Two code paths are selected at runtime:
      * a device-side tensor-descriptor fast path when all three dims are
        exact multiples of their block sizes, and
      * a stride-based path with a masked tail iteration otherwise.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: programs are grouped
    # in super-rows of GROUP_M tiles so neighbouring programs share B tiles.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    if M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0:
        # Fast path: every tile is full, so loads/stores need no masks.
        # NOTE(review): the descriptors below hard-code row-major strides
        # [K, 1] / [N, 1] and ignore the stride_* arguments — confirm this
        # branch is only reached with contiguous row-major A, B and C.
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major B descriptor
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        # Accumulate in fp32 regardless of input dtype.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # NOTE(review): result is cast to A's dtype here, while the slow
        # path below casts to C's element type — confirm A and C always
        # share a dtype when this branch is taken.
        acc = acc.to(a_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # Stride-based path: works for arbitrary strides and ragged dims.
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap indices modulo M/N so full-tile loads never go out of bounds;
        # the final store is masked, so wrapped rows/cols are discarded.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Main loop over all full BLOCK_K slices (no K mask needed).
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                # Mixed-dtype inputs are promoted to the output dtype.
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last (possibly partial) K slice with a mask.
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)

188 

189 

def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: patch the host TensorDescriptors' block shapes.

    The descriptors are created with a dummy block shape; before each
    candidate config is launched this hook rewrites them to match that
    config's BLOCK_M/BLOCK_N/BLOCK_K, honouring each operand's layout
    (row-major descriptors use [rows, cols], column-major ones the swap).
    """
    bm = nargs["BLOCK_M"]
    bn = nargs["BLOCK_N"]
    bk = nargs["BLOCK_K"]

    nargs["a_desc"].block_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    nargs["b_desc"].block_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]
    # The output C is always described row-major.
    nargs["c_desc"].block_shape = [bm, bn]

205 

206 

def get_expand_config(op):
    """Load expanded autotune search ranges for *op* from mm_hopper_tma_expand.yaml.

    Args:
        op: Operator key; currently "matmul" or "gemv" are recognized.

    Returns:
        On success, a dict::

            {"ranges": {...block/stage/warp candidate lists...},
             "strategy": [...libtuner key strategies...]}

        On any failure (unknown op, missing/malformed YAML file, entry for
        the op absent or incomplete) the sentinel ``-1`` is returned;
        callers compare the result against ``-1``.
    """
    # Built-in fallback strategies, key orders and META-name mappings.
    default_strategies = {
        "matmul": ["align32", "align32", "align32", "align32", "align32", "default"],
        "gemv": ["align32", "align32", "align32", "default"],
    }
    op_key_orders = {
        "matmul": ["M", "N", "K", "stride_am", "stride_bk", "dtype"],
        "gemv": ["M", "K", "stride_am", "stride_bk"],
    }
    op_meta_map = {
        "matmul": {
            "BM": "BLOCK_M",
            "BN": "BLOCK_N",
            "BK": "BLOCK_K",
        },
        "gemv": {
            "BM": "BLOCK_M",
            "BK": "BLOCK_K",
        },
    }

    if op not in default_strategies:
        return -1

    default_strategy = default_strategies[op]
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "mm_hopper_tma_expand.yaml"
    )
    if not os.path.exists(config_path):
        return -1

    try:
        with open(config_path, "r") as file:
            config = yaml.safe_load(file) or {}

        # The op may be absent from the YAML file; previously this produced
        # a TypeError (iterating None) that the broad except masked.
        expand_configs = config.get(op) or []

        gen_config = None
        strategy_config = None
        for single_config in expand_configs:
            if isinstance(single_config, dict) and "param_map" in single_config:
                gen_config = single_config
            if isinstance(single_config, dict) and "strategy" in single_config:
                strategy_config = single_config.get("strategy")

        if gen_config is None:
            # No usable "param_map" entry for this op.
            return -1

        param_map = gen_config["param_map"]
        meta_map = param_map["META"]

        # Per-key strategy override; missing keys fall back to the default.
        strategy = default_strategy
        if isinstance(strategy_config, dict):
            strategy = [
                strategy_config.get(k, default_strategy[idx])
                for idx, k in enumerate(op_key_orders[op])
            ]

        # Resolve the candidate ranges for block sizes, stages and warps.
        ranges = {}
        for range_key, meta_key in op_meta_map[op].items():
            ranges[range_key] = gen_config[meta_map[meta_key]]
        ranges["s"] = gen_config[param_map["num_stages"]]
        ranges["w"] = gen_config[param_map["num_warps"]]

        return {
            "ranges": ranges,
            "strategy": strategy,
        }
    except Exception:
        # Best-effort: fall back to built-in configs, but leave a trace
        # instead of swallowing the error silently.
        logger.debug(
            "Failed to parse expand config for %s from %s",
            op,
            config_path,
            exc_info=True,
        )
        return -1

274 

275 

def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the triton.Config candidate list for the TMA matmul kernel.

    When USE_FLAGTUNE=1 and a usable expand-config entry exists in
    mm_hopper_tma_expand.yaml, the search space comes from that file;
    otherwise a built-in grid of block sizes, stage counts and warp counts
    is used.  Every config carries *pre_hook* to fix up descriptor block
    shapes before launch.
    """
    # Built-in default search space.
    bm_options = [32, 64, 128, 256]
    bn_options = [32, 64, 128]
    bk_options = [32, 64, 128]
    stage_options = [2, 3, 4]
    warp_options = [4, 8]

    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("matmul")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_tma_expand.yaml for matmul kernel autotuning"
            )
            ranges = expand_config["ranges"]
            bm_options = ranges["BM"]
            bn_options = ranges["BN"]
            bk_options = ranges["BK"]
            stage_options = ranges["s"]
            warp_options = ranges["w"]

    configs = []
    for bm in bm_options:
        for bn in bn_options:
            for bk in bk_options:
                for stages in stage_options:
                    for warps in warp_options:
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
                                num_stages=stages,
                                num_warps=warps,
                                pre_hook=pre_hook,
                            )
                        )
    return configs

310 

311 

@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    # NOTE(review): get_expand_config("matmul") is evaluated twice at
    # decoration time (once in the condition, once for the value) and a
    # third time inside matmul_get_configs(); the reads could disagree if
    # the YAML file changes in between — confirm this is acceptable.
    strategy=get_expand_config("matmul")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    """GEMM via host-created TMA tensor descriptors: C = A @ B.

    A/B/C are addressed exclusively through a_desc/b_desc/c_desc; the
    stride_* arguments are not read in this body (layout lives in the
    descriptors) and `dtype` only serves as part of the autotune key.
    Column-major operands are loaded through their transposed descriptor
    and transposed back in registers.
    NOTE(review): `enable_warp_specialization` is never referenced in this
    body — confirm whether it is consumed by the tuner or is dead.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Group programs in super-rows of GROUP_M tiles for L2 reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # fp32 accumulator regardless of input precision.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # Column-major A: descriptor is the transposed view; undo it.
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # Column-major B: descriptor is the transposed view; undo it.
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 inputs: use the 3-pass TF32 emulation for accuracy.
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)

381 

382 

def get_higher_dtype(a, b):
    """Return the wider of two dtypes per the ordering fp16 < bf16 < fp32.

    Both arguments must be one of torch.float16, torch.bfloat16,
    torch.float32 (enforced via assert).  Identical inputs are returned
    unchanged; otherwise the dtype appearing later in the ordering wins.
    """
    _ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]

    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype that appears earlier in the ordering is the lower one,
    # so the other operand is the result.
    if _ordered_datatypes.index(a) < _ordered_datatypes.index(b):
        return b
    return a

398 

def general_mm(a, b, c, M, N, K):
    """Run the general (M, K) @ (K, N) matmul writing into c.

    Uses the host-side TMA kernel when the installed Triton exposes
    TensorDescriptor and the dtypes/dims satisfy TMA's 128-bit alignment
    rule (see is_tma_compatible); otherwise falls back to the stride-based
    mm_kernel_general.

    Returns:
        c, the output tensor passed in.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # 1-D launch grid: one program per (BLOCK_M, BLOCK_N) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    # NOTE(review): evaluating `triton.tools.tensor_descriptor` itself can
    # raise AttributeError on Triton builds where that submodule is not
    # auto-imported — hasattr() only guards the final attribute lookup.
    # Confirm against the minimum supported Triton version.
    if hasattr(
        triton.tools.tensor_descriptor, "TensorDescriptor"
    ) and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the real one is patched per-config by
        # matmul_tma_set_block_size_hook before each launch.
        dummy_block = [1, 1]
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor

        # Column-major operands are described via their transposed view so
        # the descriptor always sees a row-major layout; the kernel
        # transposes the loaded tile back.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        # Short dtype name (e.g. "float16"); part of the autotune key.
        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:
        # mm_kernel_general's fast path creates device-side tensor
        # descriptors, which require a host-registered allocator.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
            )
    return c

478 

479 

def gemv_get_configs():
    """Build the triton.Config candidate list for the gemv kernel.

    When USE_FLAGTUNE=1 and a usable expand-config entry exists in
    mm_hopper_tma_expand.yaml, a full grid of candidates is generated from
    it; otherwise a single known-good default config is returned.
    """
    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("gemv")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_tma_expand.yaml for gemv kernel autotuning"
            )
            ranges = expand_config["ranges"]
            configs = []
            for bm in ranges["BM"]:
                for bk in ranges["BK"]:
                    for stages in ranges["s"]:
                        for warps in ranges["w"]:
                            configs.append(
                                triton.Config(
                                    {"BLOCK_M": bm, "BLOCK_K": bk},
                                    num_stages=stages,
                                    num_warps=warps,
                                )
                            )
            return configs
    # Fallback: one fixed configuration.
    return [
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ]

504 

505 

@libentry()
@libtuner(
    configs=gemv_get_configs(),
    key=["M", "K", "stride_am", "stride_bk"],
    # NOTE(review): get_expand_config("gemv") is evaluated twice at
    # decoration time here and again inside gemv_get_configs(); the reads
    # could disagree if the YAML file changes in between.
    strategy=get_expand_config("gemv")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("gemv") != -1
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case).

    Each program computes BLOCK_M output rows: C[m] = sum_k A[m, k] * B[k],
    accumulating in fp32 and casting to C's element type on store.
    C is indexed by row offset only, i.e. it is assumed densely packed.
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # fp32 accumulator for this block of rows
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension in BLOCK_K chunks
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: elementwise product reduced over the K axis
        acc += tl.sum(a * b[None, :], axis=1)

    # Store result (masked to the valid rows)
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

560 

561 

def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case"""
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # One program per BLOCK_M output rows.
    def grid(META):
        return (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
        )
    return c

584 

585 

def streamk_scenario(a, b, M, N, K):
    """Decide whether this GEMM should take the stream-K kernel path.

    The stream-K configuration has only been tuned on devices with compute
    capability major == 8 (A100); other devices fall through to the
    general path.  It is restricted to contiguous fp16/bf16 operands whose
    K dimension strongly dominates both M and N.
    NOTE: thresholds may change once benchmarks exist for other devices.
    """
    capability = get_device_capability()
    if capability[0] != 8:
        return False
    supported = [torch.float16, torch.bfloat16]
    if a.dtype not in supported or b.dtype not in supported:
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    # K-dominated problems benefit from splitting the reduction.
    return K > M * 5 and K > N * 5

600 

601 

def mm(a, b):
    """Matrix multiply a (M, K) @ b (K, N), allocating the output.

    Dispatch order: dedicated gemv kernel when N == 1, stream-K kernel
    when streamk_scenario() approves, otherwise the general kernel.
    """
    device = a.device
    # Compact an input when neither of its strides is 1 (no dimension is
    # contiguous), so the kernels see a simple layout.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # The output dtype is the wider of the two input dtypes.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=out_dtype)

    if N == 1:
        # Matrix-vector product: take the dedicated gemv path.
        return gemv_mm(a, b, c, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)

626 

627 

def mm_out(a, b, *, out):
    """Matrix multiply a @ b into the caller-provided tensor *out*.

    Same dispatch as mm(): gemv kernel for N == 1, stream-K kernel when
    streamk_scenario() approves, otherwise the general kernel.  Returns
    *out*.
    """
    # Compact an input when neither of its strides is 1 (no dimension is
    # contiguous), so the kernels see a simple layout.
    a = a.contiguous() if a.stride(0) > 1 and a.stride(1) > 1 else a
    b = b.contiguous() if b.stride(0) > 1 and b.stride(1) > 1 else b
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    if N == 1:
        # Matrix-vector product: take the dedicated gemv path.
        return gemv_mm(a, b, out, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    return general_mm(a, b, out, M, N, K)