Coverage for src/flag_gems/runtime/backend/_ascend/ops/sort.py: 0% of 50 statements (coverage.py v7.6.9, created at 2026-03-16 02:02 +0800)
import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    # Each program instance sorts one contiguous row of N elements.
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    offset = tl.program_id(0) * N + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    # Padding lanes (cols >= N) load the dtype's max (ascending) or min
    # (descending) sentinel so they sink to the tail of the sorted row.
    if IS_FLOAT:
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
        in_val = tl.where(in_val.dtype.is_fp64(), in_val, in_val.to(tl.float32))
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val).to(tl.int32)
    index_val = tl.arange(0, BLOCK_SIZE)

    # Argsort the full block; padding lanes are discarded by the masked stores.
    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)
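# --- Illustrative launch sketch (not part of the original module) -----------
# Assuming a contiguous row-major tensor `x` of shape (batch, N) with N <= 128,
# the host-side launch performed by `sort` below is roughly:
#
#     grid = (batch,)
#     sort_kernel[grid](
#         x, out, out_idx,
#         N=N,
#         BLOCK_SIZE=triton.next_power_of_2(N),
#         DESCENDING=False,
#         IS_FLOAT=x.is_floating_point(),
#     )
#
# `out` and `out_idx` are preallocated with torch.empty_like; the names here
# are illustrative only.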

def sort(inp, dim=-1, descending=False):
    logger.debug("GEMS_ASCEND SORT")
    sort_elem_cnt = inp.shape[dim]
    if sort_elem_cnt == 1:
        # A length-1 dimension is already sorted; all indices are zero.
        return inp, torch.zeros_like(inp, dtype=torch.int64)
    elif sort_elem_cnt > 128:  # TODO: Optimize implementation for large cases.
        return torch.sort(inp, stable=False, dim=dim, descending=descending)
    block_size = triton.next_power_of_2(sort_elem_cnt)

    # Move the sort dimension to the innermost position so each kernel program
    # reads a contiguous row.
    if dim < 0:
        dim = dim + inp.ndim
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()
    batch_size = math.prod(inp.shape) // sort_elem_cnt

    out = torch.empty_like(inp)
    out_index = torch.empty_like(inp, dtype=torch.int64)

    with torch_device_fn.device(inp.device):
        # One program per row of sort_elem_cnt elements.
        sort_kernel[(batch_size,)](
            inp,
            out,
            out_index,
            N=sort_elem_cnt,
            BLOCK_SIZE=block_size,
            DESCENDING=descending,
            IS_FLOAT=inp.is_floating_point(),
        )

    if dim != inp.ndim - 1:
        # Restore the original dimension order.
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index
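
A minimal usage sketch follows. It assumes flag_gems is installed with the Ascend backend active, that this module is importable under the path shown in the header, and that the Ascend device is exposed to PyTorch as "npu"; none of this is asserted by the source above.

import torch

from flag_gems.runtime.backend._ascend.ops.sort import sort  # path assumed from the header

# Rows of 32 elements stay on the Triton kernel path (<= 128 elements per row).
x = torch.randn(4, 32, device="npu")  # "npu" device name is an assumption
vals, idx = sort(x, dim=-1, descending=True)

# Values should match the eager reference; indices may differ on ties because
# the kernel sort is not stable.
ref_vals, _ = torch.sort(x, dim=-1, descending=True)
torch.testing.assert_close(vals, ref_vals)
assert torch.equal(torch.gather(x, -1, idx), vals)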