Coverage for src/flag_gems/runtime/backend/_ascend/ops/index.py: 0%
59 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2from typing import List
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@triton.jit
def index_kernel_func(
    input_ptr,  # pointer to the source tensor (contiguous layout assumed — TODO confirm at call site)
    stride: tl.constexpr,  # number of contiguous elements to copy per selected row
    index_len,  # total number of row indices to gather
    index_ptr,  # pointer to the flattened row-index tensor
    out_ptr,  # pointer to the output tensor
    BLOCK_SIZE: tl.constexpr,  # rows handled by one program instance
    MAX_DATA_SIZE: tl.constexpr,  # max elements moved per inner load/store chunk
):
    # Gather kernel: for each index i in [0, index_len), copy the `stride`
    # contiguous elements starting at input_ptr + index[i] * stride into
    # out_ptr + i * stride. Each program instance processes BLOCK_SIZE rows.
    pid0 = tl.program_id(axis=0)

    for i in range(0, BLOCK_SIZE):
        offset = pid0 * BLOCK_SIZE + i

        # Guard the tail block: the last program may own fewer than
        # BLOCK_SIZE valid rows.
        if offset < index_len:
            in_start_index = tl.load(index_ptr + offset) * stride
            out_start_offset = offset * stride
            # Ceil-divide: number of MAX_DATA_SIZE-sized chunks covering `stride`.
            loop_num = (stride - 1) // MAX_DATA_SIZE + 1

            for loop_idx in range(0, loop_num):
                inner_offset = loop_idx * MAX_DATA_SIZE + tl.arange(0, MAX_DATA_SIZE)
                # Mask off lanes past the end of the row in the final chunk.
                mask = inner_offset < stride
                cur_value = tl.load(
                    input_ptr + in_start_index + inner_offset, mask=mask
                )
                tl.store(
                    out_ptr + out_start_offset + inner_offset, cur_value, mask=mask
                )
def index_wrapper(input, indices, out):
    """Launch the gather kernel that implements advanced indexing.

    Copies input[flat_index[i]] (a run of `stride` contiguous elements) into
    out[i] for every i. `indices` must already be broadcast to a common shape.
    Note: the parameter name `input` shadows the builtin but is kept for
    API compatibility.
    """
    input_shape = input.shape
    n_indexed = len(indices)

    # Elements copied per index: product of the trailing, non-indexed dims.
    stride = 1
    for dim in input_shape[n_indexed:]:
        stride *= dim

    index_len = indices[0].numel()

    # Fold the per-dimension index tensors into one flat row index
    # (row-major order over the first n_indexed dimensions).
    actual_index = indices[0]
    for axis in range(1, n_indexed):
        actual_index = actual_index * input_shape[axis] + indices[axis]

    BLOCK_SIZE = 32
    MAX_DATA_SIZE = 16 * 1024

    grid = lambda meta: (triton.cdiv(index_len, meta["BLOCK_SIZE"]),)
    index_kernel_func[grid](
        input,
        stride,
        index_len,
        actual_index,
        out,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DATA_SIZE=MAX_DATA_SIZE,
    )
def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned elementwise-maximum shape over *indices*.

    Shapes are aligned at their trailing axis; for each aligned position the
    largest extent among the tensors that have that axis is taken (0 if none
    do). The result has the rank of the highest-rank tensor.
    """
    max_rank = max(len(idx.shape) for idx in indices)
    trailing: List[int] = []
    # Walk positions from the last axis backwards, collecting the max extent.
    for pos in range(max_rank):
        dims = [
            idx.shape[len(idx.shape) - 1 - pos]
            for idx in indices
            if len(idx.shape) - 1 - pos >= 0
        ]
        trailing.append(max(dims, default=0))
    trailing.reverse()
    return trailing
def broadcast_indices(indices, target_shape):
    """Expand every index tensor in *indices* (in place) to *target_shape*.

    Tensors already matching the target shape are left untouched; others are
    replaced by a broadcast view via torch.broadcast_to.
    """
    want = tuple(target_shape)
    for pos, idx in enumerate(indices):
        if tuple(idx.shape) == want:
            continue
        indices[pos] = torch.broadcast_to(idx, target_shape)
def index(inp, indices):
    """Advanced-indexing entry point: gather rows of *inp* selected by *indices*.

    The index tensors are broadcast to a common shape; the output shape is
    that broadcast shape followed by the non-indexed trailing dims of *inp*.
    """
    logger.debug("GEMS_ASCEND INDEX")
    index_list = list(indices)

    out_shape = get_max_rank_shape(index_list)
    broadcast_indices(index_list, out_shape)
    # Append the trailing (non-indexed) dimensions of the input.
    out_shape += inp.shape[len(index_list) :]

    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    index_wrapper(inp, index_list, out)
    return out