Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0% (265 statements)

import logging
import os
from typing import Optional

import torch
import triton
import triton.language as tl
import yaml

from flag_gems import runtime
from flag_gems.ops.mm_streamk import streamk_mm
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.device_info import get_device_capability, get_sm_count

logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")
CACHE_USAGE_THRESHOLD = 0.8


def is_tma_compatible(a, b, N, K):
    """
    Check whether the tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements × 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements × 4 bytes = 16 bytes)

    Args:
        a, b: Input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's 128-bit alignment requirement
    """
    return (
        a.dtype in (torch.float16, torch.bfloat16)
        and b.dtype in (torch.float16, torch.bfloat16)
        and N % 8 == 0
        and K % 8 == 0
    ) or (
        a.dtype in (torch.float32,)
        and b.dtype in (torch.float32,)
        and N % 4 == 0
        and K % 4 == 0
    )
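
# Illustrative sketch (added for this listing, not in the original module):
# exercising the alignment rule above with concrete shapes. CPU tensors work
# because only the dtypes and the (N, K) integers are inspected.
def _demo_is_tma_compatible():
    a = torch.empty(64, 48, dtype=torch.float16)  # K = 48, a multiple of 8
    b = torch.empty(48, 40, dtype=torch.float16)  # N = 40, a multiple of 8
    assert is_tma_compatible(a, b, N=40, K=48)

    a32 = torch.empty(64, 48, dtype=torch.float32)
    b32 = torch.empty(48, 42, dtype=torch.float32)  # N = 42, not a multiple of 4
    assert not is_tma_compatible(a32, b32, N=42, K=48)
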



@triton.jit
def prev_multiple_of(a, b):
    # the largest x < a such that x % b == 0
    return tl.cdiv(a, b) * b - b

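
# Hedged sketch (added): the same arithmetic in plain Python, with ceil-division
# written as -(-a // b), to make the rounding behavior explicit.
def _prev_multiple_of_py(a: int, b: int) -> int:
    return -(-a // b) * b - b

# _prev_multiple_of_py(10, 4) -> 8; _prev_multiple_of_py(8, 4) -> 4
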



def matmul_tma_set_block_size_hook(nargs):
    BLOCK_M = nargs["BLOCK_M"]
    BLOCK_N = nargs["BLOCK_N"]
    BLOCK_K = nargs["BLOCK_K"]
    if nargs["A_ROW_MAJOR"]:
        nargs["a_desc"].block_shape = [BLOCK_M, BLOCK_K]
    else:
        nargs["a_desc"].block_shape = [BLOCK_K, BLOCK_M]

    if nargs["B_ROW_MAJOR"]:
        nargs["b_desc"].block_shape = [BLOCK_K, BLOCK_N]
    else:
        nargs["b_desc"].block_shape = [BLOCK_N, BLOCK_K]

    nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N]


def get_expand_config(op):
    default_strategies = {
        "matmul": ["align32", "align32", "align32", "align32", "align32", "default"],
        "gemv": ["align32", "align32", "align32", "default"],
    }
    op_key_orders = {
        "matmul": ["M", "N", "K", "stride_am", "stride_bk", "dtype"],
        "gemv": ["M", "K", "stride_am", "stride_bk"],
    }
    op_meta_map = {
        "matmul": {
            "BM": "BLOCK_M",
            "BN": "BLOCK_N",
            "BK": "BLOCK_K",
        },
        "gemv": {
            "BM": "BLOCK_M",
            "BK": "BLOCK_K",
        },
    }

    if op not in default_strategies:
        return -1

    default_strategy = default_strategies[op]
    config_path = os.path.join(os.path.dirname(__file__), "..", "mm_hopper_expand.yaml")
    if not os.path.exists(config_path):
        return -1

    try:
        with open(config_path, "r") as file:
            config = yaml.safe_load(file) or {}

        expand_configs = config.get(op)

        gen_config = None
        strategy_config = None
        for single_config in expand_configs:
            if isinstance(single_config, dict) and "param_map" in single_config:
                gen_config = single_config
            if isinstance(single_config, dict) and "strategy" in single_config:
                strategy_config = single_config.get("strategy")

        param_map = gen_config["param_map"]
        meta_map = param_map["META"]

        strategy = default_strategy
        if isinstance(strategy_config, dict):
            strategy = [
                strategy_config.get(k, default_strategy[idx])
                for idx, k in enumerate(op_key_orders[op])
            ]

        ranges = {}
        for range_key, meta_key in op_meta_map[op].items():
            ranges[range_key] = gen_config[meta_map[meta_key]]
        ranges["s"] = gen_config[param_map["num_stages"]]
        ranges["w"] = gen_config[param_map["num_warps"]]

        return {
            "ranges": ranges,
            "strategy": strategy,
        }
    except Exception:
        return -1

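
# Hedged sketch (added, not the shipped file): a mm_hopper_expand.yaml layout
# that the parser above would accept, inferred purely from the lookups it
# performs. The short keys on the right-hand side of param_map ("bm", "ns",
# ...) are invented for this sketch.
#
#   matmul:
#     - param_map:
#         META: {BLOCK_M: bm, BLOCK_N: bn, BLOCK_K: bk}
#         num_stages: ns
#         num_warps: nw
#       bm: [64, 128]
#       bn: [64, 128]
#       bk: [32, 64]
#       ns: [3, 4]
#       nw: [4, 8]
#     - strategy: {M: align32, N: align32, K: align32, dtype: default}
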



def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("matmul")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_expand.yaml for matmul kernel autotuning"
            )
            ranges = expand_config["ranges"]
            return [
                triton.Config(
                    {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
                    num_stages=s,
                    num_warps=w,
                    pre_hook=pre_hook,
                )
                for BM in ranges["BM"]
                for BN in ranges["BN"]
                for BK in ranges["BK"]
                for s in ranges["s"]
                for w in ranges["w"]
            ]
    return [
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
            num_stages=s,
            num_warps=w,
            pre_hook=pre_hook,
        )
        for BM in [32, 64, 128, 256]
        for BN in [32, 64, 128]
        for BK in [32, 64, 128]
        for s in [2, 3, 4]
        for w in [4, 8]
    ]

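
# Illustration (added): the fallback grid above enumerates
# 4 * 3 * 3 * 3 * 2 = 216 candidate configs for the tuner to benchmark.
def _default_matmul_config_count() -> int:
    return len([
        (BM, BN, BK, s, w)
        for BM in [32, 64, 128, 256]
        for BN in [32, 64, 128]
        for BK in [32, 64, 128]
        for s in [2, 3, 4]
        for w in [4, 8]
    ])  # -> 216
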



@libentry()
@libtuner(
    configs=matmul_get_configs(pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=get_expand_config("matmul")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else ["default", "default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program IDs for better L2 locality
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    if M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0:
        # offsets
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        # column-major
        # b_desc = tl.make_tensor_descriptor(
        #     B,
        #     shape=[N, K],
        #     strides=[K, 1],
        #     block_shape=[BLOCK_N, BLOCK_K],
        # )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        acc = acc.to(a_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last (possibly partial) K-block needs a mask
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(offsets, acc, mask=mask)

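
# Hedged illustration (added): the GROUP_M swizzle used by both mm kernels, in
# plain Python. It maps a linear pid to (pid_m, pid_n) so that up to GROUP_M
# tiles along M share the same pid_n neighborhood and reuse B tiles from L2.
def _demo_group_swizzle(pid: int, grid_m: int, grid_n: int, group_m: int = 8):
    width = group_m * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * group_m, group_m)
    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

# e.g. grid_m=16, grid_n=4, group_m=8: pids 0..7 map to pid_m 0..7, pid_n 0.
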



@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=get_expand_config("matmul")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("matmul") != -1
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 path: tf32x3 approximates fp32 accuracy with three TF32 dots
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)




def get_higher_dtype(a, b):
    _ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]

    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    for d in _ordered_datatypes:
        if a is d:
            return b
        if b is d:
            return a

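
# Hedged example (added): the promotion order above returns the later entry in
# _ordered_datatypes, i.e. the wider of the two dtypes.
def _demo_get_higher_dtype():
    assert get_higher_dtype(torch.float16, torch.float32) is torch.float32
    assert get_higher_dtype(torch.float32, torch.bfloat16) is torch.float32
    assert get_higher_dtype(torch.float16, torch.float16) is torch.float16
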


def general_mm(a, b, c, M, N, K):
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    if hasattr(
        triton.tools.tensor_descriptor, "TensorDescriptor"
    ) and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # placeholder block shape; the autotuner's pre_hook fills in the real one
        dummy_block = [1, 1]
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor

        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:

        # device-side tl.make_tensor_descriptor needs a workspace allocator
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
            )
    return c




def gemv_get_configs():
    if os.environ.get("USE_FLAGTUNE") == "1":
        expand_config = get_expand_config("gemv")
        if expand_config != -1:
            logger.debug(
                "Using expand configurations from mm_hopper_expand.yaml for gemv kernel autotuning"
            )
            ranges = expand_config["ranges"]
            return [
                triton.Config(
                    {"BLOCK_M": BM, "BLOCK_K": BK},
                    num_stages=s,
                    num_warps=w,
                )
                for BM in ranges["BM"]
                for BK in ranges["BK"]
                for s in ranges["s"]
                for w in ranges["w"]
            ]
    return [
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ]



@libentry()
@libtuner(
    configs=gemv_get_configs(),
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=get_expand_config("gemv")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1" and get_expand_config("gemv") != -1
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)."""
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over the K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load a [BLOCK_M, BLOCK_K] tile of matrix A
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load a [BLOCK_K] slice of vector B
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: reduce over the K dimension
        acc += tl.sum(a * b[None, :], axis=1)

    # Store the result
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

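
# Hedged check (added, assumes a CUDA device): compare the GEMV path against
# torch.mm on an N=1 problem via mm() defined below. Loose tolerances for fp16.
def _check_gemv_against_torch():
    a = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
    b = torch.randn(512, 1, dtype=torch.float16, device="cuda")
    ref = torch.mm(a, b)
    out = mm(a, b)  # N == 1 routes to gemv_mm below
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
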



def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for the N=1 case."""
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
        )
    return c




def streamk_scenario(a, b, M, N, K):
    # TODO: this may change according to real benchmark results.
    # Currently, the best configuration for stream-K has only been tested on
    # A100 (capability[0] == 8); optimal settings for other devices still need
    # to be determined through real testing.
    capability = get_device_capability()
    return (
        capability[0] == 8
        and a.dtype in [torch.float16, torch.bfloat16]
        and b.dtype in [torch.float16, torch.bfloat16]
        and a.is_contiguous()
        and b.is_contiguous()
        and K > M * 5
        and K > N * 5
    )

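
# Hedged illustration (added): shapes against the K-dominant heuristic above,
# assuming an A100 (capability 8.x) and contiguous fp16 inputs.
#   M, N, K = 128, 128, 4096   -> K > 5*M and K > 5*N: stream-K path
#   M, N, K = 1024, 1024, 4096 -> K <= 5*N: falls back to general_mm
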



def mm(a, b):
    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # check constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocate output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    # Optimize for the N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    else:
        return general_mm(a, b, c, M, N, K)

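
# Hedged usage sketch (added, assumes a CUDA device): mm() dispatches to gemv,
# stream-K, or the general kernel based only on shapes, dtypes, and strides.
def _demo_mm_dispatch():
    a = torch.randn(256, 384, dtype=torch.float16, device="cuda")
    b = torch.randn(384, 256, dtype=torch.float16, device="cuda")
    c = mm(a, b)  # general path: K is not 5x larger than M and N
    torch.testing.assert_close(c, torch.mm(a, b), rtol=1e-2, atol=1e-2)
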



def mm_out(a, b, *, out):
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # check constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Optimize for the N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    else:
        return general_mm(a, b, out, M, N, K)
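

# Hedged usage sketch (added, assumes a CUDA device): mm_out writes into a
# preallocated tensor, mirroring torch.mm(a, b, out=...).
def _demo_mm_out():
    a = torch.randn(128, 64, dtype=torch.float16, device="cuda")
    b = torch.randn(64, 32, dtype=torch.float16, device="cuda")
    out = torch.empty(128, 32, dtype=torch.float16, device="cuda")
    mm_out(a, b, out=out)
    torch.testing.assert_close(out, torch.mm(a, b), rtol=1e-2, atol=1e-2)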