Coverage for src/flag_gems/runtime/backend/_ascend/ops/sort.py: 0%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

12logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

13 

14 

15@libentry() 

16@triton.jit() 

17def sort_kernel( 

18 in_ptr, 

19 out_ptr, 

20 out_index_ptr, 

21 N: tl.constexpr, 

22 BLOCK_SIZE: tl.constexpr, 

23 DESCENDING: tl.constexpr, 

24 IS_FLOAT: tl.constexpr, 

25): 

26 cols = tl.arange(0, BLOCK_SIZE) 

27 mask = cols < N 

28 offset = tl.program_id(0) * N + cols 

29 in_ptr += offset 

30 out_ptr += offset 

31 out_index_ptr += offset 

32 

33 if IS_FLOAT: 

34 mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING) 

35 in_val = tl.load(in_ptr, mask=mask, other=mask_val) 

36 in_val = tl.where(in_val.dtype.is_fp64(), in_val, in_val.to(tl.float32)) 

37 else: 

38 mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING) 

39 in_val = tl.load(in_ptr, mask=mask, other=mask_val).to(tl.int32) 

40 index_val = tl.arange(0, BLOCK_SIZE) 

41 

42 sorted_in_val, sorted_index_val = argsort( 

43 in_val, index_val, 0, descending=DESCENDING 

44 ) 

45 tl.store(out_ptr, sorted_in_val, mask=mask) 

46 tl.store(out_index_ptr, sorted_index_val, mask=mask) 

47 

48 

49def sort(inp, dim=-1, descending=False): 

50 logger.debug("GEMS_ASCEND SORT") 

51 sort_elem_cnt = inp.shape[dim] 

52 if sort_elem_cnt == 1: 

53 return inp, torch.zeros_like(inp, dtype=torch.int64) 

54 elif sort_elem_cnt > 128: # TODO: Optimize implementation for large cases. 

55 return torch.sort(inp, stable=False, dim=dim, descending=descending) 

56 block_size = triton.next_power_of_2(sort_elem_cnt) 

57 

58 if dim < 0: 

59 dim = dim + inp.ndim 

60 if dim != inp.ndim - 1: 

61 inp = torch.movedim(inp, dim, -1).contiguous() 

62 else: 

63 inp = inp.contiguous() 

64 batch_size = math.prod(inp.shape) // sort_elem_cnt 

65 

66 out = torch.empty_like(inp) 

67 out_index = torch.empty_like(inp, dtype=torch.int64) 

68 

69 with torch_device_fn.device(inp.device): 

70 sort_kernel[batch_size,]( 

71 inp, 

72 out, 

73 out_index, 

74 N=sort_elem_cnt, 

75 BLOCK_SIZE=block_size, 

76 DESCENDING=descending, 

77 IS_FLOAT=inp.is_floating_point(), 

78 ) 

79 

80 if dim != inp.ndim - 1: 

81 out = torch.movedim(out, -1, dim) 

82 out_index = torch.movedim(out_index, -1, dim) 

83 return out, out_index