Coverage for src/flag_gems/runtime/backend/_ascend/ops/index_select.py: 0%
54 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.utils import dim_compress, libentry
from flag_gems.utils import triton_lang_extension as tle

# Module logger keyed by this file's basename, e.g.
# "flag_gems.runtime._ascend.ops.index_select".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_select"))
@triton.jit
def index_select_kernel(
    inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Gather columns of a row-major (M, N) matrix into a (M, index_len) output.

    For every row r < M and output column c < index_len:
        out[r, c] = inp[r, index[c]]
    (offsets computed as inp_off = r * N + index[c] and
    out_off = r * index_len + c, i.e. both tensors are treated as
    contiguous row-major 2-D buffers).

    Grid layout: axis 1 tiles the index/output columns; axis 0 tiles the
    rows, but since the launcher may cap axis 0 below the number of row
    tiles, each program loops `cdiv(M, num_pid_x)` times to cover the
    remaining row tiles sequentially.
    """
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    num_pid_x = tle.num_programs(axis=0)
    # Number of row tiles each program along axis 0 is responsible for.
    loop_count = tl.cdiv(M, num_pid_x)
    for loop in range(0, loop_count):
        # (BLOCK_M, 1) column vector of absolute row indices for this tile.
        rows_offsets = (pid_x * loop_count + loop) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        rows_mask = rows_offsets < M
        cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)
        # NOTE(review): relies on Triton broadcasting the logical `and` of a
        # (BLOCK_M, 1) row mask with a (BLOCK_N,) column mask — confirm this
        # lowers to an elementwise & on the target Triton version.
        out_mask = rows_mask and (cols_offsets < index_len)
        # Out-of-range lanes read index 0 (other=0); they are masked out on
        # the final store.
        indices = tl.load(
            index + cols_offsets, mask=(cols_offsets < index_len), other=0
        )
        inp_off = rows_offsets * N + indices[None, :]
        out_off = rows_offsets * index_len + cols_offsets[None, :]
        selected = tl.load(inp + inp_off, mask=rows_mask, other=0.0)
        tl.store(out + out_off, selected, mask=out_mask)
def index_select(inp, dim, index):
    """Return a new tensor that indexes `inp` along `dim` with `index`.

    Equivalent to ``torch.index_select(inp, dim, index)``: the output has the
    same shape as ``inp`` except that ``dim`` has length ``index.numel()``.

    Args:
        inp: input tensor.
        dim: dimension to select along; may be negative.
        index: 0-D or 1-D integer tensor of indices into ``inp.size(dim)``.

    Raises:
        AssertionError: if ``dim`` is out of range, ``index`` has more than
            one dimension, or any index value is out of range.
    """
    logger.debug("GEMS_ASCEND INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"
    # BUG FIX: the original code asserted a bare generator expression, which
    # is always truthy, so out-of-range indices were never caught. Validate
    # the index values with a real tensor reduction instead.
    if index.numel() > 0:
        assert bool(
            ((index >= 0) & (index < inp.size(dim))).all()
        ), "Index out of range"
    if index.ndim == 0:
        index = index.unsqueeze(0)
    dim = dim % inp.ndim
    inp_shape = list(inp.shape)
    index_len = index.numel()

    # Move `dim` to the innermost position so the kernel can operate on a
    # contiguous (M, N) view: N is the selected dimension, M everything else.
    inp = dim_compress(inp, dim)
    N = inp_shape[dim]
    M = inp.numel() // N
    out_shape = list(inp.shape)
    out_shape[inp.ndim - 1] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    def grid(meta):
        dim0 = triton.cdiv(M, meta["BLOCK_M"])
        dim1 = triton.cdiv(index_len, meta["BLOCK_N"]) if index_len > 0 else 1
        # Cap the total program count below the hardware grid limit; the
        # kernel loops over the remaining row tiles internally.
        # BUG FIX: guard on dim0 > 1 — otherwise when dim1 alone is >= 65536
        # this loop never terminates (cdiv(1, 2) == 1).
        while dim0 > 1 and dim0 * dim1 >= 65536:
            dim0 = triton.cdiv(dim0, 2)
        return (
            dim0,
            dim1,
        )

    index_select_kernel[grid](inp, out, M, N, index, index_len)
    if dim != out.ndim - 1:
        # dim_compress moved `dim` to the back; permute it home and make the
        # result contiguous.
        order = [i for i in range(out.ndim - 1)]
        order.insert(dim, out.ndim - 1)
        return out.permute(order).contiguous()
    else:
        return out