Coverage for src/flag_gems/runtime/backend/_hygon/ops/sort.py: 0%

182 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def unwrap_if_constexpr(o):
    """Return the raw Python value held by a ``tl.constexpr``, else ``o`` unchanged."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o

16 

17 

@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Look up the triton integer dtype with the given bit width and signedness."""
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)

23 

24 

@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Bit pattern ``100...0``: only the top bit of a ``num_bits``-wide word set."""
    width = unwrap_if_constexpr(num_bits)
    return 1 << (width - 1)

29 

30 

@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Bit pattern ``011...1``: every bit below the top bit of a ``num_bits``-wide word."""
    width = unwrap_if_constexpr(num_bits)
    return (1 << (width - 1)) - 1

35 

36 

@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    """Identity mapping for unsigned keys: unsigned ints already sort in raw
    bit order, so only a descending sort needs work — inverting every bit
    exactly reverses the unsigned order."""
    if descending:
        out = ~x
    else:
        out = x
    return out

41 

42 

@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    """Map a signed-integer tensor to unsigned ints whose ascending unsigned
    order matches the requested sort order of the inputs.

    Flipping the sign bit (XOR with 100...0) converts two's-complement order
    into unsigned order; XOR with 011...1 instead additionally inverts every
    value bit, which yields the reversed (descending) order.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    # Reinterpret the bits as unsigned without changing them.
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # 0111111....1
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # 1000000...0
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out

59 

60 

@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    """Map a floating-point tensor to unsigned ints whose ascending unsigned
    order matches the numeric order of the inputs.

    Classic radix-sort float encoding: positive values get only the sign bit
    flipped (so they land above all negatives), negative values get every
    bit inverted (so more-negative sorts lower). ``descending`` uses the
    complementary mask to reverse the ordering.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    # Two reinterpretations of the same bits: signed for the arithmetic
    # shift below, unsigned for the final XOR.
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    # Arithmetic shift by (num_bits - 1) broadcasts the sign bit, giving
    # 0 for positive inputs and all-ones for negative; OR-ing in the sign
    # bit then yields the per-element mask:
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    if descending:
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)

83 

84 

@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    """Bijectively encode ``x`` as unsigned ints whose ascending unsigned
    order equals the requested sort order of ``x``.

    (NOTE: the name's "preverse" spelling is historical; it means
    order-*preserving*. Renaming would break external callers.)
    """
    # Explicitly handle bool to avoid ambiguity
    if x.dtype == tl.int1:
        out = uint_to_uint(x, descending)
    elif x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    else:
        # Fallback: treat any other dtype as raw unsigned bits.
        out = uint_to_uint(x, descending)
    return out

99 

100 

@triton.jit
def count_kernel(
    arr_ptr,  # Input values, logically (m, N), row-major contiguous
    count_ptr,  # Output: (Grid, 2**k_bits)
    m,
    N,
    grid_n,  # [FIX] Explicitly pass grid_n
    k_bits: tl.constexpr,
    bit_offset: tl.constexpr,
    BLOCK_N: tl.constexpr,
    descending: tl.constexpr,
):
    """Histogram one radix digit per block (counting phase of radix sort).

    Each program handles a BLOCK_N-wide slice of one row and atomically
    accumulates, into its private 2**k_bits-wide slice of ``count_ptr``, how
    many of its elements fall into each digit bin at ``bit_offset``.
    """
    pid = tl.program_id(0)
    # Use explicitly passed grid_n to avoid inconsistency
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    # [FIX] Use int64 for pointer arithmetic to be safe with large m
    val = tl.load(arr_ptr + pid_m.to(tl.int64) * N + n_offset, mask=mask, other=0)
    # Order-preserving unsigned encoding: one digit-extraction path then
    # works for every dtype and both sort directions.
    val_u = convert_to_uint_preverse_order(val, descending)

    bfe_mask: tl.constexpr = (1 << k_bits) - 1
    # Extract the k_bits-wide digit currently being sorted on.
    key = (val_u >> bit_offset) & bfe_mask

    # Cast key to int32 to match atomic_add pointer arithmetic requirements
    key = key.to(tl.int32)

    NUM_BINS: tl.constexpr = 1 << k_bits
    # Each program owns its own NUM_BINS-wide row of the count buffer, so
    # the atomics only contend between lanes of this program.
    off_base = pid * NUM_BINS
    tl.atomic_add(count_ptr + off_base + key, 1, mask=mask)

134 

135 

@triton.jit
def scatter_kernel(
    arr_ptr,
    arr_out_ptr,
    idx_ptr,  # Optional: input indices
    idx_out_ptr,  # Optional: output indices
    global_offsets_ptr,  # Input: (Grid, 2**k_bits) - Precomputed prefix sum
    m,
    N,
    grid_n,  # [FIX] Explicitly pass grid_n
    k_bits: tl.constexpr,
    bit_offset: tl.constexpr,
    BLOCK_N: tl.constexpr,
    descending: tl.constexpr,
):
    """Scatter phase of one radix-sort pass.

    Each program re-reads its BLOCK_N slice of one row, recomputes each
    element's digit, and writes value (and optionally its running argsort
    index) to the destination given by the precomputed per-block bin offset
    plus the element's rank among equal digits within the block. The
    per-bin ``tl.cumsum`` rank keeps the sort stable.
    """
    pid = tl.program_id(0)
    # Use explicitly passed grid_n
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    NUM_BINS: tl.constexpr = 1 << k_bits
    bfe_mask: tl.constexpr = NUM_BINS - 1

    # Base destination index for this block (ptr to the start of bins for this block)
    off_base_ptr = global_offsets_ptr + pid * NUM_BINS

    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = n_offset < N

    # 1. Load Data
    # [FIX] Use int64 for pointer arithmetic
    val = tl.load(arr_ptr + pid_m.to(tl.int64) * N + n_offset, mask=mask, other=0)
    val_u = convert_to_uint_preverse_order(val, descending)
    key = (val_u >> bit_offset) & bfe_mask
    key = key.to(tl.int32)

    # 2. Load Index (Pre-load OUTSIDE the loop)
    # The index belongs to the thread's element, it is invariant of the bin loop.
    # Loading it once here ensures stability and correctness.
    src_idx = tl.zeros((BLOCK_N,), dtype=tl.int64)
    if idx_ptr is not None:
        src_idx = tl.load(
            idx_ptr + pid_m.to(tl.int64) * N + n_offset, mask=mask, other=0
        )

    # 3. Calculate Local Rank and Scatter
    # One iteration per digit value; only lanes whose digit matches write.
    for b in range(0, NUM_BINS):
        # Load the scalar offset for the specific bin
        base_offset = tl.load(off_base_ptr + b)

        is_bin = key == b

        # Compute local prefix sum for stability
        local_cumsum = tl.cumsum(is_bin.to(tl.int32), axis=0)
        local_rank = local_cumsum - 1

        dest_idx = base_offset + local_rank
        write_mask = mask & is_bin

        # Store Data
        tl.store(arr_out_ptr + pid_m.to(tl.int64) * N + dest_idx, val, mask=write_mask)

        # Store Index (using the pre-loaded value)
        if idx_ptr is not None:
            tl.store(
                idx_out_ptr + pid_m.to(tl.int64) * N + dest_idx,
                src_idx,
                mask=write_mask,
            )

205 

206 

def radix_sort(arr, k_bits=4, descending=False):
    """Stable LSD radix sort along the last dimension of ``arr``.

    Each pass histograms a ``k_bits``-wide digit on the device, exclusive-scans
    the per-block histograms on the host, then scatters values and their
    running argsort indices to their destinations. Returns
    ``(sorted_values, argsort_indices)``, both with ``arr``'s shape; the
    input tensor itself is never mutated.
    """
    # Treat the input as (rows, n) with n the sorted dimension.
    n = arr.shape[-1]
    rows = arr.numel() // n
    dtype = arr.dtype
    total_bits = 1 if dtype == torch.bool else (arr.itemsize * 8)

    # Wider digits need smaller blocks to balance register pressure.
    BLOCK_N = 512 if k_bits >= 8 else 1024

    grid_n = triton.cdiv(n, BLOCK_N)
    num_bins = 1 << k_bits
    n_passes = triton.cdiv(total_bits, k_bits)
    grid_total = rows * grid_n

    # Double buffers for values.
    # TODO: If we can modify inplace, we can reuse `arr` as the source buffer.
    src_vals = arr.clone()
    dst_vals = torch.empty_like(arr)

    # Double buffers for the running argsort indices, seeded with 0..n-1
    # per row.
    src_idx = (
        torch.arange(0, n, dtype=torch.int64, device=arr.device)
        .broadcast_to(arr.shape)
        .contiguous()
    )
    dst_idx = torch.empty_like(src_idx)

    # Per-block histogram buffer: (rows * grid_n, num_bins).
    counts = torch.zeros((grid_total, num_bins), dtype=torch.int32, device=arr.device)

    with torch_device_fn.device(arr.device):
        for pass_idx in range(n_passes):
            bit_offset = pass_idx * k_bits

            # --- Step 1: per-block digit histograms on the device ---
            counts.zero_()
            count_kernel[(grid_total,)](
                src_vals,
                counts,
                rows,
                n,
                grid_n,  # Pass grid_n explicitly
                k_bits,
                bit_offset,
                BLOCK_N,
                descending,
            )

            # --- Step 2: exclusive scans on the host ---
            # View counts as (rows, grid_n, bins).
            cnt3 = counts.view(rows, grid_n, num_bins)

            # Total count per bin for each row (.sum() on int32 yields int64).
            per_bin_total = cnt3.sum(dim=1)  # (rows, bins)

            # Where each bin starts within its row (exclusive scan over bins).
            bin_start = torch.cumsum(per_bin_total, dim=1) - per_bin_total

            # Where each block starts within its bin (exclusive scan over blocks).
            block_start = torch.cumsum(cnt3, dim=1) - cnt3

            # Final destination offset = bin start + block offset in bin;
            # forced to int32 to match what the scatter kernel expects.
            offsets = bin_start.unsqueeze(1) + block_start
            offsets = offsets.view(grid_total, num_bins).contiguous()
            offsets = offsets.to(torch.int32)

            # --- Step 3: stable scatter by digit ---
            scatter_kernel[(grid_total,)](
                src_vals,
                dst_vals,
                src_idx,
                dst_idx,
                offsets,
                rows,
                n,
                grid_n,  # Pass grid_n explicitly
                k_bits,
                bit_offset,
                BLOCK_N,
                descending,
            )

            # Ping-pong buffers for the next pass.
            src_vals, dst_vals = dst_vals, src_vals
            src_idx, dst_idx = dst_idx, src_idx

    return src_vals, src_idx

303 

304 

@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,  # actual row length
    BLOCK_SIZE: tl.constexpr,  # lanes per program; presumably a power of two >= N — TODO confirm
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    """Sort one row of N elements per program with an in-register argsort.

    Inactive lanes (cols >= N) are filled with the dtype's extreme value
    (max for ascending, min for descending — see ``return_max``), presumably
    so padding sorts to the tail; padded lanes are masked out on store.
    """
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Each program owns one contiguous row of length N.
    offset = tl.program_id(0) * N + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    if IS_FLOAT:
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)

    # Positions 0..BLOCK_SIZE-1 ride along as the argsort payload.
    index_val = tl.arange(0, BLOCK_SIZE)

    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)

337 

338 

def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim``; returns ``(values, indices)``.

    Delegates to the stable radix-sort path — it is the only sort
    implemented here, so the result is stable regardless.
    """
    logger.debug("GEMS SORT")
    return sort_stable(inp, dim=dim, descending=descending, stable=False)

343 

344 

def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Stable sort of ``inp`` along ``dim`` via radix sort.

    Args:
        inp: input tensor.
        stable: accepted for API compatibility but ignored — the radix sort
            used here is always stable.
        dim: dimension to sort along (may be negative).
        descending: sort order.

    Returns:
        ``(values, indices)`` — new tensors with ``inp``'s shape; ``indices``
        is int64.
    """
    logger.debug("GEMS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable
    sort_elem_cnt = inp.shape[dim]
    if sort_elem_cnt == 1:
        # A length-1 dim is already sorted. Clone so callers never receive a
        # tensor aliasing the input (torch.sort returns new tensors).
        return inp.clone(), torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim = dim + inp.ndim
    # The kernels work on contiguous rows along the last dim: move the sort
    # dim to the end if needed, then normalize layout in one place
    # (movedim returns a non-contiguous view; contiguous() is a no-op when
    # the tensor already is).
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1)
    inp = inp.contiguous()

    dtype = inp.dtype
    # NOTE: You can increase this to 8 for higher performance on large arrays,
    # but 4 is safer for compilation/resource limits.
    num_bits_per_pass = 1 if dtype == torch.bool else 4
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    # Undo the dim move so the outputs match the caller's layout.
    if dim != inp.ndim - 1:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index