Coverage for src/flag_gems/runtime/backend/_cambricon/ops/cumsum.py: 0%

325 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import copy 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11 

12from ..utils import MAX_GRID_SIZE_Y, TOTAL_CORE_NUM 

13 

# Module-scoped logger following the flag_gems logger hierarchy.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# NOTE(review): rebinds the imported ``device`` object to its name string;
# any later ``device.name`` access in this module would fail — confirm intent.
device = device.name

# FIXME(cambricon): double 8192 when JIRA:1488 is fixed
# Max element count along the scanned dim handled by the single-kernel path.
MAX_C_MLU_CUMSUM = 8192
# Max block size along the reduction dim for the split (three-kernel) path.
MAX_C_MLU_SPILT_CUMSUM = 32768
# Upper bound for the TILE_N tuning dimension of the mid kernel.
MAX_TILE_N = 256

21 

22 

@triton.jit
def cumsum_blelloch_impl(
    in_block,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_NUM: tl.constexpr,
):
    """Inclusive scan (cumsum) over the N axis of a (BLOCK_M, BLOCK_N, BLOCK_K) block.

    BLOCK_N is viewed as TILE_NUM tiles of TILE_N elements each
    (BLOCK_N == TILE_NUM * TILE_N).  A Blelloch work-efficient scan runs
    along the TILE_N axis for all tiles at once; the per-tile totals are
    then scanned (exclusively, via right-shift + tl.cumsum) and broadcast
    back so every tile carries the sum of all preceding tiles.

    NOTE(review): uses in-place slice assignment on tl tensors
    (``x_block[...] = ...``), an extension of the Cambricon Triton port —
    not portable to stock Triton.
    """
    x_block = tl.reshape(in_block, (BLOCK_M, TILE_NUM, TILE_N, BLOCK_K))
    # Trans TILE_N and apply blelloch in TILE_N dim
    x_block = tl.trans(x_block, 0, 2, 1, 3)
    # Apply blelloch algo
    # Up-Sweep Phase: binary-tree reduction; after it, index 2^k*j - 1 holds
    # the sum of its 2^k-wide subtree (the last index holds the tile total).
    step = 1
    while step < TILE_N:
        idx_a = step - 1
        idx_b = idx_a + step
        while idx_b < TILE_N:
            x_block[:, idx_b, :, :] = x_block[:, idx_a, :, :] + x_block[:, idx_b, :, :]
            idx_a += 2 * step
            idx_b += 2 * step
        step *= 2
    # Down-Sweep Phase: propagate subtree sums down so every index becomes an
    # inclusive prefix.  idx_a == 0 is skipped on purpose: after the up-sweep
    # x[step] is already an inclusive prefix, so no add is needed there.
    step //= 2
    while step > 0:
        idx_b = TILE_N - 1 - step
        idx_a = idx_b - step
        while idx_a > 0:
            x_block[:, idx_b, :, :] = x_block[:, idx_a, :, :] + x_block[:, idx_b, :, :]
            idx_b -= 2 * step
            idx_a -= 2 * step
        step //= 2
    # Deal the last tile row exclusive sum(Composed by right shift and tl.cumsum)
    # Right shift 1 position for the last tile row
    partial_sum = tl.zeros((BLOCK_M, TILE_NUM, BLOCK_K), dtype=tl.dtype(DTYPE))
    if TILE_NUM > 1:
        partial_sum[:, 1:, :] = x_block[:, TILE_N - 1, 0 : (TILE_NUM - 1), :]
    partial_sum = tl.cumsum(partial_sum, axis=1)
    # Apply cycle add for all tile data: offset every tile by the total of
    # the tiles before it.
    x_block += partial_sum[:, None, :, :]
    # Trans TILE_N dim to original pos
    x_block = tl.trans(x_block, 0, 2, 1, 3)
    x_block = tl.reshape(x_block, (BLOCK_M, BLOCK_N, BLOCK_K))
    return x_block

69 

70 

def config_prune(configs, named_args, **kwargs):
    """Early-prune autotuning configs for ``cumsum_blelloch``.

    For small N (row fits in one NRAM block, N <= MAX_C_MLU_CUMSUM) every
    config is collapsed to a single BLOCK_N == next_power_of_2(N),
    num_stages == 1 variant.  For large N, experimentally derived
    heuristics discard configs that are invalid (BLOCK_M > M,
    TILE_N > BLOCK_N) or known to perform poorly.  Duplicate configs
    (same key tuple) are kept only once.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel launch arguments; ``M`` and ``N`` are read.

    Returns:
        The pruned, deduplicated list of configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, TILE_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        new_config = config
        # When N is less than MAX_C_MLU_CUMSUM, no reduction loops. Unify different BLOCK_N configs.
        if N <= MAX_C_MLU_CUMSUM:
            # Deep-copy before mutating so the shared config objects stay intact.
            new_config = copy.deepcopy(config)
            BLOCK_N = new_config.kwargs["BLOCK_N"] = triton.next_power_of_2(N)
            num_stages = new_config.num_stages = 1
        else:
            # When N is greater than MAX_C_MLU_CUMSUM, the pruning condition was
            # obtained through experimentation. It may result in not finding the
            # optimal solution.
            if BLOCK_N < 2048:
                continue
            if BLOCK_N >= 2048 and TILE_N < 8:
                continue
            if (
                BLOCK_N < MAX_C_MLU_CUMSUM
                and BLOCK_M < M
                and BLOCK_M <= (MAX_C_MLU_CUMSUM // BLOCK_N * 2)
            ):
                continue
            # BLOCK_M can only be 1 when BLOCK_N is at its maximum
            if BLOCK_N == MAX_C_MLU_CUMSUM and BLOCK_M > 1:
                continue
        # Prune invalid BLOCK_M
        if BLOCK_M > M:
            continue
        # Prune invalid TILE_N
        if TILE_N > BLOCK_N:
            continue
        # The pruning condition was obtained through experimentation. It may
        # result in not finding the optimal solution.
        if BLOCK_N > 128 and TILE_N < 8:
            continue
        key = (BLOCK_M, BLOCK_N, TILE_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, new_config)
    # Idiomatic: the deduplicated values are exactly the pruned configs.
    return list(configs_map.values())

123 

124 

@libentry()
@libtuner(
    configs=[
        triton.Config(
            {
                "BLOCK_M": m,
                "BLOCK_N": 2**n,
                "TILE_N": 2**t,
            },
            num_stages=s,
            num_warps=1,
        )
        for m in range(1, 20, 3)
        for n in range(7, 13, 1)
        for t in range(0, 7, 1)
        for s in [1, 3]
    ],
    key=[
        "M",
        "N",
        "K",
    ],
    strategy=["log", "log", "log"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.heuristics(
    values={
        # Number of TILE_N-wide tiles per BLOCK_N; fall back to a single tile
        # when TILE_N does not evenly divide BLOCK_N.
        "TILE_NUM": lambda args: args["BLOCK_N"] // args["TILE_N"]
        if args["BLOCK_N"] % args["TILE_N"] == 0
        and args["BLOCK_N"] // args["TILE_N"] >= 1
        else 1,
        # With a single tile, the tile spans the whole block.
        "TILE_N": lambda args: args["BLOCK_N"]
        if args["TILE_NUM"] == 1
        else args["TILE_N"],
    },
)
@triton.jit
def cumsum_blelloch(
    inp,
    out,
    M,
    N,
    K,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_NUM: tl.constexpr,
):
    """Single-kernel inclusive cumsum over the N axis of an (M, N, K) volume.

    Each program owns BLOCK_M rows at one k index and loops over N in
    BLOCK_N chunks; ``kep`` carries the last scanned column of the previous
    chunk so consecutive chunks chain into a full-row scan.  DTYPE names
    the accumulation type ("fp32" or "int32", chosen by the host wrapper).
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Carry buffer: last column holds the running total of previous chunks.
    kep = tl.full([BLOCK_M, BLOCK_N, 1], float(0), tl.dtype(DTYPE))
    for col_offset in range(0, N, BLOCK_N):
        n_offset = col_offset + tl.arange(0, BLOCK_N)
        # Pointers to the start of the row
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        x_ptrs = inp + offsets
        y_ptrs = out + offsets

        # Load data into NRAM
        in_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.dtype(DTYPE))

        # Chunk-local inclusive scan (BLOCK_K fixed to 1 here).
        x_block = cumsum_blelloch_impl(
            in_block, DTYPE, BLOCK_M, BLOCK_N, 1, TILE_N, TILE_NUM
        )
        # Add last block partial sum to current block
        x_block = tl.reshape(x_block, (BLOCK_M, BLOCK_N))
        kep_tmp = kep[:, BLOCK_N - 1, :]
        x_block += kep_tmp
        kep = x_block[:, :, None]
        # Store result back to global memory
        tl.store(y_ptrs, x_block, mask=mask)

199 

200 

def get_reduction_dim_block_size(N):
    """Pick the per-program block size along the reduction dim (split path).

    Divides N over all cores (rounding up), caps the result at
    MAX_C_MLU_SPILT_CUMSUM, and rounds up to a power of two.
    """
    # Ceil-divide N over the available cores.
    per_core = -(-N // TOTAL_CORE_NUM)
    capped = min(per_core, MAX_C_MLU_SPILT_CUMSUM)
    # In blelloch, block_size = TILE_N * TILE_NUM; both must be powers of 2,
    # so the block size itself must be one too.
    return triton.next_power_of_2(capped)

208 

209 

def config_prune_mid(configs, named_args, **kwargs):
    """Early-prune autotuning configs for ``cumsum_kernel_mid``.

    Drops configs whose block does not fit the problem (BLOCK_M > M,
    BLOCK_K > K, TILE_N > BLOCK_N), whose total block volume exceeds the
    on-chip limit MAX_C_MLU_SPILT_CUMSUM, or that are experimentally known
    to perform poorly.  Duplicate configs (same key tuple) are kept once.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel launch arguments; ``M``, ``K`` and the host-fixed
            ``BLOCK_N`` are read.

    Returns:
        The pruned, deduplicated list of configs.
    """
    M = named_args["M"]
    K = named_args["K"]
    BLOCK_N = named_args["BLOCK_N"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_K, TILE_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        new_config = config
        # Prune invalid BLOCK_M
        if BLOCK_M > M:
            continue
        # Prune invalid BLOCK_K
        if BLOCK_K > K:
            continue
        # Whole block must fit within the on-chip working-set limit.
        if BLOCK_N * BLOCK_K * BLOCK_M > MAX_C_MLU_SPILT_CUMSUM:
            continue
        # Prune invalid TILE_N
        if TILE_N > BLOCK_N:
            continue
        # The pruning condition was obtained through experimentation. It may
        # result in not finding the optimal solution.
        if BLOCK_N > 128 and TILE_N < 8:
            continue
        key = (BLOCK_M, BLOCK_N, BLOCK_K, TILE_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, new_config)
    # Idiomatic: the deduplicated values are exactly the pruned configs.
    return list(configs_map.values())

246 

247 

@libentry()
@libtuner(
    configs=[
        triton.Config(
            {
                "BLOCK_M": m,
                "BLOCK_K": 2**k,
                "TILE_N": 2**t,
            },
            num_stages=s,
            num_warps=1,
        )
        for m in range(1, 10, 3)
        for k in range(0, 3, 1)
        for t in range(5, int(math.log(MAX_TILE_N, 2) + 1), 1)
        for s in [1, 3]
    ],
    key=[
        "M",
        "N",
        "K",
        "BLOCK_N",
    ],
    strategy=["log", "log", "log", "log"],
    prune_configs_by={"early_config_prune": config_prune_mid},
)
@triton.heuristics(
    values={
        # Number of TILE_N-wide tiles per BLOCK_N; fall back to a single tile
        # when TILE_N does not evenly divide BLOCK_N.
        "TILE_NUM": lambda args: args["BLOCK_N"] // args["TILE_N"]
        if args["BLOCK_N"] % args["TILE_N"] == 0
        and args["BLOCK_N"] // args["TILE_N"] >= 1
        else 1,
        # With a single tile, the tile spans the whole block.
        "TILE_N": lambda args: args["BLOCK_N"]
        if args["TILE_NUM"] == 1
        else args["TILE_N"],
    },
)
@triton.jit
def cumsum_kernel_mid(
    inp,
    out,
    prefix_sum,
    M,
    N,
    K,
    BLOCK_N: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_NUM: tl.constexpr,
):
    """First pass of the split cumsum: scan each N-block independently.

    Writes the block-local inclusive scan to ``out`` and each block's total
    (its last scanned element) to ``prefix_sum``, laid out as
    (M, num_jobs_n, K) for the subsequent cross-block scan pass.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_jobs_n = tl.num_programs(1)
    pid_k = tl.program_id(2)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    k_offset = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Flattened (m, n, k) offsets into the contiguous (M, N, K) volume.
    offsets = (
        m_offset[:, None, None] * N * K
        + n_offset[None, :, None] * K
        + k_offset[None, None, :]
    )
    mask = (m_offset[:, None, None] < M and n_offset[None, :, None] < N) and k_offset[
        None, None, :
    ] < K
    x_ptrs = inp + offsets
    y_ptrs = out + offsets

    # Load data into NRAM
    in_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.dtype(DTYPE))

    x_block = cumsum_blelloch_impl(
        in_block, DTYPE, BLOCK_M, BLOCK_N, BLOCK_K, TILE_N, TILE_NUM
    )
    tl.store(y_ptrs, x_block, mask=mask)
    # Publish this block's total (last element of the scan) for pass two.
    prefix_sum_offsets = (
        m_offset[:, None] * num_jobs_n * K + pid_n * K + k_offset[None, :]
    )
    prefix_sum_mask = m_offset[:, None] < M and k_offset[None, :] < K
    prefix_sum_ptrs = prefix_sum + prefix_sum_offsets
    tl.store(prefix_sum_ptrs, x_block[:, BLOCK_N - 1, :], prefix_sum_mask)

336 

337 

@libentry()
@libtuner(
    configs=[
        triton.Config(
            {
                "BLOCK_M": m,
                "BLOCK_K": 2**k,
            },
            num_stages=s,
            num_warps=1,
        )
        for m in [1, 3, 6]
        for k in range(0, 3, 1)
        for s in [1, 3]
    ],
    key=[
        "M",
        "N",
        "K",
        "BLOCK_N",
    ],
    strategy=["log", "log", "log", "log"],
)
@triton.jit
def cumsum_kernel_result(
    inp,
    prefix_sum,
    out,
    M,
    N,
    K,
    BLOCK_N: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Third pass of the split cumsum: add carry-ins to block-local scans.

    ``inp`` holds per-block inclusive scans (from ``cumsum_kernel_mid``)
    and ``prefix_sum`` the scanned block totals (from ``cumsum_blelloch``);
    each program adds the total of all preceding N-blocks to its block and
    writes the final cumsum to ``out``.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    num_jobs_n = tl.num_programs(1)
    pid_k = tl.program_id(2)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k_offset = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    # Flattened (m, n, k) offsets into the contiguous (M, N, K) volume.
    offsets = (
        m_offset[:, None, None] * N * K
        + n_offset[None, :, None] * K
        + k_offset[None, None, :]
    )
    mask = (m_offset[:, None, None] < M and n_offset[None, :, None] < N) and k_offset[
        None, None, :
    ] < K
    x_ptrs = inp + offsets
    y_ptrs = out + offsets

    # Load data into NRAM
    x_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.dtype(DTYPE))

    # The first N-block has no carry-in.
    if pid_n > 0:
        # Carry-in is the scanned total of the previous block; prefix_sum is
        # laid out as (M, num_jobs_n, K).
        sum_offsets = (
            m_offset[:, None] * num_jobs_n * K + (pid_n - 1) * K + k_offset[None, :]
        )
        sum_mask = m_offset[:, None] < M and k_offset[None, :] < K
        sum_ptrs = prefix_sum + sum_offsets
        sum_block = tl.load(sum_ptrs, mask=sum_mask, other=0.0).to(tl.dtype(DTYPE))
        x_block += sum_block[:, None, :]

    # Store result back to global memory
    tl.store(y_ptrs, x_block, mask=mask)

412 

413 

def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Compute an inclusive cumsum of ``inp`` along ``dim`` on MLU.

    The input is viewed as a contiguous (M, N, K) volume where N is the
    scanned dimension, M the product of leading dims and K the product of
    trailing dims.  Short rows (or many of them) use the single
    ``cumsum_blelloch`` kernel; very long rows with few (M * K) of them
    take a three-kernel split path: block-local scan, scan of block
    totals, then a final carry-add pass.

    Args:
        inp: input tensor (made contiguous internally).
        dim: dimension to scan; may be negative.
        dtype: optional result dtype; ``None`` follows the input, with
            bool promoted to int32.
        out: optional preallocated output tensor.

    Returns:
        Tensor of cumulative sums with the same shape as ``inp``.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    M = 1
    N = shape[dim]
    for i in range(dim):
        M *= shape[i]
    inp = inp.contiguous()
    K = inp.numel() // M // N

    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            # cumsum of bools is a count; promote to int32.
            dtype = torch.int32
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    # Grid for the single-kernel path: one program per (BLOCK_M rows, k).
    blelloch_grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )

    # Kernels accumulate in fp32 for floating outputs, int32 otherwise.
    dtypestr = "fp32" if torch.is_floating_point(out) else "int32"
    if (M * K < TOTAL_CORE_NUM / 2) and (N > MAX_C_MLU_CUMSUM):
        # Too few rows to fill the cores and rows too long for one block:
        # split the N dimension across programs as well.
        # result BLOCK_N must be same as mid BLOCK_N
        mid_out = torch.empty_like(inp, dtype=dtype)
        BLOCK_N = get_reduction_dim_block_size(N)
        # Per-(m, n-block, k) totals produced by the mid kernel ...
        prefix_sum_inp = torch.empty(
            M, triton.cdiv(N, BLOCK_N), K, dtype=dtype, device=inp.device
        )
        # ... and their scanned (running) sums.
        prefix_sum = torch.empty(
            M, triton.cdiv(N, BLOCK_N), K, dtype=dtype, device=inp.device
        )
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, BLOCK_N),
            triton.cdiv(K, meta["BLOCK_K"]),
        )
        with torch_device_fn.device(inp.device):
            # 1) scan each N-block independently, emitting block totals
            cumsum_kernel_mid[grid](
                inp, mid_out, prefix_sum_inp, M, N, K, BLOCK_N, dtypestr
            )
            # 2) scan the block totals along the block axis
            cumsum_blelloch[blelloch_grid](
                prefix_sum_inp, prefix_sum, M, triton.cdiv(N, BLOCK_N), K, dtypestr
            )
            # 3) add each block's predecessor total to its scanned elements
            cumsum_kernel_result[grid](
                mid_out, prefix_sum, out, M, N, K, BLOCK_N, dtypestr
            )
    else:
        with torch_device_fn.device(inp.device):
            cumsum_blelloch[blelloch_grid](inp, out, M, N, K, dtypestr)
    return out

467 

468 

def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of ``inp`` along ``dim`` (torch.cumsum)."""
    logger.debug("GEMS_CAMBRICON CUMSUM")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype)

472 

473 

def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Inclusive cumulative sum along ``dim``, written into ``out``."""
    logger.debug("GEMS_CAMBRICON CUMSUM_OUT")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype, out=out)

477 

478 

@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Single-block normalized cumsum: out = cumsum(x) / sum(x) per row.

    One program handles one full row of K elements (K <= BLOCK).
    """
    row_start = tl.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # Accumulate reduced-precision inputs in fp32.  BUGFIX: bf16 was not
    # upcast here, unlike block_cumsum_kernel / block_update_kernel, even
    # though normed_cumsum accepts bfloat16 inputs.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)

491 

492 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Per-chunk serial cumsum used by ``normed_cumsum``.

    Each program scans a (r rows) x (t * TILE cols) chunk tile by tile,
    carrying the running sum between tiles.  Optionally it also publishes
    the chunk's running total to ``sums`` (OUTPUT_SUMS) and, when one
    chunk covers the whole row, normalizes the output in a second sweep
    (NORMALIZE).  HAS_OUT_LAYOUT selects separate output strides.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y * r, (grid.y + 1) * r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    n_chunks = tl.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running sum carried across the t tiles of this chunk.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            # Accumulate half-precision inputs in fp32.
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                # Output uses its own strides (e.g. the (rows, chunks) layout
                # of the partial-sum matrix in pass two).
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Publish the running chunk total; the value after the last
                # tile is the full chunk sum consumed by the next pass.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep: divide every element by the row total.  Only
            # meaningful when this chunk spans the whole row (n_chunks == 1).
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE

560 

561 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Final pass of the multi-chunk normalized cumsum.

    Adds each chunk's carry-in (``base``, the exclusive prefix of chunk
    totals, i.e. ``cumsums - sums``) to the chunk-local scans in ``inp``
    and divides by the row total ``rscale``, writing the result to ``out``.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    n_gridx = tl.num_programs(1)

    # NOTE(review): ``base`` is laid out as (n_rows, n_chunks) by the caller,
    # yet this offset uses num_programs(1) (the batch count) and the per-row
    # advance below is ``gridx`` rather than the chunk count; this only lines
    # up for the r == 1 launches used here — confirm before reusing r > 1.
    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)  # carry-in for this row/chunk
        rscale = tl.load(rscale_ptr)  # row total used for normalization
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                # Optional alternate output layout.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE

617 

618 

# Upper bound on grid.y imposed by the device launch limits; rows beyond this
# are folded into batches inside normed_cumsum.
GRID_Y_LIMIT = MAX_GRID_SIZE_Y

620 

621 

def normed_cumsum(inp, dim=-1):
    """Cumulative sum along ``dim``, normalized by the per-row total.

    Equivalent to ``inp.cumsum(dim) / inp.sum(dim, keepdim=True)`` for
    floating-point inputs.  Rows that fit in a single chunk are handled by
    one fused kernel; longer rows are split into chunks scanned in
    parallel, followed by a scan of the chunk totals and a final
    update/normalize pass.

    Args:
        inp: floating-point tensor (fp16 / bf16 / fp32 / fp64).
        dim: dimension to scan; may be negative.

    Returns:
        A new tensor holding the normalized cumulative sums.
    """
    logger.debug("GEMS_CAMBRICON NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        # (MLU analogue of torch.cuda multi_processor_count).
        num_sms = TOTAL_CORE_NUM
        TILE = 2048
        # Each row is split into n_chunks of chunks where each chunk is composed
        # of n_tiles of tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        # Cap grid.y; fold multiple rows per program when n_rows is huge.
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # Whole row fits in one chunk: scan and normalize in one kernel.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # BUGFIX: acc_dtype was previously assigned only when the input was
        # not float64, leaving it undefined (NameError) for fp64 inputs.
        acc_dtype = torch.float32 if inp.dtype != torch.float64 else torch.float64
        # BUGFIX: these were allocated with device=device.name, but the
        # module-level ``device`` is already a name string, so ``.name``
        # raised AttributeError; allocate on the input's device instead.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums (the (n_rows, n_chunks) totals).
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Pass three: add each chunk's exclusive carry-in and divide by the
        # row total (last column of the scanned totals).
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out