Coverage for src/flag_gems/runtime/backend/_ascend/ops/randperm.py: 0%
273 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.topk import argsort
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils.random_utils import philox_backend_seed_offset
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

# Module-level alias for the backend device descriptor: inside ``randperm``
# the name ``device`` is shadowed by that function's keyword parameter.
device_ = device

# Integer range limits. They carry a ``tl.constexpr`` annotation, presumably
# so the ``@triton.jit`` helpers in this file may read them as compile-time
# globals.
_MIN_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).min
_MAX_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).max
_MIN_INT16_VAL: tl.constexpr = torch.iinfo(torch.int16).min
_MAX_INT16_VAL: tl.constexpr = torch.iinfo(torch.int16).max
_MIN_INT32_VAL: tl.constexpr = torch.iinfo(torch.int32).min
_MAX_INT32_VAL: tl.constexpr = torch.iinfo(torch.int32).max
_MIN_INT64_VAL: tl.constexpr = torch.iinfo(torch.int64).min
_MAX_INT64_VAL: tl.constexpr = torch.iinfo(torch.int64).max
# All-ones 32-bit pattern, i.e. 2**32 - 1.
_MAX_UINT32_VAL: tl.constexpr = 0xFFFFFFFF
_MIN_UINT32_VAL: tl.constexpr = 0
# Signed 24-bit range; randperm draws int32 keys from it when 24 key bits
# are enough to cover n.
_MIN_INT24_VAL: tl.constexpr = -(1 << 23)
_MAX_INT24_VAL: tl.constexpr = (1 << 23) - 1
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the lower or upper bound of an integer Triton dtype.

    ``dtype`` selects the range and ``return_max`` chooses its upper (truthy)
    or lower (falsy) bound. Supported dtypes: int8, int16, int32, int64 and
    uint32; anything else raises ``ValueError``.

    NOTE(review): both arguments are expected to be compile-time constants
    (the caller in this file passes a pointer element dtype and
    ``not DESCENDING`` with ``DESCENDING: tl.constexpr``) -- confirm before
    calling it with runtime values.
    """
    if dtype is tl.int8:
        return _MAX_INT8_VAL if return_max else _MIN_INT8_VAL
    elif dtype is tl.int16:
        return _MAX_INT16_VAL if return_max else _MIN_INT16_VAL
    elif dtype is tl.int32:
        return _MAX_INT32_VAL if return_max else _MIN_INT32_VAL
    elif dtype is tl.int64:
        return _MAX_INT64_VAL if return_max else _MIN_INT64_VAL
    elif dtype is tl.uint32:
        return _MAX_UINT32_VAL if return_max else _MIN_UINT32_VAL
    else:
        raise ValueError("Unknown dtype")
@libentry()
@triton.jit
def bitonic_sortbykey_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Sort one row of ``N`` keys together with their values in a single tile.

    Program ``i`` handles the elements ``[i * N, (i + 1) * N)`` of every
    buffer. Sorted keys are written to ``y_ptr`` and the co-permuted values
    from ``chunk_index`` to ``index_ptr``. ``BLOCK_SIZE`` must be a power of
    two that is >= ``N``.
    """
    row_start = tl.program_id(0) * N
    lane = tl.arange(0, BLOCK_SIZE)
    in_row = lane < N

    # Lanes past N are filled with the dtype's extreme value (max when
    # ascending, min when descending) so they sort to the tail.
    # NOTE(review): a real key equal to that extreme would tie with the
    # padding; randperm's keys avoid it because torch.randint's high bound
    # is exclusive -- confirm for other callers.
    pad_key = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    keys = tl.load(chunk_x + row_start + lane, mask=in_row, other=pad_key)
    vals = tl.load(chunk_index + row_start + lane, mask=in_row)

    sorted_keys, sorted_vals = argsort(keys, vals, 0, descending=DESCENDING)
    tl.store(y_ptr + row_start + lane, sorted_keys, mask=in_row)
    tl.store(index_ptr + row_start + lane, sorted_vals, mask=in_row)
@triton.jit
def radix_type_convert(k):
    """Flip the sign bit of signed-integer keys before radix bucketing.

    For int8/int16/int32/int64 keys, ``mask`` is the sign bit. Negative values
    get the sign bit cleared (``ik & 0x7F...``) and non-negative values get it
    set (``ik | 0x80...``). Over the full width of the type this maps signed
    order onto unsigned bit-pattern order. For example, with int8:
    -128 -> 0x00, -1 -> 0x7F, 0 -> 0x80, 127 -> 0xFF. Any other dtype is
    returned unchanged.

    NOTE(review): the radix kernels only inspect the low ``valid_bits`` bits
    via ``(key >> bit_offset) & bit_mask``. For 24-bit keys stored in int32,
    the flipped bit (bit 31) lies outside that window, so the examined bits
    are simply the low 24 bits of the two's-complement value. Those bits are
    unique over [-2**23, 2**23) but not in signed order. That is enough for
    randperm's random keys; confirm before reusing this for an ordered sort.
    The exact result dtype of ``ik | 0x80`` etc. depends on how Triton
    promotes the Python literal; downstream code only relies on the low bits.
    """
    if tl.constexpr(k.dtype == tl.int8):
        ik = k.to(tl.int8, bitcast=True)
        mask = (ik >> 7) & 0x1  # 1 for negative values
        o = tl.where(mask, ik & 0x7F, ik | 0x80)
    elif tl.constexpr(k.dtype == tl.int16):
        ik = k.to(tl.int16, bitcast=True)
        mask = (ik >> 15) & 0x1  # 1 for negative values
        o = tl.where(mask, ik & 0x7FFF, ik | 0x8000)
    elif tl.constexpr(k.dtype == tl.int32):
        ik = k.to(tl.int32, bitcast=True)
        mask = (ik >> 31) & 0x1  # 1 for negative values
        o = tl.where(mask, ik & 0x7FFFFFFF, ik | 0x80000000)
    elif tl.constexpr(k.dtype == tl.int64):
        ik = k.to(tl.int64, bitcast=True)
        mask = (ik >> 63) & 0x1  # 1 for negative values
        o = tl.where(mask, ik & 0x7FFFFFFFFFFFFFFF, ik | 0x8000000000000000)
    else:
        o = k
    return o
@libentry()
@triton.jit
def digit_hist_kernel(
    digit_hist,
    key,
    n_elements,
    bits_per_pass,
    bins,
    passes,
    bit_mask,
    bins_segment,
    BLOCK_SIZE: tl.constexpr,
):
    """Count, per tile and per radix pass, how many keys fall into each bin.

    Grid axis 0 tiles the keys in blocks of ``BLOCK_SIZE``. Grid axis 1
    selects the segment of ``bins_segment`` consecutive bins this program
    counts.

    ``digit_hist`` is addressed as a flattened ``(passes, bins + 1, num_tiles)``
    buffer, where ``num_tiles == tl.num_programs(0)``:
      * entry ``[p, b + 1, t]`` receives the number of valid keys in tile
        ``t`` whose pass-``p`` digit equals ``b``;
      * entry ``[p, 0, t]`` is set to 0, so a cumulative sum over the bin
        axis yields exclusive bin offsets.
    """
    bin_segid = tl.program_id(1)
    pid0 = tl.program_id(0)
    grid0 = tl.num_programs(0)

    key_offset = pid0.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    key_mask = key_offset < n_elements
    key_data = tl.load(key + key_offset, mask=key_mask)
    ikey_data = radix_type_convert(key_data)
    # The bins owned by this program are the same for every pass.
    blk_bin_start = bin_segid * bins_segment
    bit_offset = 0
    for p in range(passes):
        key_digit = (ikey_data >> bit_offset) & bit_mask
        for s in range(bins_segment):
            bin_id = s + blk_bin_start
            # Element-wise `&` (not Python `and`) on the two boolean tensors,
            # matching radix_sortbykey_scatter_kernel.
            digit_mask = tl.where((key_digit == bin_id) & key_mask, 1, 0)
            digit_sum = tl.sum(digit_mask)
            # +1 for exclusive: slot 0 of each pass is reserved for the zero.
            bin_offset = p * (bins + 1) * grid0 + (bin_id + 1) * grid0 + pid0
            # reduce rather than global atomic for perf issue
            tl.store(digit_hist + bin_offset, digit_sum)
        # Only the first bin segment writes the leading zero for this tile.
        tl.store(digit_hist + p * (bins + 1) * grid0 + pid0, 0, mask=bin_segid == 0)
        bit_offset += bits_per_pass
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("randperm"),
    key=["n_elements"],
)
@triton.jit
def radix_sortbykey_scatter_kernel(
    key_out,
    value_out,
    key_in,
    value_in,
    digit_hist,
    d_lookback,
    n_elements,
    bit_offset,
    passes,
    p,
    num_portions,
    portion_size,
    portion_id,
    bit_mask,
    bins_segment,
    max_tiles_per_portion,
    bins: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter one tile of keys/values to their output slots for radix pass ``p``.

    This kernel is launched once per portion (``portion_id``) with grid
    ``(tiles in this portion, bins // bins_segment)``:
      * axis 0 selects a ``BLOCK_SIZE`` tile inside the portion;
      * axis 1 selects the ``bins_segment`` consecutive digit bins this
        program handles.

    Destination offsets come from a decoupled-lookback scan over
    ``d_lookback``, which holds one int32 word per (portion, pass, tile, bin):
      * bit 31 (GLOBAL) set: the low bits hold the inclusive count of the bin
        over this tile and every earlier tile of the portion;
      * bit 30 (PARTIAL) set: the low bits hold only this tile's count;
      * 0: the tile has not published yet (the host zero-fills the buffer).

    ``digit_hist[portion_id, p, bin]`` supplies the bin's starting offset for
    this portion.
    """
    LOOKBACK_PARTIAL_MASK = 1 << 30
    LOOKBACK_GLOBAL_MASK = 1 << 31
    LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK
    LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK

    pid0 = tl.program_id(0)
    # 64-bit element offsets: portion_id * portion_size may exceed int32.
    portion_id_i64 = portion_id
    portion_id_i64 = portion_id_i64.to(tl.int64)
    key_offset = (
        portion_id_i64 * portion_size
        + pid0.to(tl.int64) * BLOCK_SIZE
        + tl.arange(0, BLOCK_SIZE)
    )

    key_mask = key_offset < n_elements
    value_data = tl.load(value_in + key_offset, mask=key_mask)
    key_data = tl.load(key_in + key_offset, mask=key_mask)

    ikey_data = radix_type_convert(key_data)
    key_digit = (ikey_data >> bit_offset) & bit_mask

    blk_bin_start = tl.program_id(1) * bins_segment
    last_block = tl.program_id(0) == tl.num_programs(0) - 1
    for s in range(bins_segment):
        bin_id = s + blk_bin_start
        key_digit_mask = (key_digit == bin_id) & key_mask
        key_elem_mask = tl.where(key_digit_mask, 1, 0)
        # 0-based rank of each matching key among this tile's matching keys,
        # in input order (this keeps the pass stable).
        key_block_rank = tl.cumsum(key_elem_mask)
        key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0)
        bin_of_bucket = tl.sum(key_elem_mask)
        # Publish this tile's own count so later tiles can make progress.
        partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            partial_counter,
            cache_modifier=".cg",
        )
        bin_offset = p * (bins + 1) + bin_id
        # Starting offset of this bin for the current portion.
        prefix_offsets = tl.load(
            digit_hist + bin_offset + portion_id * passes * (bins + 1)
        )
        bk = pid0 - 1

        inc_sum = bin_of_bucket

        # Walk back over earlier tiles, adding their counts, until one that
        # already holds an inclusive (GLOBAL) total is found.
        while bk >= 0:
            rd_lbk_offset = (
                (portion_id * passes + p) * max_tiles_per_portion + bk
            ) * bins + bin_id
            partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            # NOTE(review): the wait is bounded. If tile `bk` has still not
            # published after max_wait polls, partial_prefix stays 0. It then
            # adds nothing below and, lacking the GLOBAL bit, the walk moves on
            # to bk - 1, so that tile's count is silently dropped and this
            # tile's offsets come out too small. Presumably bounded to avoid a
            # hang if earlier tiles are not co-resident on Ascend -- confirm.
            max_wait = 1000
            wait_count = 0
            while partial_prefix == 0 and wait_count < max_wait:
                partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
                wait_count += 1
            inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32)
            if partial_prefix & LOOKBACK_GLOBAL_MASK:
                # acts as `break`: an inclusive total ends the walk
                bk = -1
            else:
                bk -= 1

        # Publish the inclusive total for successors.
        global_counter = inc_sum | LOOKBACK_GLOBAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            global_counter,
            cache_modifier=".cg",
        )
        inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64)
        # The last tile of a non-final portion hands the running bin offsets
        # to the next portion's launch through digit_hist[portion_id + 1].
        if last_block and portion_id < num_portions - 1:
            tl.store(
                digit_hist + bin_offset + (portion_id + 1) * passes * (bins + 1),
                inc_bucket_offset,
            )
        # Bin start + count in earlier tiles (inclusive sum minus own count)
        # + rank inside this tile.
        global_offsets = (
            inc_bucket_offset - bin_of_bucket.to(tl.int64) + key_block_rank.to(tl.int64)
        )
        tl.store(key_out + global_offsets, key_data, mask=key_digit_mask)
        tl.store(value_out + global_offsets, value_data, mask=key_digit_mask)
@libentry()
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def duplicate_keys_shuffle_kernel(
    value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr
):
    """Randomly shuffle each ``BLOCK_SIZE`` tile of ``value_in`` in place.

    Every lane draws a Philox random number, which is reduced to a sort key.
    The lanes are then argsorted by those keys, and lane ``j``'s value is
    written to tile position ``sorted_chunk_index[j]``.

    Every lane of every tile is shuffled, regardless of the keys used by the
    preceding sort. The kernel name and the caller's comment suggest the aim
    is to break up the order a stable radix sort leaves among equal keys --
    TODO confirm the intended scope.
    """
    pid0 = tl.program_id(0)
    offset_range = tl.arange(0, BLOCK_SIZE)
    value_offset = pid0.to(tl.int64) * BLOCK_SIZE + offset_range
    value_mask = value_offset < n_elements
    value_data = tl.load(value_in + value_offset, mask=value_mask)

    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Split the 64-bit Philox offset into two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.int32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.int32)

    # A distinct counter per global lane gives each lane its own stream value.
    i4 = pid0 * BLOCK_SIZE + offset_range
    c0 = c0 + i4
    _O = c0 * 0  # zero tensor for the unused counter words

    r0, _, _, _ = tl.philox(philox_seed, c0, c1, _O, _O)

    # Remainder of r0 modulo BLOCK_SIZE. NOTE(review): if the int32 cast yields
    # negative values and `//` truncates toward zero, r1 can be negative. It
    # is still strictly below BLOCK_SIZE, which is all the masking below needs.
    r0_int = r0.to(tl.int32)
    r1 = r0_int - (r0_int // BLOCK_SIZE) * BLOCK_SIZE

    # Out-of-range lanes get key BLOCK_SIZE, larger than any valid lane's key,
    # so they sort last. In a partial tile with v valid lanes, the first v
    # sorted entries are therefore exactly lanes 0..v-1. Their values land only
    # on positions < v, and the store mask drops the rest.
    mask_val = tl.full([BLOCK_SIZE], BLOCK_SIZE, dtype=tl.int32)
    r1 = tl.where(value_mask, r1, mask_val)

    _, sorted_chunk_index = argsort(r1, offset_range, 0, descending=False)

    store_offset = pid0.to(tl.int64) * BLOCK_SIZE + sorted_chunk_index.to(tl.int64)
    tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements)
def _radix_sort_by_key(key, value, valid_bits):
    # LSD radix sort (4 key bits per pass) followed by an in-tile shuffle of
    # the values; returns a new tensor and leaves ``key``/``value`` untouched.
    n_elements = key.numel()
    BLOCK_SIZE = 1024
    bits_per_pass = 4
    bits_per_segment = 3
    passes = triton.cdiv(valid_bits, bits_per_pass)
    bins = 2**bits_per_pass
    bins_per_segment = 2**bits_per_segment
    bit_mask = bins - 1

    # The top 2 bits of each int32 lookback word are status flags, so one
    # scatter launch ("portion") covers at most 2**30 elements.
    portion_size = 2**30
    num_portions = triton.cdiv(n_elements, portion_size)
    max_portion_items = portion_size if num_portions > 1 else n_elements
    max_tiles_per_portion = triton.cdiv(max_portion_items, BLOCK_SIZE)

    hist_dtype = torch.int64 if num_portions > 1 else torch.int32
    grid_hist = (triton.cdiv(n_elements, BLOCK_SIZE), bins // bins_per_segment)

    digit_hist_slice = torch.empty(
        (passes, bins + 1, grid_hist[0]), dtype=hist_dtype, device=key.device
    )
    digit_hist = torch.empty(
        (num_portions, passes, bins + 1), dtype=hist_dtype, device=key.device
    )
    # Zero marks a lookback slot as "not yet published" for the scatter kernel.
    d_lookback = torch.zeros(
        num_portions * passes * bins * max_tiles_per_portion,
        dtype=torch.int32,
        device=key.device,
    )

    key_out_p = torch.empty_like(key)
    key_out_q = torch.empty_like(key)
    value_out_p = torch.empty_like(value)
    value_out_q = torch.empty_like(value)

    # Step 1: per-tile digit counts for every pass.
    with torch_device_fn.device(key.device):
        digit_hist_kernel[grid_hist](
            digit_hist_slice,
            key,
            n_elements,
            bits_per_pass,
            bins,
            passes,
            bit_mask,
            bins_per_segment,
            BLOCK_SIZE,
        )

    # Step 2: sum over tiles, then cumsum over the bin axis. Slot 0 is zero,
    # so slot b holds the count of keys with a smaller digit. The result is
    # broadcast to every portion as its initial bin offsets.
    digit_hist_slice = torch.sum(digit_hist_slice, dim=2, keepdim=False)
    digit_hist_slice = digit_hist_slice.cumsum(dim=1)  # shape of [passes, bins + 1]
    digit_hist.copy_(digit_hist_slice)

    # Step 3: one scatter per pass, ping-ponging between the p/q buffers.
    bit_offset = 0
    for p in range(passes):
        if p % 2 == 0:
            k_in = key if p == 0 else key_out_p
            v_in = value if p == 0 else value_out_p
            k_out, v_out = key_out_q, value_out_q
        else:
            k_in, v_in = key_out_q, value_out_q
            k_out, v_out = key_out_p, value_out_p
        # Portions run in order: the last tile of portion i writes the
        # starting offsets that portion i + 1 reads from digit_hist.
        for portion_id in range(num_portions):
            portion_items = min(n_elements - portion_id * portion_size, portion_size)
            tiles_per_portion = triton.cdiv(portion_items, BLOCK_SIZE)
            grid_scatter = (tiles_per_portion, grid_hist[1])
            with torch_device_fn.device(key.device):
                radix_sortbykey_scatter_kernel[grid_scatter](
                    k_out,
                    v_out,
                    k_in,
                    v_in,
                    digit_hist,
                    d_lookback,
                    n_elements,
                    bit_offset,
                    passes,
                    p,
                    num_portions,
                    portion_size,
                    portion_id,
                    bit_mask,
                    bins_per_segment,
                    max_tiles_per_portion,
                    bins,
                    BLOCK_SIZE,
                )
        bit_offset += bits_per_pass

    # Step 4: shuffle the values inside each small tile.
    BLOCK_SIZE_SHUFFLE = 128
    grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),)
    philox_seed, philox_offset = philox_backend_seed_offset(n_elements)
    with torch_device_fn.device(key.device):
        duplicate_keys_shuffle_kernel[grid_shuffle](
            v_out,
            n_elements,
            philox_seed,
            philox_offset,
            BLOCK_SIZE_SHUFFLE,
            num_warps=4,
        )
    return v_out


def _bitonic_sort_by_key(key, value):
    # Single-program ascending bitonic sort for tiny, non-empty inputs that
    # fit in one power-of-two tile.
    n_elements = key.numel()
    BLOCK_SIZE = triton.next_power_of_2(n_elements)
    logger.debug(n_elements)

    k_out = torch.empty_like(key)
    v_out = torch.empty_like(value)
    with torch_device_fn.device(key.device):
        bitonic_sortbykey_kernel[(1,)](
            k_out, v_out, key, value, n_elements, BLOCK_SIZE, False
        )
    return v_out


def sort_by_key(key, value, valid_bits):
    """Return a new tensor with ``value`` reordered by ``key``.

    Args:
        key: Integer key tensor, addressed as a flat array of ``key.numel()``
            elements.
        value: Tensor with as many elements as ``key``; it is not modified.
        valid_bits: Number of low bits of the sign-flipped key (see
            ``radix_type_convert``) that the radix path sorts on.

    Returns:
        The reordered values. Inputs with at most 4 elements use a bitonic
        sort in ascending key order. Larger inputs use a radix sort that then
        randomly shuffles the values inside every 128-element tile, so on that
        path the output is only approximately key-ordered (which is what
        ``randperm`` needs). An empty input returns an empty tensor.
    """
    n_elements = key.numel()
    if n_elements == 0:
        # The bitonic path sizes its tile with triton.next_power_of_2(0),
        # which does not give a usable block size; nothing to sort anyway.
        return torch.empty_like(value)
    if n_elements > 4:
        return _radix_sort_by_key(key, value, valid_bits)
    return _bitonic_sort_by_key(key, value)
def randperm(
    n,
    *,
    generator=None,
    out=None,
    dtype=torch.int64,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
    """Return a random permutation of the integers ``0 .. n - 1``.

    This is the Ascend implementation of ``torch.randperm``. It draws one
    random key per element with ``torch.randint`` and reorders
    ``torch.arange(n)`` by those keys via ``sort_by_key``.

    Args:
        n: Number of elements; must not exceed the int64 maximum.
        generator: Accepted for signature compatibility; not used.
        out: Optional tensor that receives the result. When given, it is
            resized to the result's shape, filled, and returned.
        dtype: Result dtype; one of ``torch.int16``, ``torch.int32``,
            ``torch.int64``.
        layout: Accepted for signature compatibility; not used.
        device: Target device; defaults to the backend device.
        requires_grad: Accepted for signature compatibility; not used.
        pin_memory: Accepted for signature compatibility; not used.

    Returns:
        A 1-D tensor of length ``n`` (``out`` itself when provided).

    Raises:
        AssertionError: If ``dtype`` is unsupported or ``n`` is too large.
    """
    logger.debug("GEMS_ASCEND RANDPERM")
    assert dtype == torch.int16 or dtype == torch.int32 or dtype == torch.int64
    assert n <= _MAX_INT64_VAL, "n exceeds maximum int64"

    if device is None:
        device = torch.device(device_.name)
    in_range = torch.arange(n, dtype=dtype, device=device)

    # Use the narrowest key width whose key space covers n. Fewer valid key
    # bits means fewer radix passes in sort_by_key.
    if n <= 2**8:
        valid_bits, key_dtype = 8, torch.int8
        keymin, keymax = _MIN_INT8_VAL, _MAX_INT8_VAL
    elif n <= 2**16:
        valid_bits, key_dtype = 16, torch.int16
        keymin, keymax = _MIN_INT16_VAL, _MAX_INT16_VAL
    elif n <= 2**24:
        valid_bits, key_dtype = 24, torch.int32
        keymin, keymax = _MIN_INT24_VAL, _MAX_INT24_VAL
    elif n <= 2**32:
        valid_bits, key_dtype = 32, torch.int32
        keymin, keymax = _MIN_INT32_VAL, _MAX_INT32_VAL
    else:
        valid_bits, key_dtype = 64, torch.int64
        keymin, keymax = _MIN_INT64_VAL, _MAX_INT64_VAL

    # torch.randint's ``high`` is exclusive: keys lie in [keymin, keymax - 1].
    rand_key = torch.randint(
        low=keymin, high=keymax, size=[n], dtype=key_dtype, device=device
    )
    perm_range = sort_by_key(rand_key, in_range, valid_bits)

    if out is not None:
        # Honor torch.randperm's out= contract: fill and return the caller's
        # buffer instead of silently ignoring it.
        out.resize_(perm_range.shape)
        out.copy_(perm_range)
        return out
    return perm_range