Coverage for src/flag_gems/runtime/backend/_ascend/ops/randperm.py: 0%

273 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.topk import argsort 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils.random_utils import philox_backend_seed_offset 

12 

13logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

14 

15device_ = device 

16 

17_MIN_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).min 

18_MAX_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).max 

19_MIN_INT16_VAL: tl.constexpr = torch.iinfo(torch.int16).min 

20_MAX_INT16_VAL: tl.constexpr = torch.iinfo(torch.int16).max 

21_MIN_INT32_VAL: tl.constexpr = torch.iinfo(torch.int32).min 

22_MAX_INT32_VAL: tl.constexpr = torch.iinfo(torch.int32).max 

23_MIN_INT64_VAL: tl.constexpr = torch.iinfo(torch.int64).min 

24_MAX_INT64_VAL: tl.constexpr = torch.iinfo(torch.int64).max 

25_MAX_UINT32_VAL: tl.constexpr = (1 << 32) - 1 

26_MIN_UINT32_VAL: tl.constexpr = 0 

27_MIN_INT24_VAL: tl.constexpr = -(2**23) 

28_MAX_INT24_VAL: tl.constexpr = 2**23 - 1 

29 

30 

31@triton.jit 

32def _get_iinfo_val( 

33 dtype, 

34 return_max, 

35): 

36 if dtype is tl.int64: 

37 if return_max: 

38 return _MAX_INT64_VAL 

39 else: 

40 return _MIN_INT64_VAL 

41 elif dtype is tl.int32: 

42 if return_max: 

43 return _MAX_INT32_VAL 

44 else: 

45 return _MIN_INT32_VAL 

46 elif dtype is tl.int16: 

47 if return_max: 

48 return _MAX_INT16_VAL 

49 else: 

50 return _MIN_INT16_VAL 

51 elif dtype is tl.int8: 

52 if return_max: 

53 return _MAX_INT8_VAL 

54 else: 

55 return _MIN_INT8_VAL 

56 elif dtype is tl.uint32: 

57 if return_max: 

58 return _MAX_UINT32_VAL 

59 else: 

60 return _MIN_UINT32_VAL 

61 else: 

62 raise ValueError("Unknown dtype") 

63 

64 

65@libentry() 

66@triton.jit 

67def bitonic_sortbykey_kernel( 

68 y_ptr, 

69 index_ptr, 

70 chunk_x, 

71 chunk_index, 

72 N: tl.constexpr, 

73 BLOCK_SIZE: tl.constexpr, 

74 DESCENDING: tl.constexpr, 

75): 

76 cur_batch = tl.program_id(0) 

77 chunk_x += cur_batch * N 

78 chunk_index += cur_batch * N 

79 index_ptr += cur_batch * N 

80 y_ptr += cur_batch * N 

81 

82 cols = tl.arange(0, BLOCK_SIZE) 

83 mask = cols < N 

84 

85 mask_val = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING) 

86 

87 chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val) 

88 chunk_index_val = tl.load(chunk_index + cols, mask=mask) 

89 

90 sorted_chunk_x, sorted_chunk_index = argsort( 

91 chunk_x_val, chunk_index_val, 0, descending=DESCENDING 

92 ) 

93 tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < N) 

94 tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < N) 

95 

96 

97@triton.jit 

98def radix_type_convert(k): 

99 if tl.constexpr(k.dtype == tl.int8): 

100 ik = k.to(tl.int8, bitcast=True) 

101 mask = (ik >> 7) & 0x1 

102 o = tl.where(mask, ik & 0x7F, ik | 0x80) 

103 elif tl.constexpr(k.dtype == tl.int16): 

104 ik = k.to(tl.int16, bitcast=True) 

105 mask = (ik >> 15) & 0x1 

106 o = tl.where(mask, ik & 0x7FFF, ik | 0x8000) 

107 elif tl.constexpr(k.dtype == tl.int32): 

108 ik = k.to(tl.int32, bitcast=True) 

109 mask = (ik >> 31) & 0x1 

110 o = tl.where(mask, ik & 0x7FFFFFFF, ik | 0x80000000) 

111 elif tl.constexpr(k.dtype == tl.int64): 

112 ik = k.to(tl.int64, bitcast=True) 

113 mask = (ik >> 63) & 0x1 

114 o = tl.where(mask, ik & 0x7FFFFFFFFFFFFFFF, ik | 0x8000000000000000) 

115 else: 

116 o = k 

117 return o 

118 

119 

120@libentry() 

121@triton.jit 

122def digit_hist_kernel( 

123 digit_hist, 

124 key, 

125 n_elements, 

126 bits_per_pass, 

127 bins, 

128 passes, 

129 bit_mask, 

130 bins_segment, 

131 BLOCK_SIZE: tl.constexpr, 

132): 

133 bin_segid = tl.program_id(1) 

134 pid0 = tl.program_id(0) 

135 grid0 = tl.num_programs(0) 

136 

137 key_offset = pid0.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

138 key_mask = key_offset < n_elements 

139 key_data = tl.load(key + key_offset, mask=key_mask) 

140 ikey_data = radix_type_convert(key_data) 

141 bit_offset = 0 

142 for p in range(passes): 

143 key_digit = (ikey_data >> bit_offset) & bit_mask 

144 blk_bin_start = bin_segid * bins_segment 

145 for s in range(bins_segment): 

146 bin_id = s + blk_bin_start 

147 digit_mask = tl.where(key_digit == bin_id and key_mask, 1, 0) 

148 digit_sum = tl.sum(digit_mask) 

149 # +1 for exclusive 

150 bin_offset = p * (bins + 1) * grid0 + (bin_id + 1) * grid0 + pid0 

151 # reduce rather than global atomic for perf issue 

152 tl.store(digit_hist + bin_offset, digit_sum) 

153 tl.store(digit_hist + p * (bins + 1) * grid0 + pid0, 0, mask=bin_segid == 0) 

154 bit_offset += bits_per_pass 

155 

156 

157@libentry() 

158@triton.autotune( 

159 configs=runtime.get_tuned_config("randperm"), 

160 key=["n_elements"], 

161) 

162@triton.jit 

163def radix_sortbykey_scatter_kernel( 

164 key_out, 

165 value_out, 

166 key_in, 

167 value_in, 

168 digit_hist, 

169 d_lookback, 

170 n_elements, 

171 bit_offset, 

172 passes, 

173 p, 

174 num_portions, 

175 portion_size, 

176 portion_id, 

177 bit_mask, 

178 bins_segment, 

179 max_tiles_per_portion, 

180 bins: tl.constexpr, 

181 BLOCK_SIZE: tl.constexpr, 

182): 

183 LOOKBACK_PARTIAL_MASK = 1 << 30 

184 LOOKBACK_GLOBAL_MASK = 1 << 31 

185 LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK 

186 LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK 

187 

188 pid0 = tl.program_id(0) 

189 portion_id_i64 = portion_id 

190 portion_id_i64 = portion_id_i64.to(tl.int64) 

191 key_offset = ( 

192 portion_id_i64 * portion_size 

193 + pid0.to(tl.int64) * BLOCK_SIZE 

194 + tl.arange(0, BLOCK_SIZE) 

195 ) 

196 

197 key_mask = key_offset < n_elements 

198 value_data = tl.load(value_in + key_offset, mask=key_mask) 

199 key_data = tl.load(key_in + key_offset, mask=key_mask) 

200 

201 ikey_data = radix_type_convert(key_data) 

202 key_digit = (ikey_data >> bit_offset) & bit_mask 

203 

204 blk_bin_start = tl.program_id(1) * bins_segment 

205 last_block = tl.program_id(0) == tl.num_programs(0) - 1 

206 for s in range(bins_segment): 

207 bin_id = s + blk_bin_start 

208 key_digit_mask = (key_digit == bin_id) & key_mask 

209 key_elem_mask = tl.where(key_digit_mask, 1, 0) 

210 key_block_rank = tl.cumsum(key_elem_mask) 

211 key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0) 

212 bin_of_bucket = tl.sum(key_elem_mask) 

213 partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK 

214 tl.store( 

215 d_lookback 

216 + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins 

217 + bin_id, 

218 partial_counter, 

219 cache_modifier=".cg", 

220 ) 

221 bin_offset = p * (bins + 1) + bin_id 

222 prefix_offsets = tl.load( 

223 digit_hist + bin_offset + portion_id * passes * (bins + 1) 

224 ) 

225 bk = pid0 - 1 

226 

227 inc_sum = bin_of_bucket 

228 

229 while bk >= 0: 

230 rd_lbk_offset = ( 

231 (portion_id * passes + p) * max_tiles_per_portion + bk 

232 ) * bins + bin_id 

233 partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True) 

234 max_wait = 1000 

235 wait_count = 0 

236 while partial_prefix == 0 and wait_count < max_wait: 

237 partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True) 

238 wait_count += 1 

239 inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32) 

240 if partial_prefix & LOOKBACK_GLOBAL_MASK: 

241 # break 

242 bk = -1 

243 else: 

244 bk -= 1 

245 

246 global_counter = inc_sum | LOOKBACK_GLOBAL_MASK 

247 tl.store( 

248 d_lookback 

249 + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins 

250 + bin_id, 

251 global_counter, 

252 cache_modifier=".cg", 

253 ) 

254 inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64) 

255 if last_block and portion_id < num_portions - 1: 

256 tl.store( 

257 digit_hist + bin_offset + (portion_id + 1) * passes * (bins + 1), 

258 inc_bucket_offset, 

259 ) 

260 global_offsets = ( 

261 inc_bucket_offset - bin_of_bucket.to(tl.int64) + key_block_rank.to(tl.int64) 

262 ) 

263 tl.store(key_out + global_offsets, key_data, mask=key_digit_mask) 

264 tl.store(value_out + global_offsets, value_data, mask=key_digit_mask) 

265 

266 

267@libentry() 

268@triton.jit(do_not_specialize=["philox_seed", "philox_offset"]) 

269def duplicate_keys_shuffle_kernel( 

270 value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr 

271): 

272 pid0 = tl.program_id(0) 

273 offset_range = tl.arange(0, BLOCK_SIZE) 

274 value_offset = pid0.to(tl.int64) * BLOCK_SIZE + offset_range 

275 value_mask = value_offset < n_elements 

276 value_data = tl.load(value_in + value_offset, mask=value_mask) 

277 

278 philox_seed = philox_seed.to(tl.int64) 

279 philox_offset = philox_offset.to(tl.int64) 

280 

281 c0 = (philox_offset & 0xFFFFFFFF).to(tl.int32) 

282 c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.int32) 

283 

284 i4 = pid0 * BLOCK_SIZE + offset_range 

285 c0 = c0 + i4 

286 _O = c0 * 0 

287 

288 r0, _, _, _ = tl.philox(philox_seed, c0, c1, _O, _O) 

289 

290 r0_int = r0.to(tl.int32) 

291 r1 = r0_int - (r0_int // BLOCK_SIZE) * BLOCK_SIZE 

292 

293 mask_val = tl.full([BLOCK_SIZE], BLOCK_SIZE, dtype=tl.int32) 

294 r1 = tl.where(value_mask, r1, mask_val) 

295 

296 _, sorted_chunk_index = argsort(r1, offset_range, 0, descending=False) 

297 

298 store_offset = pid0.to(tl.int64) * BLOCK_SIZE + sorted_chunk_index.to(tl.int64) 

299 tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements) 

300 

301 

302def sort_by_key(key, value, valid_bits): 

303 n_elements = key.numel() 

304 if n_elements > 4: 

305 # radix method 

306 BLOCK_SIZE = 1024 

307 bits_per_pass = 4 

308 bits_per_segment = 3 

309 passes = triton.cdiv(valid_bits, bits_per_pass) 

310 bins = 2**bits_per_pass 

311 bins_per_sgement = 2**bits_per_segment 

312 bit_mask = bins - 1 

313 

314 portion_size = 2**30 # 2 bits reserved for mask 

315 num_portions = triton.cdiv(n_elements, portion_size) 

316 max_portion_items = portion_size if num_portions > 1 else n_elements 

317 max_tiles_per_portion = triton.cdiv(max_portion_items, BLOCK_SIZE) 

318 

319 hist_dtype = torch.int64 if num_portions > 1 else torch.int32 

320 grid_hist = (triton.cdiv(n_elements, BLOCK_SIZE), bins // bins_per_sgement) 

321 

322 digit_hist_slice = torch.empty( 

323 (passes, bins + 1, grid_hist[0]), dtype=hist_dtype, device=key.device 

324 ) 

325 

326 digit_hist = torch.empty( 

327 (num_portions, passes, bins + 1), dtype=hist_dtype, device=key.device 

328 ) 

329 d_lookback = torch.empty( 

330 num_portions * passes * bins * max_tiles_per_portion, 

331 dtype=torch.int32, 

332 device=key.device, 

333 ) 

334 

335 key_out_p = torch.empty_like(key) 

336 key_out_q = torch.empty_like(key) 

337 value_out_p = torch.empty_like(value) 

338 value_out_q = torch.empty_like(value) 

339 # step1 

340 d_lookback.zero_() 

341 with torch_device_fn.device(key.device): 

342 digit_hist_kernel[grid_hist]( 

343 digit_hist_slice, 

344 key, 

345 n_elements, 

346 bits_per_pass, 

347 bins, 

348 passes, 

349 bit_mask, 

350 bins_per_sgement, 

351 BLOCK_SIZE, 

352 ) 

353 

354 # step2 

355 digit_hist_slice = torch.sum(digit_hist_slice, dim=2, keepdim=False) 

356 digit_hist_slice = digit_hist_slice.cumsum(dim=1) # shape of [passes, bins + 1] 

357 digit_hist.copy_(digit_hist_slice) 

358 

359 bit_offset = 0 

360 

361 for p in range(passes): 

362 k_in = (key if p == 0 else key_out_p) if p % 2 == 0 else key_out_q 

363 v_in = (value if p == 0 else value_out_p) if p % 2 == 0 else value_out_q 

364 k_out = key_out_q if p % 2 == 0 else key_out_p 

365 v_out = value_out_q if p % 2 == 0 else value_out_p 

366 # step3 

367 for portion_id in range(num_portions): 

368 portion_items = min( 

369 n_elements - portion_id * portion_size, portion_size 

370 ) 

371 tiles_per_portion = triton.cdiv(portion_items, BLOCK_SIZE) 

372 grid_scatter = (tiles_per_portion, grid_hist[1]) 

373 with torch_device_fn.device(key.device): 

374 radix_sortbykey_scatter_kernel[grid_scatter]( 

375 k_out, 

376 v_out, 

377 k_in, 

378 v_in, 

379 digit_hist, 

380 d_lookback, 

381 n_elements, 

382 bit_offset, 

383 passes, 

384 p, 

385 num_portions, 

386 portion_size, 

387 portion_id, 

388 bit_mask, 

389 bins_per_sgement, 

390 max_tiles_per_portion, 

391 bins, 

392 BLOCK_SIZE, 

393 ) 

394 

395 bit_offset += bits_per_pass 

396 

397 # last step, shuffle inner-block data 

398 BLOCK_SIZE_SHUFFLE = 128 

399 grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),) 

400 philox_seed, philox_offset = philox_backend_seed_offset(n_elements) 

401 with torch_device_fn.device(key.device): 

402 duplicate_keys_shuffle_kernel[grid_shuffle]( 

403 v_out, 

404 n_elements, 

405 philox_seed, 

406 philox_offset, 

407 BLOCK_SIZE_SHUFFLE, 

408 num_warps=4, 

409 ) 

410 return v_out 

411 else: 

412 # bitonic method 

413 BLOCK_SIZE = triton.next_power_of_2(n_elements) 

414 logger.debug(n_elements) 

415 

416 grid = (1,) 

417 k_out = torch.empty_like(key) 

418 v_out = torch.empty_like(value) 

419 with torch_device_fn.device(key.device): 

420 bitonic_sortbykey_kernel[grid]( 

421 k_out, v_out, key, value, n_elements, BLOCK_SIZE, False 

422 ) 

423 return v_out 

424 

425 

426def randperm( 

427 n, 

428 *, 

429 generator=None, 

430 out=None, 

431 dtype=torch.int64, 

432 layout=torch.strided, 

433 device=None, 

434 requires_grad=False, 

435 pin_memory=False, 

436): 

437 logger.debug("GEMS_ASCEND RANDPERM") 

438 assert dtype == torch.int16 or dtype == torch.int32 or dtype == torch.int64 

439 assert n <= _MAX_INT64_VAL, "n exceeds maximum int64" 

440 

441 if device is None: 

442 device = torch.device(device_.name) 

443 in_range = torch.arange(n, dtype=dtype, device=device) 

444 

445 u8max = 2**8 

446 u16max = 2**16 

447 u24max = 2**24 

448 u32max = 2**32 

449 

450 if n <= u8max: 

451 valid_bits = 8 

452 key_dtype = torch.int8 

453 keymin = _MIN_INT8_VAL 

454 keymax = _MAX_INT8_VAL 

455 elif n <= u16max: 

456 valid_bits = 16 

457 key_dtype = torch.int16 

458 keymin = _MIN_INT16_VAL 

459 keymax = _MAX_INT16_VAL 

460 elif n <= u24max: 

461 valid_bits = 24 

462 key_dtype = torch.int32 

463 keymin = _MIN_INT24_VAL 

464 keymax = _MAX_INT24_VAL 

465 elif n <= u32max: 

466 valid_bits = 32 

467 key_dtype = torch.int32 

468 keymin = _MIN_INT32_VAL 

469 keymax = _MAX_INT32_VAL 

470 else: 

471 valid_bits = 64 

472 key_dtype = torch.int64 

473 keymin = _MIN_INT64_VAL 

474 keymax = _MAX_INT64_VAL 

475 

476 rand_key = torch.randint( 

477 low=keymin, high=keymax, size=[n], dtype=key_dtype, device=device 

478 ) 

479 # breakpoint() 

480 perm_range = sort_by_key(rand_key, in_range, valid_bits) 

481 return perm_range