Coverage for src/flag_gems/ops/index_select.py: 68%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module-level logger named after this module; index_select() below emits a
# debug record through it on every call.
logger = logging.getLogger(name=__name__)

12 

13 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_select"))
@triton.jit
def index_select_kernel(
    inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Gather columns of an (M, N) input into an (M, index_len) output.

    Each program instance handles one BLOCK_M x BLOCK_N tile:
    - rows ``pid_x * BLOCK_M + [0, BLOCK_M)``;
    - output columns ``pid_y * BLOCK_N + [0, BLOCK_N)``.

    For output column ``j`` the source column is ``index[j]``, i.e.
    ``out[r, j] = inp[r, index[j]]``. Output elements whose index falls
    outside ``[0, N)`` are written as 0.

    Args:
        inp: pointer to the input, addressed as ``row * N + col`` (assumes a
            contiguous row-major (M, N) layout — the caller must arrange this).
        out: pointer to the output, addressed as ``row * index_len + col``.
        M: number of rows.
        N: number of input columns (size of the selected dimension).
        index: pointer to ``index_len`` contiguous column indices.
        index_len: number of indices, i.e. number of output columns.
        BLOCK_M, BLOCK_N: tile sizes; presumably filled in by the
            ``triton.heuristics`` decorator above — confirm against the
            runtime "index_select" config.
    """
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)

    # Promote row offsets to int64: ``row * N`` / ``row * index_len`` would
    # overflow int32 for tensors with more than 2**31 elements.
    rows_offsets = (pid_x * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)[:, None]
    rows_mask = rows_offsets < M
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)
    cols_mask = cols_offsets < index_len

    # Element-wise AND, broadcast to a BLOCK_M x BLOCK_N tile.
    out_mask = rows_mask & cols_mask[None, :]

    indices = tl.load(index + cols_offsets, mask=cols_mask, other=0)
    index_valid_mask = (indices >= 0) & (indices < N)

    inp_off = rows_offsets * N + indices[None, :]
    out_off = rows_offsets * index_len + cols_offsets[None, :]

    # Read the input only where the index is in range. Still write every
    # in-bounds output element, so out-of-range indices produce 0 instead of
    # leaving the output buffer's uninitialised contents behind.
    load_mask = out_mask & index_valid_mask[None, :]
    selected = tl.load(inp + inp_off, mask=load_mask, other=0.0)
    tl.store(out + out_off, selected, mask=out_mask)

39 

40 

def index_select(inp, dim, index):
    """Return the slices of ``inp`` along ``dim`` picked out by ``index``.

    The result has ``inp``'s shape with ``shape[dim]`` replaced by
    ``index.numel()``. For in-range indices,
    ``result.select(dim, j) == inp.select(dim, index[j])``.

    Args:
        inp: input tensor.
        dim: dimension to select along; negative values count from the end.
        index: 0-d or 1-d integer tensor of positions along ``dim``. A 0-d
            index is treated as a single-element 1-d index. Assumed to live on
            ``inp``'s device (not checked here).

    Returns:
        A new contiguous tensor with ``inp``'s dtype and device.

    Raises:
        AssertionError: if ``dim`` is out of range for ``inp`` or ``index``
            has more than one dimension.
    """
    logger.debug("GEMS INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"

    if index.ndim == 0:
        index = index.unsqueeze(0)
    # The kernel addresses indices as ``index + offset``. A strided 1-D view
    # (e.g. ``idx[::2]``) must therefore be materialised first.
    index = index.contiguous()
    dim = dim % inp.ndim
    inp_shape = list(inp.shape)
    index_len = index.numel()

    if inp.numel() == 0 or index_len == 0:
        # Nothing to read or nothing to write. This also avoids the
        # ``numel() // N`` below dividing by zero when ``inp`` is empty along
        # ``dim``; in that case every index is out of range anyway.
        result_shape = list(inp_shape)
        result_shape[dim] = index_len
        return torch.zeros(result_shape, dtype=inp.dtype, device=inp.device)

    # The code below assumes dim_compress moves ``dim`` to the last axis (and
    # yields a layout the kernel can address as ``row * N + col``). That lets
    # ``inp`` be treated as an (M, N) matrix whose columns are the positions
    # along ``dim``.
    inp = dim_compress(inp, dim)
    N = inp_shape[dim]
    M = inp.numel() // N
    out_shape = list(inp.shape)
    out_shape[-1] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(index_len, meta["BLOCK_N"]),
    )
    index_select_kernel[grid](inp, out, M, N, index, index_len)

    last = out.ndim - 1
    if dim == last:
        return out
    # Move the selected axis from the end back to position ``dim``, keeping
    # the other axes in their original relative order.
    order = list(range(last))
    order.insert(dim, last)
    return out.permute(order).contiguous()