Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%

216 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.ops.mm_streamk import streamk_mm 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as tle 

13from flag_gems.utils.device_info import get_device_capability, get_sm_count 

14 

15logger = logging.getLogger(__name__) 

16CACHE_USAGE_THRESHOLD = 0.8 

17 

18 

def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements x 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements x 4 bytes = 16 bytes)

    Args:
        a, b: Input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's 128-bit alignment requirement
    """
    half_types = (torch.float16, torch.bfloat16)
    if a.dtype in half_types and b.dtype in half_types:
        # 2-byte elements: 16-byte alignment needs multiples of 8.
        return N % 8 == 0 and K % 8 == 0
    if a.dtype == torch.float32 and b.dtype == torch.float32:
        # 4-byte elements: 16-byte alignment needs multiples of 4.
        return N % 4 == 0 and K % 4 == 0
    return False

47 

48 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a:
    # ceil(a / b) * b is the smallest multiple >= a, so subtracting b yields
    # a - b when a % b == 0 and floor(a / b) * b otherwise.
    return tl.cdiv(a, b) * b - b

53 

54 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,  # pointer to the (M, K) operand
    B,  # pointer to the (K, N) operand
    C,  # pointer to the (M, N) output
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """General GEMM kernel computing C = A @ B with fp32 accumulation.

    Two code paths:
    - When M, N, and K each divide their block size evenly, loads/stores go
      through device-side tensor descriptors with no masking.
    - Otherwise, a pointer-arithmetic path with loop peeling handles the
      ragged final K-tile and masks the write-back to C.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: programs are grouped so
    # that GROUP_M consecutive row-tiles share the same column-tile stream.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    if M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0:
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        # NOTE(review): descriptor strides are hard-coded to [K, 1] / [N, 1]
        # (row-major contiguous) rather than using the stride_* arguments —
        # presumably callers guarantee contiguous inputs on this path; confirm.
        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        # column-major
        # b_desc = tl.make_tensor_descriptor(
        #     B,
        #     shape = [N, K],
        #     strides = [K, 1],
        #     block_shape = [BLOCK_N, BLOCK_K],
        # )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Cast the fp32 accumulator back to the input element type for storage.
        acc = acc.to(a_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap row/col indices modulo M/N so unmasked full-tile loads stay in
        # bounds; the final masked store discards the wrapped lanes.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Main loop over K tiles that need no masking.
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                # Mixed-dtype operands: promote both sides to the output dtype.
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last (possibly partial) K tile, with masked loads
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)

186 

187 

def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: resize the host TMA descriptors' block shapes.

    Each autotune config carries its own BLOCK_M/N/K, but the descriptors
    were built with a placeholder block shape; this hook rewrites them to
    match the config before the kernel launches. Column-major operands use
    a transposed block shape because their descriptors describe the
    transposed view.
    """
    bm = nargs["BLOCK_M"]
    bn = nargs["BLOCK_N"]
    bk = nargs["BLOCK_K"]

    a_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    b_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]

    nargs["a_desc"].block_shape = a_shape
    nargs["b_desc"].block_shape = b_shape
    nargs["c_desc"].block_shape = [bm, bn]

203 

204 

def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the autotune search space for the host-TMA GEMM kernel.

    Enumerates the cartesian product of block sizes, pipeline stages, and
    warp counts; every config shares the same pre-hook that resizes the
    TMA descriptors to the config's block shapes.
    """
    configs = []
    for bm in (32, 64, 128, 256):
        for bn in (32, 64, 128):
            for bk in (32, 64, 128):
                for stages in (2, 3, 4):
                    for warps in (4, 8):
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
                                num_stages=stages,
                                num_warps=warps,
                                pre_hook=pre_hook,
                            )
                        )
    return configs

219 

220 

@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,  # host-built descriptor over A (transposed view when column-major)
    b_desc,  # host-built descriptor over B (transposed view when column-major)
    c_desc,  # host-built descriptor over C
    M,
    N,
    K,
    stride_am,  # strides participate in the autotune key only; loads go via descriptors
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,  # dtype tag for the autotune key; not read in the body
    enable_warp_specialization=True,  # NOTE(review): currently unused in the body
):
    """GEMM kernel using host-side TMA descriptors: C = A @ B.

    Column-major operands are loaded through their transposed descriptor and
    transposed in-register back to the (M, K) / (K, N) orientation before the
    dot product. Accumulation is in fp32; fp32 inputs use the "tf32x3"
    precision mode for near-fp32 accuracy on tensor cores.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped program-id swizzle for better L2 reuse (same scheme as
    # mm_kernel_general).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # Descriptor describes A^T; load the (K, M) tile and transpose back.
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # Descriptor describes B^T; load the (N, K) tile and transpose back.
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 path: three-pass TF32 emulation for higher accuracy.
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)

288 

289 

def get_higher_dtype(a, b):
    """Return the wider of two dtypes under the order fp16 < bf16 < fp32.

    Both arguments must be one of torch.float16, torch.bfloat16, or
    torch.float32. Identical inputs are returned unchanged.
    """
    promotion_order = [torch.float16, torch.bfloat16, torch.float32]

    if a is b:
        return a

    assert a in promotion_order
    assert b in promotion_order

    # The dtype appearing later in the promotion order wins.
    return a if promotion_order.index(a) > promotion_order.index(b) else b

304 

305 

def general_mm(a, b, c, M, N, K):
    """Dispatch a general (M, K) x (K, N) GEMM to one of two Triton kernels.

    Uses the host-side TMA kernel when the installed Triton exposes
    ``TensorDescriptor`` and the operands meet TMA's 128-bit alignment
    requirement; otherwise falls back to the generic kernel, which builds
    descriptors on-device and therefore needs a host-provided allocator.

    Args:
        a, b: Input 2-D tensors.
        c: Preallocated (M, N) output tensor.
        M, N, K: GEMM dimensions.

    Returns:
        c, with the product written into it.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    # Use a guarded import rather than
    # `hasattr(triton.tools.tensor_descriptor, "TensorDescriptor")`: the
    # attribute chain is evaluated *before* hasattr runs, so Triton builds
    # without the `tensor_descriptor` submodule raised AttributeError instead
    # of falling back. The import degrades cleanly. (triton >= 3.5)
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor

        has_host_tma = True
    except ImportError:
        has_host_tma = False

    if has_host_tma and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the autotuner pre-hook rewrites it per config.
        dummy_block = [1, 1]

        # Column-major operands are described through their transposed view so
        # the descriptor strides stay row-major; the kernel transposes after load.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        # e.g. "float16" — participates in the autotune key as a constexpr.
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:
        # The fallback kernel creates tensor descriptors on-device, which
        # requires a host-side scratch allocator.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
            )
    return c

385 

386 

@libentry()
@triton.jit
def gemv_kernel(
    A,  # pointer to the (M, K) matrix
    B,  # pointer to the length-K vector
    C,  # pointer to the length-M output vector (assumed contiguous: stored by row offset only)
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)"""
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows (fp32 regardless of input dtype)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension (elementwise product then row-sum)
        acc += tl.sum(a * b[None, :], axis=1)

    # Store result, cast back to the output element type
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

432 

433 

def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case"""
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # Fixed tile sizes: 32 rows per program, 256-wide K chunks.
    block_rows = 32
    block_cols = 256

    def grid(meta):
        # One program per BLOCK_M rows; meta is unused here.
        return (triton.cdiv(M, block_rows),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            BLOCK_M=block_rows,
            BLOCK_K=block_cols,
        )
    return c

460 

461 

def streamk_scenario(a, b, M, N, K):
    """Decide whether this GEMM should use the stream-K kernel.

    The stream-K configuration has only been validated on devices with
    compute capability major == 8 (A100); thresholds for other devices are
    TODO pending real benchmarking. Beyond the hardware check, stream-K is
    chosen only for contiguous half-precision operands whose K dimension
    strongly dominates both M and N.
    """
    capability = get_device_capability()
    if capability[0] != 8:
        return False

    half_types = (torch.float16, torch.bfloat16)
    if a.dtype not in half_types or b.dtype not in half_types:
        return False

    if not (a.is_contiguous() and b.is_contiguous()):
        return False

    # K-dominated ("deep and skinny") problems benefit from split-K work division.
    return K > M * 5 and K > N * 5

476 

477 

def mm(a, b):
    """Matrix multiply a @ b, allocating and returning the output tensor.

    Dispatches to a GEMV kernel when N == 1, to stream-K on supported
    hardware/shapes, and otherwise to the general GEMM path. The output
    dtype follows the fp16 < bf16 < fp32 promotion order.
    """
    device = a.device

    # Materialize a contiguous copy when neither stride is 1 (input is
    # neither row- nor column-major).
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]

    # Allocate the output at the promoted dtype.
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    # Matrix-vector fast path.
    if N == 1:
        return gemv_mm(a, b, c, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)

502 

503 

def mm_out(a, b, *, out):
    """Matrix multiply a @ b into a caller-provided output tensor.

    Mirrors ``mm`` but writes into ``out`` instead of allocating. Dispatches
    to the GEMV kernel when N == 1, to stream-K on supported hardware/shapes,
    and otherwise to the general GEMM path.

    Args:
        a: (M, K) input tensor.
        b: (K, N) input tensor.
        out: Preallocated (M, N) output tensor (keyword-only).

    Returns:
        out, with the product written into it.
    """
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Validate the caller-provided buffer up front so the kernels never write
    # to a mis-shaped tensor (which would corrupt memory or fault in-kernel).
    assert out.shape == (M, N), "output has incompatible dimensions"

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    else:
        return general_mm(a, b, out, M, N, K)