Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cumsum.py: 0%

269 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2import math 

3import os 

4 

5import torch 

6import triton 

7import triton.language as tl 

8from torch._prims_common import is_boolean_dtype, is_integer_dtype 

9 

10from flag_gems.runtime import device, torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

15device = device.name 

16 

17 

18@libentry() 

19@triton.jit(do_not_specialize=["n_elements", "part_num"]) 

20def scan_part_sum_kernel( 

21 inp, 

22 out, 

23 partial_sum, 

24 n_elements, 

25 part_num, 

26 BLOCK_SIZE: tl.constexpr, 

27): 

28 pid = tle.program_id(0) 

29 offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

30 mask = offset < n_elements 

31 

32 inp_ptrs = inp + offset 

33 inp_vals = tl.load(inp_ptrs, mask=mask) 

34 if ( 

35 tl.constexpr(inp_vals.dtype.is_int64()) 

36 or tl.constexpr(inp_vals.dtype.is_uint64()) 

37 ) or tl.constexpr(inp_vals.dtype.is_fp64()): 

38 inp_vals = inp_vals 

39 elif tl.constexpr(inp_vals.dtype.is_int()): 

40 inp_vals = inp_vals.to(tl.int32) 

41 else: 

42 inp_vals = inp_vals.to(tl.float32) 

43 result = tl.cumsum(inp_vals, axis=0) 

44 

45 part_sum_via_sum = tl.sum(inp_vals) 

46 

47 out_ptrs = out + offset 

48 tl.store(out_ptrs, result, mask=mask) 

49 

50 partial_sum_ptrs = partial_sum + pid 

51 tl.store(partial_sum_ptrs, part_sum_via_sum) 

52 

53 

54@libentry() 

55@triton.jit(do_not_specialize=["n_elements", "part_num"]) 

56def add_base_sum_kernel( 

57 out, 

58 partial_sum, 

59 n_elements, 

60 part_num, 

61 BLOCK_SIZE: tl.constexpr, 

62): 

63 pid = tle.program_id(0) 

64 offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

65 mask = offset < n_elements 

66 

67 out_ptrs = out + offset 

68 out_vals = tl.load(out_ptrs, mask=mask) 

69 

70 if pid > 0: 

71 partial_sum_ptrs = partial_sum + pid - 1 

72 last_part_sum_via_sum = tl.load(partial_sum_ptrs) 

73 

74 final_vals = out_vals + last_part_sum_via_sum 

75 tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask) 

76 

77 

78@libentry() 

79@triton.jit(do_not_specialize=["part_num"]) 

80def scan_part_sum_abc_kernel( 

81 inp, 

82 out, 

83 partial_sum, 

84 B, 

85 C, 

86 part_num, 

87 BLOCK_SIZE: tl.constexpr, 

88): 

89 pid_a = tle.program_id(0) 

90 pid_b = tle.program_id(1) 

91 pid_c = tle.program_id(2) 

92 

93 a_idx = pid_a 

94 b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

95 c_idx = pid_c 

96 

97 offset = a_idx * B * C + b_idx * C + c_idx 

98 base_part_offset = a_idx * part_num * C + c_idx 

99 part_offset = base_part_offset + pid_b * C 

100 

101 mask = b_idx < B 

102 inp_ptrs = inp + offset 

103 inp_vals = tl.load(inp_ptrs, mask=mask) 

104 if ( 

105 tl.constexpr(inp_vals.dtype.is_int64()) 

106 or tl.constexpr(inp_vals.dtype.is_uint64()) 

107 ) or tl.constexpr(inp_vals.dtype.is_fp64()): 

108 inp_vals = inp_vals 

109 elif tl.constexpr(inp_vals.dtype.is_int()): 

110 inp_vals = inp_vals.to(tl.int32) 

111 else: 

112 inp_vals = inp_vals.to(tl.float32) 

113 result = tl.cumsum(inp_vals, axis=0) 

114 

115 part_sum_via_sum = tl.sum(inp_vals) 

116 

117 offset = tl.where(mask, offset, -1) 

118 out_ptrs = out + offset 

119 tl.store(out_ptrs, result, mask=mask) 

120 

121 partial_sum_ptrs = partial_sum + part_offset 

122 tl.store(partial_sum_ptrs, part_sum_via_sum) 

123 

124 

125@libentry() 

126@triton.jit(do_not_specialize=["part_num"]) 

127def add_base_sum_abc_kernel( 

128 out, 

129 partial_sum, 

130 B, 

131 C, 

132 part_num, 

133 BLOCK_SIZE: tl.constexpr, 

134): 

135 pid_a = tle.program_id(0) 

136 pid_b = tle.program_id(1) 

137 pid_c = tle.program_id(2) 

138 

139 a_idx = pid_a 

140 b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

141 c_idx = pid_c 

142 

143 base_offset = a_idx * B * C + c_idx 

144 offset = base_offset + b_idx * C 

145 base_part_offset = a_idx * part_num * C + c_idx 

146 last_part_offset = base_part_offset + (pid_b - 1) * C 

147 

148 mask = b_idx < B 

149 out_ptrs = out + offset 

150 out_vals = tl.load(out_ptrs, mask=mask) 

151 

152 if pid_b > 0: 

153 partial_sum_ptrs = partial_sum + last_part_offset 

154 last_part_sum_via_sum = tl.load(partial_sum_ptrs) 

155 

156 final_vals = out_vals + last_part_sum_via_sum 

157 tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask) 

158 

159 

160def scan_then_fan_col(inp, out, n_ele, dtype): 

161 # TODO(all): tune on target board 

162 BLOCK_SIZE = 1024 

163 if n_ele <= 1024 * 4: 

164 BLOCK_SIZE = triton.next_power_of_2(n_ele) 

165 part_num = math.ceil(n_ele / BLOCK_SIZE) 

166 partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device) 

167 

168 grid = (part_num,) 

169 with torch_device_fn.device(inp.device): 

170 scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE) 

171 

172 if part_num >= 2: 

173 scan_then_fan_col(partial_sum, partial_sum, part_num, dtype) 

174 with torch_device_fn.device(inp.device): 

175 add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE) 

176 

177 

178def scan_then_fan(inp, out, A, B, C, dtype): 

179 # TODO(all): tune on target board 

180 BLOCK_SIZE = 1024 

181 if B <= 1024 * 4: 

182 BLOCK_SIZE = triton.next_power_of_2(B) 

183 part_num = math.ceil(B / BLOCK_SIZE) 

184 partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device) 

185 

186 grid = (A, part_num, C) 

187 

188 if inp.shape[1] > 8192: 

189 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

190 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

191 scan_part_sum_abc_kernel[grid]( 

192 inp, out, partial_sum, B, C, part_num, BLOCK_SIZE 

193 ) 

194 

195 if "TRITONXPU_OTHER_SIM" in os.environ: 

196 del os.environ["TRITONXPU_OTHER_SIM"] 

197 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

198 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

199 

200 else: 

201 with torch_device_fn.device(inp.device): 

202 scan_part_sum_abc_kernel[grid]( 

203 inp, out, partial_sum, B, C, part_num, BLOCK_SIZE 

204 ) 

205 

206 if part_num >= 2: 

207 scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype) 

208 with torch_device_fn.device(inp.device): 

209 add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE) 

210 

211 

212def cumsum_wrapper(inp, dim=1, dtype=None, out=None): 

213 assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" 

214 shape = inp.shape 

215 dim = dim % inp.ndim 

216 M = 1 

217 N = shape[dim] 

218 for i in range(dim): 

219 M *= shape[i] 

220 inp = inp.contiguous() 

221 K = inp.numel() // M // N 

222 

223 if dtype is None: 

224 dtype = inp.dtype 

225 if is_integer_dtype(dtype) or is_boolean_dtype(dtype): 

226 dtype = torch.int64 

227 if out is None: 

228 out = torch.empty_like(inp, dtype=dtype) 

229 

230 compute_dtype = out.dtype 

231 if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16: 

232 compute_dtype = torch.float32 

233 

234 if M == 1 and K == 1: 

235 scan_then_fan_col(inp, out, N, compute_dtype) 

236 else: 

237 scan_then_fan(inp, out, M, N, K, compute_dtype) 

238 return out 

239 

240 

241def cumsum(inp, dim=1, *, dtype=None): 

242 logger.debug("GEMS CUMSUM") 

243 return cumsum_wrapper(inp, dim, dtype) 

244 

245 

246def cumsum_out(inp, dim=1, *, dtype=None, out): 

247 logger.debug("GEMS CUMSUM_OUT") 

248 return cumsum_wrapper(inp, dim, dtype, out) 

249 

250 

251@libentry() 

252@triton.jit(do_not_specialize=["K"]) 

253def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr): 

254 row_start = tle.program_id(0) * K 

255 row_off = tl.arange(0, BLOCK) 

256 x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0) 

257 if x.dtype.is_fp16(): 

258 x = x.to(tl.float32) 

259 y_sum = tl.sum(x, 0) 

260 y = tl.cumsum(x, 0) 

261 y = y / y_sum 

262 tl.store(out + row_start + row_off, y, mask=row_off < K) 

263 

264 

265@libentry() 

266@triton.jit( 

267 do_not_specialize=[ 

268 "r", 

269 "t", 

270 "R", 

271 "K", 

272 "r_stride", 

273 "out_r_stride", 

274 ] 

275) 

276def block_cumsum_kernel( 

277 inp, 

278 out, 

279 sums, 

280 r: tl.constexpr, 

281 t: tl.constexpr, 

282 R: tl.constexpr, 

283 K: tl.constexpr, 

284 r_stride: tl.constexpr, 

285 k_stride: tl.constexpr, 

286 out_r_stride: tl.constexpr, 

287 out_k_stride: tl.constexpr, 

288 OUTPUT_SUMS: tl.constexpr, 

289 NORMALIZE: tl.constexpr, 

290 HAS_OUT_LAYOUT: tl.constexpr, 

291 TILE: tl.constexpr, 

292): 

293 # One CTA processes a (r, t*tile) chunk 

294 # rows = [ grid.y, grid.y + r ) 

295 # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile ) 

296 gridx = tle.program_id(0).to(tl.int64) 

297 gridy = tle.program_id(1).to(tl.int64) 

298 n_chunks = tle.num_programs(0) 

299 

300 for row in range(gridy * r, min((gridy + 1) * r, R)): 

301 curr_cumsum = tl.zeros((1,), tl.float32) 

302 row_offset = row * r_stride 

303 cols = gridx * t * TILE + tl.arange(0, TILE) 

304 for ti in range(0, t): 

305 cols_offset = cols * k_stride 

306 x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0) 

307 if x.dtype.is_fp16() | x.dtype.is_bf16(): 

308 x = x.to(tl.float32) 

309 tile_sum = tl.sum(x, 0)[None] 

310 tile_cumsum = tl.cumsum(x, 0) + curr_cumsum 

311 curr_cumsum += tile_sum 

312 if HAS_OUT_LAYOUT: 

313 cols_offset = cols * out_k_stride 

314 row_offset = row * out_r_stride 

315 tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K) 

316 if OUTPUT_SUMS: 

317 tl.store(sums + row * n_chunks + gridx[None], curr_cumsum) 

318 cols += TILE 

319 if NORMALIZE: 

320 cols = gridx * t * TILE + tl.arange(0, TILE) 

321 for _ in range(0, t): 

322 cols_offset = cols * k_stride 

323 if HAS_OUT_LAYOUT: 

324 cols_offset = cols * out_k_stride 

325 row_offset = row * out_r_stride 

326 x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0) 

327 if x.dtype.is_fp16() | x.dtype.is_bf16(): 

328 x = x.to(tl.float32) 

329 x = x / curr_cumsum 

330 tl.store(out + row_offset + cols_offset, x, mask=cols < K) 

331 cols += TILE 

332 

333 

334@libentry() 

335@triton.jit( 

336 do_not_specialize=[ 

337 "r", 

338 "t", 

339 "R", 

340 "K", 

341 "r_stride", 

342 "out_r_stride", 

343 ] 

344) 

345def block_update_kernel( 

346 inp, 

347 base, 

348 rscale_ptr, 

349 out, 

350 r, 

351 t, 

352 R, 

353 K, 

354 r_stride, 

355 k_stride, 

356 out_r_stride, 

357 out_k_stride, 

358 rscale_stride, 

359 HAS_OUT_LAYOUT: tl.constexpr, 

360 TILE: tl.constexpr, 

361): 

362 # One CTA processes a (r, t*tile) chunk 

363 # rows = [ grid.y, grid.y + r ) 

364 # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile ) 

365 gridx = tle.program_id(0).to(tl.int64) 

366 gridy = tle.program_id(1).to(tl.int64) 

367 n_gridx = tle.num_programs(1) 

368 

369 base += gridy * n_gridx + gridx 

370 rscale_ptr += gridy * rscale_stride 

371 

372 for row in range(gridy, min(gridy + r, R)): 

373 d = tl.load(base) 

374 rscale = tl.load(rscale_ptr) 

375 base += gridx 

376 rscale_ptr += rscale_stride 

377 row_offset = row * r_stride 

378 cols = gridx * t * TILE + tl.arange(0, TILE) 

379 for _ in range(0, t): 

380 cols_offset = cols * k_stride 

381 x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0) 

382 x += d 

383 x /= rscale 

384 if HAS_OUT_LAYOUT: 

385 cols_offset = cols * out_k_stride 

386 row_offset = row * out_r_stride 

387 tl.store(out + row_offset + cols_offset, x, mask=cols < K) 

388 cols += TILE 

389 

390 

391GRID_Y_LIMIT = 65535 

392 

393 

394def normed_cumsum(inp, dim=-1): 

395 logger.debug("GEMS NORMED_CUMSUM") 

396 assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64) 

397 dim = dim % inp.ndim 

398 N = inp.numel() 

399 K = inp.size(dim) 

400 # inp = inp.contiguous() 

401 # First and last dims are easier to handle, but transpose the middle dim to the last 

402 ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True) 

403 is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1]) 

404 if is_mid_dim: 

405 inp = inp.transpose(dim, -1).contiguous() 

406 dim = -1 

407 out = torch.empty_like(inp) 

408 with torch_device_fn.device(inp.device.index): 

409 # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta 

410 num_sms = torch_device_fn.get_device_properties(device).multi_processor_count 

411 TILE = 8192 

412 # Each row is split into n_chunks of chunks where each chunk is compised of 

413 # n_tiles of tiles. Different chunks are assigned to different ctas. 

414 n_rows = N // K 

415 n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE)) 

416 n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks) 

417 k_stride = inp.stride(dim) 

418 r_stride = inp.size(dim) if k_stride == 1 else 1 

419 if n_rows > GRID_Y_LIMIT: 

420 batch = triton.cdiv(n_rows, GRID_Y_LIMIT) 

421 n_batch = triton.cdiv(n_rows, batch) 

422 else: 

423 batch = 1 

424 n_batch = n_rows 

425 

426 grid = (n_chunks, n_batch) 

427 if n_chunks == 1: 

428 block_cumsum_kernel[grid]( 

429 inp, 

430 out, 

431 0, 

432 batch, 

433 n_tiles, 

434 n_rows, 

435 K, 

436 r_stride, 

437 k_stride, 

438 r_stride, 

439 k_stride, 

440 OUTPUT_SUMS=False, 

441 NORMALIZE=True, 

442 HAS_OUT_LAYOUT=False, 

443 TILE=TILE, 

444 isCloseUnrollControl=True, 

445 ) 

446 return out 

447 

448 if inp.dtype != torch.float64: 

449 acc_dtype = torch.float32 

450 sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=device.name) 

451 cumsums = torch.empty_like(sums) 

452 block_cumsum_kernel[grid]( 

453 inp, 

454 out, 

455 sums, 

456 batch, 

457 n_tiles, 

458 n_rows, 

459 K, 

460 r_stride, 

461 k_stride, 

462 r_stride, 

463 k_stride, 

464 OUTPUT_SUMS=True, 

465 NORMALIZE=False, 

466 HAS_OUT_LAYOUT=False, 

467 TILE=TILE, 

468 isCloseUnrollControl=True, 

469 ) 

470 # Pass two, scan partial cumsums 

471 block_cumsum_kernel[(1, n_batch)]( 

472 sums, 

473 cumsums, 

474 0, 

475 batch, 

476 1, 

477 n_rows, 

478 n_chunks, 

479 n_chunks, 

480 1, 

481 n_chunks, 

482 1, 

483 OUTPUT_SUMS=False, 

484 NORMALIZE=False, 

485 HAS_OUT_LAYOUT=True, 

486 TILE=TILE, 

487 isCloseUnrollControl=True, 

488 ) 

489 # print(sums) 

490 rscale = cumsums[..., -1] 

491 block_update_kernel[grid]( 

492 out, 

493 cumsums - sums, 

494 rscale, 

495 out, 

496 batch, 

497 n_tiles, 

498 n_rows, 

499 K, 

500 r_stride, 

501 k_stride, 

502 r_stride, 

503 k_stride, 

504 n_chunks, 

505 HAS_OUT_LAYOUT=False, 

506 TILE=TILE, 

507 ) 

508 return out