Coverage for src/flag_gems/ops/mm_streamk.py: 22%

234 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8from flag_gems.utils import triton_lang_extension as tle 

9 

# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)

11 

12 

@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly less than ``a``.

    For positive ``a``: if ``b`` divides ``a`` this is ``a - b``, otherwise it
    is ``(a // b) * b``.
    """
    num_blocks = tl.cdiv(a, b)
    return (num_blocks - 1) * b

17 

18 

@triton.jit()
def swizzle_tile(
    tile_id,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Map a linear tile id to a swizzled ``(pid_m, pid_n)`` tile coordinate.

    Tiles are visited in groups of ``GROUP_M`` rows so consecutive program
    ids touch nearby rows of the output, which improves L2 cache reuse.
    """
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    # One group spans GROUP_M rows of tiles; the final group may be shorter.
    tiles_per_group = GROUP_M * tiles_n
    group = tile_id // tiles_per_group
    rows_in_group = tl.minimum(tiles_m - group * GROUP_M, GROUP_M)
    pid_m = group * GROUP_M + (tile_id % rows_in_group)
    pid_n = (tile_id % tiles_per_group) // rows_in_group
    return pid_m, pid_n

37 

38 

@triton.jit()
def linear_tile(
    tile_id,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Map a linear tile id to ``(pid_m, pid_n)`` in plain row-major order.

    ``pid_n`` varies fastest. ``M``, ``BLOCK_M`` and ``GROUP_M`` are unused;
    they are accepted only to keep the signature interchangeable with
    ``swizzle_tile``.
    """
    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = tile_id // tiles_n
    pid_n = tile_id - pid_m * tiles_n
    return pid_m, pid_n

55 

56 

@triton.jit(
    do_not_specialize=[
        "iters_per_pid",
        "iters_remaining",
        "iters_per_tile",
        "start_iter",
        "end_iter",
    ]
)
def mac_loop(
    A,
    B,
    C,
    P,
    M,
    N,
    K,
    locks,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    iters_per_pid,
    iters_remaining,
    iters_per_tile,
    start_iter,
    end_iter,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Accumulate K-iterations ``[start_iter, end_iter)`` of one output tile.

    Stream-K building block: one tile's K loop may be split across several
    programs.  A program that does NOT own the tile's first iteration writes
    its partial accumulator into the scratch buffer ``P`` and raises its spin
    lock; the program owning the first iteration spins on the locks of the
    continuing programs, merges their partials from ``P``, and stores the
    finished tile into ``C``.

    Args:
        A, B, C: pointers to the operand / output matrices.
        P: float32 scratch buffer, one BLOCK_M x BLOCK_N slab per program.
        locks: per-program int32 flags used as spin locks.
        iters_per_pid / iters_remaining: quotient / remainder of splitting all
            K-iterations evenly over the programs.
        iters_per_tile: BLOCK_K steps per tile, i.e. cdiv(K, BLOCK_K).
        start_iter / end_iter: half-open global K-iteration range for this
            program.
    """
    # where are we in the grid
    pid = tle.program_id(0)
    tile_id = start_iter // iters_per_tile

    pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    # Give the compiler contiguity hints only along a unit-stride dimension.
    if stride_am == 1:
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if stride_bk == 1:
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    # pointers
    A_base = A + ram[:, None] * stride_am
    B_base = B + rbn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    if end_iter % iters_per_tile != 0:
        # This span stops before the tile's last K-step, so every chunk is a
        # full BLOCK_K and needs no bounds mask.
        for current_iter in range(start_iter, end_iter):
            k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
            a = tl.load(A_base + (k_offset_in_tile + rk[None, :]) * stride_ak)
            b = tl.load(B_base + (k_offset_in_tile + rk[:, None]) * stride_bk)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    else:
        # This span ends exactly at the tile boundary: peel the tile's final
        # K-step so only that step pays for the K-boundary mask.
        prev_multiple = prev_multiple_of(K, BLOCK_K)
        for current_iter in range(start_iter, end_iter - 1):
            k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
            a = tl.load(A_base + (k_offset_in_tile + rk[None, :]) * stride_ak)
            b = tl.load(B_base + (k_offset_in_tile + rk[:, None]) * stride_bk)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # handle the last iter (masked against the K boundary)
        rk = prev_multiple + tl.arange(0, BLOCK_K)
        mask_k = rk < K
        a = tl.load(A_base + rk[None, :] * stride_ak, mask=mask_k[None, :])
        b = tl.load(B_base + rk[:, None] * stride_bk, mask=mask_k[:, None])
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    rm1 = tl.arange(0, BLOCK_M)
    rn1 = tl.arange(0, BLOCK_N)

    # First situation: not the tile's starting part — only publish the
    # partial accumulator to P, then raise this program's lock.
    if start_iter % iters_per_tile != 0:
        P_ptr = P + pid * BLOCK_M * BLOCK_N + (rm1[:, None] * BLOCK_N + rn1[None, :])
        tl.store(P_ptr, acc, cache_modifier=".cg")
        # tl.debug_barrier()
        tl.atomic_xchg(locks + pid, 1)
    else:  # Owner of the tile's first part: read and merge the other partials.
        next_pid = pid + 1
        stop_loading_iter = start_iter + iters_per_tile
        end = end_iter
        # Walk the programs that continued this tile; spin until each has
        # published its partial, then fold it into acc.
        while end < stop_loading_iter:
            while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
                pass
            P_ptr = (
                P
                + next_pid * BLOCK_M * BLOCK_N
                + (rm1[:, None] * BLOCK_N + rn1[None, :])
            )
            acc += tl.load(P_ptr, cache_modifier=".cg")
            # next_pid covered iters_per_pid iterations, plus one extra if it
            # received a share of the remainder.
            end += iters_per_pid + (next_pid < iters_remaining)
            next_pid += 1

        # acc = acc.to(C.dtype.element_ty) #
        C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(C_, acc, mask=mask)

165 

166 

@libentry()
@triton.jit(
    do_not_specialize=[
        "iters_per_pid",
        "iters_remaining",
        "iters_per_tile",
    ],
)
def first_wave(
    A,
    B,
    C,
    M,
    N,
    K,
    locks,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    iters_per_pid,
    iters_remaining,
    iters_per_tile,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Stream-K cooperative wave over the first tiles of ``C``.

    All K-iterations of the cooperative tiles are split evenly over the
    programs; a program may cover the tail of one tile, several whole tiles,
    and the head of another.  For each tile, the program that computes the
    tile's LAST K-step stores its partial result to ``C`` and raises the
    tile's lock; every other contributing program spins on that lock and then
    atomically adds its partial into ``C``.

    Args:
        locks: per-tile int32 flags (indexed by tile_id) used as spin locks.
        iters_per_pid / iters_remaining: quotient / remainder of the even
            split of all K-iterations over the programs.
        iters_per_tile: BLOCK_K steps per tile, i.e. cdiv(K, BLOCK_K).
        EVEN_K: compile-time flag; when true, K is a multiple of BLOCK_K and
            no boundary masking is needed.
    """
    pid = tle.program_id(0)  # pid range from 0 to sm_count
    # Distribute the remainder: the first `iters_remaining` programs each get
    # one extra iteration.
    start_iter = pid * iters_per_pid + tl.minimum(pid, iters_remaining)
    last_iter = (pid + 1) * iters_per_pid + tl.minimum(pid + 1, iters_remaining)
    while start_iter < last_iter:
        iter_offset_in_tile = start_iter % iters_per_tile
        # Iterate over the K axis. Recalculate end_iter as M/N may change during the iteration.
        end_iter = tl.minimum(
            start_iter + (iters_per_tile - iter_offset_in_tile), last_iter
        )

        tile_id = start_iter // iters_per_tile

        pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)

        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rk = tl.arange(0, BLOCK_K)

        # Contiguity hints for coalesced global loads.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)

        # Base pointers already advanced to this program's first K-chunk
        # within the tile.
        A_base = (
            A
            + ram[:, None] * stride_am
            + rk[None, :] * stride_ak
            + BLOCK_K * stride_ak * iter_offset_in_tile
        )
        B_base = (
            B
            + rk[:, None] * stride_bk
            + rbn[None, :] * stride_bn
            + BLOCK_K * stride_bk * iter_offset_in_tile
        )

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for current_iter in range(start_iter, end_iter):
            if EVEN_K:
                a = tl.load(A_base)
                b = tl.load(B_base)
            else:
                # Mask the K boundary; only the tile's last chunk is partial.
                k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
                k_mask = (k_offset_in_tile + rk) < K
                a = tl.load(A_base, mask=k_mask[None, :], other=0.0)
                b = tl.load(B_base, mask=k_mask[:, None], other=0.0)

            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            A_base += BLOCK_K * stride_ak
            B_base += BLOCK_K * stride_bk

        # last iteration of the tile always happens before its start on another SM
        if end_iter % iters_per_tile == 0:
            # This program finished the tile: plain store of its partial,
            # then (if others contributed) open the lock so they can add.
            C_ptr = C + (
                rm[:, None] * stride_cm + rn[None, :] * stride_cn
            )  # compute inside the if/else to avoid spilling!
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            tl.store(C_ptr, acc, mask=mask)
            if iter_offset_in_tile != 0:  # only if tile has been partially processed
                tl.atomic_xchg(locks + tile_id, 1)
        else:
            # This program did not finish the tile: wait until the finisher
            # has stored, then merge via atomic add.
            while tl.atomic_cas(locks + tile_id, 1, 1) != 1:
                pass
            C_ptr = C + (
                rm[:, None] * stride_cm + rn[None, :] * stride_cn
            )  # compute inside the if/else to avoid spilling!
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            tl.atomic_add(C_ptr, acc, mask=mask, sem="relaxed")
        # next round
        start_iter = end_iter

267 

268 

@libentry()
@triton.jit(
    do_not_specialize=[
        "iters_per_pid",
        "iters_remaining",
        "iters_per_tile",
    ],
)
def first_wave_for_bf16(
    A,
    B,
    C,
    P,
    M,
    N,
    K,
    locks,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    iters_per_pid,
    iters_remaining,
    iters_per_tile,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Stream-K cooperative wave, bf16 variant.

    Same tiling and work split as ``first_wave``, but partial tiles are merged
    through the float32 scratch buffer ``P`` (one slab per program, locks
    indexed by pid) instead of ``tl.atomic_add`` on ``C`` — presumably to
    avoid atomic adds on a bf16 output buffer (NOTE(review): confirm).  The
    program owning a tile's FIRST K-step collects the other programs'
    partials from ``P`` and performs the single store to ``C``.

    Args:
        P: float32 scratch, one BLOCK_M x BLOCK_N slab per program.
        locks: per-program int32 flags used as spin locks.
        iters_per_pid / iters_remaining: quotient / remainder of the even
            split of all K-iterations over the programs.
        iters_per_tile: BLOCK_K steps per tile, i.e. cdiv(K, BLOCK_K).
        EVEN_K: compile-time flag; when true, K is a multiple of BLOCK_K.
    """
    pid = tle.program_id(0)  # pid range from 0 to sm_count
    # Distribute the remainder: the first `iters_remaining` programs each get
    # one extra iteration.
    start_iter = pid * iters_per_pid + tl.minimum(pid, iters_remaining)
    last_iter = (pid + 1) * iters_per_pid + tl.minimum(pid + 1, iters_remaining)
    while start_iter < last_iter:
        iter_offset_in_tile = start_iter % iters_per_tile
        # Iterate over the K axis. Recalculate end_iter as M/N may change during the iteration.
        end_iter = tl.minimum(
            start_iter + (iters_per_tile - iter_offset_in_tile), last_iter
        )

        tile_id = start_iter // iters_per_tile

        pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)

        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rk = tl.arange(0, BLOCK_K)

        # Contiguity hints for coalesced global loads.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)

        # Base pointers already advanced to this program's first K-chunk
        # within the tile.
        A_base = (
            A
            + ram[:, None] * stride_am
            + rk[None, :] * stride_ak
            + BLOCK_K * stride_ak * iter_offset_in_tile
        )
        B_base = (
            B
            + rk[:, None] * stride_bk
            + rbn[None, :] * stride_bn
            + BLOCK_K * stride_bk * iter_offset_in_tile
        )

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for current_iter in range(start_iter, end_iter):
            if EVEN_K:
                a = tl.load(A_base)
                b = tl.load(B_base)
            else:
                # Mask the K boundary; only the tile's last chunk is partial.
                k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
                k_mask = (k_offset_in_tile + rk) < K
                a = tl.load(A_base, mask=k_mask[None, :], other=0.0)
                b = tl.load(B_base, mask=k_mask[:, None], other=0.0)

            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            A_base += BLOCK_K * stride_ak
            B_base += BLOCK_K * stride_bk

        rm1 = tl.arange(0, BLOCK_M)
        rn1 = tl.arange(0, BLOCK_N)

        # First situation: not the tile's starting part — only publish the
        # partial accumulator to P, then raise this program's lock.
        if start_iter % iters_per_tile != 0:
            P_ptr = (
                P + pid * BLOCK_M * BLOCK_N + (rm1[:, None] * BLOCK_N + rn1[None, :])
            )
            tl.store(P_ptr, acc, cache_modifier=".cg")
            # tl.debug_barrier()
            tl.atomic_xchg(locks + pid, 1)
        else:  # Owner of the tile's first part: read and merge other partials.
            next_pid = pid + 1
            stop_loading_iter = start_iter + iters_per_tile
            end = end_iter
            # Walk the programs that continued this tile; spin until each has
            # published its partial, then fold it into acc.
            while end < stop_loading_iter:
                while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
                    pass
                P_ptr = (
                    P
                    + next_pid * BLOCK_M * BLOCK_N
                    + (rm1[:, None] * BLOCK_N + rn1[None, :])
                )
                acc += tl.load(P_ptr, cache_modifier=".cg")
                # next_pid covered iters_per_pid iterations, plus one extra if
                # it received a share of the remainder.
                end += iters_per_pid + (next_pid < iters_remaining)
                next_pid += 1

            # acc = acc.to(C.dtype.element_ty) #
            C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            tl.store(C_, acc, mask=mask)
        start_iter = end_iter

383 

384 

@libentry()
@triton.jit
def classic_mm(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    total_tiles_streamk,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Plain tiled GEMM for the tiles left after the stream-k wave.

    Each program owns exactly one output tile; program ids are shifted by
    ``total_tiles_streamk`` so the grid starts right after the cooperative
    tiles.  The K loop is peeled: full BLOCK_K chunks run unmasked, and the
    final chunk is masked against the K boundary.
    """
    # first wave has done more tiles than there are SMs, we adjust pid
    tile_id = tle.program_id(0) + total_tiles_streamk
    pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)

    row_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    col_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Contiguity hints for coalesced global loads.
    a_rows = tl.max_contiguous(tl.multiple_of(row_offs % M, BLOCK_M), BLOCK_M)
    b_cols = tl.max_contiguous(tl.multiple_of(col_offs % N, BLOCK_N), BLOCK_N)
    k_body = prev_multiple_of(K, BLOCK_K)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop over full BLOCK_K chunks — no boundary mask required.
    for k0 in range(0, k_body, BLOCK_K):
        k_offs = k0 + tl.arange(0, BLOCK_K)
        a = tl.load(A + (a_rows[:, None] * stride_am + k_offs[None, :] * stride_ak))
        b = tl.load(B + (k_offs[:, None] * stride_bk + b_cols[None, :] * stride_bn))
        if a.dtype != b.dtype:
            # Mixed-precision operands: unify on the output element type.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # Peeled final chunk, masked against the K boundary.
    k_offs = k_body + tl.arange(0, BLOCK_K)
    tail_mask = k_offs < K
    a = tl.load(
        A + (a_rows[:, None] * stride_am + k_offs[None, :] * stride_ak),
        mask=tail_mask[None, :],
    )
    b = tl.load(
        B + (k_offs[:, None] * stride_bk + b_cols[None, :] * stride_bn),
        mask=tail_mask[:, None],
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    result = accumulator.to(C.dtype.element_ty)
    # Rematerialize the offsets to save registers before the store.
    row_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    col_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (row_offs[:, None] * stride_cm + col_offs[None, :] * stride_cn)
    out_mask = (row_offs < M)[:, None] & (col_offs < N)[None, :]
    tl.store(C, result, mask=out_mask)

450 

451 

def streamk_mm(a, b, c, M, N, K, sm_count=108):
    """Compute ``c = a @ b`` with a Stream-K schedule.

    The trailing partial wave of output tiles is computed cooperatively by
    ``sm_count`` programs that split work along K (``first_wave`` /
    ``first_wave_for_bf16``); all remaining full waves are handled by the
    classic one-tile-per-program kernel (``classic_mm``).

    Args:
        a: left operand, 2D tensor of shape (M, K) — strides are passed
            explicitly so non-contiguous layouts are supported.
        b: right operand, 2D tensor of shape (K, N).
        c: preallocated output tensor of shape (M, N); written in place.
        M, N, K: GEMM problem sizes.
        sm_count: number of streaming multiprocessors used for the
            cooperative wave (one program per SM).

    Returns:
        The output tensor ``c``.
    """
    logger.debug(
        "GEMS MM, [mm scenario]: streamk, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # TODO: change the hard code to tuning config
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128
    num_stages = 3
    num_warps = 8
    GROUP_M = 8
    number_blocks_m = triton.cdiv(M, BLOCK_M)
    number_blocks_n = triton.cdiv(N, BLOCK_N)

    total_tiles = number_blocks_m * number_blocks_n
    iters_per_tile = triton.cdiv(K, BLOCK_K)
    tiles_per_wave = sm_count

    # Tiles of the trailing partial wave go to the cooperative (stream-k)
    # kernel; all full waves go to the classic kernel.
    number_cooperative_tiles = total_tiles % tiles_per_wave
    number_other_tiles = total_tiles - number_cooperative_tiles
    # Rebalance: a tiny partial wave borrows one full wave to amortize the
    # cooperative launch; a nearly-full partial wave is cheaper done
    # classically.
    if number_other_tiles > 0 and number_cooperative_tiles < sm_count * 0.5:
        number_cooperative_tiles = number_cooperative_tiles + tiles_per_wave
    elif number_other_tiles > 0 and number_cooperative_tiles > sm_count * 0.8:
        number_cooperative_tiles = 0

    if number_cooperative_tiles > 0:
        # mini wave
        total_iters_streamk = number_cooperative_tiles * iters_per_tile
        iters_per_pid = total_iters_streamk // tiles_per_wave
        iters_remaining = total_iters_streamk % tiles_per_wave
        even_k = K % BLOCK_K == 0

        if a.dtype == torch.bfloat16:
            # bf16 path merges partial tiles through a float32 scratch buffer
            # (one slab per program, lock per program) instead of atomic adds
            # on c.
            locks = torch.zeros((tiles_per_wave,), device=a.device, dtype=torch.int32)
            P = torch.empty(
                (tiles_per_wave, BLOCK_M, BLOCK_N), device=a.device, dtype=torch.float32
            )
            # with torch_device_fn.device(a.device):
            first_wave_for_bf16[(tiles_per_wave,)](
                a,
                b,
                c,
                P,
                M,
                N,
                K,
                locks,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                iters_per_pid=iters_per_pid,
                iters_remaining=iters_remaining,
                iters_per_tile=iters_per_tile,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                BLOCK_K=BLOCK_K,
                GROUP_M=GROUP_M,
                EVEN_K=even_k,
                num_stages=num_stages,
                num_warps=num_warps,
            )
        else:
            # Non-bf16 path synchronizes with one lock per cooperative tile.
            locks = torch.zeros(
                (number_cooperative_tiles,), device=a.device, dtype=torch.int32
            )
            first_wave[(tiles_per_wave,)](
                a,
                b,
                c,
                M,
                N,
                K,
                locks,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                iters_per_pid=iters_per_pid,
                iters_remaining=iters_remaining,
                iters_per_tile=iters_per_tile,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                BLOCK_K=BLOCK_K,
                GROUP_M=GROUP_M,
                EVEN_K=even_k,
                num_stages=num_stages,
                num_warps=num_warps,
            )

    # Guard the classic launch: when the cooperative wave covers every tile
    # (e.g. total_tiles < sm_count), the remaining grid would be zero-sized,
    # and launching a kernel with an empty grid is invalid.
    if total_tiles > number_cooperative_tiles:
        classic_mm[(total_tiles - number_cooperative_tiles,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_tiles_streamk=number_cooperative_tiles,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            GROUP_M=GROUP_M,
            num_stages=num_stages,
            num_warps=num_warps,
        )
    return c