Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sort.py: 0%

307 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11from .topk import _get_finfo_val, _get_iinfo_val, argsort 

12 

# Module logger, parented under the package-wide "flag_gems" logger;
# lstrip(".") defensively drops any leading dot from __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

def unwrap_if_constexpr(o):
    """Return the plain value wrapped by a ``tl.constexpr``; pass anything else through."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o

18 

19 

@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    # Compile-time helper: the Triton integer dtype with the given bit width
    # and signedness.
    num_bits = unwrap_if_constexpr(num_bits)
    signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(num_bits, signed)

25 

26 

@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    # Compile-time bit pattern 0b100...0: only the top (sign) bit set.
    num_bits = unwrap_if_constexpr(num_bits)
    return 1 << (num_bits - 1)

31 

32 

@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    # Compile-time bit pattern 0b011...1: all bits set except the top one.
    num_bits = unwrap_if_constexpr(num_bits)
    return (1 << (num_bits - 1)) - 1

37 

38 

@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    # Unsigned ints already compare in radix (unsigned) order; flipping every
    # bit exactly reverses that order for a descending sort.
    out = ~x if descending else x
    return out

43 

44 

@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    # Bitcast a signed integer to an unsigned key whose *unsigned* ordering
    # matches the requested signed ordering (classic radix-sort key transform).
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # 0111111....1
        # XOR with 0b011...1 equals ~(ux ^ 0b100...0), i.e. the bitwise NOT of
        # the ascending key below — which reverses the unsigned order.
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # 1000000...0
        # Flipping only the sign bit maps negatives below positives in
        # unsigned comparison while preserving order within each sign.
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out

61 

62 

@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    # Bitcast an IEEE float to an unsigned key whose unsigned ordering matches
    # the float ordering: set the sign bit for positives, flip all bits for
    # negatives (standard radix-sort float transform).
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    # Arithmetic shift by (num_bits - 1) replicates the sign bit: 0 for
    # non-negative x, all-ones for negative x; OR-ing the top bit produces the
    # per-element XOR mask spelled out below.
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    if descending:
        # XOR with the complemented mask equals ~(ascending key): reversed order.
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)

85 

86 

@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    # Convert x to an unsigned integer key whose unsigned ordering matches the
    # requested (ascending/descending) ordering of the original dtype.
    # NOTE(review): a dtype matching none of the branches (e.g. bool) would
    # leave `out` unbound — presumably callers never pass one; confirm.
    if x.dtype.is_floating():
        if x.dtype == tl.bfloat16:
            # Widen bf16 to fp32 first; bf16 -> fp32 is exact, so order is kept.
            x = x.to(tl.float32)
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    return out

98 

99 

@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
):
    # Build per-row radix-digit histograms for *every* pass in one launch.
    # Each CTA owns a contiguous span of columns of one row and atomically
    # adds its partial per-bin counts into the global histogram.
    # arr_ptr: (m, n)
    # out_ptr: (m, n_passes, r), where r = 2 ** k_bits is the number of bins
    pid = tl.program_id(0)
    pid_n = pid // m
    pid_m = pid % m

    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1  # a.k.a. 2 ** k_bits - 1
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    cta_n_start = CTA_TILE_N * pid_n
    cta_n_end = tl.minimum(cta_n_start + CTA_TILE_N, n)

    for p in range(0, num_passes):  # parallel
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):  # parallel
            bin_indices = r_start + tl.arange(0, TILE_R)
            acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):  # sequential
                n_offsets = n_start + tl.arange(0, TILE_N)  # (TILE_N, )
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask)
                # Order-preserving unsigned key, then extract this pass's digit.
                arr = convert_to_uint_preverse_order(arr, descending)
                key = (arr >> bit_offset) & bfe_mask  # (TILE_N, )
                matches = tl.where(
                    mask, (bin_indices[:, None] == key), False
                )  # (TILE_R, TILE_N)
                acc += matches
            local_sum = tl.sum(acc, axis=1)
            # Relaxed ordering suffices: the histogram is only read after the
            # kernel completes.
            tl.atomic_add(
                out_ptr + pid_m * num_passes * r + p * r + bin_indices,
                local_sum,
                sem="relaxed",
            )

146 

147 

@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
):
    # One radix pass: scatter every element (and its associated value) of a
    # (m, N) array to its globally sorted position for the current digit.
    # An element's destination is (bin base offset) + (exclusive count of
    # equal digits before it); the cross-CTA part of that count is computed
    # with a single-pass decoupled-lookback scan over the status words.

    # r: num_bins = 2 ** k_bits
    # OUT_N: grid_n = cdiv(N, )

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # flag_ptr: (m, r, OUT_N)

    # grid: (m, grid_r, grid_n)

    # load data
    pid = tl.program_id(0)
    pid_m = pid % m
    pid_n = pid // m
    pid_r = tl.program_id(1)

    # bit masks
    # Status word layout (uint32): bit31 = inclusive prefix ready,
    # bit30 = aggregate ready, bits 0..29 = the count itself.
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # initialize flag to zero-local sum is not ready
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    # cumsum for a bin_index
    n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
    mask = n_offsets < N
    arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
    arr_u = convert_to_uint_preverse_order(arr, descending)
    key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )

    # since triton can only use scalar as condition, loop by bin_index
    # status must be pre zero-initialized, or else we have to initialize it
    for bin_index in range(cta_r_start, cta_r_end):
        matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
        # cta level cumsum per bin
        # CAUTION: tl.sum in triton 3.2 does not promote type
        local_sum = tl.sum(matches.to(tl.uint32), axis=0)
        # Publish this CTA's aggregate so later CTAs can look back at it.
        pack0 = aggregate_mask | local_sum
        status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n
        tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

        # decoupled lookback
        exclusive_prefix = tl.zeros((), dtype=tl.uint32)
        i_lookback = pid_n - 1
        while i_lookback >= 0:
            flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
            pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uint32
            # Spin until the predecessor has published something non-zero.
            while pack1 == 0:
                pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
            exclusive_prefix += pack1 & v_mask
            if (pack1 & aggregate_mask) == aggregate_mask:
                # Only an aggregate so far: keep walking further back.
                i_lookback -= 1
            else:
                # Inclusive prefix found: lookback complete for this bin.
                i_lookback = -1
        # Publish our inclusive prefix so successors can stop their lookback here.
        pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
        tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

        local_ex_cumsum = (
            tl.cumsum(matches.to(tl.uint32), axis=0) - matches
        )  # (TILE_N, )
        ex_cumsum_in_bin = (
            exclusive_prefix + local_ex_cumsum
        )  # global ex_cumsum_in_bin (TILE_N, )

        # ex_cumsum_bins (m, n_passes, r)
        ex_cumsum_bins = tl.load(
            excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index
        )  # scalar
        pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

        # scatter
        tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
        if associate_arr_ptr is not None:
            associate_arr = tl.load(
                associate_arr_ptr + pid_m * N + n_offsets, mask=mask
            )
            tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches)

248 

249 

@triton.jit
def count_kernel(
    x_ptr,
    counts_ptr,  # Output: [M, grid_n, num_bins]
    M,
    N,
    bit_offset,
    num_bins: tl.constexpr,
    BLOCK_N: tl.constexpr,
    descending: tl.constexpr,
):
    # Counting phase of one radix pass: each program histograms its BLOCK_N
    # slice of one row into num_bins digit buckets for the current pass.
    pid = tl.program_id(0)

    num_blocks_per_row = tl.cdiv(N, BLOCK_N)
    row_idx = pid // num_blocks_per_row
    block_idx = pid % num_blocks_per_row

    row_start = row_idx * N
    n_offset = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    val = tl.load(x_ptr + row_start + n_offset, mask=mask, other=0)
    # Order-preserving unsigned key; this pass's digit lives in
    # bits [bit_offset, bit_offset + log2(num_bins)).
    val_u = convert_to_uint_preverse_order(val, descending)

    bfe_mask = num_bins - 1
    key = (val_u >> bit_offset) & bfe_mask

    for i in range(num_bins):
        # Out-of-range lanes are excluded by `mask`.
        bin_mask = (key == i) & mask
        count = tl.sum(bin_mask.to(tl.int32))
        out_offset = (
            (row_idx * num_blocks_per_row * num_bins) + (block_idx * num_bins) + i
        )
        tl.store(counts_ptr + out_offset, count)

284 

285 

@triton.jit
def scatter_kernel(
    x_ptr,
    x_out_ptr,
    idx_in_ptr,
    idx_out_ptr,
    global_offsets_ptr,
    M,
    N,
    bit_offset,
    num_bins: tl.constexpr,
    BLOCK_N: tl.constexpr,
    descending: tl.constexpr,
):
    # Scatter phase of one radix pass: move each element (and its original
    # index) to global_start + local_rank, where global_start is this block's
    # precomputed base for the element's bin and local_rank is the element's
    # order inside the block — the in-block cumsum keeps the pass stable.
    pid = tl.program_id(0)
    num_blocks_per_row = tl.cdiv(N, BLOCK_N)
    row_idx = pid // num_blocks_per_row
    block_idx = pid % num_blocks_per_row

    row_start = row_idx * N
    n_offset = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    val = tl.load(x_ptr + row_start + n_offset, mask=mask, other=0)
    val_u = convert_to_uint_preverse_order(val, descending)

    idx = tl.load(idx_in_ptr + row_start + n_offset, mask=mask, other=0)

    bfe_mask = num_bins - 1
    key = (val_u >> bit_offset) & bfe_mask

    for i in range(num_bins):
        bin_mask = (key == i) & mask
        # 0-based rank of each matching element within this block
        # (exclusive count of earlier matches).
        local_rank = tl.cumsum(bin_mask.to(tl.int32), axis=0) - 1

        offset_idx = (
            (row_idx * num_blocks_per_row * num_bins) + (block_idx * num_bins) + i
        )
        global_start = tl.load(global_offsets_ptr + offset_idx)

        dest_idx = row_start + global_start + local_rank

        tl.store(x_out_ptr + dest_idx, val, mask=bin_mask)
        tl.store(idx_out_ptr + dest_idx, idx, mask=bin_mask)

330 

331 

def radix_sort_low_mem(arr, k_bits=4, descending=False):
    """Stable LSD radix sort along the last dimension (low-memory variant).

    Uses a count + scatter kernel pair per pass; the only auxiliary buffer is
    a small per-block histogram (no decoupled-lookback status array).

    Args:
        arr: input tensor, sorted along its last dimension. Any rank is
            accepted; leading dims are flattened into a batch internally.
        k_bits: radix bits consumed per pass (num_bins = 2 ** k_bits).
        descending: sort order.

    Returns:
        ``(sorted_values, gather_indices)`` — both with the same shape as
        ``arr``; indices are positions along the last dimension.
    """
    orig_shape = arr.shape
    N = orig_shape[-1]
    # Flatten all leading dims into a batch so the kernels see (M, N). This
    # fixes two defects of the previous version: >2-D inputs crashed on the
    # (M, N) unpack, and 1-D inputs came back with a spurious leading dim.
    # The clone also keeps the caller's tensor out of the ping-pong buffers,
    # so the input is no longer mutated when num_passes is even.
    arr_in = arr.reshape(-1, N).clone()
    M = arr_in.shape[0]
    arr_out = torch.empty_like(arr_in)

    # Original position of every element along the last dim.
    idx_in = (
        torch.arange(N, device=arr.device, dtype=torch.int64)
        .broadcast_to(arr_in.shape)
        .clone()
    )
    idx_out = torch.empty_like(idx_in)

    dtype = arr.dtype
    if dtype == torch.bool:
        num_bits = 1  # a single significant bit
    elif dtype == torch.bfloat16:
        # bf16 keys are widened to fp32 inside the kernels, so sort 32 bits.
        num_bits = 4 * 8
    else:
        num_bits = arr.element_size() * 8
    num_passes = (num_bits + k_bits - 1) // k_bits
    num_bins = 2**k_bits

    BLOCK_N = 512
    grid_n = triton.cdiv(N, BLOCK_N)
    grid = (M * grid_n,)

    with torch_device_fn.device(arr.device):
        counts = torch.empty(
            (M, grid_n, num_bins), device=arr.device, dtype=torch.int32
        )

        for p in range(num_passes):
            bit_offset = p * k_bits
            # NOTE(review): `is_use_mask_zero` is not a parameter of
            # count_kernel/scatter_kernel as defined in this file; presumably
            # it is a launch option of this backend's Triton fork — confirm.
            count_kernel[grid](
                arr_in,
                counts,
                M,
                N,
                bit_offset,
                num_bins,
                BLOCK_N,
                descending,
                is_use_mask_zero=True,
            )

            # Exclusive prefix sums: per-row bin starts, then per-block
            # offsets within each bin.
            total_counts_per_bin = counts.sum(dim=1)
            bin_global_starts = (
                torch.cumsum(total_counts_per_bin, dim=1) - total_counts_per_bin
            )
            block_prefix_sum = torch.cumsum(counts, dim=1) - counts
            global_offsets = (
                bin_global_starts.unsqueeze(1)
                .broadcast_to(block_prefix_sum.shape)
                .clone()
                + block_prefix_sum
            )

            scatter_kernel[grid](
                arr_in,
                arr_out,
                idx_in,
                idx_out,
                global_offsets,
                M,
                N,
                bit_offset,
                num_bins,
                BLOCK_N,
                descending,
                is_use_mask_zero=True,
            )

            # Ping-pong: the just-written buffers become next pass's inputs.
            arr_in, arr_out = arr_out, arr_in
            idx_in, idx_out = idx_out, idx_in

    # Restore the caller's shape (keeps 1-D inputs 1-D, supports >2-D).
    return arr_in.reshape(orig_shape), idx_in.reshape(orig_shape)

412 

413 

def radix_sort(arr, k_bits=8, descending=False):
    """LSD radix sort along the last dim using decoupled-lookback sweeps.

    Returns ``(sorted_values, gather_indices)``; indices are positions along
    the last dimension and both outputs share ``arr``'s shape.
    """
    n = arr.shape[-1]
    m = arr.numel() // n
    assert n < (1 << 30), "we have not implemented 2**30 per launch"
    dtype = arr.dtype
    num_bits = 1 if dtype == torch.bool else (arr.element_size() * 8)

    num_bins = 2**k_bits
    n_passes = triton.cdiv(num_bits, k_bits)

    # Launch geometry for the histogram kernel.
    hist_tile_n = 1024
    tiles_n_per_cta = 8
    cta_span = tiles_n_per_cta * hist_tile_n
    hist_tile_r = 16
    hist_grid_n = triton.cdiv(n, cta_span)

    with torch_device_fn.device(arr.device):
        # Per-(row, pass) digit histograms, accumulated atomically.
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        compute_global_hist_kernel[(m * hist_grid_n, 1, 1)](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            tiles_n_per_cta,
            hist_tile_n,
            hist_tile_r,
            k_bits,
            descending,
        )
        # Exclusive prefix sum over bins for every (row, pass).
        ex_cumsum_bins = (torch.cumsum(global_hist, -1) - global_hist).to(
            torch.uint32
        )

        # Ping-pong buffers for keys and their original positions.
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int64, device=arr_in.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        indices_out = torch.empty_like(indices_in)

        # Launch geometry for the sweep kernel.
        sweep_tile_r = 8
        sweep_tile_n = 2048
        grid_r = triton.cdiv(num_bins, sweep_tile_r)
        grid_n = triton.cdiv(n, sweep_tile_n)

        # Decoupled-lookback status words, one per (row, bin, CTA).
        status = torch.empty(
            (m, num_bins, grid_n), device=arr.device, dtype=torch.uint32
        )

        for pass_id in range(n_passes):
            status.zero_()  # lookback flags must start as "not ready"
            sweep[(m * grid_n, grid_r)](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                pass_id,
                pass_id * k_bits,
                m,
                n,
                grid_n,
                sweep_tile_n,
                sweep_tile_r,
                k_bits,
                descending,
            )
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    return arr_in, indices_in

497 

498 

@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    # One program sorts one length-N row in registers via `argsort` (from
    # .topk). Padding lanes (cols >= N) are loaded as the dtype's max/min so
    # they sort to the end, and are masked out again on store.
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    offset = tl.program_id(0) * N + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    if IS_FLOAT:
        # Sentinel fills padding so it lands after all real elements.
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
        # Keep fp64 keys as-is; everything else is compared in fp32.
        in_val = tl.where(in_val.dtype.is_fp64(), in_val, in_val.to(tl.float32))
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        # NOTE(review): narrowing to int32 would corrupt int64 keys beyond
        # 32 bits — presumably this path never sees int64; confirm.
        in_val = tl.load(in_ptr, mask=mask, other=mask_val).to(tl.int32)
    index_val = tl.arange(0, BLOCK_SIZE)

    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)

531 

532 

def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim``; returns ``(values, indices)``.

    Small axes (<= 512 elements) are handled by the in-register Triton
    kernel; larger axes fall back to ``torch.sort``.
    """
    logger.debug("GEMS SORT")
    elem_cnt = inp.shape[dim]
    # A length-1 axis is trivially sorted; every index is 0.
    if elem_cnt == 1:
        return inp, torch.zeros_like(inp, dtype=torch.int64)
    if elem_cnt > 512:  # TODO: Optimize implementation for large cases.
        return torch.sort(inp, stable=False, dim=dim, descending=descending)
    padded_len = triton.next_power_of_2(elem_cnt)

    if dim < 0:
        dim += inp.ndim
    last_dim = inp.ndim - 1
    # The kernel sorts the innermost axis of a contiguous buffer.
    if dim != last_dim:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()
    batch = inp.numel() // elem_cnt

    out = torch.empty_like(inp)
    out_index = torch.empty_like(inp, dtype=torch.int64)

    with torch_device_fn.device(inp.device):
        sort_kernel[batch,](
            inp,
            out,
            out_index,
            N=elem_cnt,
            BLOCK_SIZE=padded_len,
            DESCENDING=descending,
            IS_FLOAT=inp.is_floating_point(),
            num_warps=4,
        )

    # Undo the axis move, if any.
    if dim != last_dim:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index

569 

570 

def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Stable sort of ``inp`` along ``dim`` via LSD radix sort.

    Args:
        inp: input tensor of any rank.
        stable: ignored — the radix implementation is always stable.
        dim: axis to sort along.
        descending: sort order.

    Returns:
        ``(values, indices)`` with ``inp``'s shape; indices are positions
        along ``dim``.
    """
    logger.debug("GEMS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable
    sort_elem_cnt = inp.shape[dim]
    if sort_elem_cnt == 1:
        # Trivially sorted; indices are all zero.
        return inp, torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim = dim + inp.ndim
    # The radix driver sorts the innermost axis of a contiguous buffer.
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    dtype = inp.dtype
    num_bits_per_pass = 1 if dtype == torch.bool else 4
    # Flatten to 2-D (batch, sort_len) for radix_sort_low_mem — it unpacks
    # (M, N) — and view the results back to inp's shape afterwards. This
    # fixes >2-D inputs and keeps 1-D inputs from gaining a leading dim.
    out, out_index = radix_sort_low_mem(
        inp.view(-1, sort_elem_cnt), num_bits_per_pass, descending
    )
    out = out.reshape(inp.shape)
    out_index = out_index.reshape(inp.shape)

    if dim != inp.ndim - 1:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index