Coverage for src/flag_gems/runtime/backend/_ascend/ops/index.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2from typing import List 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

9 

10 

@triton.jit
def index_kernel_func(
    input_ptr,
    stride: tl.constexpr,
    index_len,
    index_ptr,
    out_ptr,
    BLOCK_SIZE: tl.constexpr,
    MAX_DATA_SIZE: tl.constexpr,
):
    """Gather contiguous rows from `input_ptr` into `out_ptr`.

    For each of the `index_len` entries in `index_ptr`, copies a contiguous
    run of `stride` elements starting at `index * stride` in the input to
    position `offset * stride` in the output. Rows wider than MAX_DATA_SIZE
    are copied in MAX_DATA_SIZE-element tiles.

    Args:
        input_ptr: pointer to the flattened source tensor.
        stride: number of contiguous elements per indexed row (constexpr).
        index_len: number of indices to process.
        index_ptr: pointer to the flattened (pre-combined) index tensor.
        out_ptr: pointer to the flattened output tensor.
        BLOCK_SIZE: indices handled sequentially by one program (constexpr).
        MAX_DATA_SIZE: tile width for each row copy (constexpr).
    """
    pid0 = tl.program_id(axis=0)

    # Each program instance walks BLOCK_SIZE consecutive index slots.
    for i in range(0, BLOCK_SIZE):
        offset = pid0 * BLOCK_SIZE + i

        # Guard the tail program whose slots extend past index_len.
        if offset < index_len:
            # Start of the selected source row, in elements.
            in_start_index = tl.load(index_ptr + offset) * stride
            out_start_offset = offset * stride
            # ceil(stride / MAX_DATA_SIZE): number of tiles per row.
            loop_num = (stride - 1) // MAX_DATA_SIZE + 1

            for loop_idx in range(0, loop_num):
                inner_offset = loop_idx * MAX_DATA_SIZE + tl.arange(0, MAX_DATA_SIZE)
                # Mask out lanes past the row width in the final partial tile.
                mask = inner_offset < stride
                cur_value = tl.load(
                    input_ptr + in_start_index + inner_offset, mask=mask
                )
                tl.store(
                    out_ptr + out_start_offset + inner_offset, cur_value, mask=mask
                )

40 

41 

def index_wrapper(input, indices, out):
    """Combine multi-dimensional indices into one flat index and launch the
    gather kernel that copies the selected rows of `input` into `out`.

    `indices` must already be broadcast to a common shape; `out` must be
    preallocated with the result shape.
    """
    shape = input.shape
    ndim = len(shape)
    num_indices = len(indices)

    # Width of one copied row: the product of all trailing, non-indexed dims.
    row_width = 1
    for extent in shape[num_indices:]:
        row_width *= extent

    num_rows = indices[0].numel()

    # Fold the per-dimension indices into a single linearized index over the
    # leading (indexed) dimensions: ((i0 * s1 + i1) * s2 + i2) ...
    flat_index = indices[0]
    for pos in range(1, num_indices):
        flat_index = flat_index * shape[pos] + indices[pos]

    BLOCK_SIZE = 32
    MAX_DATA_SIZE = 16 * 1024

    grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_SIZE"]),)

    index_kernel_func[grid](
        input,
        row_width,
        num_rows,
        flat_index,
        out,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DATA_SIZE=MAX_DATA_SIZE,
    )

71 

72 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned elementwise maximum of the index shapes.

    This is the broadcast target shape: for each axis position counted from
    the trailing axis, take the largest extent among the tensors that have
    an axis at that position.
    """
    max_rank = max(len(idx.shape) for idx in indices)
    result: List[int] = []
    # Walk axes right-to-left so shorter tensors align at the trailing axis.
    for offset in range(1, max_rank + 1):
        extents = [idx.shape[-offset] for idx in indices if len(idx.shape) >= offset]
        result.append(max(extents, default=0))
    result.reverse()
    return result

84 

85 

def broadcast_indices(indices, target_shape):
    """Expand, in place, every tensor in `indices` to `target_shape`.

    Tensors already matching the target shape are left untouched; others are
    replaced with a broadcast view (no data copy).
    """
    target = tuple(target_shape)
    for pos in range(len(indices)):
        if tuple(indices[pos].shape) != target:
            indices[pos] = torch.broadcast_to(indices[pos], target_shape)

90 

91 

def index(inp, indices):
    """Advanced-indexing entry point: gather `inp[indices]` on Ascend.

    Broadcasts the index tensors to a common shape, allocates the output,
    and delegates the copy to the Triton gather kernel.
    """
    logger.debug("GEMS_ASCEND INDEX")
    index_list = list(indices)

    broadcast_shape = get_max_rank_shape(index_list)
    broadcast_indices(index_list, broadcast_shape)

    # Output shape: broadcast index shape followed by the non-indexed dims.
    out_shape = broadcast_shape + list(inp.shape[len(index_list):])
    result = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    index_wrapper(inp, index_list, result)
    return result