Coverage for src/flag_gems/runtime/backend/_mthreads/ops/randperm.py: 0%
266 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.topk import argsort
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils.random_utils import philox_backend_seed_offset
# Keep a module-level alias to the backend device descriptor; the name
# `device` is shadowed by the keyword argument of `randperm` below.
device_ = device
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Integer range limits wrapped in tl.constexpr so the same constants are
# usable from host code and as compile-time values inside Triton kernels.
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
# uint32 range spelled out directly (torch.iinfo may lack a uint32 entry
# depending on the torch version — TODO confirm).
_MAX_UINT32_VAL = tl.constexpr((1 << 32) - 1)
_MIN_UINT32_VAL = tl.constexpr(0)
# 24-bit signed range; used for keys stored in int32 when n fits in 24 bits
# (see the key-width selection in `randperm`).
_MIN_INT24_VAL = tl.constexpr(-(2**23))
_MAX_INT24_VAL = tl.constexpr(2**23 - 1)
# Return the minimum/maximum representable value for an integer Triton
# dtype.  `return_max` selects the maximum (truthy) or minimum (falsy).
# The if/elif chain on `dtype` is resolved at Triton compile time; an
# unsupported dtype raises ValueError during kernel compilation.
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    if dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int8:
        if return_max:
            return _MAX_INT8_VAL
        else:
            return _MIN_INT8_VAL
    elif dtype is tl.uint32:
        if return_max:
            return _MAX_UINT32_VAL
        else:
            return _MIN_UINT32_VAL
    else:
        raise ValueError("Unknown dtype")
# Single-block sort-by-key for one batch of N elements (N <= BLOCK_SIZE).
#
# Program `cur_batch` sorts keys from `chunk_x` together with their
# payload from `chunk_index`, writing results to `y_ptr` / `index_ptr`.
# Lanes beyond N are padded with the key dtype's max (ascending) or min
# (descending) so they fall to the end of the sorted order and are then
# masked out of the stores.
@libentry()
@triton.jit
def bitonic_sortbykey_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    # Advance all pointers to this batch's row of N elements.
    cur_batch = tl.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    index_ptr += cur_batch * N
    y_ptr += cur_batch * N
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Padding value that sorts after every real key.
    mask_val = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask)
    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < N)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < N)
# Map a signed integer key to an order-preserving bit pattern by widening
# to int64 and flipping the original type's sign bit: negatives end up in
# the lower half and non-negatives in the upper half, so unsigned digit
# comparison in the radix sort preserves signed ordering.  Non-matching
# dtypes pass through unchanged.
@triton.jit
def radix_type_convert(k):
    ik = k.to(tl.int64)
    if tl.constexpr(k.dtype == tl.int8):
        mask = (ik >> 7) & 0x1
        o = tl.where(mask, ik & 0x7F, ik | 0x80)
    elif tl.constexpr(k.dtype == tl.int16):
        mask = (ik >> 15) & 0x1
        o = tl.where(mask, ik & 0x7FFF, ik | 0x8000)
    elif tl.constexpr(k.dtype == tl.int32):
        mask = (ik >> 31) & 0x1
        o = tl.where(mask, ik & 0x7FFFFFFF, ik | 0x80000000)
    elif tl.constexpr(k.dtype == tl.int64):
        mask = (ik >> 63) & 0x1
        o = tl.where(mask, ik & 0x7FFFFFFFFFFFFFFF, ik | 0x8000000000000000)
    else:
        o = k
    return o
# Build per-tile digit histograms for every radix pass in a single sweep.
#
# Grid axis 0 walks tiles of BLOCK_SIZE keys; axis 1 walks segments of
# `bins_segment` bins.  For pass p and bin b, tile `pid0`'s count is
# written at digit_hist[p, b + 1, pid0]; the b = 0 slot is zeroed so a
# later cumsum over the bin axis (done on the host) yields exclusive
# prefix offsets.
@libentry()
@triton.jit
def digit_hist_kernel(
    digit_hist,  # laid out as [passes, bins + 1, num_tiles]
    key,
    n_elements,
    bits_per_pass,
    bins,
    passes,
    bit_mask,  # bins - 1; extracts one digit
    bins_segment,  # number of bins handled per program on axis 1
    BLOCK_SIZE: tl.constexpr,
):
    bin_segid = tl.program_id(1)
    pid0 = tl.program_id(0)
    grid0 = tl.num_programs(0)
    key_offset = pid0.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    key_mask = key_offset < n_elements
    key_data = tl.load(key + key_offset, mask=key_mask)
    # Order-preserving unsigned representation; digits of this pattern
    # sort signed keys correctly.
    ikey_data = radix_type_convert(key_data)
    bit_offset = 0
    for p in range(passes):
        key_digit = (ikey_data >> bit_offset) & bit_mask
        blk_bin_start = bin_segid * bins_segment
        for s in range(bins_segment):
            bin_id = s + blk_bin_start
            digit_mask = tl.where(key_digit == bin_id and key_mask, 1, 0)
            digit_sum = tl.sum(digit_mask)
            # +1 for exclusive: slot 0 of each pass stays free so the
            # host-side cumsum becomes an exclusive prefix sum.
            bin_offset = p * (bins + 1) * grid0 + (bin_id + 1) * grid0 + pid0
            # reduce rather than global atomic for perf issue: per-tile
            # stores are summed on the host instead.
            tl.store(digit_hist + bin_offset, digit_sum)
        # Only one program per tile (segment 0) zeroes the exclusive slot.
        tl.store(digit_hist + p * (bins + 1) * grid0 + pid0, 0, mask=bin_segid == 0)
        bit_offset += bits_per_pass
# One scatter pass of the LSD radix sort using decoupled look-back
# (single-pass chained scan) to obtain global tile offsets.
#
# For each bin of its segment, a program publishes its tile-local count
# tagged PARTIAL, spins on predecessor tiles' counters to accumulate an
# inclusive prefix, republishes the prefix tagged GLOBAL, and scatters
# its keys/values to their sorted positions in key_out/value_out.  When
# the input is split into portions, the last tile of a portion also
# seeds the histogram base offsets of the next portion.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("randperm"),
    key=["n_elements"],
)
@triton.jit
def radix_sortbykey_scatter_kernel(
    key_out,
    value_out,
    key_in,
    value_in,
    digit_hist,  # per-portion exclusive prefix offsets, [portions, passes, bins + 1]
    d_lookback,  # int32 tile counters; 2 high bits carry the status flags
    n_elements,
    bit_offset,  # bit position of the digit processed in this pass
    passes,
    p,  # index of the current pass
    num_portions,
    portion_size,
    portion_id,
    bit_mask,
    bins_segment,
    max_tiles_per_portion,
    bins: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Status flags stored in the two high bits of each look-back counter;
    # the low 30 bits hold the count/prefix value.
    LOOKBACK_PARTIAL_MASK = 1 << 30
    LOOKBACK_GLOBAL_MASK = 1 << 31
    LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK
    LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK
    pid0 = tl.program_id(0)
    portion_id_i64 = portion_id
    portion_id_i64 = portion_id_i64.to(tl.int64)
    # 64-bit element offsets: a portion boundary can exceed int32 range.
    key_offset = (
        portion_id_i64 * portion_size
        + pid0.to(tl.int64) * BLOCK_SIZE
        + tl.arange(0, BLOCK_SIZE)
    )
    key_mask = key_offset < n_elements
    value_data = tl.load(value_in + key_offset, mask=key_mask)
    key_data = tl.load(key_in + key_offset, mask=key_mask)
    ikey_data = radix_type_convert(key_data)
    key_digit = (ikey_data >> bit_offset) & bit_mask
    blk_bin_start = tl.program_id(1) * bins_segment
    last_block = tl.program_id(0) == tl.num_programs(0) - 1
    for s in range(bins_segment):
        bin_id = s + blk_bin_start
        key_digit_mask = (key_digit == bin_id) & key_mask
        key_elem_mask = tl.where(key_digit_mask, 1, 0)
        # Rank of each matching element within this tile (0-based).
        key_block_rank = tl.cumsum(key_elem_mask)
        key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0)
        bin_of_bucket = tl.sum(key_elem_mask)
        # Publish the tile-local count immediately so successors can start.
        partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            partial_counter,
            cache_modifier=".cg",
        )
        bin_offset = p * (bins + 1) + bin_id
        # Exclusive base offset for this bin in this portion/pass.
        prefix_offsets = tl.load(
            digit_hist + bin_offset + portion_id * passes * (bins + 1)
        )
        bk = pid0 - 1
        inc_sum = bin_of_bucket
        while bk >= 0:
            rd_lbk_offset = (
                (portion_id * passes + p) * max_tiles_per_portion + bk
            ) * bins + bin_id
            partial_prefix = 0
            # Spin until predecessor `bk` has published a nonzero counter;
            # the no-op CAS acts as an acquire-ordered load.
            while partial_prefix == 0:
                partial_prefix = tl.atomic_cas(
                    d_lookback + rd_lbk_offset, 0, 0, sem="acquire"
                )
            inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32)
            if partial_prefix & LOOKBACK_GLOBAL_MASK:
                # break (Triton lacks `break`; force the loop to exit)
                bk = -1
            else:
                bk -= 1
        # Republish the inclusive prefix so successors stop looking back.
        global_counter = inc_sum | LOOKBACK_GLOBAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            global_counter,
            cache_modifier=".cg",
        )
        inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64)
        # Last tile of a non-final portion seeds the next portion's bases.
        if last_block and portion_id < num_portions - 1:
            tl.store(
                digit_hist + bin_offset + (portion_id + 1) * passes * (bins + 1),
                inc_bucket_offset,
            )
        # base + rank-within-tile = final sorted position.
        global_offsets = (
            inc_bucket_offset - bin_of_bucket.to(tl.int64) + key_block_rank.to(tl.int64)
        )
        tl.store(key_out + global_offsets, key_data, mask=key_digit_mask)
        tl.store(value_out + global_offsets, value_data, mask=key_digit_mask)
# Randomly permute each BLOCK_SIZE-sized block of `value_in` in place.
# For parallelization, the entire block is shuffled rather than only runs
# of adjacent equal elements as the PyTorch GPU backend does.
@libentry()
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def duplicate_keys_shuffle_kernel(
    value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr
):
    pid0 = tl.program_id(0)
    offset_range = tl.arange(0, BLOCK_SIZE)
    value_offset = pid0.to(tl.int64) * BLOCK_SIZE + offset_range
    value_mask = value_offset < n_elements
    value_data = tl.load(value_in + value_offset, mask=value_mask)
    # Build per-element Philox counters from the backend seed/offset:
    # low/high halves of the 64-bit offset plus the global element index.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c0 += i4
    _O = c0 * 0
    r0, _, _, _ = tl.philox(philox_seed, c0, c1, _O, _O)
    _block_size = BLOCK_SIZE
    r1 = r0 % _block_size.to(tl.uint32)
    # Out-of-range lanes get uint32 max so they sort to the end.
    mask_val = _get_iinfo_val(tl.uint32, True)
    r1 = tl.where(value_offset < n_elements, r1, mask_val)
    # Sorting random keys yields a random permutation of lane indices;
    # scatter the loaded values through that permutation.
    _, sorted_chunk_index = argsort(r1, offset_range, 0, descending=False)
    store_offset = pid0.to(tl.int64) * BLOCK_SIZE + sorted_chunk_index.to(tl.int64)
    tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements)
def sort_by_key(key, value, valid_bits, generator=None):
    """Sort `value` by `key` and return the permuted values in a new tensor.

    Inputs of at most 2048 elements are handled by a single-block bitonic
    sort; larger inputs use a multi-pass LSD radix sort with decoupled
    look-back, followed by an intra-block shuffle that randomizes the
    relative order of elements (so duplicate keys do not yield a
    deterministic permutation).

    Args:
        key: tensor of integer sort keys.
        value: payload tensor, same length as `key`.
        valid_bits: number of low key bits that are significant.
        generator: optional torch RNG generator for the final shuffle.
    """
    total = key.numel()

    if total <= 2 * 1024:
        # Bitonic path: whole problem fits into one program.
        block = triton.next_power_of_2(total)
        sorted_keys = torch.empty_like(key)
        sorted_vals = torch.empty_like(value)
        with torch_device_fn.device(key.device):
            bitonic_sortbykey_kernel[(1,)](
                sorted_keys, sorted_vals, key, value, total, block, False
            )
        return sorted_vals

    # ---------------- radix path ----------------
    BLOCK = 1024
    RADIX_BITS = 4  # digit width per pass
    SEG_BITS = 3  # bins handled per program along grid axis 1
    num_passes = triton.cdiv(valid_bits, RADIX_BITS)
    num_bins = 1 << RADIX_BITS
    seg_bins = 1 << SEG_BITS
    digit_mask = num_bins - 1

    # Portions cap tile counts at 2**30 elements: the look-back counters
    # reserve their two high bits for status flags.
    portion_size = 2**30
    num_portions = triton.cdiv(total, portion_size)
    max_portion_items = portion_size if num_portions > 1 else total
    max_tiles = triton.cdiv(max_portion_items, BLOCK)

    hist_dtype = torch.int64 if num_portions > 1 else torch.int32
    grid_hist = (triton.cdiv(total, BLOCK), num_bins // seg_bins)

    per_tile_hist = torch.empty(
        (num_passes, num_bins + 1, grid_hist[0]), dtype=hist_dtype, device=key.device
    )
    digit_hist = torch.empty(
        (num_portions, num_passes, num_bins + 1), dtype=hist_dtype, device=key.device
    )
    lookback = torch.empty(
        num_portions * num_passes * num_bins * max_tiles,
        dtype=torch.int32,
        device=key.device,
    )

    # Ping-pong scratch buffers for the alternating passes.
    key_buf_a = torch.empty_like(key)
    key_buf_b = torch.empty_like(key)
    val_buf_a = torch.empty_like(value)
    val_buf_b = torch.empty_like(value)

    # Step 1: per-tile digit histograms for all passes in one sweep.
    lookback.zero_()
    with torch_device_fn.device(key.device):
        digit_hist_kernel[grid_hist](
            per_tile_hist,
            key,
            total,
            RADIX_BITS,
            num_bins,
            num_passes,
            digit_mask,
            seg_bins,
            BLOCK,
        )

    # Step 2: reduce over tiles, then exclusive prefix via cumsum (slot 0
    # of every pass was zeroed by the kernel).
    reduced = torch.sum(per_tile_hist, dim=2, keepdim=False)
    reduced = reduced.cumsum(dim=1)  # shape [num_passes, num_bins + 1]
    digit_hist.copy_(reduced)

    shift = 0
    for p in range(num_passes):
        # Even passes read buffer A (the input on pass 0) and write B;
        # odd passes go the other way.
        if p % 2 == 0:
            src_k, src_v = (key, value) if p == 0 else (key_buf_a, val_buf_a)
            dst_k, dst_v = key_buf_b, val_buf_b
        else:
            src_k, src_v = key_buf_b, val_buf_b
            dst_k, dst_v = key_buf_a, val_buf_a
        # Step 3: scatter each portion to its globally sorted position.
        for portion_id in range(num_portions):
            items = min(total - portion_id * portion_size, portion_size)
            grid_scatter = (triton.cdiv(items, BLOCK), grid_hist[1])
            with torch_device_fn.device(key.device):
                radix_sortbykey_scatter_kernel[grid_scatter](
                    dst_k,
                    dst_v,
                    src_k,
                    src_v,
                    digit_hist,
                    lookback,
                    total,
                    shift,
                    num_passes,
                    p,
                    num_portions,
                    portion_size,
                    portion_id,
                    digit_mask,
                    seg_bins,
                    max_tiles,
                    num_bins,
                    BLOCK,
                )
        shift += RADIX_BITS

    # Final step: shuffle values within each block so equal keys land in
    # random relative order.
    SHUFFLE_BLOCK = 512
    grid_shuffle = (triton.cdiv(total, SHUFFLE_BLOCK),)
    philox_seed, philox_offset = philox_backend_seed_offset(
        total, generator=generator
    )
    with torch_device_fn.device(key.device):
        duplicate_keys_shuffle_kernel[grid_shuffle](
            dst_v,
            total,
            philox_seed,
            philox_offset,
            SHUFFLE_BLOCK,
            num_warps=4,
        )
    return dst_v
def randperm(
    n,
    *,
    generator=None,
    out=None,
    dtype=torch.int64,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
    """Return a random permutation of the integers 0 .. n - 1.

    Attaches a random key to every index and sorts the indices by key via
    `sort_by_key`; the key width is the narrowest integer type that can
    distinguish n items, which also minimizes the number of radix passes.

    Args:
        n: number of elements in the permutation.
        generator: optional torch RNG generator used for the shuffle.
        dtype: output dtype; must be int16, int32 or int64.
        device: target device; defaults to the backend device.
        out, layout, requires_grad, pin_memory: accepted for API
            compatibility with torch.randperm; not used here.
    """
    logger.debug("GEMS_MTHREADS RANDPERM")
    assert dtype in (torch.int16, torch.int32, torch.int64)
    assert n <= _MAX_INT64_VAL, "n exceeds maximum int64"

    if device is None:
        device = torch.device(device_.name)
    indices = torch.arange(n, dtype=dtype, device=device)

    # (upper bound on n, valid key bits, key storage dtype, key range).
    # 24-bit keys are stored in int32 with a narrowed random range.
    key_choices = (
        (2**8, 8, torch.int8, _MIN_INT8_VAL, _MAX_INT8_VAL),
        (2**16, 16, torch.int16, _MIN_INT16_VAL, _MAX_INT16_VAL),
        (2**24, 24, torch.int32, _MIN_INT24_VAL, _MAX_INT24_VAL),
        (2**32, 32, torch.int32, _MIN_INT32_VAL, _MAX_INT32_VAL),
    )
    for bound, bits, kdtype, lo, hi in key_choices:
        if n <= bound:
            valid_bits, key_dtype, keymin, keymax = bits, kdtype, lo, hi
            break
    else:
        valid_bits, key_dtype = 64, torch.int64
        keymin, keymax = _MIN_INT64_VAL, _MAX_INT64_VAL

    rand_key = torch.randint(
        low=keymin, high=keymax, size=[n], dtype=key_dtype, device=device
    )
    return sort_by_key(rand_key, indices, valid_bits, generator=generator)