Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather.py: 0%
81 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems.ops.scatter import scatter_
from flag_gems.utils import libentry
from flag_gems.utils.shape_utils import restride_dim
# Module logger; the f-string resolves to the op's own module name
# (e.g. "flag_gems.runtime._ascend.ops.gather").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# Hardware specification: Atlas 800T/I A2 product's on-chip memory capacity is 192KB
UB_SIZE_BYTES = 192 * 1024
def compute_base_offset(shape, strides, dim):
    """Compute per-element base memory offsets, excluding one dimension.

    For every flat index ``k`` over a tensor of the given ``shape``
    (C-contiguous ordering: last dimension varies fastest), entry ``k`` of
    the result is ``sum(coord[i] * strides[i] for i != dim)``.  The caller
    adds the ``dim`` contribution separately using the gather index tensor.

    Args:
        shape: iterable of dimension sizes.
        strides: per-dimension strides matching ``shape``.
        dim: dimension whose stride contribution is excluded; negative
            values are normalized (the original code only excluded
            correctly for non-negative ``dim``).

    Returns:
        1-D CPU ``torch.Tensor`` of length ``prod(shape)`` with the offsets.
    """
    if dim < 0:
        dim += len(shape)
    # math.prod avoids building a throwaway tensor just to multiply sizes.
    idx = torch.arange(math.prod(shape), device="cpu")
    offset = torch.zeros_like(idx)
    # Peel coordinates off the flat index from the innermost dimension
    # outward, accumulating each non-`dim` stride contribution in place —
    # no (ndim, numel) coordinate matrix is materialized.
    for i in reversed(range(len(shape))):
        coord_i = idx % shape[i]
        idx = idx // shape[i]
        if i != dim:
            offset += coord_i * strides[i]
    return offset
@libentry()
@triton.heuristics({"BLOCK_SIZE": lambda args: 4096})
@triton.jit
def _gather_flat_kernel_fixed(
    inp,  # pointer to the (restrided) input tensor
    index,  # pointer to the gather indices, addressed flat
    out,  # pointer to the output; same numel/layout as index
    base_offset,  # pointer to precomputed offsets over all dims except the gather dim
    inp_dim_stride,  # input stride along the gather dimension
    N,  # total number of elements to gather
    BLOCK_SIZE: tl.constexpr,
):
    # Flat gather: out[i] = inp[base_offset[i] + index[i] * inp_dim_stride].
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask out the tail of the last block when N is not a multiple of BLOCK_SIZE.
    mask = offset < N
    cur_index = tl.load(index + offset, mask=mask, other=0)
    base = tl.load(base_offset + offset, mask=mask, other=0)
    # Full input offset = non-gather-dim base + gather-dim contribution.
    inp_offset = base + cur_index * inp_dim_stride
    val = tl.load(inp + inp_offset, mask=mask, other=0)
    tl.store(out + offset, val, mask=mask)
def gather_flat_fixed(inp: torch.Tensor, dim: int, index: torch.Tensor, out=None):
    """Generic gather along ``dim`` via the flat-offset Triton kernel.

    Precomputes (on CPU) the base offset of every output element over the
    non-gather dimensions, then launches ``_gather_flat_kernel_fixed`` to
    combine those with the index tensor on the device.

    Args:
        inp: source tensor on the NPU.
        dim: gather dimension; negative values are accepted.
        index: index tensor; the output has its shape.
        out: optional preallocated output (created if ``None``).

    Returns:
        Tensor of ``index.shape`` with ``inp``'s dtype.
    """
    logger.debug("GEMS_ASCEND GATHER (fixed version)")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    N = index.numel()
    dim_stride = inp.stride(dim)
    inp_strided = restride_dim(inp, dim, index.shape)
    # Normalize any negative dim (the original only handled -1) so
    # compute_base_offset excludes the correct dimension from the sum.
    if dim < 0:
        dim += inp_strided.dim()
    base_offset = compute_base_offset(index.shape, inp_strided.stride(), dim).to(
        torch.int64
    )
    # Offsets are built on CPU; move them to the NPU for the kernel.
    base_offset = base_offset.npu()
    grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
    _gather_flat_kernel_fixed[grid](
        inp_strided,
        index,
        out,
        base_offset,
        dim_stride,
        N,
    )
    return out
@triton.jit
def _gather_high_perf_kernel(
    # Pointers
    x_ptr,  # input rows (2-D)
    idx_ptr,  # index rows (2-D)
    out_ptr,  # output rows (2-D), shaped like the index
    stride_x_rows,
    stride_x_feats,
    stride_idx_rows,
    stride_idx_cols,
    stride_out_rows,
    stride_out_cols,
    num_indices: tl.constexpr,  # indices per row (compile-time)
    x_size: tl.constexpr,  # features per input row (compile-time)
):
    # One program per row: out[row, c] = x[row, idx[row, c]].
    row_id = tl.program_id(0)
    offs_idx = tl.arange(0, num_indices)
    offs_x = tl.arange(0, x_size)
    # Load indices for this row
    idx_ptrs = idx_ptr + row_id * stride_idx_rows + offs_idx * stride_idx_cols
    indices = tl.load(idx_ptrs)
    # Load feature vector
    x_ptrs = x_ptr + row_id * stride_x_rows + offs_x * stride_x_feats
    x_vals = tl.load(x_ptrs)
    # Perform gather
    out_vals = tl.gather(x_vals, indices, 0)
    # Store result
    out_ptrs = out_ptr + row_id * stride_out_rows + offs_idx * stride_idx_cols * 0 + offs_idx * stride_out_cols
    tl.store(out_ptrs, out_vals)
def gather_high_perf(inp: torch.Tensor, index: torch.Tensor, out=None):
    """Fast path for last-dim gather on 2-D tensors.

    Launches one kernel program per row; each program loads a whole input
    row plus its index row and writes ``out[r, c] = inp[r, index[r, c]]``.
    """
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    feat_len = inp.shape[-1]
    idx_len = index.shape[-1]
    row_count = index.shape[0]
    # One-dimensional grid: a program per output row.
    _gather_high_perf_kernel[(row_count,)](
        inp,
        index,
        out,
        inp.stride(0),
        inp.stride(1),
        index.stride(0),
        index.stride(1),
        out.stride(0),
        out.stride(1),
        num_indices=idx_len,
        x_size=feat_len,
    )
    return out
def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather entry point: pick the high-perf last-dim path when possible.

    The fast path applies only to 2-D inputs gathered along the last
    dimension whose working set (input row + index row + output row)
    fits in the 192KB on-chip buffer; everything else falls back to the
    flat-offset kernel.  ``sparse_grad`` is accepted for API parity.
    """
    logger.debug("GEMS_ASCEND GATHER")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    dim = dim % inp.dim()
    last_dim = inp.dim() - 1
    # Per-row bytes: input row + index row + output row.
    row_bytes = (
        inp.size(-1) * inp.element_size()
        + index.size(-1) * index.element_size()
        + index.size(-1) * inp.element_size()
    )
    if dim == last_dim and inp.dim() == 2 and row_bytes < UB_SIZE_BYTES:
        return gather_high_perf(inp, index, out)
    return gather_flat_fixed(inp, dim, index, out)
def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of gather: scatter-add ``grad`` into a zero tensor of ``self``'s shape."""
    logger.debug("GEMS_ASCEND GATHER BACKWARD")
    grad_input = grad.new_zeros(self.shape)
    return scatter_(grad_input, dim, index, grad, reduce="add")