Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sort.py: 0%

207 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

9logger = logging.getLogger(__name__) 

10 

11 

def unwrap_if_constexpr(o):
    """Return the plain Python value held by a ``tl.constexpr``; pass anything else through."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o

14 

15 

@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Return the Triton integer dtype with the given bit width and signedness."""
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)

21 

22 

@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Bit pattern ``100...0``: only the most significant of ``num_bits`` bits is set."""
    return 1 << (unwrap_if_constexpr(num_bits) - 1)

27 

28 

@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Bit pattern ``011...1``: every bit set except the most significant of ``num_bits``."""
    width = unwrap_if_constexpr(num_bits)
    return (1 << (width - 1)) - 1

33 

34 

@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    """Radix key for unsigned integers: identity for ascending sorts,
    bitwise complement for descending (NOT exactly reverses unsigned order).
    """
    if descending:
        out = ~x
    else:
        out = x
    return out

39 

40 

@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    """Map signed integers to unsigned keys whose *unsigned* comparison order
    matches the *signed* order of ``x`` (reversed when ``descending``).
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    # Reinterpret the bits as unsigned without changing them.
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # 0111111....1
        # XOR with 011...1 equals the bitwise NOT of the ascending key below,
        # which exactly reverses the unsigned order.
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # 1000000...0
        # Toggling the sign bit shifts the signed range onto the unsigned range
        # while preserving order (the classic radix-sort trick).
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out

57 

58 

@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    """Map IEEE-754 floats to unsigned keys whose unsigned comparison order
    matches the floating-point order of ``x`` (reversed when ``descending``).

    Classic radix-sort trick: a positive float only needs its sign bit set;
    a negative float needs *all* bits flipped, because larger magnitudes of
    negative numbers must compare smaller.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    # The arithmetic shift broadcasts the sign bit across the whole word, so:
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    if descending:
        # ~mask complements the ascending key, exactly reversing the order.
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)

81 

82 

@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    """Reinterpret ``x`` as unsigned integers whose unsigned order matches the
    requested sort order of ``x``.

    Dispatches on the element dtype (floating / signed int / unsigned int).
    NOTE(review): the name's "preverse" is a historical typo for "preserve";
    call sites depend on the spelling, so it is kept. A dtype matching none of
    the branches leaves ``out`` unbound and fails at Triton compile time.
    """
    if x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    return out

92 

93 

@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
    M_PER_SPLIT: tl.constexpr,
):
    """Accumulate per-row digit histograms for every radix pass at once.

    Each program covers one slab of ``CTA_TILE_N`` columns of one row and
    atomically adds its local counts into ``out_ptr``.
    """
    # grid layout:
    # program_id(0) -> split id s
    # program_id(1) -> pid_n
    # program_id(2) -> pid_m_idx (index inside split)
    s = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_m_idx = tl.program_id(2)
    pid_m = s * M_PER_SPLIT + pid_m_idx
    if pid_m >= m:
        # the last split may be ragged; excess programs do nothing
        return

    # arr_ptr: (m, n)
    # out_ptr: (m, n_passes, r), where r = 2 ** k_bits is the number of bins
    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    cta_n_start = CTA_TILE_N * pid_n
    cta_n_end = tl.minimum(cta_n_start + CTA_TILE_N, n)

    for p in range(0, num_passes):
        bit_offset = p * num_bits_per_pass
        # Process the r bins TILE_R at a time to bound register use.
        for r_start in range(0, r, TILE_R):
            bin_indices = r_start + tl.arange(0, TILE_R)
            acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):
                n_offsets = n_start + tl.arange(0, TILE_N)
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask)
                # convert to order-preserving unsigned keys, then extract the
                # digit for pass p
                arr = convert_to_uint_preverse_order(arr, descending)
                key = (arr >> bit_offset) & bfe_mask
                # one-hot match matrix: (TILE_R bins) x (TILE_N elements)
                matches = tl.where(mask, (bin_indices[:, None] == key), False)
                acc += matches
            local_sum = tl.sum(acc, axis=1)
            # NOTE(review): local_sum is int64 while out_ptr holds int32 —
            # relies on triton's implicit conversion in atomic_add; confirm.
            tl.atomic_add(
                out_ptr + pid_m * num_passes * r + p * r + bin_indices,
                local_sum,
                sem="relaxed",
            )

146 

147 

@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
    M_PER_SPLIT: tl.constexpr,
):
    """One scatter pass of the LSD radix sort, using decoupled look-back.

    For each (row, bin) this program publishes its tile-local count into
    ``status_ptr`` and spins on the left neighbours' status words to build an
    exclusive prefix, then scatters keys (and the associated values) to their
    final positions for this pass.

    Status word encoding (uint32):
      bit 31 (``inclusive_prefix_mask``) -> inclusive prefix is published
      bit 30 (``aggregate_mask``)        -> only the local aggregate is published
      bits 0..29 (``v_mask``)            -> the count value itself
    ``status_ptr`` must be zero-initialized before launch.
    """
    # r: num_bins = 2 ** k_bits
    # OUT_N: grid_n = cdiv(N, )

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # flag_ptr: (m, r, OUT_N)

    # grid: (S, grid_n, grid_r)
    # program_id(0) -> split id (s)
    # program_id(1) -> pid_n
    # program_id(2) -> pid_r

    s = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_r = tl.program_id(2)

    # bit masks
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # initialize flag to zero-local sum is not ready
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    for local_pid_m_idx in range(0, M_PER_SPLIT):
        pid_m = s * M_PER_SPLIT + local_pid_m_idx
        if pid_m < m:
            n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
            mask = n_offsets < N
            arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
            arr_u = convert_to_uint_preverse_order(arr, descending)
            key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )

            # since triton can only use scalar as condition, loop by bin_index
            # status must be pre zero-initialized, or else we have to initialize it
            for bin_index in range(cta_r_start, cta_r_end):
                matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
                # cta level cumsum per bin
                # CAUTION: tl.sum in triton 3.2 does not promote type
                local_sum = tl.sum(matches.to(tl.uint32), axis=0)
                # publish "aggregate ready" so programs to the right can proceed
                pack0 = aggregate_mask | local_sum
                status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n
                tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

                # decoupled lookback: walk left, summing neighbour aggregates,
                # until a neighbour exposes its inclusive prefix
                exclusive_prefix = tl.zeros((), dtype=tl.uint32)
                i_lookback = pid_n - 1
                while i_lookback >= 0:
                    flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
                    pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uin32
                    while pack1 == 0:
                        # spin until the left neighbour publishes something
                        pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
                    exclusive_prefix += pack1 & v_mask
                    if (pack1 & aggregate_mask) == aggregate_mask:
                        # only an aggregate: keep walking left
                        i_lookback -= 1
                    else:
                        # inclusive prefix found: lookback is complete
                        i_lookback = -1
                # upgrade our status word to "inclusive prefix ready"
                pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
                tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

                # rank of each matching element within this tile's bin
                local_ex_cumsum = (
                    tl.cumsum(matches.to(tl.uint32), axis=0) - matches
                )  # (TILE_N, )
                ex_cumsum_in_bin = (
                    exclusive_prefix + local_ex_cumsum
                )  # global ex_cumsum_in_bin (TILE_N, )

                # ex_cumsum_bins (m, n_passes, r)
                ex_cumsum_bins = tl.load(
                    excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index
                )  # scalar
                pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

                # scatter
                tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
                if associate_arr_ptr is not None:
                    # carry the associated payload (e.g. indices) alongside keys
                    associate_arr = tl.load(
                        associate_arr_ptr + pid_m * N + n_offsets, mask=mask
                    )
                    tl.store(
                        associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches
                    )

254 

255 

def radix_sort(arr, k_bits=8, descending=False):
    """Stable LSD radix sort of ``arr`` along its last dimension.

    Args:
        arr: contiguous tensor viewed as (m, n) rows; each row is sorted
            independently along the last dimension.
        k_bits: number of radix bits consumed per pass.
        descending: sort order.

    Returns:
        ``(values, indices)`` where ``indices`` (int64) are the positions of
        the sorted elements within their original row.
    """
    n = arr.shape[-1]
    m = arr.numel() // n
    # the sweep kernel packs counts into the low 30 bits of a uint32 status
    # word, so a row longer than 2**30 would overflow the encoding
    assert n < (1 << 30), "we have not implemented 2**30 per launch"
    dtype = arr.dtype
    num_bits = 1 if dtype == torch.bool else (arr.itemsize * 8)

    # smaller tiles for 8-byte elements to limit on-chip memory use
    if arr.dtype == torch.int64:
        TILE_N = 512
    else:
        TILE_N = 1024
    tiles_n_per_cta = 8
    CTA_TILE_N = tiles_n_per_cta * TILE_N

    num_bins = 2**k_bits
    n_passes = triton.cdiv(num_bits, k_bits)
    TILE_R = 16

    grid_n = triton.cdiv(n, CTA_TILE_N)

    # split the m rows so no single grid axis exceeds the launch limit
    MAX_GRID = 65535
    S = (m + MAX_GRID - 1) // MAX_GRID
    M_PER_SPLIT = triton.cdiv(m, S)
    # grid_for_global_hist: 3D grid (S, grid_n, M_PER_SPLIT)
    grid_for_global_hist = (S, grid_n, M_PER_SPLIT)

    with torch_device_fn.device(arr.device):
        # per-row digit histogram for every pass: (m, n_passes, num_bins)
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        # launch compute_global_hist_kernel with M_PER_SPLIT passed
        compute_global_hist_kernel[grid_for_global_hist](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            tiles_n_per_cta,
            TILE_N,
            TILE_R,
            k_bits,
            descending,
            M_PER_SPLIT,
        )

        # ex_cumsum_bins shape: (m, n_passes, num_bins)
        # NOTE(review): torch.uint32 requires a fairly recent torch — confirm
        # against the project's minimum supported version.
        ex_cumsum_bins = torch.empty_like(global_hist, dtype=torch.uint32)
        # For each split, compute cumsum on the slice [s_start : s_end]
        for s in range(S):
            s_start = s * M_PER_SPLIT
            s_end = min(m, s_start + M_PER_SPLIT)
            if s_start >= s_end:
                continue
            # slice: shape (m_chunk, n_passes, num_bins)
            slice_hist = global_hist[s_start:s_end]  # this is a view
            # compute cumsum over last dim for this slice only (smaller kernel)
            # exclusive cumsum = inclusive cumsum minus the element itself
            slice_ex_cumsum = torch.cumsum(slice_hist, dim=-1) - slice_hist
            # write back to ex_cumsum_bins (and cast to uint32)
            ex_cumsum_bins[s_start:s_end] = slice_ex_cumsum.to(torch.uint32)

        # sort: ping-pong between *_in and *_out buffers, one sweep per pass
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int64, device=arr_in.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        indices_out = torch.empty_like(indices_in)

        TILE_R = 8
        grid_r = triton.cdiv(num_bins, TILE_R)
        TILE_N = 3072
        grid_n = triton.cdiv(n, TILE_N)

        # grid_for_sweep using same S (splits)
        grid_for_sweep = (S, grid_n, grid_r)

        # decoupled-lookback status words: (m, num_bins, grid_n)
        status = torch.empty(
            (m, num_bins, grid_n), device=arr.device, dtype=torch.uint32
        )

        for i in range(0, n_passes):
            bit_offset = i * k_bits
            # sweep requires a zeroed status array at the start of each pass
            status.zero_()
            sweep[grid_for_sweep](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                i,
                bit_offset,
                m,
                n,
                grid_n,
                TILE_N,
                TILE_R,
                k_bits,
                descending,
                M_PER_SPLIT,
            )
            # swap so the next pass reads this pass's output
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    # after the final swap, *_in holds the fully sorted result
    return arr_in, indices_in

364 

365 

def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim``; returns ``(values, indices)`` like ``torch.sort``.

    The backend radix sort is stable regardless; ``stable=False`` only
    signals that the caller does not require stability.
    """
    # We only implement stable radix sort here
    logger.debug("GEMS_CAMBRICON SORT")
    return sort_stable(inp, stable=False, dim=dim, descending=descending)

370 

371 

def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Stable sort of ``inp`` along ``dim`` via LSD radix sort.

    Args:
        inp: input tensor.
        stable: accepted for API compatibility; the radix sort is always
            stable, so the flag is ignored.
        dim: dimension to sort along (may be negative).
        descending: sort order.

    Returns:
        ``(values, indices)`` with the same shape as ``inp``; ``indices``
        (int64) give the original positions along ``dim``.
    """
    logger.debug("GEMS_CAMBRICON SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable
    sort_elem_cnt = inp.shape[dim]
    # Fast path: nothing to reorder. Using <= 1 (was == 1) also guards the
    # empty-dimension case, which previously reached radix_sort and crashed
    # there on `m = arr.numel() // n` with n == 0.
    if sort_elem_cnt <= 1:
        return inp, torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim = dim + inp.ndim
    # radix_sort works on the last dimension; move `dim` there if needed and
    # ensure contiguity either way
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    dtype = inp.dtype
    # bool has a single significant bit; everything else uses 4-bit digits
    num_bits_per_pass = 1 if dtype == torch.bool else 4
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    # undo the dimension move so outputs match the caller's layout
    if dim != inp.ndim - 1:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index