Coverage for src/flag_gems/runtime/backend/_ascend/ops/cumsum.py: 0%

321 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7from torch._prims_common import is_boolean_dtype, is_integer_dtype 

8 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from ..utils import CORE_NUM 

14 

# NOTE(review): this rebinds the imported `device` object to its name string,
# shadowing the runtime device handle for the rest of the module. After this
# line, `device` is a plain str, so `device.name` would raise AttributeError.
device = device.name
# Module logger named after this op file (e.g. "flag_gems.runtime._ascend.ops.cumsum").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

17 

18 

@tl.constexpr
def get_scan_accum_type(inp_dtype: tl.dtype) -> tl.dtype:
    """Pick the accumulator dtype for a scan over values of ``inp_dtype``.

    Half-precision floats (bf16/fp16) accumulate in fp32 to limit rounding
    error; all integer types (signed, unsigned, and bool) accumulate in
    int64 to avoid overflow; anything else (fp32/fp64) keeps its own dtype.
    Evaluated at compile time (``tl.constexpr``).
    """
    if inp_dtype.is_bf16() or inp_dtype.is_fp16():
        return tl.float32
    if inp_dtype.is_int():  # signed or not(including bool)
        return tl.int64
    else:
        return inp_dtype

27 

28 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase 1 of a 1-D scan-then-fan cumsum.

    Each program computes the inclusive cumsum of its own BLOCK_SIZE slice
    of `inp` into `out`, and writes that slice's total into
    `partial_sum[pid]` for the later fan (base-add) phase.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Accumulation-width selection, resolved at compile time:
    # 64-bit types keep their width; narrower ints accumulate in int32,
    # everything else in float32.
    # NOTE(review): narrowing small ints to int32 can overflow for long
    # scans whose true (int64) cumsum exceeds int32 range — confirm this is
    # an accepted trade-off on this backend.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    # Block total (masked lanes were loaded as 0 contributions).
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)

63 

64 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase 2 (fan) of the 1-D scan-then-fan cumsum.

    `partial_sum` must already hold the *cumulative* sums of the blocks
    (the caller recursively scans it). Every block except the first adds
    the previous block's inclusive prefix to its own partial results,
    turning per-block cumsums into a global cumsum.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    # Block 0 is already correct; every later block shifts by the scanned
    # total of all preceding blocks.
    if pid > 0:
        partial_sum_ptrs = partial_sum + pid - 1
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)

87 

88 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase 1 of scan-then-fan over the middle axis of an (A, B, C) view.

    Grid is (A, part_num, C): program (a, b, c) scans one BLOCK_SIZE strip
    along B at fixed (a, c), writing the strip's inclusive cumsum to `out`
    and its total to `partial_sum` (laid out as (A, part_num, C)).
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flat offsets into the contiguous (A, B, C) input and into the
    # (A, part_num, C) partial-sum buffer.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Same compile-time accumulation-width selection as scan_part_sum_kernel:
    # 64-bit types pass through, narrower ints use int32, others float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    # Strip total for the fan phase.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)

133 

134 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase 2 (fan) of scan-then-fan over the middle axis of (A, B, C).

    `partial_sum` must already hold scanned (cumulative) strip sums.
    Each program with pid_b > 0 adds the previous strip's inclusive prefix
    at its (a, c) position to its slice of `out`.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # Index of the *previous* strip's cumulative sum in the
    # (A, part_num, C) partial-sum buffer.
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    # Strip 0 needs no base; later strips shift by the preceding prefix.
    if pid_b > 0:
        partial_sum_ptrs = partial_sum + last_part_offset
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)

168 

169 

def scan_then_fan_col(inp, out, n_ele, dtype):
    """Inclusive cumsum of a flat `n_ele`-element tensor via scan-then-fan.

    Splits the array into blocks, scans each block in parallel, recursively
    scans the per-block totals, then fans the scanned totals back as base
    offsets. `partial_sum` holds the block totals in `dtype`.
    """
    # TODO(all): tune on target board
    # Small inputs fit a single (or few) power-of-two block(s); otherwise
    # use a fixed 1024-wide block.
    block = triton.next_power_of_2(n_ele) if n_ele <= 1024 * 4 else 1024
    part_num = math.ceil(n_ele / block)
    partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, block)

    # A single block is already a complete scan; otherwise scan the block
    # totals and fan them back.
    if part_num >= 2:
        scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, block)

186 

187 

def scan_then_fan(inp, out, A, B, C, dtype):
    """Inclusive cumsum along the middle axis of an (A, B, C) view.

    Same scan-then-fan scheme as `scan_then_fan_col`, but each (a, c)
    column along B is scanned independently; `partial_sum` is shaped
    (A, part_num, C) and scanned recursively along its middle axis.
    """
    # TODO(all): tune on target board
    block = triton.next_power_of_2(B) if B <= 1024 * 4 else 1024
    part_num = math.ceil(B / block)
    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[grid](inp, out, partial_sum, B, C, part_num, block)

    # Recurse on the strip totals, then fan them back as base offsets.
    if part_num >= 2:
        scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, block)

206 

207 

def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Dispatch an inclusive cumsum along `dim` of `inp`.

    Flattens the shape around `dim` into (M, N, K) with N = shape[dim].
    When K == 1 (scan over the innermost axis) a row-oriented
    reduce-then-scan is used; otherwise the generic scan-then-fan path.
    Integer/bool inputs promote to int64 (matching torch.cumsum); fp16/bf16
    accumulate in float32. Returns `out` (allocated if not supplied).
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = inp.shape
    N = shape[dim]
    # Product of the leading dims; empty slice -> 1.
    M = math.prod(shape[:dim])
    inp = inp.contiguous()
    K = inp.numel() // M // N

    if dtype is None:
        dtype = inp.dtype
        # torch.cumsum promotes integral/bool results to int64.
        if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
            dtype = torch.int64
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    # Half-precision inputs accumulate in fp32; otherwise use the output dtype.
    if inp.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = out.dtype

    if K == 1:  # row scan
        reduce_then_scan_row(inp, out, M, N, compute_dtype)
    else:  # col scan
        scan_then_fan(inp, out, M, N, K, compute_dtype)

    return out

236 

237 

def reduce_then_scan_row(x, out, M, N, compute_dtype):
    """Row-wise (innermost-axis) inclusive cumsum of an (M, N) view.

    Small rows (N <= 16384) use a single persistent scan kernel per row.
    Larger rows use a 3-kernel reduce-then-scan: (1) per-block sums,
    (2) a root scan of the block sums, (3) a block scan that adds each
    block's exclusive prefix. `compute_dtype` sizes the intermediate
    block-sum buffers. Returns `out`.
    """
    if N <= 16384:  # persistent: one program scans an entire row
        TILE_SIZE = triton.next_power_of_2(N)
        reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
            x,
            out,
            N,
            TILE_SIZE,
        )
        return out

    TILE_SIZE = min(4096, triton.next_power_of_2(N))
    num_tiles = triton.cdiv(N, TILE_SIZE)
    num_ctas = num_tiles
    ROOT_SCAN_TILE_SIZE = triton.next_power_of_2(num_ctas)
    tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
    block_sums = torch.empty(
        (
            M,
            num_ctas,
        ),
        dtype=compute_dtype,
        device=x.device,
    )
    block_inclusive_prefix = torch.empty(
        (
            M,
            num_ctas,
        ),
        dtype=compute_dtype,
        device=x.device,
    )

    # 3-kernel implementation
    # FIX: the grid here used to be the 4-tuple (M, num_ctas, 1, 1); Triton
    # launch grids are at most 3-D, so use (M, num_ctas, 1) like the other
    # launches in this function.
    reduce_then_scan_block_sum_kernel_row[(M, num_ctas, 1)](
        x,
        block_sums,
        N,
        tiles_per_cta,
        TILE_SIZE,
    )
    reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
        block_sums,
        block_inclusive_prefix,
        num_ctas,
        ROOT_SCAN_TILE_SIZE,
    )
    reduce_then_scan_block_scan_kernel_row[(M, num_ctas, 1)](
        x,
        block_inclusive_prefix,
        out,
        N,
        num_ctas,
        tiles_per_cta,
        TILE_SIZE,
    )
    return out

295 

296 

@triton.jit
def reduce_then_scan_block_sum_kernel_row(
    in_ptr,
    block_sum_ptr,
    N,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """The same kernel as the block sum in parallel reduce

    Program (pid_m, pid_n) reduces its contiguous span of row pid_m
    (tiles_per_cta tiles of TILE_SIZE) into a single scalar and stores it
    at block_sum_ptr[pid_m, pid_n].
    """
    pid_n = tl.program_id(1).to(tl.int64)
    pid_m = tl.program_id(0).to(tl.int64)
    num_programs_n = tl.num_programs(1)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)

    # Accumulate in the widened dtype chosen by get_scan_accum_type.
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    acc = tl.zeros((TILE_SIZE,), dtype=acc_dtype)
    # Per-lane partial sums across the tiles; masked lanes contribute 0.
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        x = tl.load(in_ptr + pid_m * N + offsets, mask=offsets < N).to(acc_dtype)
        acc += x
    # Collapse the lane partials to the span total.
    block_sum = tl.sum(acc, 0)
    tl.store(
        block_sum_ptr + pid_m * num_programs_n + pid_n, block_sum, cache_modifier=".cg"
    )

322 

323 

@triton.jit
def reduce_then_scan_root_scan_kernel_row(in_ptr, out_ptr, N, TILE_SIZE: tl.constexpr):
    """Almost The same kernel as the persistent scan kernel

    One program per row: loads up to TILE_SIZE elements (TILE_SIZE must be
    >= N) and writes their inclusive cumsum, accumulating in the widened
    dtype from get_scan_accum_type. Also serves as the root scan over
    per-block sums in the 3-kernel path.
    """
    pid = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, TILE_SIZE)
    mask = offsets < N
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    x = tl.load(in_ptr + pid * N + offsets, mask=mask, other=0).to(acc_dtype)
    out = tl.cumsum(x, 0)
    tl.store(out_ptr + pid * N + offsets, out, mask=mask)

334 

335 

@triton.jit
def reduce_then_scan_block_scan_kernel_row(
    in_ptr,
    previous_sum_ptr,
    out_ptr,
    N,
    num_tiles_n,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """Final pass of the 3-kernel row scan.

    Program (pid_m, pid_n) rescans its span of row pid_m, offsetting every
    tile by the running `prefix`, which starts at the *previous* block's
    inclusive prefix (0 for pid_n == 0 via the masked load) and grows by
    each tile's total.
    """
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1).to(tl.int64)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)

    # Exclusive prefix for this block: previous block's inclusive prefix,
    # or 0 for the first block (mask=pid_n > 0).
    prefix = tl.load(
        previous_sum_ptr + pid_m * num_tiles_n + pid_n - 1, mask=pid_n > 0, other=0
    ).to(acc_dtype)
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        mask = offsets < N
        x = tl.load(in_ptr + pid_m * N + offsets, mask=mask).to(acc_dtype)
        tile_scan = prefix + tl.cumsum(x, 0)
        # Advance the running prefix past this tile for the next iteration.
        prefix += tl.sum(x, 0)
        tl.store(
            out_ptr + pid_m * N + offsets, tile_scan, mask=mask, cache_modifier=".cg"
        )

364 

365 

def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of `inp` along `dim` (torch.cumsum analog)."""
    logger.debug("GEMS_ASCEND CUMSUM")
    return cumsum_wrapper(inp, dim, dtype=dtype)

369 

370 

def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Out-variant of `cumsum`: writes the result into `out` and returns it."""
    logger.debug("GEMS_ASCEND CUMSUM_OUT")
    return cumsum_wrapper(inp, dim, dtype=dtype, out=out)

374 

375 

@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Single-block normalized cumsum of one K-element row.

    Writes cumsum(x) / sum(x) for the row owned by this program.
    BLOCK must be >= K so the whole row fits in one tile.
    """
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # FIX: promote bf16 as well as fp16 to fp32 before accumulating —
    # previously only fp16 was promoted, giving low-precision bf16 sums.
    # Matches the promotion in block_cumsum_kernel.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)

388 

389 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Chunked row cumsum used by normed_cumsum.

    Each program scans `r` rows, covering `t` tiles of TILE columns of its
    chunk. Optionally stores each chunk's running total into `sums`
    (OUTPUT_SUMS) and/or normalizes the chunk by its final total
    (NORMALIZE). HAS_OUT_LAYOUT selects separate output strides.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running inclusive total of the tiles scanned so far in this chunk.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            # Half-precision inputs accumulate in fp32.
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                # NOTE(review): this rebinds row_offset to the *output* layout,
                # and the next tile's input load reuses it — only safe when the
                # input and output row strides coincide, as in the current
                # callers. Confirm before adding new HAS_OUT_LAYOUT call sites.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # sums is laid out (R, n_chunks); gridx[None] makes a 1-element
                # vector to match curr_cumsum's shape.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep: divide the chunk by its final total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE

458 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Final pass of normed_cumsum: add each chunk's exclusive base prefix
    (`base`) to the partial cumsums and divide by the row total (`rscale_ptr`).
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_gridx = tle.num_programs(1)

    # NOTE(review): `base` is laid out (R, n_chunks) by the caller, yet it is
    # advanced with n_gridx = num_programs(1) here and by `gridx` per row
    # below; this only lines up in degenerate grid shapes. Looks suspicious —
    # verify against the batch>1 path of normed_cumsum before relying on it.
    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    # NOTE(review): the row range starts at gridy (not gridy * r), unlike
    # block_cumsum_kernel — confirm intended for batch > 1.
    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE

514 

515 

# Cap on the launch grid's second axis; rows beyond this are batched per
# program in normed_cumsum.
GRID_Y_LIMIT = 65535

517 

518 

def normed_cumsum(inp, dim=-1):
    """Cumulative sum along `dim`, normalized by each row's total.

    Equivalent to inp.cumsum(dim) / inp.sum(dim, keepdim=True) for floating
    inputs. Single-chunk rows are handled in one fused kernel; wider rows use
    a three-pass scheme (partial cumsums, root scan of chunk sums, then a
    base-add + rescale pass). Returns a new tensor shaped like `inp`
    (possibly transposed-contiguous when `dim` is a middle dim).
    """
    logger.debug("GEMS_ASCEND NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # inp = inp.contiguous()
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        TILE = 2048
        # Each row is split into n_chunks of chunks where each chunk is comprised of
        # n_tiles of tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(CORE_NUM, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        # Batch rows per program when the row count exceeds the grid limit.
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # Whole row fits a single chunk: fuse scan + normalization.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # FIX: acc_dtype was previously assigned only when inp.dtype was not
        # float64, leaving it unbound (NameError) for float64 inputs.
        acc_dtype = torch.float32 if inp.dtype != torch.float64 else torch.float64
        # FIX: these buffers used `device=device.name`, but the module rebinds
        # `device` to a plain string (which has no `.name`), raising
        # AttributeError. Allocate on the input's device instead.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Row totals (last inclusive prefix) used for normalization;
        # cumsums - sums gives each chunk's exclusive base prefix.
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out