Coverage for src/flag_gems/runtime/backend/_aipu/ops/cumsum.py: 0%

258 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils import get_device_properties, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Rebind the imported runtime `device` object to its backend-name string;
# the rest of this module only uses the name (note this shadows the import).
device = device.name
logger = logging.getLogger(__name__)

14 

15 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase one of 1-D scan-then-fan.

    Each program cumsums its own BLOCK_SIZE-wide slice of `inp` into `out`
    and writes the slice total into `partial_sum[pid]`, so a later pass can
    add the scanned bases back in (see add_base_sum_kernel).
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Promote narrow dtypes so accumulation happens in at least 32 bits;
    # 64-bit int/uint/float inputs are scanned in their own width.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    # NOTE(review): the masked load has no `other=`, so tail lanes are
    # undefined and can pollute the last block's total. The consumer only
    # reads totals of blocks that have a successor, so this looks benign,
    # but `other=0` would be safer — confirm.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)

50 

51 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase two of 1-D scan-then-fan.

    Adds the scanned total of every preceding block (partial_sum[pid - 1],
    already an inclusive scan) onto this block's local cumsum in `out`.
    """
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    out_ptrs = out + offsets
    block_vals = tl.load(out_ptrs, mask=in_bounds)

    # Block 0 has no predecessor, hence no base to add.
    if block_id > 0:
        base = tl.load(partial_sum + (block_id - 1))
        shifted = block_vals + base
        tl.store(out_ptrs, shifted.to(block_vals.dtype), mask=in_bounds)

74 

75 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase one of scan-then-fan for an (A, B, C)-shaped contiguous view,
    scanning along the middle axis B.

    Each program handles one BLOCK_SIZE-long strip of B for a fixed (a, c):
    it writes the strip's inclusive cumsum into `out` and the strip total
    into `partial_sum` laid out as (A, part_num, C).
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flat offsets assume a contiguous (A, B, C) layout.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Promote narrow dtypes so accumulation happens in at least 32 bits;
    # 64-bit int/uint/float inputs are scanned in their own width.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    # NOTE(review): masked load has no `other=`; tail lanes are undefined
    # and may pollute the last strip's total, which downstream never
    # consumes — still, `other=0` would be safer. Confirm.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)

120 

121 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Phase two of scan-then-fan over an (A, B, C) view.

    Adds the scanned total of all preceding B-strips (stored at
    partial_sum[a, pid_b - 1, c]) onto this strip's local cumsum in `out`.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    lanes = tl.arange(0, BLOCK_SIZE)
    b_pos = pid_b * BLOCK_SIZE + lanes

    # Element offsets assume a contiguous (A, B, C) layout.
    plane_base = pid_a * B * C + pid_c
    elem_offsets = plane_base + b_pos * C
    part_plane_base = pid_a * part_num * C + pid_c
    prev_part_offset = part_plane_base + (pid_b - 1) * C

    valid = b_pos < B
    out_ptrs = out + elem_offsets
    local_vals = tl.load(out_ptrs, mask=valid)

    # The first strip already holds the final cumsum; only later strips
    # need a base added.
    if pid_b > 0:
        base = tl.load(partial_sum + prev_part_offset)
        shifted = local_vals + base
        tl.store(out_ptrs, shifted.to(local_vals.dtype), mask=valid)

155 

156 

def scan_then_fan_col(inp, out, n_ele, dtype):
    """Cumulative sum of a flat `n_ele`-element tensor via scan-then-fan.

    Each block is scanned independently, the per-block totals are scanned
    (recursively, when more than one block exists) and fanned back in as
    per-block bases. `dtype` is the accumulation dtype for partial sums.
    """
    # TODO(all): tune on target board
    block_size = 1024 if n_ele > 1024 * 4 else triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / block_size)
    partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, block_size)

    if part_num >= 2:
        # Scan the block totals in place, then add them as bases.
        scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, block_size)

173 

174 

def scan_then_fan(inp, out, A, B, C, dtype):
    """Cumulative sum along the middle axis of an (A, B, C) view.

    Scan-then-fan: per-strip scans along B, then the (A, part_num, C)
    strip totals are scanned recursively and fanned back in as bases.
    `dtype` is the accumulation dtype for partial sums.
    """
    # TODO(all): tune on target board
    block_size = 1024 if B > 1024 * 4 else triton.next_power_of_2(B)
    part_num = math.ceil(B / block_size)
    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[grid](
            inp, out, partial_sum, B, C, part_num, block_size
        )

    if part_num >= 2:
        # Scan the strip totals in place, then add them as bases.
        scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, block_size)

193 

194 

def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Compute the cumulative sum of `inp` along `dim`.

    The tensor is viewed as (M, N, K) with N = size of `dim`; a 1-D fast
    path handles the M == K == 1 case. When `dtype` is None it defaults to
    the input dtype (bool promotes to int64). `out` is allocated if absent.

    Returns the output tensor.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    inp = inp.contiguous()
    K = inp.numel() // M // N

    if dtype is None:
        dtype = torch.int64 if inp.dtype is torch.bool else inp.dtype
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    # Half-precision inputs are accumulated in fp32 for accuracy.
    if inp.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = out.dtype

    if M == 1 and K == 1:
        scan_then_fan_col(inp, out, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out

222 

223 

def cumsum(inp, dim=1, *, dtype=None):
    """Cumulative sum along `dim`; see cumsum_wrapper for details."""
    logger.debug("GEMS CUMSUM")
    return cumsum_wrapper(inp, dim, dtype)

227 

228 

def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Cumulative sum along `dim`, written into `out`; see cumsum_wrapper."""
    logger.debug("GEMS CUMSUM_OUT")
    return cumsum_wrapper(inp, dim, dtype, out=out)

232 

233 

@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Single-block row-normalized cumsum: out = cumsum(x) / sum(x) for one
    K-element row per program. Requires BLOCK >= K.
    """
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # Fix: also promote bf16 (not just fp16) to fp32 before accumulating,
    # consistent with block_cumsum_kernel's half-precision handling.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)

246 

247 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Row-wise blocked cumulative sum over an R x K problem.

    One CTA scans a (r, t*TILE) chunk: for each of its rows it accumulates
    `t` consecutive TILE-wide tiles into a running cumsum written to `out`.
    With OUTPUT_SUMS the per-(row, chunk) total goes to `sums` (laid out
    row * num_programs(0) + chunk). With NORMALIZE a second sweep divides
    the chunk by its own running total. With HAS_OUT_LAYOUT the output uses
    out_r_stride/out_k_stride instead of the input strides.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running total of all tiles seen so far in this row chunk.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            # Accumulate half-precision inputs in fp32.
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                # NOTE(review): this clobbers row_offset, which the next ti
                # iteration reuses for the *input* load; it is harmless only
                # while HAS_OUT_LAYOUT callers launch with t == 1 — confirm.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Overwritten each tile; the last tile leaves the full total.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep: divide the chunk by the row-chunk total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE

315 

316 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Fan-in pass for the multi-chunk normalized cumsum:
    out = (inp + base) / rscale, per (row, chunk) tile group.

    `base` holds the per-(row, chunk) prefix to add (cumsums - sums in the
    caller) and `rscale_ptr` the per-row divisor (the row total).
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_gridx = tle.num_programs(1)

    # NOTE(review): the producer (block_cumsum_kernel) writes `base`/`sums`
    # at row * num_programs(0) + chunk, but this start index uses
    # num_programs(1) and the per-row advance below is `+= gridx`. These
    # only coincide for particular launch shapes — verify against the
    # producer's layout before relying on r > 1 here.
    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    # NOTE(review): row range differs from block_cumsum_kernel's
    # [gridy * r, (gridy + 1) * r); with r > 1 the two kernels would walk
    # different rows — confirm batch == 1 on this backend.
    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE

372 

373 

# Upper bound on programs launched along grid axis 1 (rows); normed_cumsum
# batches several rows per program once n_rows exceeds this.
GRID_Y_LIMIT = 65535

375 

376 

def normed_cumsum(inp, dim=-1):
    """Cumulative sum along `dim`, normalized by the total of that dim
    (i.e. cumsum(x, dim) / sum(x, dim), so each scanned row ends at 1).

    Args:
        inp: floating-point tensor (fp16 / bf16 / fp32 / fp64).
        dim: dimension to scan; may be negative.

    Returns:
        A new tensor shaped like `inp` (a transposed copy is scanned when
        `dim` is a middle dimension in memory order).
    """
    logger.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        num_sms = get_device_properties(device).multi_processor_count
        TILE = (
            2048 if K >= 2048 else triton.next_power_of_2(K)
        )  # TODO: _aipu changed TILE from 2048 to this
        # Each row is split into n_chunks of chunks where each chunk is
        # comprised of n_tiles of tiles. Different chunks are assigned to
        # different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            # Batch several rows per program once the grid y-limit is hit.
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # Single chunk per row: one fused scan + normalize pass suffices.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # Fix: previously acc_dtype was only assigned when the input was not
        # float64, leaving it undefined (NameError) for fp64 inputs.
        acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
        # Fix: module-level `device` is already the backend-name string, so
        # `device.name` raised AttributeError; allocate on the input device.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        rscale = cumsums[..., -1]
        # cumsums - sums is the exclusive prefix (base) of each chunk.
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out