Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sort.py: 0%
307 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
11from .topk import _get_finfo_val, _get_iinfo_val, argsort
# Child logger under the shared "flag_gems" logger; lstrip(".") guards against
# a leading dot if __name__ were a relative module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def unwrap_if_constexpr(o):
    """Return the plain Python value held by a tl.constexpr, or `o` unchanged."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o
@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Build the Triton integer dtype with the given bit width and signedness."""
    bits = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(bits, is_signed)
@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Bit pattern 0b100...0: only the top (sign) bit of an N-bit word set."""
    width = unwrap_if_constexpr(num_bits)
    return 1 << (width - 1)
@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Bit pattern 0b011...1: every bit of an N-bit word set except the top."""
    width = unwrap_if_constexpr(num_bits)
    top_bit = 1 << (width - 1)
    return top_bit - 1
@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    """Order-preserving key mapping for unsigned ints.

    Unsigned keys already compare in ascending order, so this is the identity;
    for descending order the bitwise complement reverses the unsigned order.
    """
    result = x
    if descending:
        result = ~x
    return result
@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    """Map signed integer keys to unsigned ints whose unsigned ascending order
    matches the requested signed order.

    Ascending: XOR with 0b100...0 flips the sign bit, shifting the signed
    range onto the unsigned range order-preservingly. Descending: XOR with
    0b011...1 yields the bitwise complement of the ascending key, which
    reverses the order.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    # reinterpret the signed bits as unsigned without value conversion
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # 0111111....1
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # 1000000...0
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out
@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    """Map IEEE-754 float keys to unsigned ints whose unsigned ascending order
    matches the requested float order (standard float radix-sort trick).

    Positive floats get only the sign bit flipped; negative floats get all
    bits flipped (so more-negative values map lower). For descending order
    the mask is complemented, which reverses the mapping's order.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift:
    # shifting by (num_bits - 1) broadcasts the sign bit to every position
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # mask is:
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    if descending:
        out = ux ^ (~mask)  # == ~(ux ^ mask): complement reverses the order
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)
@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    """Bijectively map keys of any supported dtype to unsigned integers whose
    unsigned ascending order equals the requested sort order of the inputs.

    Dispatch is resolved at compile time from the key dtype.
    """
    if x.dtype.is_floating():
        if x.dtype == tl.bfloat16:
            # widen bf16 to fp32 first (order-preserving) so the generic
            # float bit-trick below applies
            x = x.to(tl.float32)
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    # NOTE(review): any other dtype leaves `out` undefined — presumably
    # unreachable for the dtypes the callers pass; confirm.
    return out
@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
):
    """Accumulate per-row digit histograms for every radix pass in one launch.

    arr_ptr: (m, n) input keys.
    out_ptr: (m, num_passes, r) histograms, r = 2 ** num_bits_per_pass bins;
             must be zero-initialized by the caller because bins are
             accumulated with atomic_add across CTAs.
    Each CTA covers TILE_N * tiles_n_per_cta contiguous columns of one row.
    """
    pid = tl.program_id(0)
    # 1-D grid of size m * grid_n: row index varies fastest
    pid_n = pid // m
    pid_m = pid % m

    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1  # a.k.a. 2 ** k_bits - 1
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    cta_n_start = CTA_TILE_N * pid_n
    cta_n_end = tl.minimum(cta_n_start + CTA_TILE_N, n)

    for p in range(0, num_passes):  # parallel
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):  # parallel
            bin_indices = r_start + tl.arange(0, TILE_R)
            acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):  # sequential
                n_offsets = n_start + tl.arange(0, TILE_N)  # (TILE_N, )
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask)
                # map keys so their unsigned order matches the sort order,
                # then extract this pass's digit
                arr = convert_to_uint_preverse_order(arr, descending)
                key = (arr >> bit_offset) & bfe_mask  # (TILE_N, )
                matches = tl.where(
                    mask, (bin_indices[:, None] == key), False
                )  # (TILE_R, TILE_N)
                acc += matches
            local_sum = tl.sum(acc, axis=1)
            # NOTE(review): local_sum is int64 while the caller allocates the
            # histogram as int32 — presumably the backend casts on atomic_add;
            # confirm.
            tl.atomic_add(
                out_ptr + pid_m * num_passes * r + p * r + bin_indices,
                local_sum,
                sem="relaxed",
            )
@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
):
    """One scatter pass of a chained-scan ("onesweep") radix sort using
    decoupled lookback for the cross-CTA exclusive prefix per bin.

    Each program handles one TILE_N slice of one row for TILE_R bins.
    `status_ptr` must be zero-initialized before launch: each 32-bit status
    word packs a value in bits 0..29, an "aggregate published" flag in bit 30
    and an "inclusive prefix published" flag in bit 31 (safe because the
    caller asserts N < 2**30).
    """
    # r: num_bins = 2 ** k_bits
    # OUT_N: grid_n = cdiv(N, TILE_N)

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # flag_ptr: (m, r, OUT_N)

    # grid: (m, grid_r, grid_n)

    # load data
    pid = tl.program_id(0)
    pid_m = pid % m
    pid_n = pid // m
    pid_r = tl.program_id(1)

    # bit masks for the status word (see docstring)
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # initialize flag to zero-local sum is not ready
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    # cumsum for a bin_index
    n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
    mask = n_offsets < N
    arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
    arr_u = convert_to_uint_preverse_order(arr, descending)
    key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )

    # since triton can only use scalar as condition, loop by bin_index
    # status must be pre zero-initialized, or else we have to initialize it
    for bin_index in range(cta_r_start, cta_r_end):
        matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
        # cta level cumsum per bin
        # CAUTION: tl.sum in triton 3.2 does not promote type
        local_sum = tl.sum(matches.to(tl.uint32), axis=0)
        # publish this CTA's aggregate so successors can make progress
        pack0 = aggregate_mask | local_sum
        status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n
        tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

        # decoupled lookback: walk predecessor CTAs, summing their aggregates
        # until one exposes an inclusive prefix (or the row start is reached)
        exclusive_prefix = tl.zeros((), dtype=tl.uint32)
        i_lookback = pid_n - 1
        while i_lookback >= 0:
            flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
            pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uint32
            # spin until the predecessor publishes something (status != 0)
            while pack1 == 0:
                pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
            exclusive_prefix += pack1 & v_mask
            if (pack1 & aggregate_mask) == aggregate_mask:
                i_lookback -= 1
            else:
                # inclusive prefix seen: lookback complete
                i_lookback = -1
        # publish our inclusive prefix for successor CTAs
        pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
        tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

        local_ex_cumsum = (
            tl.cumsum(matches.to(tl.uint32), axis=0) - matches
        )  # (TILE_N, )
        ex_cumsum_in_bin = (
            exclusive_prefix + local_ex_cumsum
        )  # global ex_cumsum_in_bin (TILE_N, )

        # ex_cumsum_bins (m, n_passes, r): start offset of this bin in the row
        ex_cumsum_bins = tl.load(
            excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index
        )  # scalar
        pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

        # scatter keys (and the associated payload, if any) to their slots
        tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
        if associate_arr_ptr is not None:
            associate_arr = tl.load(
                associate_arr_ptr + pid_m * N + n_offsets, mask=mask
            )
            tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches)
@triton.jit
def count_kernel(
    x_ptr,
    counts_ptr,  # Output: [M, grid_n, num_bins]
    M,
    N,
    bit_offset,
    num_bins: tl.constexpr,
    BLOCK_N: tl.constexpr,
    descending: tl.constexpr,
):
    """Per-(row, block) histogram of the current radix digit.

    One program counts, for a BLOCK_N-wide slice of one row of x_ptr (M, N),
    how many keys fall into each of the num_bins values of the digit that
    starts at `bit_offset`. num_bins must be a power of two.
    """
    pid = tl.program_id(0)

    num_blocks_per_row = tl.cdiv(N, BLOCK_N)
    row_idx = pid // num_blocks_per_row
    block_idx = pid % num_blocks_per_row

    row_start = row_idx * N
    n_offset = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    val = tl.load(x_ptr + row_start + n_offset, mask=mask, other=0)
    # map keys so their unsigned order matches the requested sort order
    val_u = convert_to_uint_preverse_order(val, descending)

    bfe_mask = num_bins - 1
    key = (val_u >> bit_offset) & bfe_mask

    for i in range(num_bins):
        bin_mask = (key == i) & mask
        count = tl.sum(bin_mask.to(tl.int32))
        out_offset = (
            (row_idx * num_blocks_per_row * num_bins) + (block_idx * num_bins) + i
        )
        tl.store(counts_ptr + out_offset, count)
@triton.jit
def scatter_kernel(
    x_ptr,
    x_out_ptr,
    idx_in_ptr,
    idx_out_ptr,
    global_offsets_ptr,
    M,
    N,
    bit_offset,
    num_bins: tl.constexpr,
    BLOCK_N: tl.constexpr,
    descending: tl.constexpr,
):
    """Scatter keys (and their index payload) to their sorted slots for one
    radix pass.

    global_offsets_ptr [M, grid_n, num_bins] holds, per (row, block, bin),
    the starting position of that block's chunk of the bin within the row
    (bin start + exclusive prefix over earlier blocks). Combined with the
    in-block cumsum rank this makes the pass stable: earlier blocks, and
    earlier elements within a block, land first.
    """
    pid = tl.program_id(0)
    num_blocks_per_row = tl.cdiv(N, BLOCK_N)
    row_idx = pid // num_blocks_per_row
    block_idx = pid % num_blocks_per_row

    row_start = row_idx * N
    n_offset = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    val = tl.load(x_ptr + row_start + n_offset, mask=mask, other=0)
    val_u = convert_to_uint_preverse_order(val, descending)

    # payload: source indices accumulated across passes
    idx = tl.load(idx_in_ptr + row_start + n_offset, mask=mask, other=0)

    bfe_mask = num_bins - 1
    key = (val_u >> bit_offset) & bfe_mask

    for i in range(num_bins):
        bin_mask = (key == i) & mask
        # 0-based rank of each matching element within this block's bin
        local_rank = tl.cumsum(bin_mask.to(tl.int32), axis=0) - 1

        offset_idx = (
            (row_idx * num_blocks_per_row * num_bins) + (block_idx * num_bins) + i
        )
        global_start = tl.load(global_offsets_ptr + offset_idx)

        dest_idx = row_start + global_start + local_rank

        tl.store(x_out_ptr + dest_idx, val, mask=bin_mask)
        tl.store(idx_out_ptr + dest_idx, idx, mask=bin_mask)
def radix_sort_low_mem(arr, k_bits=4, descending=False):
    """Stable LSD radix sort along the last dimension of `arr`.

    Processes `k_bits` bits per pass with a count/scatter kernel pair and a
    torch-side prefix sum in between, double-buffering keys and indices.

    Args:
        arr: tensor of rank >= 1; each slice along the last dim is sorted
            independently. The input tensor is NOT modified.
        k_bits: radix bits consumed per pass (bins per pass = 2 ** k_bits).
        descending: sort order.

    Returns:
        (values, indices): both shaped like `arr`; `indices` holds the int64
        source positions within the last dimension (stable order).
    """
    # Flatten all leading dims into one batch axis; restored before return.
    # Previously `M, N = arr.shape` rejected rank != 2 and a 1-D input came
    # back shaped (1, N).
    orig_shape = arr.shape
    N = arr.shape[-1]
    M = arr.numel() // N
    # The clone guarantees a contiguous working buffer and keeps the caller's
    # tensor intact: the buffer swap below would otherwise scatter into the
    # original input from the second pass onward.
    arr_in = arr.reshape(M, N).clone()
    arr_out = torch.empty_like(arr_in)

    idx_in = (
        torch.arange(N, device=arr.device, dtype=torch.int64)
        .broadcast_to(arr_in.shape)
        .clone()
    )
    idx_out = torch.empty_like(idx_in)

    dtype = arr.dtype
    if dtype == torch.bool:
        num_bits = 1
    elif dtype == torch.bfloat16:
        # bf16 keys are widened to fp32 inside the kernels, so sort 32 bits
        num_bits = 4 * 8
    else:
        num_bits = arr.element_size() * 8
    num_passes = (num_bits + k_bits - 1) // k_bits
    num_bins = 2**k_bits

    BLOCK_N = 512
    grid_n = triton.cdiv(N, BLOCK_N)
    grid = (M * grid_n,)

    with torch_device_fn.device(arr.device):
        counts = torch.empty(
            (M, grid_n, num_bins), device=arr.device, dtype=torch.int32
        )

        for p in range(num_passes):
            bit_offset = p * k_bits
            # histogram of the current digit per (row, block)
            count_kernel[grid](
                arr_in,
                counts,
                M,
                N,
                bit_offset,
                num_bins,
                BLOCK_N,
                descending,
                # NOTE(review): not a kernel parameter — presumably a
                # kunlunxin-specific launch option; confirm.
                is_use_mask_zero=True,
            )

            # exclusive prefix over bins: start of each bin within the row
            total_counts_per_bin = counts.sum(dim=1)
            bin_global_starts = (
                torch.cumsum(total_counts_per_bin, dim=1) - total_counts_per_bin
            )
            # exclusive prefix over blocks: each block's start within its bin
            block_prefix_sum = torch.cumsum(counts, dim=1) - counts
            global_offsets = (
                bin_global_starts.unsqueeze(1)
                .broadcast_to(block_prefix_sum.shape)
                .clone()
                + block_prefix_sum
            )

            scatter_kernel[grid](
                arr_in,
                arr_out,
                idx_in,
                idx_out,
                global_offsets,
                M,
                N,
                bit_offset,
                num_bins,
                BLOCK_N,
                descending,
                is_use_mask_zero=True,
            )

            # double-buffer swap: the next pass reads this pass's output
            arr_in, arr_out = arr_out, arr_in
            idx_in, idx_out = idx_out, idx_in

    # restore the caller's shape (handles 1-D and higher-rank inputs)
    return arr_in.reshape(orig_shape), idx_in.reshape(orig_shape)
def radix_sort(arr, k_bits=8, descending=False):
    """Stable radix sort along the last dim via single-sweep scatter with
    decoupled lookback.

    Builds the per-pass digit histograms for ALL passes in one launch, then
    runs one `sweep` launch per pass, double-buffering keys and indices.
    Returns (sorted_values, int64 indices); `arr` is cloned before sorting.
    """
    n = arr.shape[-1]
    m = arr.numel() // n
    # the sweep status word stores counts in 30 bits, so n must fit
    assert n < (1 << 30), "we have not implemented 2**30 per launch"
    dtype = arr.dtype
    num_bits = 1 if dtype == torch.bool else (arr.element_size() * 8)

    TILE_N = 1024
    tiles_n_per_cta = 8
    CTA_TILE_N = tiles_n_per_cta * TILE_N

    num_bins = 2**k_bits
    n_passes = triton.cdiv(num_bits, k_bits)
    TILE_R = 16

    grid_n = triton.cdiv(n, CTA_TILE_N)
    grid_for_global_hist = (m * grid_n, 1, 1)

    with torch_device_fn.device(arr.device):
        # zero-init is required: bins are accumulated with atomic_add
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        compute_global_hist_kernel[grid_for_global_hist](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            tiles_n_per_cta,
            TILE_N,
            TILE_R,
            k_bits,
            descending,
        )
        # exclusive cumsum over bins: start offset of each digit value
        ex_cumsum_bins = torch.cumsum(global_hist, -1) - global_hist
        ex_cumsum_bins = ex_cumsum_bins.to(torch.uint32)

        # sort
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int64, device=arr_in.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        indices_out = torch.empty_like(indices_in)

        # the sweep kernel uses its own tiling, independent of the histogram's
        TILE_R = 8
        grid_r = triton.cdiv(num_bins, TILE_R)
        TILE_N = 2048
        grid_n = triton.cdiv(n, TILE_N)
        grid_for_sweep = (m * grid_n, grid_r)

        status = torch.empty(
            (m, num_bins, grid_n), device=arr.device, dtype=torch.uint32
        )

        for i in range(0, n_passes):
            bit_offset = i * k_bits
            # lookback protocol requires an all-zero ("not ready") status
            status.zero_()
            sweep[grid_for_sweep](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                i,
                bit_offset,
                m,
                n,
                grid_n,
                TILE_N,
                TILE_R,
                k_bits,
                descending,
            )
            # print(f"< sorted last {bit_offset + k_bits:>2d} bits: {arr_out}")
            # double-buffer swap for the next pass
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    return arr_in, indices_in
@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    """Sort one row of N elements entirely inside a single program.

    Each program handles row `program_id(0)`. Out-of-range lanes are padded
    with the dtype's max (ascending) or min (descending) so the padding sorts
    to the tail and the first N results are valid.
    """
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    offset = tl.program_id(0) * N + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    if IS_FLOAT:
        # pad value sorts after all real keys
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
        # widen sub-fp32 floats to fp32 for argsort; fp64 stays fp64
        in_val = tl.where(in_val.dtype.is_fp64(), in_val, in_val.to(tl.float32))
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        # NOTE(review): int64 keys are narrowed to int32 here — presumably
        # acceptable for this backend's supported range; confirm.
        in_val = tl.load(in_ptr, mask=mask, other=mask_val).to(tl.int32)
    index_val = tl.arange(0, BLOCK_SIZE)

    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)
def sort(inp, dim=-1, descending=False):
    """Sort `inp` along `dim`; returns (values, int64 indices).

    Small rows (<= 512 elements) are sorted by a single-block bitonic kernel;
    larger rows fall back to torch.sort.
    """
    logger.debug("GEMS SORT")
    n = inp.shape[dim]
    if n == 1:
        return inp, torch.zeros_like(inp, dtype=torch.int64)
    if n > 512:  # TODO: Optimize implementation for large cases.
        return torch.sort(inp, stable=False, dim=dim, descending=descending)
    block_size = triton.next_power_of_2(n)

    if dim < 0:
        dim += inp.ndim
    needs_move = dim != inp.ndim - 1
    # the kernel sorts contiguous rows along the last axis
    if needs_move:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()
    batch_size = math.prod(inp.shape) // n

    out = torch.empty_like(inp)
    out_index = torch.empty_like(inp, dtype=torch.int64)

    with torch_device_fn.device(inp.device):
        sort_kernel[batch_size,](
            inp,
            out,
            out_index,
            N=n,
            BLOCK_SIZE=block_size,
            DESCENDING=descending,
            IS_FLOAT=inp.is_floating_point(),
            num_warps=4,
        )

    if needs_move:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index
def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Stable sort of `inp` along `dim` via LSD radix sort.

    `stable` is accepted for API compatibility and otherwise ignored: the
    radix sort used here is always stable.
    """
    logger.debug("GEMS SORT.STABLE")
    _ = stable
    n = inp.shape[dim]
    if n == 1:
        return inp, torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim += inp.ndim
    last = inp.ndim - 1
    # the radix sort works along the last axis of a contiguous tensor
    if dim == last:
        inp = inp.contiguous()
    else:
        inp = torch.movedim(inp, dim, -1).contiguous()

    # bool keys carry a single significant bit; everything else uses 4-bit digits
    bits_per_pass = 1 if inp.dtype == torch.bool else 4
    out, out_index = radix_sort_low_mem(inp, bits_per_pass, descending)

    if dim != last:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index