Coverage for src/flag_gems/runtime/backend/_hygon/ops/sort.py: 0%
182 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
# Module-level logger, keyed by module path per the project-wide convention.
logger = logging.getLogger(__name__)
def unwrap_if_constexpr(o):
    """Return the wrapped value of a ``tl.constexpr``, or *o* unchanged."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o
@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Resolve the Triton integer dtype with the given bit width and signedness."""
    bits = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(bits, is_signed)
@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Bit pattern ``100...0``: only the MSB set for a num_bits-wide integer."""
    width = unwrap_if_constexpr(num_bits)
    return 1 << (width - 1)
@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Bit pattern ``011...1``: every bit except the MSB set."""
    width = unwrap_if_constexpr(num_bits)
    return (1 << (width - 1)) - 1
@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    # An unsigned integer's bit pattern already matches its ascending sort
    # order; flipping every bit reverses that order for descending sorts.
    out = ~x if descending else x
    return out
@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    # Map signed integers to unsigned keys whose bit patterns compare in the
    # same order as the original values (reversed when descending=True).
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # XOR with 0111...1 equals ~(ux ^ 1000...0): the bitwise complement
        # of the ascending key, i.e. the reversed order.
        # 0111111....1
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # Flipping only the sign bit places negative values below positive
        # ones under unsigned comparison.
        # 1000000...0
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out
@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    # Map floating-point values to unsigned keys that preserve (or, when
    # descending, reverse) the value ordering under unsigned comparison.
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    # Arithmetic shift by (num_bits - 1) replicates the sign bit across the
    # whole word (all zeros for non-negative, all ones for negative); the OR
    # then forces the sign bit on in both cases.
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    if descending:
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)
@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    # Dispatch on dtype (resolved at compile time) to the order-preserving
    # unsigned-key transform used by the radix sort kernels.
    # NOTE(review): "preverse" looks like a typo for "preserve"; kept because
    # other kernels in this file call it by this name.
    # Explicitly handle bool to avoid ambiguity
    if x.dtype == tl.int1:
        out = uint_to_uint(x, descending)
    elif x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    else:
        # Fallback: interpret any other dtype's raw bits as an unsigned key.
        out = uint_to_uint(x, descending)
    return out
@triton.jit
def count_kernel(
    arr_ptr,  # input data, logically (m, N) row-major contiguous
    count_ptr,  # Output: (Grid, 2**k_bits)
    m,
    N,
    grid_n,  # [FIX] Explicitly pass grid_n
    k_bits: tl.constexpr,  # key bits consumed per radix pass
    bit_offset: tl.constexpr,  # starting bit of the digit for this pass
    BLOCK_N: tl.constexpr,  # elements handled per program
    descending: tl.constexpr,
):
    # Histogram phase of the radix sort: each program counts how many of its
    # tile's elements fall into each of the 2**k_bits bins.
    pid = tl.program_id(0)
    # Use explicitly passed grid_n to avoid inconsistency
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    # [FIX] Use int64 for pointer arithmetic to be safe with large m
    val = tl.load(arr_ptr + pid_m.to(tl.int64) * N + n_offset, mask=mask, other=0)
    val_u = convert_to_uint_preverse_order(val, descending)

    # Extract this pass's k_bits-wide digit from the order-preserving key.
    bfe_mask: tl.constexpr = (1 << k_bits) - 1
    key = (val_u >> bit_offset) & bfe_mask

    # Cast key to int32 to match atomic_add pointer arithmetic requirements
    key = key.to(tl.int32)

    NUM_BINS: tl.constexpr = 1 << k_bits
    off_base = pid * NUM_BINS
    tl.atomic_add(count_ptr + off_base + key, 1, mask=mask)
@triton.jit
def scatter_kernel(
    arr_ptr,  # input data, logically (m, N) row-major contiguous
    arr_out_ptr,  # output data buffer, same layout
    idx_ptr,  # Optional: input indices
    idx_out_ptr,  # Optional: output indices
    global_offsets_ptr,  # Input: (Grid, 2**k_bits) - Precomputed prefix sum
    m,
    N,
    grid_n,  # [FIX] Explicitly pass grid_n
    k_bits: tl.constexpr,
    bit_offset: tl.constexpr,
    BLOCK_N: tl.constexpr,
    descending: tl.constexpr,
):
    # Scatter phase of the radix sort: using the host-computed exclusive
    # prefix sums, each program writes its elements (and their indices) to
    # their destination within the current digit's bin, preserving the
    # relative order of equal keys (stability).
    pid = tl.program_id(0)
    # Use explicitly passed grid_n
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    NUM_BINS: tl.constexpr = 1 << k_bits
    bfe_mask: tl.constexpr = NUM_BINS - 1

    # Base destination index for this block (ptr to the start of bins for this block)
    off_base_ptr = global_offsets_ptr + pid * NUM_BINS

    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    # 1. Load Data
    # [FIX] Use int64 for pointer arithmetic
    val = tl.load(arr_ptr + pid_m.to(tl.int64) * N + n_offset, mask=mask, other=0)
    val_u = convert_to_uint_preverse_order(val, descending)
    key = (val_u >> bit_offset) & bfe_mask
    key = key.to(tl.int32)

    # 2. Load Index (Pre-load OUTSIDE the loop)
    # The index belongs to the thread's element, it is invariant of the bin loop.
    # Loading it once here ensures stability and correctness.
    src_idx = tl.zeros((BLOCK_N,), dtype=tl.int64)
    if idx_ptr is not None:
        src_idx = tl.load(
            idx_ptr + pid_m.to(tl.int64) * N + n_offset, mask=mask, other=0
        )

    # 3. Calculate Local Rank and Scatter
    for b in range(0, NUM_BINS):
        # Load the scalar offset for the specific bin
        base_offset = tl.load(off_base_ptr + b)

        is_bin = key == b

        # Compute local prefix sum for stability
        local_cumsum = tl.cumsum(is_bin.to(tl.int32), axis=0)
        local_rank = local_cumsum - 1

        dest_idx = base_offset + local_rank
        write_mask = mask & is_bin

        # Store Data
        tl.store(arr_out_ptr + pid_m.to(tl.int64) * N + dest_idx, val, mask=write_mask)

        # Store Index (using the pre-loaded value)
        if idx_ptr is not None:
            tl.store(
                idx_out_ptr + pid_m.to(tl.int64) * N + dest_idx,
                src_idx,
                mask=write_mask,
            )
def radix_sort(arr, k_bits=4, descending=False):
    """LSD (least-significant-digit) radix sort over the last dimension.

    Treats ``arr`` as an (m, n) batch where n is the last-dim size and sorts
    each of the m rows independently and stably, consuming ``k_bits`` key
    bits per pass (count -> host-side scan -> scatter).

    Args:
        arr: tensor to sort; assumed contiguous with the sort dim last.
        k_bits: key bits per pass; the number of bins is 2**k_bits.
        descending: sort order.

    Returns:
        Tuple ``(sorted_values, indices)`` where ``indices`` (int64) gives
        each sorted element's original position along the last dimension.
    """
    # Determine dimensions
    n = arr.shape[-1]
    m = arr.numel() // n
    dtype = arr.dtype
    num_bits = 1 if dtype == torch.bool else (arr.itemsize * 8)

    # Tuning parameters
    # Increase k_bits to 8 for speed if compilation allows.
    # BLOCK_N needs to balance register usage.
    BLOCK_N = 512 if k_bits >= 8 else 1024

    grid_n = triton.cdiv(n, BLOCK_N)
    num_bins = 1 << k_bits
    n_passes = triton.cdiv(num_bits, k_bits)
    # [FIX] Hoisted out of the pass loop: the launch grid is loop-invariant.
    grid_total = m * grid_n

    # Double buffering
    # TODO: If we can modify inplace, we can arr_in = arr
    arr_in = arr.clone()
    arr_out = torch.empty_like(arr)

    # Indices double buffering
    indices_in = (
        torch.arange(0, n, dtype=torch.int64, device=arr.device)
        .broadcast_to(arr.shape)
        .contiguous()
    )
    indices_out = torch.empty_like(indices_in)

    # Count Buffer: (Total_Blocks, num_bins)
    counts = torch.zeros((grid_total, num_bins), dtype=torch.int32, device=arr.device)

    with torch_device_fn.device(arr.device):
        for i in range(n_passes):
            bit_offset = i * k_bits

            # Step 1: Count — per-(row, block) digit histograms.
            # The buffer is reused across passes, so clear it each time.
            counts.zero_()

            count_kernel[(grid_total,)](
                arr_in,
                counts,
                m,
                n,
                grid_n,  # Pass grid_n explicitly
                k_bits,
                bit_offset,
                BLOCK_N,
                descending,
            )

            # Step 2: Scan (Host Side with PyTorch)
            # Calculate global offsets for Scatter

            # View counts as (m, grid_n, bins)
            cnt_view = counts.view(m, grid_n, num_bins)

            # Total count per bin for each row m
            # .sum() on int32 produces int64 in PyTorch
            total_per_bin = cnt_view.sum(dim=1)  # (m, bins)

            # Global start position of each bin (Exclusive Scan over bins)
            start_per_bin = torch.cumsum(total_per_bin, dim=1) - total_per_bin

            # Offset of each block within its bin (Exclusive Scan over grid)
            offset_in_bin = torch.cumsum(cnt_view, dim=1) - cnt_view

            # Final Offsets = Bin_Start + Block_Offset_In_Bin
            final_offsets = start_per_bin.unsqueeze(1) + offset_in_bin
            final_offsets = final_offsets.view(grid_total, num_bins).contiguous()

            # Force offsets to int32 to match kernel pointer expectations
            final_offsets = final_offsets.to(torch.int32)

            # Step 3: Scatter elements and their source indices.
            scatter_kernel[(grid_total,)](
                arr_in,
                arr_out,
                indices_in,
                indices_out,
                final_offsets,
                m,
                n,
                grid_n,  # Pass grid_n explicitly
                k_bits,
                bit_offset,
                BLOCK_N,
                descending,
            )

            # Swap buffers for next pass
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    # After the final swap, arr_in/indices_in hold the fully sorted result.
    return arr_in, indices_in
@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,  # input, logically (rows, N) contiguous
    out_ptr,  # sorted values output, same layout
    out_index_ptr,  # sorted source indices output, same layout
    N: tl.constexpr,  # number of valid elements per row
    BLOCK_SIZE: tl.constexpr,  # tile width >= N covering one full row
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    # One program sorts one whole row in registers via the shared `argsort`
    # helper. Out-of-range lanes are padded with the dtype's max (ascending)
    # or min (descending) sentinel so the padding sorts to the tail
    # (presumably — sentinel semantics live in _get_finfo_val/_get_iinfo_val).
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    offset = tl.program_id(0) * N + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    if IS_FLOAT:
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)

    index_val = tl.arange(0, BLOCK_SIZE)

    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)
def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim``; always stable, backed by radix sort."""
    logger.debug("GEMS SORT")
    # The `stable` flag is ignored downstream — the radix sort implementation
    # is stable regardless, so torch.sort's default (stable=False) is safe.
    result = sort_stable(inp, stable=False, dim=dim, descending=descending)
    return result
def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Stable sort of ``inp`` along ``dim`` using the radix sort backend.

    Args:
        inp: input tensor.
        stable: accepted for torch.sort API compatibility and ignored —
            the radix sort implementation is always stable.
        dim: dimension to sort along; may be negative.
        descending: sort order.

    Returns:
        Tuple ``(values, indices)``: sorted values and their int64 source
        positions along ``dim``, both shaped like ``inp``.
    """
    logger.debug("GEMS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable
    sort_elem_cnt = inp.shape[dim]
    if sort_elem_cnt == 1:
        # A length-1 dimension is trivially sorted; every index is 0.
        return inp, torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim = dim + inp.ndim
    # radix_sort works on the last dimension: move `dim` to the end if
    # needed. Both branches also force contiguity, so no further
    # is_contiguous() check is required afterwards.
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    dtype = inp.dtype
    # NOTE: You can increase this to 8 for higher performance on large arrays,
    # but 4 is safer for compilation/resource limits.
    num_bits_per_pass = 1 if dtype == torch.bool else 4
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    # Restore the original dimension order.
    if dim != inp.ndim - 1:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index