Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sort.py: 0%
207 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
9logger = logging.getLogger(__name__)
def unwrap_if_constexpr(o):
    """Return the plain Python value inside a ``tl.constexpr`` wrapper.

    Non-constexpr objects are returned unchanged.
    """
    if isinstance(o, tl.constexpr):
        return o.value
    return o
@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Compile-time lookup of the Triton integer dtype with the given
    bit width and signedness (e.g. 32/False -> uint32)."""
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)
@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Compile-time bit pattern 0b100...0: only the top (sign) bit of a
    ``num_bits``-wide word is set."""
    width = unwrap_if_constexpr(num_bits)
    return 1 << (width - 1)
@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Compile-time bit pattern 0b011...1: every bit of a ``num_bits``-wide
    word set except the top (sign) bit."""
    width = unwrap_if_constexpr(num_bits)
    return (1 << (width - 1)) - 1
@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    # Unsigned keys already sort in the correct ascending order; for a
    # descending sort, flipping every bit reverses the unsigned order.
    if descending:
        out = ~x
    else:
        out = x
    return out
@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    """Bijectively map a signed-integer tensor to an unsigned tensor whose
    unsigned order matches the requested signed order, so radix sort can
    always operate on unsigned keys.

    Ascending: XOR with 0b100...0 flips only the sign bit, moving negatives
    below positives in the unsigned domain. Descending: XOR with 0b011...1
    flips every bit except the sign bit instead.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    # Reinterpret the bits as unsigned; no value conversion happens here.
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # 0111111....1
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # 1000000...0
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out
@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    """Bijectively map a floating-point tensor to unsigned integers whose
    unsigned order matches the requested float order (the classic IEEE-754
    radix-sort trick): set the sign bit on positives, flip all bits on
    negatives; `descending` inverts the mask.

    NOTE(review): this treats NaN by its raw bit pattern — presumably NaNs
    end up grouped at one end; confirm against the reference sort if NaN
    ordering matters.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    # Two reinterpretations of the same bits: signed view for the arithmetic
    # shift below, unsigned view for the final XOR.
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)
    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift:
    # sx >> (num_bits - 1) broadcasts the sign bit (0 for >=0, all-ones for <0).
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    if descending:
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)
@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    """Dispatch to the dtype-specific order-preserving unsigned conversion
    (float / signed int / unsigned int).

    NOTE(review): 'preverse' looks like a typo for 'preserve'; the name is
    kept because the kernels in this file call it as-is.
    """
    if x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    # No else branch: any other dtype would leave `out` unbound and fail
    # at Triton compile time.
    return out
@triton.jit
def compute_global_hist_kernel(
    arr_ptr,            # input keys, logically (m, n)
    out_ptr,            # per-row digit histograms, logically (m, num_passes, r)
    num_passes,         # number of radix passes (digits per key)
    m,                  # number of independent rows to sort
    n,                  # length of the sort dimension
    tiles_n_per_cta,    # how many TILE_N tiles one CTA scans
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,            # bins processed per inner iteration
    num_bits_per_pass: tl.constexpr,  # digit width in bits
    descending: tl.constexpr,
    M_PER_SPLIT: tl.constexpr,       # rows handled per split along grid axis 0
):
    """Accumulate, for every row and every radix pass, a histogram of the
    digit values, via atomic adds into `out_ptr`. One CTA covers a
    contiguous chunk of columns for a single row."""
    # grid layout:
    # program_id(0) -> split id s
    # program_id(1) -> pid_n
    # program_id(2) -> pid_m_idx (index inside split)
    s = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_m_idx = tl.program_id(2)
    pid_m = s * M_PER_SPLIT + pid_m_idx
    if pid_m >= m:
        return
    # arr_ptr: (m, n)
    # out_ptr: (m, n_passes, r), where r = 2 ** k_bits is the number of bins
    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    cta_n_start = CTA_TILE_N * pid_n
    cta_n_end = tl.minimum(cta_n_start + CTA_TILE_N, n)
    for p in range(0, num_passes):
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):
            bin_indices = r_start + tl.arange(0, TILE_R)
            acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):
                n_offsets = n_start + tl.arange(0, TILE_N)
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask)
                # Convert to the order-preserving unsigned key, then extract
                # the digit for this pass.
                arr = convert_to_uint_preverse_order(arr, descending)
                key = (arr >> bit_offset) & bfe_mask
                # (TILE_R, TILE_N) match matrix; out-of-range lanes count 0.
                matches = tl.where(mask, (bin_indices[:, None] == key), False)
                acc += matches
            local_sum = tl.sum(acc, axis=1)
            # Relaxed atomics are enough: only the final totals matter.
            tl.atomic_add(
                out_ptr + pid_m * num_passes * r + p * r + bin_indices,
                local_sum,
                sem="relaxed",
            )
@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
    M_PER_SPLIT: tl.constexpr,
):
    """One radix pass: scatter each key (and its associated value) to its
    stable position, using a decoupled-lookback prefix sum across CTAs.

    Each status word packs a 30-bit count plus two flag bits:
    bit 30 = "aggregate (local sum) ready", bit 31 = "inclusive prefix
    ready". `status_ptr` must be zero-initialized before launch; CTA pid_n
    spin-waits on its left neighbors until their counts are published.
    """
    # r: num_bins = 2 ** k_bits
    # OUT_N: grid_n = cdiv(N, )
    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # flag_ptr: (m, r, OUT_N)
    # grid: (S, grid_n, grid_r)
    # program_id(0) -> split id (s)
    # program_id(1) -> pid_n
    # program_id(2) -> pid_r
    s = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_r = tl.program_id(2)
    # bit masks
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1
    # initialize flag to zero-local sum is not ready
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)
    for local_pid_m_idx in range(0, M_PER_SPLIT):
        pid_m = s * M_PER_SPLIT + local_pid_m_idx
        if pid_m < m:
            n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
            mask = n_offsets < N
            arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
            # Digit extraction on the order-preserving unsigned key.
            arr_u = convert_to_uint_preverse_order(arr, descending)
            key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )
            # since triton can only use scalar as condition, loop by bin_index
            # status must be pre zero-initialized, or else we have to initialize it
            for bin_index in range(cta_r_start, cta_r_end):
                matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
                # cta level cumsum per bin
                # CAUTION: tl.sum in triton 3.2 does not promote type
                local_sum = tl.sum(matches.to(tl.uint32), axis=0)
                # Publish "aggregate ready" + this CTA's count for the bin.
                pack0 = aggregate_mask | local_sum
                status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n
                tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")
                # decoupled lookback: walk left accumulating neighbor counts;
                # stop early at any neighbor whose inclusive prefix is ready.
                exclusive_prefix = tl.zeros((), dtype=tl.uint32)
                i_lookback = pid_n - 1
                while i_lookback >= 0:
                    flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
                    pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uin32
                    # Spin until the neighbor has published anything (nonzero).
                    while pack1 == 0:
                        pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
                    exclusive_prefix += pack1 & v_mask
                    if (pack1 & aggregate_mask) == aggregate_mask:
                        i_lookback -= 1
                    else:
                        # Neighbor published an inclusive prefix: done looking back.
                        i_lookback = -1
                # Upgrade our own status to "inclusive prefix ready".
                pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
                tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")
                # Exclusive cumsum of matches within this CTA's tile.
                local_ex_cumsum = (
                    tl.cumsum(matches.to(tl.uint32), axis=0) - matches
                )  # (TILE_N, )
                ex_cumsum_in_bin = (
                    exclusive_prefix + local_ex_cumsum
                )  # global ex_cumsum_in_bin (TILE_N, )
                # ex_cumsum_bins (m, n_passes, r)
                ex_cumsum_bins = tl.load(
                    excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index
                )  # scalar
                pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )
                # scatter
                tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
                if associate_arr_ptr is not None:
                    associate_arr = tl.load(
                        associate_arr_ptr + pid_m * N + n_offsets, mask=mask
                    )
                    tl.store(
                        associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches
                    )
def radix_sort(arr, k_bits=8, descending=False):
    """Stable LSD radix sort of `arr` along its last dimension.

    Args:
        arr: contiguous tensor viewed as (m, n); n is the sort dimension.
        k_bits: digit width in bits per pass (bins per pass = 2**k_bits).
        descending: sort order.

    Returns:
        (sorted_values, indices) — both shaped like `arr`, indices int64.
    """
    n = arr.shape[-1]
    m = arr.numel() // n
    assert n < (1 << 30), "we have not implemented 2**30 per launch"
    dtype = arr.dtype
    # bool keys carry a single significant bit; otherwise use the full width.
    num_bits = 1 if dtype == torch.bool else (arr.itemsize * 8)
    # Smaller tile for 8-byte keys to bound on-chip memory use.
    if arr.dtype == torch.int64:
        TILE_N = 512
    else:
        TILE_N = 1024
    tiles_n_per_cta = 8
    CTA_TILE_N = tiles_n_per_cta * TILE_N
    num_bins = 2**k_bits
    n_passes = triton.cdiv(num_bits, k_bits)
    TILE_R = 16
    grid_n = triton.cdiv(n, CTA_TILE_N)
    # Hardware grid-dimension limit; rows beyond it are folded into splits.
    MAX_GRID = 65535
    S = (m + MAX_GRID - 1) // MAX_GRID
    M_PER_SPLIT = triton.cdiv(m, S)
    # grid_for_global_hist: 3D grid (S, grid_n, M_PER_SPLIT)
    grid_for_global_hist = (S, grid_n, M_PER_SPLIT)
    with torch_device_fn.device(arr.device):
        # Per-row, per-pass digit histograms (atomically accumulated).
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        # launch compute_global_hist_kernel with M_PER_SPLIT passed
        compute_global_hist_kernel[grid_for_global_hist](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            tiles_n_per_cta,
            TILE_N,
            TILE_R,
            k_bits,
            descending,
            M_PER_SPLIT,
        )
        # ex_cumsum_bins shape: (m, n_passes, num_bins)
        ex_cumsum_bins = torch.empty_like(global_hist, dtype=torch.uint32)
        # For each split, compute cumsum on the slice [s_start : s_end]
        for s in range(S):
            s_start = s * M_PER_SPLIT
            s_end = min(m, s_start + M_PER_SPLIT)
            if s_start >= s_end:
                continue
            # slice: shape (m_chunk, n_passes, num_bins)
            slice_hist = global_hist[s_start:s_end]  # this is a view
            # compute cumsum over last dim for this slice only (smaller kernel)
            # exclusive cumsum = inclusive cumsum minus the element itself
            slice_ex_cumsum = torch.cumsum(slice_hist, dim=-1) - slice_hist
            # write back to ex_cumsum_bins (and cast to uint32)
            ex_cumsum_bins[s_start:s_end] = slice_ex_cumsum.to(torch.uint32)
        # sort: ping-pong between (arr_in, indices_in) and (arr_out, indices_out)
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int64, device=arr_in.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        indices_out = torch.empty_like(indices_in)
        # Retile for the sweep kernel (different tile shape than the histogram).
        TILE_R = 8
        grid_r = triton.cdiv(num_bins, TILE_R)
        TILE_N = 3072
        grid_n = triton.cdiv(n, TILE_N)
        # grid_for_sweep using same S (splits)
        grid_for_sweep = (S, grid_n, grid_r)
        # Decoupled-lookback status words: (m, num_bins, grid_n).
        status = torch.empty(
            (m, num_bins, grid_n), device=arr.device, dtype=torch.uint32
        )
        for i in range(0, n_passes):
            bit_offset = i * k_bits
            # sweep requires a zeroed status array before every pass.
            status.zero_()
            sweep[grid_for_sweep](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                i,
                bit_offset,
                m,
                n,
                grid_n,
                TILE_N,
                TILE_R,
                k_bits,
                descending,
                M_PER_SPLIT,
            )
            # Swap buffers; after the final pass the result is in *_in.
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in
    return arr_in, indices_in
def sort(inp, dim=-1, descending=False):
    """Sort `inp` along `dim`; returns (values, indices)."""
    logger.debug("GEMS_CAMBRICON SORT")
    # Thin wrapper: only a stable radix sort exists, so delegate to it
    # (the `stable` flag is ignored by the implementation anyway).
    return sort_stable(inp, dim=dim, descending=descending, stable=False)
def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Stable sort along `dim` via radix sort; returns (values, indices)."""
    logger.debug("GEMS_CAMBRICON SORT.STABLE")
    # Only a stable radix sort is implemented, so the flag has no effect.
    _ = stable

    # Guard clause: a length-1 sort dimension is already sorted.
    if inp.shape[dim] == 1:
        return inp, torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim += inp.ndim
    needs_move = dim != inp.ndim - 1
    if needs_move:
        # Bring the sort dimension to the back; radix_sort works on the last dim.
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    # bool keys need a single bit per pass; everything else uses 4-bit digits.
    k_bits = 1 if inp.dtype == torch.bool else 4
    out, out_index = radix_sort(inp, k_bits, descending)

    if needs_move:
        # Restore the original dimension order on both outputs.
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index