Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather.py: 0%

81 statements  

coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.scatter import scatter_ 

8from flag_gems.utils import libentry 

9from flag_gems.utils.shape_utils import restride_dim 

10 

# Logger namespaced under the flag_gems Ascend backend; the suffix is this
# module's basename (i.e. "gather").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# Hardware specification: Atlas 800T/I A2 product's on-chip memory capacity is 192KB.
# Used by gather() as the working-set cutoff for the high-performance path.
UB_SIZE_BYTES = 192 * 1024

14 

15 

def compute_base_offset(shape, strides, dim):
    """Compute per-element input offsets, excluding the gather dimension.

    For every flat (row-major) position in a tensor of the given ``shape``,
    return the linear offset obtained by summing ``coord[i] * strides[i]``
    over every axis ``i`` except ``dim``.  The gather kernel later adds the
    dynamic ``index * stride[dim]`` term on top of this base.

    Args:
        shape: target shape (the index tensor's shape).
        strides: strides of the (restrided) input tensor, same length as shape.
        dim: axis to exclude from the offset sum (expected non-negative).

    Returns:
        1-D int64 CPU tensor of length prod(shape) (length 1 for an empty shape).
    """
    numel = 1
    for extent in shape:
        numel *= extent

    remaining = torch.arange(numel, device="cpu")
    offset = torch.zeros_like(remaining)
    # Peel coordinates off the flat index from the innermost axis outwards,
    # accumulating the stride contribution of every non-gather axis.
    # Single pass — no (ndims, numel) coordinate matrix is materialized, and
    # an empty shape (0-d index) yields the single offset [0] instead of
    # crashing on an empty coordinate matrix.
    for axis in reversed(range(len(shape))):
        coord = remaining % shape[axis]
        remaining = remaining // shape[axis]
        if axis != dim:
            offset += coord * strides[axis]
    return offset

30 

31 

@libentry()
@triton.heuristics({"BLOCK_SIZE": lambda args: 4096})  # fixed 4096-element tiles
@triton.jit
def _gather_flat_kernel_fixed(
    inp,             # pointer to the (restrided) input tensor
    index,           # pointer to the flattened gather indices
    out,             # pointer to the flattened output (same numel as index)
    base_offset,     # pointer to precomputed per-element base offsets (excludes dim)
    inp_dim_stride,  # input stride along the gather dimension
    N,               # total number of elements to gather
    BLOCK_SIZE: tl.constexpr,
):
    # Flat 1-D gather: out[i] = inp[base_offset[i] + index[i] * inp_dim_stride].
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < N  # guard the tail block when N is not a multiple of BLOCK_SIZE

    cur_index = tl.load(index + offset, mask=mask, other=0)
    base = tl.load(base_offset + offset, mask=mask, other=0)

    # Combine the static base offset with the dynamic index term along `dim`.
    inp_offset = base + cur_index * inp_dim_stride

    val = tl.load(inp + inp_offset, mask=mask, other=0)
    tl.store(out + offset, val, mask=mask)

55 

56 

def gather_flat_fixed(inp: torch.Tensor, dim: int, index: torch.Tensor, out=None):
    """General gather fallback: precompute base offsets on CPU, gather on NPU.

    Args:
        inp: source tensor.
        dim: dimension to gather along (negative values are normalized here).
        index: int64 index tensor; output has index's shape and inp's dtype.
        out: optional preallocated output; allocated when None.

    Returns:
        The gathered tensor (``out``).
    """
    logger.debug("GEMS_ASCEND GATHER (fixed version)")

    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

    # Normalize negative dims up front. The previous code only special-cased
    # dim == -1 (and only after restride_dim), so a direct call with e.g.
    # dim == -2 would reach compute_base_offset un-normalized, where the
    # `i != dim` comparison never matches and every offset comes out wrong.
    dim = dim % inp.dim()

    N = index.numel()
    dim_stride = inp.stride(dim)
    inp_strided = restride_dim(inp, dim, index.shape)
    base_offset = compute_base_offset(index.shape, inp_strided.stride(), dim).to(
        torch.int64
    )
    # Base offsets are built on CPU, then moved to the NPU for the kernel.
    base_offset = base_offset.npu()
    grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
    _gather_flat_kernel_fixed[grid](
        inp_strided,
        index,
        out,
        base_offset,
        dim_stride,
        N,
    )
    return out

82 

83 

@triton.jit
def _gather_high_perf_kernel(
    # Pointers
    x_ptr,    # input matrix, (rows, x_size)
    idx_ptr,  # index matrix, (rows, num_indices)
    out_ptr,  # output matrix, (rows, num_indices)
    stride_x_rows,
    stride_x_feats,
    stride_idx_rows,
    stride_idx_cols,
    stride_out_rows,
    stride_out_cols,
    num_indices: tl.constexpr,  # index/output row length (compile-time)
    x_size: tl.constexpr,       # input row length (compile-time)
):
    # One program per row: load the full input row and its full index row
    # on-chip, gather within registers, and store the output row.
    # NOTE(review): no tail masking — assumes the caller sized the launch so
    # whole rows fit (gather() gates on UB_SIZE_BYTES); confirm tl.arange
    # extent constraints on this backend.
    row_id = tl.program_id(0)

    offs_idx = tl.arange(0, num_indices)
    offs_x = tl.arange(0, x_size)

    # Load indices for this row
    idx_ptrs = idx_ptr + row_id * stride_idx_rows + offs_idx * stride_idx_cols
    indices = tl.load(idx_ptrs)

    # Load feature vector
    x_ptrs = x_ptr + row_id * stride_x_rows + offs_x * stride_x_feats
    x_vals = tl.load(x_ptrs)

    # Perform gather
    out_vals = tl.gather(x_vals, indices, 0)

    # Store result
    out_ptrs = out_ptr + row_id * stride_out_rows + offs_idx * stride_out_cols
    tl.store(out_ptrs, out_vals)

118 

119 

def gather_high_perf(inp: torch.Tensor, index: torch.Tensor, out=None):
    """Fast-path gather for a 2-D input along its last dimension.

    Launches one kernel program per row; each program gathers an entire
    row on-chip. Output has ``index``'s shape and ``inp``'s dtype.
    """
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

    # One program per index row.
    rows = index.shape[0]
    _gather_high_perf_kernel[(rows,)](
        inp,
        index,
        out,
        inp.stride(0),
        inp.stride(1),
        index.stride(0),
        index.stride(1),
        out.stride(0),
        out.stride(1),
        num_indices=index.shape[-1],
        x_size=inp.shape[-1],
    )
    return out

144 

145 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather entry point: pick the fast single-row path when it applies.

    The fast path requires a 2-D input, gathering along the last dimension,
    with one row's working set (input row + index row + output row) fitting
    in on-chip memory (UB_SIZE_BYTES). Otherwise fall back to the general
    flat-offset implementation. ``sparse_grad`` is accepted for API
    compatibility and not used here.
    """
    logger.debug("GEMS_ASCEND GATHER")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

    dim = dim % inp.dim()

    # Per-row working set: input row + output row (inp dtype) + index row.
    row_bytes = (
        (inp.size(-1) + index.size(-1)) * inp.element_size()
        + index.size(-1) * index.element_size()
    )

    last_dim = dim == inp.dim() - 1
    if last_dim and inp.dim() == 2 and row_bytes < UB_SIZE_BYTES:
        return gather_high_perf(inp, index, out)

    return gather_flat_fixed(inp, dim, index, out)

164 

165 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of gather: scatter-add ``grad`` into a zero tensor of
    ``self``'s shape at the gathered positions."""
    logger.debug("GEMS_ASCEND GATHER BACKWARD")
    return scatter_(grad.new_zeros(self.shape), dim, index, grad, reduce="add")