Coverage for src/flag_gems/ops/randperm.py: 46%

266 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.topk import argsort 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils.random_utils import philox_backend_seed_offset 

12 

# Module logger, plus an alias for the imported runtime ``device`` object:
# ``randperm`` takes a ``device`` keyword that shadows the imported name.
logger = logging.getLogger(__name__)
device_ = device

# Integer limits wrapped as Triton constexpr values.
_INT8_INFO = torch.iinfo(torch.int8)
_INT16_INFO = torch.iinfo(torch.int16)
_INT32_INFO = torch.iinfo(torch.int32)
_INT64_INFO = torch.iinfo(torch.int64)

_MIN_INT8_VAL = tl.constexpr(_INT8_INFO.min)
_MAX_INT8_VAL = tl.constexpr(_INT8_INFO.max)
_MIN_INT16_VAL = tl.constexpr(_INT16_INFO.min)
_MAX_INT16_VAL = tl.constexpr(_INT16_INFO.max)
_MIN_INT32_VAL = tl.constexpr(_INT32_INFO.min)
_MAX_INT32_VAL = tl.constexpr(_INT32_INFO.max)
_MIN_INT64_VAL = tl.constexpr(_INT64_INFO.min)
_MAX_INT64_VAL = tl.constexpr(_INT64_INFO.max)

# Limits with no torch.iinfo counterpart used here: unsigned 32-bit and
# signed 24-bit.
_MAX_UINT32_VAL = tl.constexpr(2**32 - 1)
_MIN_UINT32_VAL = tl.constexpr(0)
_MIN_INT24_VAL = tl.constexpr(-(1 << 23))
_MAX_INT24_VAL = tl.constexpr((1 << 23) - 1)

28 

29 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the upper or lower limit of a Triton integer dtype.

    ``dtype`` is matched by identity against int64, int32, int16, int8 and
    uint32. ``return_max`` selects the maximum when true and the minimum
    otherwise. Any other dtype falls through to the ``raise`` below.
    """
    # NOTE: the if/elif/else chain is kept as a single statement on purpose;
    # only the matching branch should be reached for a given dtype.
    if dtype is tl.int64:
        return _MAX_INT64_VAL if return_max else _MIN_INT64_VAL
    elif dtype is tl.int32:
        return _MAX_INT32_VAL if return_max else _MIN_INT32_VAL
    elif dtype is tl.int16:
        return _MAX_INT16_VAL if return_max else _MIN_INT16_VAL
    elif dtype is tl.int8:
        return _MAX_INT8_VAL if return_max else _MIN_INT8_VAL
    elif dtype is tl.uint32:
        return _MAX_UINT32_VAL if return_max else _MIN_UINT32_VAL
    else:
        raise ValueError("Unknown dtype")

62 

63 

@libentry()
@triton.jit
def bitonic_sortbykey_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Sort one ``N``-element row of keys together with its payload row.

    The row is selected by program id 0. Keys are read from ``chunk_x`` and
    payload from ``chunk_index``; the sorted keys are written to ``y_ptr`` and
    the payload, reordered alongside them, to ``index_ptr``.
    """
    row_start = tl.program_id(0) * N
    y_ptr += row_start
    index_ptr += row_start
    chunk_x += row_start
    chunk_index += row_start

    lane = tl.arange(0, BLOCK_SIZE)
    in_bounds = lane < N

    # Lanes past N are filled with the largest key when sorting ascending and
    # the smallest when sorting descending (presumably so padding sorts last).
    pad_key = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)

    keys = tl.load(chunk_x + lane, mask=in_bounds, other=pad_key)
    payload = tl.load(chunk_index + lane, mask=in_bounds)

    sorted_keys, sorted_payload = argsort(keys, payload, 0, descending=DESCENDING)
    tl.store(y_ptr + lane, sorted_keys, mask=in_bounds)
    tl.store(index_ptr + lane, sorted_payload, mask=in_bounds)

94 

95 

@triton.jit
def radix_type_convert(k):
    """Flip the sign bit of signed integer keys after widening to int64.

    For int8/int16/int32/int64 inputs, the bit at the sign position of the
    original width is cleared when it is set and set when it is clear. Inputs
    of any other dtype are returned unchanged.
    """
    widened = k.to(tl.int64)
    if tl.constexpr(k.dtype == tl.int8):
        sign_set = (widened >> 7) & 0x1
        converted = tl.where(sign_set, widened & 0x7F, widened | 0x80)
    elif tl.constexpr(k.dtype == tl.int16):
        sign_set = (widened >> 15) & 0x1
        converted = tl.where(sign_set, widened & 0x7FFF, widened | 0x8000)
    elif tl.constexpr(k.dtype == tl.int32):
        sign_set = (widened >> 31) & 0x1
        converted = tl.where(sign_set, widened & 0x7FFFFFFF, widened | 0x80000000)
    elif tl.constexpr(k.dtype == tl.int64):
        sign_set = (widened >> 63) & 0x1
        converted = tl.where(
            sign_set, widened & 0x7FFFFFFFFFFFFFFF, widened | 0x8000000000000000
        )
    else:
        converted = k
    return converted

114 

115 

@libentry()
@triton.jit
def digit_hist_kernel(
    digit_hist,
    key,
    n_elements,
    bits_per_pass,
    bins,
    passes,
    bit_mask,
    bins_segment,
    BLOCK_SIZE: tl.constexpr,
):
    """Count, per tile and per radix pass, how many keys fall in each bin.

    Program id 0 selects a ``BLOCK_SIZE``-wide tile of ``key``; program id 1
    selects a segment of ``bins_segment`` consecutive bins. For pass ``p`` and
    bin ``b`` the count is stored at
    ``digit_hist[p * (bins + 1) * grid0 + (b + 1) * grid0 + pid0]``, where
    ``grid0`` is the number of programs along axis 0. Slot 0 of every pass is
    written with 0 by the programs of segment 0, so bin counts are shifted by
    one slot.
    """
    bin_segid = tl.program_id(1)
    pid0 = tl.program_id(0)
    grid0 = tl.num_programs(0)

    key_offset = pid0.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    key_mask = key_offset < n_elements
    key_data = tl.load(key + key_offset, mask=key_mask)
    ikey_data = radix_type_convert(key_data)
    bit_offset = 0
    for p in range(passes):
        key_digit = (ikey_data >> bit_offset) & bit_mask
        blk_bin_start = bin_segid * bins_segment
        for s in range(bins_segment):
            bin_id = s + blk_bin_start
            # Element-wise AND of two boolean blocks: use `&` rather than the
            # Python `and` keyword, matching radix_sortbykey_scatter_kernel.
            digit_mask = tl.where((key_digit == bin_id) & key_mask, 1, 0)
            digit_sum = tl.sum(digit_mask)
            # +1 for exclusive
            bin_offset = p * (bins + 1) * grid0 + (bin_id + 1) * grid0 + pid0
            # reduce rather than global atomic for perf issue
            tl.store(digit_hist + bin_offset, digit_sum)
        # Leading slot of this pass; only segment 0 writes it.
        tl.store(digit_hist + p * (bins + 1) * grid0 + pid0, 0, mask=bin_segid == 0)
        bit_offset += bits_per_pass

151 

152 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("randperm"),
    key=["n_elements"],
)
@triton.jit
def radix_sortbykey_scatter_kernel(
    key_out,
    value_out,
    key_in,
    value_in,
    digit_hist,
    d_lookback,
    n_elements,
    bit_offset,
    passes,
    p,
    num_portions,
    portion_size,
    portion_id,
    bit_mask,
    bins_segment,
    max_tiles_per_portion,
    bins: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter one tile of key/value pairs to their positions for radix pass ``p``.

    Program id 0 selects a ``BLOCK_SIZE``-wide tile inside portion
    ``portion_id``; program id 1 selects a segment of ``bins_segment`` bins.
    For every bin of its segment the program:

    1. ranks the tile's elements whose digit equals the bin,
    2. publishes the tile's count in ``d_lookback`` and sums the counts of
       the preceding tiles of the same portion by reading their entries,
    3. adds the bin's starting offset read from ``digit_hist``, and
    4. stores the matching keys/values to ``key_out`` / ``value_out``.

    ``d_lookback`` entries are read in a spin loop until they become
    non-zero, so the buffer is expected to be zero before the first launch
    (``sort_by_key`` calls ``d_lookback.zero_()``).
    """
    # Bits 30 and 31 of a d_lookback entry are flags; the remaining bits
    # carry a count.
    LOOKBACK_PARTIAL_MASK = 1 << 30
    LOOKBACK_GLOBAL_MASK = 1 << 31
    LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK
    LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK

    pid0 = tl.program_id(0)
    # Element offsets are computed in int64.
    portion_id_i64 = portion_id
    portion_id_i64 = portion_id_i64.to(tl.int64)
    key_offset = (
        portion_id_i64 * portion_size
        + pid0.to(tl.int64) * BLOCK_SIZE
        + tl.arange(0, BLOCK_SIZE)
    )

    key_mask = key_offset < n_elements
    value_data = tl.load(value_in + key_offset, mask=key_mask)
    key_data = tl.load(key_in + key_offset, mask=key_mask)

    # Digit of each key for the current pass.
    ikey_data = radix_type_convert(key_data)
    key_digit = (ikey_data >> bit_offset) & bit_mask

    blk_bin_start = tl.program_id(1) * bins_segment
    last_block = tl.program_id(0) == tl.num_programs(0) - 1
    for s in range(bins_segment):
        bin_id = s + blk_bin_start
        key_digit_mask = (key_digit == bin_id) & key_mask
        key_elem_mask = tl.where(key_digit_mask, 1, 0)
        # Zero-based rank of each matching element among the tile's matches.
        key_block_rank = tl.cumsum(key_elem_mask)
        key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0)
        # Number of elements of this tile that fall in the bin.
        bin_of_bucket = tl.sum(key_elem_mask)
        # Publish this tile's own count, flagged PARTIAL. The flag makes the
        # stored word non-zero even when the count is 0.
        partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            partial_counter,
            cache_modifier=".cg",
        )
        # Starting offset of the bin for this portion and pass.
        bin_offset = p * (bins + 1) + bin_id
        prefix_offsets = tl.load(
            digit_hist + bin_offset + portion_id * passes * (bins + 1)
        )
        # Walk backwards over the preceding tiles, accumulating their counts
        # on top of this tile's own count.
        bk = pid0 - 1
        inc_sum = bin_of_bucket
        while bk >= 0:
            rd_lbk_offset = (
                (portion_id * passes + p) * max_tiles_per_portion + bk
            ) * bins + bin_id
            partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            # Spin until tile `bk` has written its entry (0 means not yet).
            while partial_prefix == 0:
                partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32)
            if partial_prefix & LOOKBACK_GLOBAL_MASK:
                # A GLOBAL entry already includes every earlier tile, so the
                # walk stops here; setting bk to -1 ends the loop.
                # break
                bk = -1
            else:
                bk -= 1
        # Replace the PARTIAL entry with the accumulated sum, flagged GLOBAL.
        global_counter = inc_sum | LOOKBACK_GLOBAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            global_counter,
            cache_modifier=".cg",
        )
        inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64)
        # The last tile of a non-final portion records where this bin ends,
        # in the digit_hist slot that the next portion reads as its start.
        if last_block and portion_id < num_portions - 1:
            tl.store(
                digit_hist + bin_offset + (portion_id + 1) * passes * (bins + 1),
                inc_bucket_offset,
            )
        # Destination = end of this tile's range - tile count + rank in tile.
        global_offsets = (
            inc_bucket_offset - bin_of_bucket.to(tl.int64) + key_block_rank.to(tl.int64)
        )
        tl.store(key_out + global_offsets, key_data, mask=key_digit_mask)
        tl.store(value_out + global_offsets, value_data, mask=key_digit_mask)

255 

256 

# For parallelization, randomly shuffle each entire block rather than only
# adjacent equal elements as the PyTorch GPU backend does.
@libentry()
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def duplicate_keys_shuffle_kernel(
    value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr
):
    """Randomly reorder, in place, the values of one ``BLOCK_SIZE``-wide block.

    Each program loads its block of ``value_in``, draws one Philox random
    number per lane, reduces it modulo ``BLOCK_SIZE``, arg-sorts the lanes by
    that number and writes every value back at the lane index produced by the
    sort. Lanes at or beyond ``n_elements`` are given the uint32 maximum as
    their sort key, and stores beyond ``n_elements`` are masked out.
    """
    pid0 = tl.program_id(0)
    lane = tl.arange(0, BLOCK_SIZE)
    block_start = pid0.to(tl.int64) * BLOCK_SIZE
    value_offset = block_start + lane
    in_bounds = value_offset < n_elements
    value_data = tl.load(value_in + value_offset, mask=in_bounds)

    # Split the 64-bit Philox offset into two 32-bit counter words and add the
    # element index to the low word.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    c0 += pid0 * BLOCK_SIZE + lane
    zero_word = c0 * 0
    r0, _, _, _ = tl.philox(philox_seed, c0, c1, zero_word, zero_word)

    # BLOCK_SIZE is bound to a local before calling `.to` on it, as in the
    # original code (presumably required for the conversion to work).
    block_size_val = BLOCK_SIZE
    sort_key = r0 % block_size_val.to(tl.uint32)
    pad_key = _get_iinfo_val(tl.uint32, True)
    sort_key = tl.where(in_bounds, sort_key, pad_key)
    _, shuffled_lane = argsort(sort_key, lane, 0, descending=False)
    store_offset = block_start + shuffled_lane.to(tl.int64)
    tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements)

285 

286 

def sort_by_key(key, value, valid_bits, generator=None):
    """Return the elements of ``value`` reordered by the sort order of ``key``.

    Inputs with at most 2048 elements are sorted by a single launch of
    ``bitonic_sortbykey_kernel`` (ascending). Larger inputs go through a
    multi-pass radix sort that consumes 4 key bits per pass, followed by
    ``duplicate_keys_shuffle_kernel``, which randomly permutes the result
    inside every 512-element block.

    Args:
        key: tensor of integer sort keys.
        value: tensor of payload elements, one per key.
        valid_bits: number of key bits covered by the radix passes.
        generator: optional generator forwarded to
            ``philox_backend_seed_offset`` for the final shuffle.

    Returns:
        A new tensor shaped like ``value``. An empty tensor is returned when
        ``key`` has no elements.
    """
    n_elements = key.numel()
    if n_elements == 0:
        # Nothing to sort; also avoids launching the bitonic kernel with a
        # zero-sized block.
        return torch.empty_like(value)

    if n_elements <= 2 * 1024:
        # bitonic method: a single program sorts the whole input
        BLOCK_SIZE = triton.next_power_of_2(n_elements)
        grid = (1,)
        k_out = torch.empty_like(key)
        v_out = torch.empty_like(value)
        with torch_device_fn.device(key.device):
            bitonic_sortbykey_kernel[grid](
                k_out, v_out, key, value, n_elements, BLOCK_SIZE, False
            )
        return v_out

    # radix method
    BLOCK_SIZE = 1024
    bits_per_pass = 4
    bits_per_segment = 3
    passes = triton.cdiv(valid_bits, bits_per_pass)
    bins = 2**bits_per_pass
    bins_per_segment = 2**bits_per_segment
    bit_mask = bins - 1

    portion_size = 2**30  # 2 bits reserved for mask
    num_portions = triton.cdiv(n_elements, portion_size)
    max_portion_items = portion_size if num_portions > 1 else n_elements
    max_tiles_per_portion = triton.cdiv(max_portion_items, BLOCK_SIZE)

    hist_dtype = torch.int64 if num_portions > 1 else torch.int32
    grid_hist = (triton.cdiv(n_elements, BLOCK_SIZE), bins // bins_per_segment)

    # Per-tile histograms written by digit_hist_kernel.
    digit_hist_slice = torch.empty(
        (passes, bins + 1, grid_hist[0]), dtype=hist_dtype, device=key.device
    )
    # Per-portion starting offset of every bin, for every pass.
    digit_hist = torch.empty(
        (num_portions, passes, bins + 1), dtype=hist_dtype, device=key.device
    )
    # Must start zeroed: the scatter kernel spins on entries until they are
    # non-zero.
    d_lookback = torch.zeros(
        num_portions * passes * bins * max_tiles_per_portion,
        dtype=torch.int32,
        device=key.device,
    )

    # Two buffer pairs that alternate as scatter destination between passes.
    key_out_p = torch.empty_like(key)
    key_out_q = torch.empty_like(key)
    value_out_p = torch.empty_like(value)
    value_out_q = torch.empty_like(value)

    # step1: per-tile digit histograms for all passes
    with torch_device_fn.device(key.device):
        digit_hist_kernel[grid_hist](
            digit_hist_slice,
            key,
            n_elements,
            bits_per_pass,
            bins,
            passes,
            bit_mask,
            bins_per_segment,
            BLOCK_SIZE,
        )

    # step2: reduce over tiles, then prefix-sum over bins
    digit_hist_slice = torch.sum(digit_hist_slice, dim=2, keepdim=False)
    digit_hist_slice = digit_hist_slice.cumsum(dim=1)  # shape of [passes, bins + 1]
    digit_hist.copy_(digit_hist_slice)

    bit_offset = 0
    k_in, v_in = key, value
    for p in range(passes):
        # Even passes write to the q buffers, odd passes to the p buffers.
        if p % 2 == 0:
            k_out, v_out = key_out_q, value_out_q
        else:
            k_out, v_out = key_out_p, value_out_p
        # step3: scatter every portion for this pass
        for portion_id in range(num_portions):
            portion_items = min(n_elements - portion_id * portion_size, portion_size)
            tiles_per_portion = triton.cdiv(portion_items, BLOCK_SIZE)
            grid_scatter = (tiles_per_portion, grid_hist[1])
            with torch_device_fn.device(key.device):
                radix_sortbykey_scatter_kernel[grid_scatter](
                    k_out,
                    v_out,
                    k_in,
                    v_in,
                    digit_hist,
                    d_lookback,
                    n_elements,
                    bit_offset,
                    passes,
                    p,
                    num_portions,
                    portion_size,
                    portion_id,
                    bit_mask,
                    bins_per_segment,
                    max_tiles_per_portion,
                    bins,
                    BLOCK_SIZE,
                )
        # The output of this pass is the input of the next one.
        k_in, v_in = k_out, v_out
        bit_offset += bits_per_pass

    # last step, shuffle inner-block data
    BLOCK_SIZE_SHUFFLE = 512
    grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),)
    philox_seed, philox_offset = philox_backend_seed_offset(
        n_elements, generator=generator
    )
    with torch_device_fn.device(key.device):
        duplicate_keys_shuffle_kernel[grid_shuffle](
            v_out,
            n_elements,
            philox_seed,
            philox_offset,
            BLOCK_SIZE_SHUFFLE,
            num_warps=4,
        )
    return v_out

408 

409 

def randperm(
    n,
    *,
    generator=None,
    out=None,
    dtype=torch.int64,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
    """Return a random permutation of the integers ``0 .. n - 1``.

    Random integer keys are drawn, one per element, and ``arange(n)`` is
    reordered by ``sort_by_key`` according to those keys. The key width
    (8/16/24/32/64 valid bits) is the smallest one whose range covers ``n``.

    Args:
        n: number of elements; must not exceed the int64 maximum.
        generator: optional generator used both for drawing the sort keys and
            for the shuffle performed inside ``sort_by_key``.
        dtype: result dtype; must be int16, int32 or int64.
        device: target device; defaults to the runtime device.
        out, layout, requires_grad, pin_memory: accepted in the signature but
            not used by this implementation.

    Returns:
        A tensor of ``n`` elements of ``dtype`` on ``device``.
    """
    logger.debug("GEMS RANDPERM")
    assert dtype == torch.int16 or dtype == torch.int32 or dtype == torch.int64
    assert n <= _MAX_INT64_VAL, "n exceeds maximum int64"

    if device is None:
        device = torch.device(device_.name)
    in_range = torch.arange(n, dtype=dtype, device=device)
    if n == 0:
        # An empty permutation needs no keys and no sorting.
        return in_range

    # (largest n served, valid key bits, key dtype, key min, key max)
    key_specs = (
        (2**8, 8, torch.int8, _MIN_INT8_VAL, _MAX_INT8_VAL),
        (2**16, 16, torch.int16, _MIN_INT16_VAL, _MAX_INT16_VAL),
        (2**24, 24, torch.int32, _MIN_INT24_VAL, _MAX_INT24_VAL),
        (2**32, 32, torch.int32, _MIN_INT32_VAL, _MAX_INT32_VAL),
    )
    widest_spec = (None, 64, torch.int64, _MIN_INT64_VAL, _MAX_INT64_VAL)
    _, valid_bits, key_dtype, keymin, keymax = next(
        (spec for spec in key_specs if n <= spec[0]), widest_spec
    )

    # Forward the generator so that the sort keys, not only the final
    # shuffle, are drawn from the caller-supplied generator.
    rand_key = torch.randint(
        low=keymin,
        high=keymax,
        size=[n],
        dtype=key_dtype,
        device=device,
        generator=generator,
    )
    perm_range = sort_by_key(rand_key, in_range, valid_bits, generator=generator)
    return perm_range

464 return perm_range