Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%

255 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import os 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.mm_streamk import streamk_mm 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as tle 

14from flag_gems.utils.device_info import get_device_capability, get_sm_count 

15 

# Module logger; the name is spelled out explicitly instead of being derived from __name__.
logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")

CACHE_USAGE_THRESHOLD = 0.8

# Tuning/expand configuration YAML, resolved relative to the parent of this module's directory.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "mm_hopper_expand.yaml",
    )
)

# Extra bytes added on top of the estimated tile footprint when filtering configs.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1024

22 

23 

def _get_shared_memory_limit_bytes():
    """Look up the opt-in per-block shared-memory limit of the current CUDA device.

    Returns:
        The ``shared_memory_per_block_optin`` device property, or ``None`` when
        CUDA is unavailable or the query raises for any reason.
    """
    try:
        if torch.cuda.is_available():
            device_index = torch.cuda.current_device()
            device_props = torch.cuda.get_device_properties(device_index)
            return device_props.shared_memory_per_block_optin
    except Exception:
        # Best-effort lookup: any failure (including a torch build that does
        # not expose the property) is reported as "limit unknown".
        pass
    return None

34 

35 

def _estimate_tma_shared_memory_bytes(
    block_m, block_n, block_k, num_stages, bytes_per_element=4
):
    """Estimate the shared memory needed by one pipelined matmul tile configuration.

    The estimate counts one ``block_m x block_k`` tile and one
    ``block_k x block_n`` tile per pipeline stage, then adds the module-level
    safety margin.

    Args:
        block_m, block_n, block_k: Tile sizes of the candidate config.
        num_stages: Number of pipeline stages of the candidate config.
        bytes_per_element: Size of one input element in bytes. Defaults to 4,
            which reproduces the previous hard-coded estimate; pass 2 for
            16-bit inputs to get a tighter bound.

    Returns:
        int: Estimated shared-memory footprint in bytes.
    """
    elements_per_stage = block_m * block_k + block_k * block_n
    stage_bytes = elements_per_stage * bytes_per_element
    return stage_bytes * num_stages + _SHARED_MEM_SAFETY_MARGIN_BYTES

40 

41 

def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): dimensions must be multiples of 8
      (8 elements x 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): dimensions must be multiples of 4
      (4 elements x 4 bytes = 16 bytes)

    N and K are always checked. When A is column-major (``a.stride(1) != 1``)
    the caller builds the descriptor over ``a.T``, so the row count of A also
    has to be aligned; it is read from ``a.shape[0]``.

    Args:
        a, b: 2-D input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    if a.dtype in half_dtypes and b.dtype in half_dtypes:
        elements_per_16_bytes = 8
    elif a.dtype == torch.float32 and b.dtype == torch.float32:
        elements_per_16_bytes = 4
    else:
        return False

    if N % elements_per_16_bytes != 0 or K % elements_per_16_bytes != 0:
        return False

    # Column-major A: for a plain transposed view the descriptor's leading
    # stride equals M, so an unaligned M would fail the host descriptor's
    # 16-byte stride check instead of falling back to the non-TMA kernel.
    if a.stride(1) != 1 and a.shape[0] % elements_per_16_bytes != 0:
        return False

    return True

70 

71 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a:
    # round a up to a multiple of b, then step back by one b.
    rounded_up = tl.cdiv(a, b) * b
    return rounded_up - b

76 

77 

def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook that writes the selected tile sizes into the descriptors.

    Reads ``BLOCK_M``/``BLOCK_N``/``BLOCK_K`` and the two layout flags from
    ``nargs`` and sets ``block_shape`` on ``a_desc``, ``b_desc`` and ``c_desc``.
    For a non-row-major operand the two block dimensions are swapped.

    Args:
        nargs: Mapping of kernel argument names to values.
    """
    block_m, block_n, block_k = (
        nargs[key] for key in ("BLOCK_M", "BLOCK_N", "BLOCK_K")
    )

    a_block = [block_m, block_k] if nargs["A_ROW_MAJOR"] else [block_k, block_m]
    nargs["a_desc"].block_shape = a_block

    b_block = [block_k, block_n] if nargs["B_ROW_MAJOR"] else [block_n, block_k]
    nargs["b_desc"].block_shape = b_block

    nargs["c_desc"].block_shape = [block_m, block_n]

93 

94 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one BLOCK_M x BLOCK_N tile of C = A @ B.

    Two code paths:
    - Descriptor path: used when the tile sizes divide M, N and K evenly AND
      A, B and C are packed row-major. The descriptors below describe packed
      row-major memory (strides ``[K, 1]`` / ``[N, 1]``), so any other layout
      must not take this path.
    - Pointer path: masked/wrapped pointer arithmetic using the real strides;
      handles ragged shapes and arbitrary layouts.

    Accumulation is done in float64 when IS_FP64 is set, otherwise in float32.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # Boolean operators are combined strictly pairwise.
    full_tiles = (M % BLOCK_M == 0 and N % BLOCK_N == 0) and K % BLOCK_K == 0
    # The descriptors hard-code packed row-major strides; verify the actual
    # strides match before trusting them (a transposed A/B or a strided C
    # would otherwise be read/written with the wrong layout).
    a_packed = stride_am == K and stride_ak == 1
    b_packed = stride_bk == N and stride_bn == 1
    c_packed = stride_cm == N and stride_cn == 1

    if (full_tiles and a_packed) and (b_packed and c_packed):
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Cast to the element type of the output descriptor before storing.
        acc = acc.to(c_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap row/column indices so loads stay in range; out-of-range results
        # are discarded by the store mask below.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Full K blocks: no mask needed along K.
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last (possibly partial) K block is masked
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # masked write-back of the tile
        tl.store(offsets, acc, mask=mask)

242 

243 

def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the autotune candidate list for the host-TMA matmul kernel.

    Enumerates every combination of the tile sizes, stage counts and warp
    counts listed below, then drops configs whose estimated shared-memory
    footprint exceeds the device limit. If the limit cannot be determined, or
    no config fits, the unfiltered list is returned.

    Args:
        pre_hook: Hook attached to every generated ``triton.Config``.

    Returns:
        list: ``triton.Config`` candidates.
    """
    candidates = []
    for block_m in (32, 64, 128, 256):
        for block_n in (32, 64, 128):
            for block_k in (32, 64, 128):
                for stages in (2, 3, 4):
                    for warps in (4, 8):
                        candidates.append(
                            triton.Config(
                                {
                                    "BLOCK_M": block_m,
                                    "BLOCK_N": block_n,
                                    "BLOCK_K": block_k,
                                },
                                num_stages=stages,
                                num_warps=warps,
                                pre_hook=pre_hook,
                            )
                        )

    smem_limit = _get_shared_memory_limit_bytes()
    if smem_limit is None:
        # Limit unknown: nothing to filter against.
        return candidates

    fitting = []
    for cfg in candidates:
        tile = cfg.kwargs
        needed = _estimate_tma_shared_memory_bytes(
            tile["BLOCK_M"], tile["BLOCK_N"], tile["BLOCK_K"], cfg.num_stages
        )
        if needed <= smem_limit:
            fitting.append(cfg)

    if fitting:
        return fitting

    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        smem_limit,
    )
    return candidates

280 

281 

@libentry()
@libtuner(
    configs=(
        runtime.ops_get_configs(
            "mm_general_tma",
            pre_hook=matmul_tma_set_block_size_hook,
            yaml_path=EXPAND_CONFIG_FILENAME,
        )
        if os.environ.get("USE_FLAGTUNE") == "1"
        else matmul_get_configs()
    ),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=(
        runtime.get_expand_config("mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME)[
            "strategy"
        ]
        if os.environ.get("USE_FLAGTUNE") == "1"
        else ["align32", "align32", "align32", "align32", "align32", "default"]
    ),
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    """Compute one BLOCK_M x BLOCK_N tile of C = A @ B through host-built descriptors.

    A non-row-major operand is loaded through its transposed descriptor and
    transposed back in registers. The stride arguments, ``dtype`` and
    ``enable_warp_specialization`` are not read in the body; ``stride_am``,
    ``stride_bk`` and ``dtype`` appear in the autotune key above.
    """
    pid = tl.program_id(0)
    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)

    # Grouped ordering: programs are assigned to tiles in bands of GROUP_M tile rows.
    programs_per_group = GROUP_M * num_tiles_n
    group_index = pid // programs_per_group
    rows_in_group = min(num_tiles_m - group_index * GROUP_M, GROUP_M)
    tile_m = group_index * GROUP_M + (pid % rows_in_group)
    tile_n = (pid % programs_per_group) // (rows_in_group)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    off_m = (tile_m * BLOCK_M).to(tl.int32)
    off_n = (tile_n * BLOCK_N).to(tl.int32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        off_k = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a_tile = a_desc.load([off_m, off_k])
        else:
            a_tile = tl.trans(a_desc.load([off_k, off_m]))

        if B_ROW_MAJOR:
            b_tile = b_desc.load([off_k, off_n])
        else:
            b_tile = tl.trans(b_desc.load([off_n, off_k]))

        # 16-bit inputs use the plain dot; everything else requests tf32x3 precision.
        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)
        else:
            acc = tl.dot(a_tile, b_tile, acc=acc, input_precision="tf32x3")

    out_tile = acc.to(c_desc.dtype)
    c_desc.store([off_m, off_n], out_tile)

359 

360 

def get_higher_dtype(a, b):
    """Return whichever of two floating dtypes ranks later in the promotion order.

    The order is float16 < bfloat16 < float32 < float64. Identical dtypes are
    returned unchanged without being validated.

    Args:
        a, b: ``torch.dtype`` values.

    Returns:
        The dtype of ``a`` and ``b`` that appears later in the order.
    """
    if a is b:
        return a

    promotion_order = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    assert a in promotion_order
    assert b in promotion_order

    # Whichever argument is met first while scanning upward is the lower one,
    # so the other argument is the result.
    lower = next((d for d in promotion_order if d is a or d is b), None)
    if lower is None:
        return None
    return b if lower is a else a

375 

376 

def general_mm(a, b, c, M, N, K):
    """Run the general matmul ``c = a @ b``, preferring the host-TMA kernel.

    The host-TMA kernel is used when ``TensorDescriptor`` can be imported from
    ``triton.tools.tensor_descriptor`` and ``is_tma_compatible`` accepts the
    inputs; otherwise the pointer-based ``mm_kernel_general`` is launched.

    Args:
        a, b: 2-D input tensors.
        c: Output tensor, written in place.
        M, N, K: Problem dimensions.

    Returns:
        The output tensor ``c``.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    # Probe with a guarded import: attribute access on `triton.tools` would
    # raise AttributeError when the submodule is missing or was never
    # imported, which defeats the purpose of a capability check.
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        TensorDescriptor = None

    if TensorDescriptor is not None and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the autotune pre-hook overwrites it with
        # the tile sizes of the selected config.
        dummy_block = [1, 1]

        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:

        # Allocator handed to Triton before launching the kernel that builds
        # tensor descriptors on the device.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                IS_FP64=a.dtype == torch.float64,
            )
    return c

462 

463 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "gemv", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)

    Each program reduces BLOCK_M rows of A against the vector B and writes
    BLOCK_M results to C. C is addressed with unit stride (``C + row``).
    Accumulation is float64 when IS_FP64 is set, otherwise float32.
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    # Widen before multiplying by strides so element offsets into A cannot
    # overflow 32-bit arithmetic on large matrices.
    row_index = row_offset.to(tl.int64)

    # Accumulator for this block of rows
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K
        k_index = k_offset.to(tl.int64)

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_index[:, None] * stride_am + k_index[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_index * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Store result
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

535 

536 

def gemv_mm(a, b, c, M, K):
    """Launch ``gemv_kernel`` for the N=1 (matrix-vector) case and return ``c``.

    Args:
        a: Matrix operand of shape (M, K).
        b: Operand with a single column; only its stride along dim 0 is passed on.
        c: Output tensor, written in place.
        M, K: Problem dimensions.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def grid(META):
        # One program per BLOCK_M rows.
        return (triton.cdiv(M, META["BLOCK_M"]),)

    use_fp64 = a.dtype == torch.float64
    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            IS_FP64=use_fp64,
        )
    return c

560 

561 

def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K kernel should handle this matmul.

    True only when the device capability major version is 8, both inputs are
    contiguous float16/bfloat16 tensors, and K exceeds five times both M and N.
    """
    # TODO: this may change according to real benchmark results.
    # Currently, the best configuration for streamk has only been tested on A100(capability[0] == 8).
    # The optimal settings for other devices need to be determined through real testing.
    capability = get_device_capability()
    if not capability[0] == 8:
        return False

    half_dtypes = [torch.float16, torch.bfloat16]
    if a.dtype not in half_dtypes:
        return False
    if b.dtype not in half_dtypes:
        return False

    if not a.is_contiguous():
        return False
    if not b.is_contiguous():
        return False

    return K > M * 5 and K > N * 5

576 

577 

def mm(a, b):
    """Matrix-multiply two 2-D tensors and return a newly allocated result.

    Dispatches to the gemv path when N == 1, to stream-K when
    ``streamk_scenario`` accepts the problem, and to ``general_mm`` otherwise.
    The output dtype is chosen by ``get_higher_dtype``.
    """
    out_device = a.device
    # An operand with no unit (or zero) stride in either dimension is copied
    # into contiguous memory first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # allocates output
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=out_device, dtype=out_dtype)

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)

602 

603 

def mm_out(a, b, *, out):
    """Matrix-multiply two 2-D tensors, writing the result into ``out``.

    Uses the same dispatch as ``mm`` (gemv for N == 1, stream-K when
    ``streamk_scenario`` accepts the problem, ``general_mm`` otherwise), but
    no output is allocated; ``out`` is passed to the kernels as given.
    """
    # An operand with no unit (or zero) stride in either dimension is copied
    # into contiguous memory first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    return general_mm(a, b, out, M, N, K)