Coverage for src/flag_gems/runtime/backend/_cambricon/ops/randperm.py: 0%
329 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6import triton.language.core as core
7from triton.language.standard import _log2, zeros_like
9from flag_gems import runtime
10from flag_gems.runtime import device, torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils.random_utils import philox_backend_seed_offset
14logger = logging.getLogger(__name__)
15device_ = device
# Integer-range limits wrapped in tl.constexpr so they can be referenced
# directly from inside @triton.jit kernels as compile-time constants.
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
_MAX_UINT32_VAL = tl.constexpr((1 << 32) - 1)
_MIN_UINT32_VAL = tl.constexpr(0)
# 24-bit signed range: randperm uses int32 keys with only 24 significant
# bits when n fits in 2**24 (see the key-width selection in randperm).
_MIN_INT24_VAL = tl.constexpr(-(2**23))
_MAX_INT24_VAL = tl.constexpr(2**23 - 1)
"""
Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
We simply carry an index tensor along so that indices are sorted together
with the values.
"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of a bitonic sort over the minor 2**n_dims
    elements of `x`, with the companion index tensor `ids` permuted in
    lockstep.

    `i` selects the pair stride 2**(n_dims - i - 1); `flip` (0/1 per element)
    selects ascending vs descending order for each compared pair.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]
    # View the data so axis 1 separates the two halves of each compared pair.
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)
    # Same left/right extraction for the companion indices.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)
    # actual compare-and-swap: pick an integer type of the same bit width so
    # values can be bit-cast and exchanged via XOR arithmetic.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    # Where `cond` holds, x ^ (left ^ right) replaces each element with its
    # partner; elsewhere the XOR with 0 leaves it unchanged.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))
    # Repeat the bitcast/XOR swap for the index tensor, which may have a
    # different bit width than the values.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Merge stage of a bitonic sort: runs `stage` compare-and-swap rounds on
    `x`, permuting `ids` in lockstep.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating (sub-sequences alternate direction)
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        # Uniform direction: 0 (ascending) or 1 (descending) for everything.
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic sort of `x` along `dim`, returning the sorted values together
    with the correspondingly permuted `ids`. The sorted extent must be a
    power of two (its log2 is taken statically below).
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # Earlier stages build bitonic runs with alternating direction (order=2);
    # the final stage applies the requested overall order.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the maximum (return_max truthy) or minimum representable value
    of the given Triton integer dtype, as a compile-time constant.

    Used to pad out-of-range lanes with a sentinel that sorts last.
    """
    if dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int8:
        if return_max:
            return _MAX_INT8_VAL
        else:
            return _MIN_INT8_VAL
    elif dtype is tl.uint32:
        if return_max:
            return _MAX_UINT32_VAL
        else:
            return _MIN_UINT32_VAL
    else:
        raise ValueError("Unknown dtype")
@libentry()
@triton.jit
def bitonic_sortbykey_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Sort one batch row of `chunk_x` with an in-block bitonic sort,
    carrying `chunk_index` along; sorted keys go to `y_ptr`, permuted
    indices to `index_ptr`. Each program handles row `program_id(0)` of
    length N (padded up to BLOCK_SIZE, a power of two).
    """
    cur_batch = tl.program_id(0)
    # Advance all pointers to this batch row.
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    index_ptr += cur_batch * N
    y_ptr += cur_batch * N
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Pad out-of-range lanes with the dtype max (ascending) or min
    # (descending) so padding sorts to the tail and never pollutes results.
    mask_val = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask)
    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < N)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < N)
@triton.jit
def radix_type_convert(k):
    """Map signed integer keys to an order-preserving unsigned encoding by
    flipping the sign bit, so that sorting the raw bit patterns (as a radix
    sort does) yields the correct signed order. Unhandled dtypes pass
    through unchanged.
    """
    ik = k.to(tl.int64)
    if tl.constexpr(k.dtype == tl.int8):
        # Negatives map to [0x00, 0x7F], non-negatives to [0x80, 0xFF].
        mask = (ik >> 7) & 0x1
        o = tl.where(mask, ik & 0x7F, ik | 0x80)
    elif tl.constexpr(k.dtype == tl.int16):
        mask = (ik >> 15) & 0x1
        o = tl.where(mask, ik & 0x7FFF, ik | 0x8000)
    elif tl.constexpr(k.dtype == tl.int32):
        mask = (ik >> 31) & 0x1
        o = tl.where(mask, ik & 0x7FFFFFFF, ik | 0x80000000)
    elif tl.constexpr(k.dtype == tl.int64):
        mask = (ik >> 63) & 0x1
        o = tl.where(mask, ik & 0x7FFFFFFFFFFFFFFF, ik | 0x8000000000000000)
    else:
        o = k
    return o
@libentry()
@triton.jit
def digit_hist_kernel(
    digit_hist,
    key,
    n_elements,
    bits_per_pass,
    bins,
    passes,
    bit_mask,
    bins_segment,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-block digit histograms for every radix pass.

    Program axis 0 selects a BLOCK_SIZE tile of `key`; axis 1 selects which
    segment of `bins_segment` consecutive bins this program counts. Counts
    are written to digit_hist laid out as [passes, bins + 1, grid0], shifted
    up by one bin (and bin slot 0 zeroed) so the host-side cumsum produces
    exclusive prefix sums.
    """
    bin_segid = tl.program_id(1)
    pid0 = tl.program_id(0)
    grid0 = tl.num_programs(0)
    # Tile of keys handled by this program.
    key_offset = pid0.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    key_mask = key_offset < n_elements
    key_data = tl.load(key + key_offset, mask=key_mask)
    # Sign-bit-flipped encoding so digit extraction follows signed order.
    ikey_data = radix_type_convert(key_data)
    bit_offset = 0
    for p in range(passes):
        # Radix digit of each key for this pass.
        key_digit = (ikey_data >> bit_offset) & bit_mask
        blk_bin_start = bin_segid * bins_segment
        for s in range(bins_segment):
            bin_id = s + blk_bin_start
            digit_mask = tl.where(key_digit == bin_id and key_mask, 1, 0)
            digit_sum = tl.sum(digit_mask)
            # +1 for exclusive
            bin_offset = p * (bins + 1) * grid0 + (bin_id + 1) * grid0 + pid0
            # reduce rather than global atomic for perf issue
            tl.store(digit_hist + bin_offset, digit_sum)
        # Bin slot 0 of each pass is zeroed (once, by the first bin segment).
        tl.store(digit_hist + p * (bins + 1) * grid0 + pid0, 0, mask=bin_segid == 0)
        bit_offset += bits_per_pass
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("randperm"),
    key=["n_elements"],
)
@triton.jit
def radix_sortbykey_scatter_kernel(
    key_out,
    value_out,
    key_in,
    value_in,
    digit_hist,
    d_lookback,
    n_elements,
    bit_offset,
    passes,
    p,
    num_portions,
    portion_size,
    portion_id,
    bit_mask,
    bins_segment,
    max_tiles_per_portion,
    bins: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """One scatter pass of the LSD radix sort-by-key, using a decoupled
    look-back prefix scan across tiles.

    Each program handles one BLOCK_SIZE tile of the current portion. For
    every bin in its bin segment it: computes the in-tile rank of matching
    keys, publishes a PARTIAL tile count to `d_lookback`, walks preceding
    tiles' entries (spinning until published) to accumulate its inclusive
    prefix, publishes the GLOBAL (inclusive) prefix, then scatters keys and
    values to `digit_hist` bin prefix + tile-exclusive offset + in-tile rank.
    """
    # Lookback entry encoding: low 30 bits carry the count, bit 30 marks a
    # PARTIAL (tile-local) count, bit 31 a GLOBAL (inclusive-prefix) count.
    LOOKBACK_PARTIAL_MASK = 1 << 30
    LOOKBACK_GLOBAL_MASK = 1 << 31
    LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK
    LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK

    pid0 = tl.program_id(0)
    portion_id_i64 = portion_id
    portion_id_i64 = portion_id_i64.to(tl.int64)
    # Absolute element offsets for this tile (64-bit to survive large n).
    key_offset = (
        portion_id_i64 * portion_size
        + pid0.to(tl.int64) * BLOCK_SIZE
        + tl.arange(0, BLOCK_SIZE)
    )
    key_mask = key_offset < n_elements
    value_data = tl.load(value_in + key_offset, mask=key_mask)
    key_data = tl.load(key_in + key_offset, mask=key_mask)
    # Sign-bit-flipped encoding, then extract this pass's digit.
    ikey_data = radix_type_convert(key_data)
    key_digit = (ikey_data >> bit_offset) & bit_mask

    blk_bin_start = tl.program_id(1) * bins_segment
    last_block = tl.program_id(0) == tl.num_programs(0) - 1
    for s in range(bins_segment):
        bin_id = s + blk_bin_start
        key_digit_mask = (key_digit == bin_id) & key_mask
        key_elem_mask = tl.where(key_digit_mask, 1, 0)
        # In-tile rank (exclusive) of each key within this bin.
        key_block_rank = tl.cumsum(key_elem_mask)
        key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0)
        bin_of_bucket = tl.sum(key_elem_mask)
        # Publish this tile's PARTIAL count so later tiles can look back.
        partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            partial_counter,
            cache_modifier=".cg",
        )
        bin_offset = p * (bins + 1) + bin_id
        # Exclusive global prefix of this bin for this pass/portion.
        prefix_offsets = tl.load(
            digit_hist + bin_offset + portion_id * passes * (bins + 1)
        )
        # Decoupled look-back over preceding tiles: accumulate their counts,
        # stopping early when a GLOBAL (already-inclusive) entry is found.
        bk = pid0 - 1
        inc_sum = bin_of_bucket
        while bk >= 0:
            rd_lbk_offset = (
                (portion_id * passes + p) * max_tiles_per_portion + bk
            ) * bins + bin_id
            partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            # Spin until tile `bk` has published (entry is non-zero).
            while partial_prefix == 0:
                partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32)
            if partial_prefix & LOOKBACK_GLOBAL_MASK:
                # break
                bk = -1
            else:
                bk -= 1
        # Publish the inclusive prefix so later tiles can stop their walk here.
        global_counter = inc_sum | LOOKBACK_GLOBAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            global_counter,
            cache_modifier=".cg",
        )
        inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64)
        # The last tile of a portion seeds the next portion's bin prefixes.
        if last_block and portion_id < num_portions - 1:
            tl.store(
                digit_hist + bin_offset + (portion_id + 1) * passes * (bins + 1),
                inc_bucket_offset,
            )
        # Destination = exclusive tile prefix + in-tile rank.
        global_offsets = (
            inc_bucket_offset - bin_of_bucket.to(tl.int64) + key_block_rank.to(tl.int64)
        )
        tl.store(key_out + global_offsets, key_data, mask=key_digit_mask)
        tl.store(value_out + global_offsets, value_data, mask=key_digit_mask)
# for parallelization, randomly shuffle the entire block rather than adjacent equal elements as pytorch GPU backend
@libentry()
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def duplicate_keys_shuffle_kernel(
    value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr
):
    """Randomly shuffle `value_in` in place within each BLOCK_SIZE tile.

    Draws one Philox random number per element, argsorts the tile by those
    random keys, and scatters the tile's values through the resulting
    permutation. Out-of-range lanes get the uint32 max key so they stay at
    the tail of the sort.
    """
    pid0 = tl.program_id(0)
    offset_range = tl.arange(0, BLOCK_SIZE)
    value_offset = pid0.to(tl.int64) * BLOCK_SIZE + offset_range
    value_mask = value_offset < n_elements
    value_data = tl.load(value_in + value_offset, mask=value_mask)
    # Philox counter setup: split the 64-bit offset into two 32-bit words and
    # add a per-element index so every lane draws a distinct random number.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c0 += i4
    _O = c0 * 0
    r0, _, _, _ = tl.philox(philox_seed, c0, c1, _O, _O)
    # Keys reduced modulo the block size; ties are broken arbitrarily.
    _block_size = BLOCK_SIZE
    r1 = r0 % _block_size.to(tl.uint32)
    mask_val = _get_iinfo_val(tl.uint32, True)
    r1 = tl.where(value_offset < n_elements, r1, mask_val)
    _, sorted_chunk_index = argsort(r1, offset_range, 0, descending=False)
    store_offset = pid0.to(tl.int64) * BLOCK_SIZE + sorted_chunk_index.to(tl.int64)
    tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements)
def sort_by_key(key, value, valid_bits, generator=None):
    """Sort `value` according to `key` and return the reordered values.

    For more than 2048 elements an LSD radix sort is used: per-block digit
    histograms, host-side exclusive prefix sums, one decoupled-lookback
    scatter kernel per pass, and a final random shuffle within small blocks
    so runs of equal keys do not keep a deterministic order. Smaller inputs
    use a single-block bitonic sort instead.

    Args:
        key: 1-D integer tensor of sort keys.
        value: payload tensor, same length as `key`.
        valid_bits: number of significant key bits to sort on.
        generator: optional torch RNG generator for the final shuffle.

    Returns:
        A new tensor containing `value` reordered by ascending `key`.
    """
    n_elements = key.numel()
    if n_elements > 2 * 1024:
        # --- radix sort path ---
        BLOCK_SIZE = 1024
        bits_per_pass = 4
        bits_per_segment = 3
        passes = triton.cdiv(valid_bits, bits_per_pass)
        bins = 2**bits_per_pass
        # Fixed typo: was `bins_per_sgement` in an earlier revision.
        bins_per_segment = 2**bits_per_segment
        bit_mask = bins - 1

        # Elements beyond 2**30 are processed in portions: lookback counters
        # keep the count in 30 bits (2 bits reserved for the flag mask).
        portion_size = 2**30
        num_portions = triton.cdiv(n_elements, portion_size)
        max_portion_items = portion_size if num_portions > 1 else n_elements
        max_tiles_per_portion = triton.cdiv(max_portion_items, BLOCK_SIZE)

        hist_dtype = torch.int64 if num_portions > 1 else torch.int32
        grid_hist = (triton.cdiv(n_elements, BLOCK_SIZE), bins // bins_per_segment)

        # Per-block histogram partials: [passes, bins + 1, num_blocks].
        digit_hist_slice = torch.empty(
            (passes, bins + 1, grid_hist[0]), dtype=hist_dtype, device=key.device
        )
        # Exclusive per-pass bin prefixes, one row per portion.
        digit_hist = torch.empty(
            (num_portions, passes, bins + 1), dtype=hist_dtype, device=key.device
        )
        # Decoupled-lookback state: one entry per (portion, pass, tile, bin).
        d_lookback = torch.empty(
            num_portions * passes * bins * max_tiles_per_portion,
            dtype=torch.int32,
            device=key.device,
        )
        # Ping-pong buffers alternated between passes.
        key_out_p = torch.empty_like(key)
        key_out_q = torch.empty_like(key)
        value_out_p = torch.empty_like(value)
        value_out_q = torch.empty_like(value)

        # step1: per-block digit histograms for every pass.
        d_lookback.zero_()  # lookback entries must start unpublished (== 0)
        with torch_device_fn.device(key.device):
            digit_hist_kernel[grid_hist](
                digit_hist_slice,
                key,
                n_elements,
                bits_per_pass,
                bins,
                passes,
                bit_mask,
                bins_per_segment,
                BLOCK_SIZE,
            )

        # step2: reduce over blocks, then exclusive prefix-sum per pass
        # (the kernel already shifted counts up by one bin).
        digit_hist_slice = torch.sum(digit_hist_slice, dim=2, keepdim=False)
        digit_hist_slice = digit_hist_slice.cumsum(dim=1)  # shape of [passes, bins + 1]
        digit_hist.copy_(digit_hist_slice)

        bit_offset = 0
        for p in range(passes):
            # Alternate input/output buffers between passes.
            k_in = (key if p == 0 else key_out_p) if p % 2 == 0 else key_out_q
            v_in = (value if p == 0 else value_out_p) if p % 2 == 0 else value_out_q
            k_out = key_out_q if p % 2 == 0 else key_out_p
            v_out = value_out_q if p % 2 == 0 else value_out_p
            # step3: scatter keys/values to their per-digit destinations.
            for portion_id in range(num_portions):
                portion_items = min(
                    n_elements - portion_id * portion_size, portion_size
                )
                tiles_per_portion = triton.cdiv(portion_items, BLOCK_SIZE)
                grid_scatter = (tiles_per_portion, grid_hist[1])
                with torch_device_fn.device(key.device):
                    radix_sortbykey_scatter_kernel[grid_scatter](
                        k_out,
                        v_out,
                        k_in,
                        v_in,
                        digit_hist,
                        d_lookback,
                        n_elements,
                        bit_offset,
                        passes,
                        p,
                        num_portions,
                        portion_size,
                        portion_id,
                        bit_mask,
                        bins_per_segment,
                        max_tiles_per_portion,
                        bins,
                        BLOCK_SIZE,
                    )
            bit_offset += bits_per_pass

        # last step, shuffle inner-block data so duplicate keys do not leave
        # values in a deterministic order.
        BLOCK_SIZE_SHUFFLE = 512
        grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),)
        philox_seed, philox_offset = philox_backend_seed_offset(
            n_elements, generator=generator
        )
        with torch_device_fn.device(key.device):
            duplicate_keys_shuffle_kernel[grid_shuffle](
                v_out,
                n_elements,
                philox_seed,
                philox_offset,
                BLOCK_SIZE_SHUFFLE,
                num_warps=4,
            )
        return v_out
    else:
        # --- bitonic sort path (small input, single block) ---
        BLOCK_SIZE = triton.next_power_of_2(n_elements)
        grid = (1,)
        k_out = torch.empty_like(key)
        v_out = torch.empty_like(value)
        with torch_device_fn.device(key.device):
            bitonic_sortbykey_kernel[grid](
                k_out, v_out, key, value, n_elements, BLOCK_SIZE, False
            )
        return v_out
def randperm(
    n,
    *,
    generator=None,
    out=None,
    dtype=torch.int64,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
    """Return a random permutation of the integers [0, n) as a tensor.

    Random keys are drawn (on CPU, then moved to the target device) in the
    narrowest integer width that can represent n, and `sort_by_key` sorts
    the identity range by those keys to produce the permutation.
    """
    logger.debug("GEMS_CAMBRICON RANDPERM")
    assert dtype == torch.int16 or dtype == torch.int32 or dtype == torch.int64
    assert n <= _MAX_INT64_VAL, "n exceeds maximum int64"

    if device is None:
        device = torch.device(device_.name)
    in_range = torch.arange(n, dtype=dtype, device=device)

    # Key-width tiers: (exclusive-ish upper bound on n, significant bits,
    # key storage dtype, key min, key max). First tier that fits wins;
    # anything above 2**32 falls through to full 64-bit keys.
    key_specs = (
        (2**8, 8, torch.int8, _MIN_INT8_VAL, _MAX_INT8_VAL),
        (2**16, 16, torch.int16, _MIN_INT16_VAL, _MAX_INT16_VAL),
        (2**24, 24, torch.int32, _MIN_INT24_VAL, _MAX_INT24_VAL),
        (2**32, 32, torch.int32, _MIN_INT32_VAL, _MAX_INT32_VAL),
    )
    for bound, bits, kdtype, kmin, kmax in key_specs:
        if n <= bound:
            valid_bits, key_dtype, keymin, keymax = bits, kdtype, kmin, kmax
            break
    else:
        valid_bits = 64
        key_dtype = torch.int64
        keymin = _MIN_INT64_VAL
        keymax = _MAX_INT64_VAL

    rand_key = torch.randint(
        low=keymin, high=keymax, size=[n], dtype=key_dtype, device="cpu"
    ).to(device)
    perm_range = sort_by_key(rand_key, in_range, valid_bits, generator=generator)
    return perm_range