Coverage for src/flag_gems/ops/sort.py: 41%
210 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
# Module-level logger; the sort entry points below emit debug messages through it.
logger = logging.getLogger(name=__name__)
def unwrap_if_constexpr(o):
    """Return the wrapped Python value when *o* is a ``tl.constexpr``; otherwise *o* itself.

    Lets the ``@tl.constexpr`` helpers in this module accept either raw Python
    values or constexpr-wrapped ones.
    """
    if not isinstance(o, tl.constexpr):
        return o
    return o.value
@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Return the Triton integer dtype of width ``num_bits`` with the requested signedness.

    Both arguments may be plain Python values or ``tl.constexpr`` wrappers.
    """
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)
@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Return the mask ``0b100...0``: only the top bit of a ``num_bits``-wide word set."""
    width = unwrap_if_constexpr(num_bits)
    top_bit = width - 1
    return 1 << top_bit
@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Return the mask ``0b011...1``: every bit but the top one of a ``num_bits``-wide word set."""
    width = unwrap_if_constexpr(num_bits)
    top_bit_only = 1 << (width - 1)
    return top_bit_only - 1
@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    """Map unsigned integers to an order-preserving unsigned key.

    Ascending order keeps the value as is; descending order takes the bitwise
    complement, which reverses the unsigned order.
    """
    key = x
    if descending:
        key = ~x
    return key
@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    """Bit-cast signed integers to unsigned keys that preserve (or reverse) their order.

    Ascending: XOR with ``0b100...0`` flips only the sign bit, so negatives land
    below non-negatives and each half keeps its internal order.
    Descending: XOR with ``0b011...1`` is the bitwise complement of the ascending
    key, which reverses the order.
    """
    width: tl.constexpr = x.dtype.primitive_bitwidth
    unsigned_ty = get_int_t(width, False)
    as_unsigned = tl.cast(x, unsigned_ty, bitcast=True)
    if descending:
        flip_bits: tl.constexpr = zero_ones(width)  # 0111...1
    else:
        flip_bits: tl.constexpr = one_zeros(width)  # 1000...0
    flip_tensor = tl.full((), value=flip_bits, dtype=unsigned_ty)
    return as_unsigned ^ flip_tensor
@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    """Bit-cast floats to unsigned ints whose unsigned order follows the float order.

    For values with the sign bit clear only the sign bit is flipped, which moves
    them above all negatives. For values with the sign bit set every bit is
    flipped, which moves them below the positives and reverses their magnitude
    order. With ``descending=True`` the complementary mask is applied, reversing
    the resulting order.

    Consequences of working on raw bits: -0.0 and +0.0 get distinct keys
    (-0.0 first in ascending order), and NaNs with the sign bit clear map above
    +inf.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift:
    # sx >> (num_bits - 1) is all ones for negative x and all zeros otherwise
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # 1000000000...0 when the sign bit is clear
    # 1111111111...1 when the sign bit is set
    if descending:
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)
@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    """Bit-cast ``x`` to an unsigned key whose unsigned order matches the order of ``x``.

    Dispatches at compile time on ``x.dtype``:

    * floating point -> ``floating_to_uint``;
    * signed integers -> ``int_to_uint``;
    * unsigned integers -> ``uint_to_uint``. In Triton this category presumably
      also covers ``int1``/bool; confirm for the Triton version in use.

    With ``descending=True`` the key order is reversed. An ascending radix pass
    over the keys therefore yields a descending sort of ``x``.

    Any other dtype fails at compile time with an explicit message, rather than
    an undefined-``out`` error.

    (The "preverse" spelling is kept because callers use this name.)
    """
    if x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    else:
        tl.static_assert(False, "convert_to_uint_preverse_order: unsupported dtype")
    return out
@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
):
    """Count, for every radix pass, how many elements of each row fall into each digit bin.

    Each program handles row ``pid_m`` and a chunk of ``TILE_N * tiles_n_per_cta``
    consecutive columns (``pid_n``) of that row. For each pass ``p`` it:

    * takes the ``num_bits_per_pass``-bit digit at bit offset
      ``p * num_bits_per_pass`` of the order-preserving unsigned key
      (``convert_to_uint_preverse_order``);
    * counts those digits per bin, ``TILE_R`` bins at a time;
    * atomically adds the counts into ``out_ptr``.

    Because counts are only ever added, ``out_ptr`` must be zeroed by the caller.

    Launch grid: 1-D, ``m * cdiv(n, TILE_N * tiles_n_per_cta)`` programs;
    axis 0 encodes ``(pid_n, pid_m)``.
    """
    # arr_ptr: (m, n)
    # out_ptr: (m, n_passes, r), where r = 2 ** k_bits is the number of bins
    pid = tl.program_id(0)
    pid_n = pid // m
    pid_m = pid % m

    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1  # a.k.a. 2 ** k_bits - 1
    # NOTE(review): annotated constexpr although tiles_n_per_cta is a runtime
    # argument; confirm Triton treats this as intended.
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    cta_n_start = CTA_TILE_N * pid_n
    cta_n_end = tl.minimum(cta_n_start + CTA_TILE_N, n)

    for p in range(0, num_passes):  # one histogram per pass; passes are independent
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):  # TILE_R bins per iteration
            bin_indices = r_start + tl.arange(0, TILE_R)
            # When r is not a multiple of TILE_R (e.g. r == 2 for bool keys with
            # TILE_R == 16), the trailing lanes are not real bins. Their slots
            # belong to the next pass/row, or lie past the end of out_ptr, and
            # must not be touched.
            bin_mask = bin_indices < r
            acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):  # sequential
                n_offsets = n_start + tl.arange(0, TILE_N)  # (TILE_N, )
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask)
                arr = convert_to_uint_preverse_order(arr, descending)
                key = (arr >> bit_offset) & bfe_mask  # (TILE_N, )
                # out-of-range columns never match any bin
                matches = tl.where(
                    mask, (bin_indices[:, None] == key), False
                )  # (TILE_R, TILE_N)
                acc += matches
            local_sum = tl.sum(acc, axis=1)
            # several programs cover the same row, so partial counts are
            # combined atomically
            tl.atomic_add(
                out_ptr + pid_m * num_passes * r + p * r + bin_indices,
                local_sum,
                mask=bin_mask,
                sem="relaxed",
            )
@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
):
    """Scatter one radix pass (the digit at ``bit_offset``) using decoupled look-back.

    Each program owns one row ``pid_m``, one tile of ``TILE_N`` columns
    (``pid_n``) and a range of up to ``TILE_R`` bins (``pid_r``). For every bin
    in its range it:

    1. counts the tile's elements whose current digit equals the bin and
       publishes that count to ``status_ptr`` flagged as an aggregate;
    2. walks backwards over earlier tiles' status words for the same row and
       bin, summing aggregates until it reaches a tile that has published an
       inclusive prefix; the sum is this tile's exclusive prefix;
    3. publishes its own inclusive prefix;
    4. writes each matching element to
       ``excumsum_bins[pid_m, pass_id, bin] + exclusive_prefix + rank``, where
       ``rank`` is the element's position among the tile's matches. When
       ``associate_arr_ptr`` is given, the associated value at the same source
       position is written to ``associate_out_ptr`` at that position too.

    Status word layout (uint32):

    * bit 31 = inclusive prefix available;
    * bit 30 = only the tile aggregate available;
    * bits 0-29 = the count.

    A word of 0 means "not published yet". Therefore ``status_ptr`` must be
    zeroed before every pass, and counts must stay below 2**30.
    """
    # r: num_bins = 2 ** k_bits
    # OUT_N: number of TILE_N-wide tiles per row, i.e. cdiv(N, TILE_N)

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # status_ptr: (m, r, OUT_N)

    # grid: (m * OUT_N, cdiv(r, TILE_R)); axis 0 encodes (pid_n, pid_m)

    # program ids
    pid = tl.program_id(0)
    pid_m = pid % m
    pid_n = pid // m
    pid_r = tl.program_id(1)

    # bit masks
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # bins handled by this program: [cta_r_start, cta_r_end)
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    # load this program's tile and extract the current digit of every key
    n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
    mask = n_offsets < N
    arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
    arr_u = convert_to_uint_preverse_order(arr, descending)
    key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )

    # since triton can only use scalar as condition, loop by bin_index
    # status must be pre zero-initialized, or else we have to initialize it
    for bin_index in range(cta_r_start, cta_r_end):
        matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
        # cta level cumsum per bin
        # CAUTION: tl.sum in triton 3.2 does not promote type
        local_sum = tl.sum(matches.to(tl.uint32), axis=0)
        # publish this tile's aggregate; the flag bit keeps the word non-zero
        # even when local_sum == 0
        pack0 = aggregate_mask | local_sum
        status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n
        tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

        # decoupled lookback
        # NOTE(review): spinning on earlier tiles assumes programs with a
        # smaller pid_n make progress before/alongside this one; confirm for
        # the target backend's scheduling.
        exclusive_prefix = tl.zeros((), dtype=tl.uint32)
        i_lookback = pid_n - 1
        while i_lookback >= 0:
            flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
            pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uint32
            # spin until tile i_lookback has published something
            while pack1 == 0:
                pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
            exclusive_prefix += pack1 & v_mask
            if (pack1 & aggregate_mask) == aggregate_mask:
                # only that tile's own count is known: keep walking back
                i_lookback -= 1
            else:
                # inclusive prefix found: it already covers all earlier tiles
                i_lookback = -1
        # publish this tile's inclusive prefix for the tiles after it
        pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
        tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

        # rank of each matching element among this tile's matches (exclusive)
        local_ex_cumsum = (
            tl.cumsum(matches.to(tl.uint32), axis=0) - matches
        )  # (TILE_N, )
        ex_cumsum_in_bin = (
            exclusive_prefix + local_ex_cumsum
        )  # global ex_cumsum_in_bin (TILE_N, )

        # ex_cumsum_bins (m, n_passes, r): start of this bin in the row's output
        ex_cumsum_bins = tl.load(
            excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index
        )  # scalar
        pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

        # scatter the original (unconverted) values, and associated values if any
        tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
        if associate_arr_ptr is not None:
            associate_arr = tl.load(
                associate_arr_ptr + pid_m * N + n_offsets, mask=mask
            )
            tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches)
def radix_sort(arr, k_bits=8, descending=False):
    """Stable LSD radix sort along the last dimension of ``arr``.

    ``arr`` is treated as ``m`` contiguous rows of ``n = arr.shape[-1]``
    elements, so callers must pass a contiguous tensor. Keys are mapped to
    order-preserving unsigned integers (reversed when ``descending``) and
    sorted ``k_bits`` bits per pass, least-significant digit first:

    1. ``compute_global_hist_kernel`` builds, in one launch, every row's digit
       histogram for every pass. An exclusive cumsum over bins then gives each
       bin's starting output position.
    2. For each pass, ``sweep`` scatters values and their current indices to
       those positions, ping-ponging between two buffer pairs.

    Args:
        arr: tensor to sort; only the last dimension is sorted.
        k_bits: bits per radix pass (``2 ** k_bits`` bins).
        descending: sort in descending order.

    Returns:
        ``(values, indices)`` shaped like ``arr``. ``indices`` is int64 and
        holds original positions along the last dimension.

    Raises:
        NotImplementedError: if ``n >= 2 ** 30``. ``sweep`` packs per-bin
            counts into 30-bit status words, so larger rows would silently
            corrupt the result.
    """
    if arr.numel() == 0:
        # Nothing to sort. This also avoids dividing by n == 0 below and
        # launching empty grids.
        return arr.clone(), torch.zeros_like(arr, dtype=torch.int64)

    n = arr.shape[-1]
    m = arr.numel() // n
    # Explicit check rather than `assert`, which `python -O` would strip.
    if n >= (1 << 30):
        raise NotImplementedError("we have not implemented 2**30 per launch")
    num_bits = 1 if arr.dtype == torch.bool else arr.itemsize * 8

    num_bins = 2**k_bits
    n_passes = triton.cdiv(num_bits, k_bits)

    # Histogram kernel tiling: each program covers
    # HIST_TILES_PER_CTA * HIST_TILE_N columns of one row.
    HIST_TILE_N = 1024
    HIST_TILES_PER_CTA = 8
    HIST_TILE_R = 16
    hist_grid_n = triton.cdiv(n, HIST_TILES_PER_CTA * HIST_TILE_N)

    # Sweep (scatter) kernel tiling: one SWEEP_TILE_N tile x SWEEP_TILE_R bins
    # per program.
    SWEEP_TILE_N = 2048
    SWEEP_TILE_R = 8
    sweep_grid_n = triton.cdiv(n, SWEEP_TILE_N)
    sweep_grid_r = triton.cdiv(num_bins, SWEEP_TILE_R)

    with torch_device_fn.device(arr.device):
        # compute_global_hist_kernel only adds into this buffer, so it must
        # start at zero.
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        compute_global_hist_kernel[(m * hist_grid_n, 1, 1)](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            HIST_TILES_PER_CTA,
            HIST_TILE_N,
            HIST_TILE_R,
            k_bits,
            descending,
        )
        # Exclusive prefix sum over bins: start position of each bin, per row
        # and pass.
        ex_cumsum_bins = torch.cumsum(global_hist, -1) - global_hist
        ex_cumsum_bins = ex_cumsum_bins.to(torch.uint32)

        # Ping-pong buffers. arr is cloned so the caller's tensor is never
        # written.
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int64, device=arr.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        indices_out = torch.empty_like(indices_in)

        # Look-back status words, one per (row, bin, tile). sweep requires
        # them to be zero at the start of every pass.
        status = torch.empty(
            (m, num_bins, sweep_grid_n), device=arr.device, dtype=torch.uint32
        )

        for pass_id in range(n_passes):
            status.zero_()
            sweep[(m * sweep_grid_n, sweep_grid_r)](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                pass_id,
                pass_id * k_bits,
                m,
                n,
                sweep_grid_n,
                SWEEP_TILE_N,
                SWEEP_TILE_R,
                k_bits,
                descending,
            )
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    # After the final swap the *_in buffers hold the last pass's output.
    return arr_in, indices_in
@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    """Sort one contiguous row of ``N`` elements inside a single program.

    Row ``program_id(0)`` starts at element offset ``program_id(0) * N``.

    Lanes at or beyond ``N`` are loaded with a padding value from
    ``_get_finfo_val``/``_get_iinfo_val`` (requested with
    ``return_max=not DESCENDING``, presumably the dtype's max when ascending and
    min when descending). Padding lanes are masked out again on store.

    Writes the sorted values to ``out_ptr`` and the in-row positions they came
    from to ``out_index_ptr``.
    """
    lane = tl.arange(0, BLOCK_SIZE)
    in_bounds = lane < N
    row_offset = tl.program_id(0) * N + lane
    src = in_ptr + row_offset
    dst_val = out_ptr + row_offset
    dst_idx = out_index_ptr + row_offset

    # The padding value is chosen at compile time from the element dtype.
    if IS_FLOAT:
        pad = _get_finfo_val(src.dtype.element_ty, return_max=not DESCENDING)
    else:
        pad = _get_iinfo_val(src.dtype.element_ty, return_max=not DESCENDING)
    values = tl.load(src, mask=in_bounds, other=pad)

    positions = tl.arange(0, BLOCK_SIZE)
    sorted_vals, sorted_pos = argsort(values, positions, 0, descending=DESCENDING)
    tl.store(dst_val, sorted_vals, mask=in_bounds)
    tl.store(dst_idx, sorted_pos, mask=in_bounds)
def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` and return ``(values, indices)``.

    Delegates to ``sort_stable`` with ``stable=False``. The underlying radix
    sort is stable either way, so an unstable request gets a stable result.
    """
    logger.debug("GEMS SORT")
    return sort_stable(
        inp,
        descending=descending,
        dim=dim,
        stable=False,
    )
def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Sort ``inp`` along ``dim``; returns ``(values, indices)`` like ``torch.sort``.

    The backing radix sort is always stable, so ``stable`` is accepted for API
    compatibility and otherwise ignored.

    Args:
        inp: tensor to sort.
        stable: ignored (results are always stable).
        dim: dimension to sort along; negative values count from the end.
        descending: sort in descending order.

    Returns:
        A new values tensor, never aliasing ``inp``, and an int64 indices
        tensor, both shaped like ``inp``.
    """
    logger.debug("GEMS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable

    # Trivially sorted inputs need no kernel launch:
    # - a 0-d tensor sorts as a single element;
    # - an empty tensor or a sort dimension of length 0/1 is already sorted.
    # Return a copy so the result never aliases the input, as with torch.sort.
    sort_elem_cnt = inp.shape[dim] if inp.ndim > 0 else 1
    if sort_elem_cnt <= 1 or inp.numel() == 0:
        return inp.clone(), torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim = dim + inp.ndim
    # radix_sort sorts the last dimension of a contiguous tensor.
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    # bool keys have a single bit; other dtypes are sorted 4 bits per pass.
    num_bits_per_pass = 1 if inp.dtype == torch.bool else 4
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    if dim != inp.ndim - 1:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index