Coverage for src/flag_gems/runtime/backend/_cambricon/ops/randperm.py: 0%

329 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6import triton.language.core as core 

7from triton.language.standard import _log2, zeros_like 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import device, torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils.random_utils import philox_backend_seed_offset 

13 

logger = logging.getLogger(__name__)
# Alias the runtime device handle so the `device` keyword parameter of
# randperm() below does not shadow it.
device_ = device

# Integer range limits wrapped as Triton compile-time constants; the jitted
# kernels use them to pick sentinel/padding values per key dtype.
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
_MAX_UINT32_VAL = tl.constexpr((1 << 32) - 1)
_MIN_UINT32_VAL = tl.constexpr(0)
# 24-bit keys are stored in int32 tensors but drawn from this narrower range
# (see the key-width selection in randperm()).
_MIN_INT24_VAL = tl.constexpr(-(2**23))
_MAX_INT24_VAL = tl.constexpr(2**23 - 1)

29 

30 

31""" 

32Note(Zhengzekang): 

33Refer from triton2.2 official `sort` implementation: 

34https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404 

35Just add indices to sort with values. 

36""" 

37 

38 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap step of a bitonic sort that also carries `ids`
    through the identical swaps.

    `i` selects the comparison distance 2**(n_dims - i - 1); `flip` gives the
    per-element sort direction (0 ascending, 1 descending) prepared by
    `_bitonic_merge`.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # View the data as [outer, 2, inner] so axis 1 separates the two halves
    # of every comparison pair.
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    # Same left/right extraction for the carried indices.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # actual compare-and-swap: pick an integer type of the same bit width so
    # the swap can be done on the raw bits with XOR.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # If the pair is out of order w.r.t. the requested direction, XOR-ing
    # with (ileft ^ iright) replaces left with right and vice versa.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Mirror the value swap on the indices, reusing the same `cond`.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)

100 

101 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Run one bitonic-merge stage over `x`, carrying `ids` along.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids

129 

130 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic sort of `x` along `dim`, returning (sorted_x, permuted_ids).

    The length along `dim` must be a power of two (n_dims = log2(len)).
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # Standard bitonic schedule: every stage but the last merges alternating
    # runs (order 2); the final stage applies the requested direction.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids

139 

140 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the maximum (return_max=True) or minimum representable value of
    an integer `dtype`, as a compile-time constant.

    Used to pad out-of-range lanes so they sort to the end of the block.
    """
    if dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int8:
        if return_max:
            return _MAX_INT8_VAL
        else:
            return _MIN_INT8_VAL
    elif dtype is tl.uint32:
        if return_max:
            return _MAX_UINT32_VAL
        else:
            return _MIN_UINT32_VAL
    else:
        raise ValueError("Unknown dtype")

173 

174 

@libentry()
@triton.jit
def bitonic_sortbykey_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Sort one batch row of keys (`chunk_x`, length N) together with its
    carried indices (`chunk_index`), writing the sorted keys to `y_ptr` and
    the permuted indices to `index_ptr`.

    BLOCK_SIZE is the power-of-two padding of N required by the bitonic sort.
    """
    # One program per batch row; advance all pointers to this row.
    cur_batch = tl.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    index_ptr += cur_batch * N
    y_ptr += cur_batch * N

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Pad lanes beyond N with the extreme value for this direction so they
    # sort to the tail and never land in the first N outputs.
    mask_val = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)

    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask)

    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # `cols < N` is the same predicate as `mask` above.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < N)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < N)

205 

206 

@triton.jit
def radix_type_convert(k):
    """Map signed integer keys to an order-preserving unsigned encoding.

    Flips the sign bit (negatives map to [0, 2**(w-1)), non-negatives to
    [2**(w-1), 2**w)) so that radix digits extracted from the result compare
    in the same order as the signed values. The result is widened to int64.
    Unrecognized dtypes pass through unchanged.
    """
    ik = k.to(tl.int64)
    if tl.constexpr(k.dtype == tl.int8):
        # mask is 1 for negative inputs (original sign bit set).
        mask = (ik >> 7) & 0x1
        o = tl.where(mask, ik & 0x7F, ik | 0x80)
    elif tl.constexpr(k.dtype == tl.int16):
        mask = (ik >> 15) & 0x1
        o = tl.where(mask, ik & 0x7FFF, ik | 0x8000)
    elif tl.constexpr(k.dtype == tl.int32):
        mask = (ik >> 31) & 0x1
        o = tl.where(mask, ik & 0x7FFFFFFF, ik | 0x80000000)
    elif tl.constexpr(k.dtype == tl.int64):
        mask = (ik >> 63) & 0x1
        o = tl.where(mask, ik & 0x7FFFFFFFFFFFFFFF, ik | 0x8000000000000000)
    else:
        o = k
    return o

225 

226 

@libentry()
@triton.jit
def digit_hist_kernel(
    digit_hist,
    key,
    n_elements,
    bits_per_pass,
    bins,
    passes,
    bit_mask,
    bins_segment,
    BLOCK_SIZE: tl.constexpr,
):
    """Build per-tile digit histograms of `key` for every radix pass.

    Grid axis 0 tiles the input (BLOCK_SIZE keys per program); axis 1 splits
    the `bins` digit values into segments of `bins_segment` bins. Output
    layout is [passes, bins + 1, grid0] (flattened), with counts written at
    bin index + 1 and slot 0 zeroed, so a later cumsum yields exclusive
    prefix offsets.
    """
    bin_segid = tl.program_id(1)
    pid0 = tl.program_id(0)
    grid0 = tl.num_programs(0)

    key_offset = pid0.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    key_mask = key_offset < n_elements
    key_data = tl.load(key + key_offset, mask=key_mask)
    # Order-preserving unsigned encoding so digit comparison matches signed order.
    ikey_data = radix_type_convert(key_data)
    bit_offset = 0
    for p in range(passes):
        key_digit = (ikey_data >> bit_offset) & bit_mask
        blk_bin_start = bin_segid * bins_segment
        for s in range(bins_segment):
            bin_id = s + blk_bin_start
            # NOTE(review): relies on Triton evaluating `and` elementwise on
            # tensors here — confirm against the targeted Triton version.
            digit_mask = tl.where(key_digit == bin_id and key_mask, 1, 0)
            digit_sum = tl.sum(digit_mask)
            # +1 for exclusive
            bin_offset = p * (bins + 1) * grid0 + (bin_id + 1) * grid0 + pid0
            # reduce rather than global atomic for perf issue
            tl.store(digit_hist + bin_offset, digit_sum)
        # Zero the leading slot of each pass (only once, from segment 0).
        tl.store(digit_hist + p * (bins + 1) * grid0 + pid0, 0, mask=bin_segid == 0)
        bit_offset += bits_per_pass

262 

263 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("randperm"),
    key=["n_elements"],
)
@triton.jit
def radix_sortbykey_scatter_kernel(
    key_out,
    value_out,
    key_in,
    value_in,
    digit_hist,
    d_lookback,
    n_elements,
    bit_offset,
    passes,
    p,
    num_portions,
    portion_size,
    portion_id,
    bit_mask,
    bins_segment,
    max_tiles_per_portion,
    bins: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter phase of one radix pass `p`, using a decoupled look-back scan
    (single-pass chained scan à la Merrill & Garland) to compute each tile's
    global digit offsets without a separate device-wide prefix-sum launch.

    Each `d_lookback` entry packs a tile's count for one digit together with
    a status flag in the top bits: PARTIAL (tile-local count only) or GLOBAL
    (inclusive prefix over all earlier tiles). `digit_hist` holds the
    per-portion exclusive digit offsets and is forwarded to the next portion
    by the last tile.
    """
    # Status flags occupy the top 2 bits of each 32-bit lookback word; the
    # remaining bits carry the count (hence portion_size = 2**30 elsewhere).
    LOOKBACK_PARTIAL_MASK = 1 << 30
    LOOKBACK_GLOBAL_MASK = 1 << 31
    LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK
    LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK

    pid0 = tl.program_id(0)
    portion_id_i64 = portion_id
    portion_id_i64 = portion_id_i64.to(tl.int64)
    key_offset = (
        portion_id_i64 * portion_size
        + pid0.to(tl.int64) * BLOCK_SIZE
        + tl.arange(0, BLOCK_SIZE)
    )

    key_mask = key_offset < n_elements
    value_data = tl.load(value_in + key_offset, mask=key_mask)
    key_data = tl.load(key_in + key_offset, mask=key_mask)

    ikey_data = radix_type_convert(key_data)
    key_digit = (ikey_data >> bit_offset) & bit_mask

    blk_bin_start = tl.program_id(1) * bins_segment
    last_block = tl.program_id(0) == tl.num_programs(0) - 1
    for s in range(bins_segment):
        bin_id = s + blk_bin_start
        key_digit_mask = (key_digit == bin_id) & key_mask
        key_elem_mask = tl.where(key_digit_mask, 1, 0)
        # Rank of each matching key within this tile (0-based).
        key_block_rank = tl.cumsum(key_elem_mask)
        key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0)
        # Number of keys in this tile that have digit == bin_id.
        bin_of_bucket = tl.sum(key_elem_mask)
        # Publish this tile's partial count so later tiles can look back at it.
        partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            partial_counter,
            cache_modifier=".cg",
        )
        bin_offset = p * (bins + 1) + bin_id
        # Exclusive offset of this digit within the current portion.
        prefix_offsets = tl.load(
            digit_hist + bin_offset + portion_id * passes * (bins + 1)
        )
        # Look back over preceding tiles, accumulating their counts until a
        # GLOBAL (inclusive-prefix) entry short-circuits the walk.
        bk = pid0 - 1
        inc_sum = bin_of_bucket
        while bk >= 0:
            rd_lbk_offset = (
                (portion_id * passes + p) * max_tiles_per_portion + bk
            ) * bins + bin_id
            partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            # Spin until the predecessor has published (d_lookback is zeroed
            # before launch, so 0 means "not ready yet").
            while partial_prefix == 0:
                partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32)
            if partial_prefix & LOOKBACK_GLOBAL_MASK:
                # break
                bk = -1
            else:
                bk -= 1
        # Upgrade our entry to a GLOBAL inclusive prefix for tiles after us.
        global_counter = inc_sum | LOOKBACK_GLOBAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            global_counter,
            cache_modifier=".cg",
        )
        inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64)
        # Last tile forwards the running offsets to the next portion's slot.
        if last_block and portion_id < num_portions - 1:
            tl.store(
                digit_hist + bin_offset + (portion_id + 1) * passes * (bins + 1),
                inc_bucket_offset,
            )
        # Exclusive global position of each matching key = inclusive prefix
        # minus this tile's count plus the in-tile rank.
        global_offsets = (
            inc_bucket_offset - bin_of_bucket.to(tl.int64) + key_block_rank.to(tl.int64)
        )
        tl.store(key_out + global_offsets, key_data, mask=key_digit_mask)
        tl.store(value_out + global_offsets, value_data, mask=key_digit_mask)

366 

367 

# for parallelization, randomly shuffle the entire block rather than adjacent equal elements as pytorch GPU backend
@libentry()
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def duplicate_keys_shuffle_kernel(
    value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr
):
    """Randomly permute each BLOCK_SIZE-sized block of `value_in` in place.

    Draws one Philox random number per element and argsorts them within the
    block; the resulting index permutation scatters the block's values back,
    randomizing the relative order of keys that compared equal in the sort.
    """
    pid0 = tl.program_id(0)
    offset_range = tl.arange(0, BLOCK_SIZE)
    value_offset = pid0.to(tl.int64) * BLOCK_SIZE + offset_range
    value_mask = value_offset < n_elements
    value_data = tl.load(value_in + value_offset, mask=value_mask)

    # Build per-element Philox counters: the 64-bit offset is split into two
    # 32-bit words and the element index is folded into the low word.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c0 += i4
    _O = c0 * 0
    r0, _, _, _ = tl.philox(philox_seed, c0, c1, _O, _O)

    _block_size = BLOCK_SIZE
    # Reduce to [0, BLOCK_SIZE) and push out-of-range lanes to the tail by
    # assigning them the maximum uint32 sort key.
    r1 = r0 % _block_size.to(tl.uint32)
    mask_val = _get_iinfo_val(tl.uint32, True)
    r1 = tl.where(value_offset < n_elements, r1, mask_val)
    _, sorted_chunk_index = argsort(r1, offset_range, 0, descending=False)
    # Scatter values through the random permutation, back into the same block.
    store_offset = pid0.to(tl.int64) * BLOCK_SIZE + sorted_chunk_index.to(tl.int64)
    tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements)

396 

397 

def sort_by_key(key, value, valid_bits, generator=None):
    """Sort `value` according to `key` and return the permuted values.

    Small inputs (<= 2048 elements) use a single-block bitonic sort; larger
    inputs use a multi-pass LSD radix sort with a decoupled look-back scan.
    Only the lowest `valid_bits` bits of `key` are significant. A final
    block-wise shuffle (seeded via `generator`) randomizes the relative
    order of duplicate keys.

    Fix: renamed the misspelled local `bins_per_sgement` to
    `bins_per_segment` (pure rename, no behavior change).
    """
    n_elements = key.numel()
    if n_elements > 2 * 1024:
        # radix method
        BLOCK_SIZE = 1024
        bits_per_pass = 4  # radix-16: 4 key bits per pass
        bits_per_segment = 3  # each grid-axis-1 program covers 2**3 bins
        passes = triton.cdiv(valid_bits, bits_per_pass)
        bins = 2**bits_per_pass
        bins_per_segment = 2**bits_per_segment
        bit_mask = bins - 1

        portion_size = 2**30  # 2 bits reserved for mask
        num_portions = triton.cdiv(n_elements, portion_size)
        max_portion_items = portion_size if num_portions > 1 else n_elements
        max_tiles_per_portion = triton.cdiv(max_portion_items, BLOCK_SIZE)

        # 64-bit histogram entries are only needed when offsets can exceed
        # int32 range (i.e. when the input spans multiple portions).
        hist_dtype = torch.int64 if num_portions > 1 else torch.int32
        grid_hist = (triton.cdiv(n_elements, BLOCK_SIZE), bins // bins_per_segment)

        # Per-tile histograms, layout [passes, bins + 1, tiles].
        digit_hist_slice = torch.empty(
            (passes, bins + 1, grid_hist[0]), dtype=hist_dtype, device=key.device
        )

        # Per-portion exclusive digit offsets consumed by the scatter kernel.
        digit_hist = torch.empty(
            (num_portions, passes, bins + 1), dtype=hist_dtype, device=key.device
        )
        d_lookback = torch.empty(
            num_portions * passes * bins * max_tiles_per_portion,
            dtype=torch.int32,
            device=key.device,
        )

        # Ping-pong scratch buffers for the alternating radix passes.
        key_out_p = torch.empty_like(key)
        key_out_q = torch.empty_like(key)
        value_out_p = torch.empty_like(value)
        value_out_q = torch.empty_like(value)

        # step1: per-tile digit histograms for every pass in one sweep.
        # The scatter kernel spin-waits on non-zero lookback entries, so the
        # buffer must start zeroed.
        d_lookback.zero_()
        with torch_device_fn.device(key.device):
            digit_hist_kernel[grid_hist](
                digit_hist_slice,
                key,
                n_elements,
                bits_per_pass,
                bins,
                passes,
                bit_mask,
                bins_per_segment,
                BLOCK_SIZE,
            )

        # step2: reduce tiles, then build exclusive prefix sums per pass
        # (the histogram kernel wrote counts at bin index + 1 for this).
        digit_hist_slice = torch.sum(digit_hist_slice, dim=2, keepdim=False)
        digit_hist_slice = digit_hist_slice.cumsum(dim=1)  # shape of [passes, bins + 1]
        digit_hist.copy_(digit_hist_slice)  # broadcast over portions

        bit_offset = 0
        for p in range(passes):
            # Pass 0 reads the original tensors; later passes ping-pong
            # between the two scratch buffer pairs.
            k_in = (key if p == 0 else key_out_p) if p % 2 == 0 else key_out_q
            v_in = (value if p == 0 else value_out_p) if p % 2 == 0 else value_out_q
            k_out = key_out_q if p % 2 == 0 else key_out_p
            v_out = value_out_q if p % 2 == 0 else value_out_p
            # step3: scatter keys/values to their per-digit global positions.
            for portion_id in range(num_portions):
                portion_items = min(
                    n_elements - portion_id * portion_size, portion_size
                )
                tiles_per_portion = triton.cdiv(portion_items, BLOCK_SIZE)
                grid_scatter = (tiles_per_portion, grid_hist[1])
                with torch_device_fn.device(key.device):
                    radix_sortbykey_scatter_kernel[grid_scatter](
                        k_out,
                        v_out,
                        k_in,
                        v_in,
                        digit_hist,
                        d_lookback,
                        n_elements,
                        bit_offset,
                        passes,
                        p,
                        num_portions,
                        portion_size,
                        portion_id,
                        bit_mask,
                        bins_per_segment,
                        max_tiles_per_portion,
                        bins,
                        BLOCK_SIZE,
                    )
            bit_offset += bits_per_pass

        # last step, shuffle inner-block data so equal keys end up in random
        # relative order (block-wise, unlike the PyTorch GPU backend).
        BLOCK_SIZE_SHUFFLE = 512
        grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),)
        philox_seed, philox_offset = philox_backend_seed_offset(
            n_elements, generator=generator
        )
        with torch_device_fn.device(key.device):
            duplicate_keys_shuffle_kernel[grid_shuffle](
                v_out,
                n_elements,
                philox_seed,
                philox_offset,
                BLOCK_SIZE_SHUFFLE,
                num_warps=4,
            )
        return v_out
    else:
        # bitonic method: one program sorts the whole (padded) input.
        BLOCK_SIZE = triton.next_power_of_2(n_elements)
        grid = (1,)
        k_out = torch.empty_like(key)
        v_out = torch.empty_like(value)
        with torch_device_fn.device(key.device):
            bitonic_sortbykey_kernel[grid](
                k_out, v_out, key, value, n_elements, BLOCK_SIZE, False
            )
        return v_out

519 

520 

def randperm(
    n,
    *,
    generator=None,
    out=None,
    dtype=torch.int64,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
    """Return a random permutation of the integers [0, n) as a 1-D tensor.

    Implemented by sorting [0, n) against random keys: the key tensor uses
    the narrowest integer width that can distinguish n elements, and
    `sort_by_key` dispatches to a bitonic or radix sort by size. `out`,
    `layout`, `requires_grad` and `pin_memory` are accepted for signature
    compatibility.
    """
    logger.debug("GEMS_CAMBRICON RANDPERM")
    assert dtype == torch.int16 or dtype == torch.int32 or dtype == torch.int64
    assert n <= _MAX_INT64_VAL, "n exceeds maximum int64"

    if device is None:
        device = torch.device(device_.name)
    in_range = torch.arange(n, dtype=dtype, device=device)

    # Key-width table: (exclusive upper bound on n, significant bits,
    # key tensor dtype, key range). A bound of None is the catch-all.
    key_configs = (
        (2**8, 8, torch.int8, _MIN_INT8_VAL, _MAX_INT8_VAL),
        (2**16, 16, torch.int16, _MIN_INT16_VAL, _MAX_INT16_VAL),
        # 24-bit keys live in int32 tensors but span a narrower range.
        (2**24, 24, torch.int32, _MIN_INT24_VAL, _MAX_INT24_VAL),
        (2**32, 32, torch.int32, _MIN_INT32_VAL, _MAX_INT32_VAL),
        (None, 64, torch.int64, _MIN_INT64_VAL, _MAX_INT64_VAL),
    )
    for bound, valid_bits, key_dtype, keymin, keymax in key_configs:
        if bound is None or n <= bound:
            break

    # Random sort keys are drawn on the CPU, then moved to the target device.
    rand_key = torch.randint(
        low=keymin, high=keymax, size=[n], dtype=key_dtype, device="cpu"
    ).to(device)
    perm_range = sort_by_key(rand_key, in_range, valid_bits, generator=generator)
    return perm_range