Coverage for src/flag_gems/runtime/backend/_ascend/ops/index_select.py: 0%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

12 

13 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_select"))
@triton.jit
def index_select_kernel(
    inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Gather columns of a row-major (M, N) input into an (M, index_len) output.

    For every row m and output column c: out[m, c] = inp[m, index[c]].
    The launch grid may clamp axis 0 below ceil(M / BLOCK_M) (see the host-side
    `grid` function), so each program walks several row tiles in a loop.
    """
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    num_pid_x = tle.num_programs(axis=0)
    # Number of row tiles this program handles.
    # NOTE(review): this divides M (rows) by the program count, but the result
    # is then multiplied by BLOCK_M as a *tile* index below — for BLOCK_M > 1
    # that over-covers M (masked, so correct, but possibly wasted work).
    # Presumably the heuristic config keeps this consistent; confirm.
    loop_count = tl.cdiv(M, num_pid_x)
    for loop in range(0, loop_count):
        # Row offsets for this tile, shaped (BLOCK_M, 1) so they broadcast
        # against the (BLOCK_N,) column offsets.
        rows_offsets = (pid_x * loop_count + loop) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        rows_mask = rows_offsets < M
        cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)

        # NOTE(review): Python `and` between tensors relies on Triton lowering
        # it to an elementwise logical-and; `&` would be the explicit form.
        out_mask = rows_mask and (cols_offsets < index_len)

        # Out-of-range lanes read index 0 (other=0), a valid column, so the
        # gather load below is safe even though its mask covers rows only.
        indices = tl.load(
            index + cols_offsets, mask=(cols_offsets < index_len), other=0
        )
        inp_off = rows_offsets * N + indices[None, :]
        out_off = rows_offsets * index_len + cols_offsets[None, :]

        selected = tl.load(inp + inp_off, mask=rows_mask, other=0.0)
        tl.store(out + out_off, selected, mask=out_mask)

41 

42 

def index_select(inp, dim, index):
    """Return a new tensor gathering entries of `inp` along `dim` at `index`.

    Mirrors ``torch.index_select``: the result has the same shape as `inp`
    except that dimension `dim` has length ``index.numel()``.

    Args:
        inp: input tensor.
        dim: dimension to select along; may be negative.
        index: 0-d or 1-d integer tensor of positions within ``inp.size(dim)``.

    Returns:
        A contiguous tensor with ``index.numel()`` entries along `dim`.
    """
    logger.debug("GEMS_ASCEND INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"
    # Fix: the original asserted a bare generator expression, which is always
    # truthy, so out-of-range indices were never caught. Use a tensor
    # reduction instead (an empty index trivially passes .all()).
    assert ((index >= 0) & (index < inp.size(dim))).all(), "Index out of range"

    if index.ndim == 0:
        index = index.unsqueeze(0)
    dim = dim % inp.ndim
    inp_shape = list(inp.shape)
    index_len = index.numel()

    # Move `dim` to the innermost position so the kernel can treat the input
    # as a row-major (M, N) matrix with the selected axis contiguous.
    inp = dim_compress(inp, dim)
    N = inp_shape[dim]
    M = inp.numel() // N
    out_shape = list(inp.shape)
    out_shape[inp.ndim - 1] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    def grid(meta):
        dim0 = triton.cdiv(M, meta["BLOCK_M"])
        dim1 = triton.cdiv(index_len, meta["BLOCK_N"]) if index_len > 0 else 1
        # Halve the row axis until the total program count fits the grid
        # limit. Fix: guard with dim0 > 1 — triton.cdiv(1, 2) == 1, so the
        # original loop never terminated when dim1 alone reached 65536.
        while dim0 * dim1 >= 65536 and dim0 > 1:
            dim0 = triton.cdiv(dim0, 2)
        return (
            dim0,
            dim1,
        )

    index_select_kernel[grid](inp, out, M, N, index, index_len)
    if dim != out.ndim - 1:
        # dim_compress moved `dim` to the last axis; permute it back into
        # place and materialize a contiguous result.
        order = [i for i in range(out.ndim - 1)]
        order.insert(dim, out.ndim - 1)
        return out.permute(order).contiguous()
    else:
        return out