Coverage for src/flag_gems/runtime/backend/_metax/ops/index_select.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7import flag_gems.runtime as runtime 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger("flag_gems." + __name__) 

12 

13 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_select"))
@triton.jit
def index_select_kernel(
    inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Column gather on a row-major (M, N) input:
    #   out[m, j] = inp[m, index[j]]   for m < M, j < index_len
    # Grid: axis 0 tiles the M rows (BLOCK_M each), axis 1 tiles the
    # index_len output columns (BLOCK_N each).
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # Row offsets are shaped (BLOCK_M, 1) so they broadcast against the
    # (BLOCK_N,) column offsets below.
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)

    out_mask = rows_mask and (cols_offsets < index_len)

    # Out-of-range column lanes read index[0]'s slot as 0 (`other=0`), so
    # the gather below stays in bounds; their values are dropped on store.
    indices = tl.load(index + cols_offsets, mask=(cols_offsets < index_len), other=0)
    inp_off = rows_offsets * N + indices[None, :]
    out_off = rows_offsets * index_len + cols_offsets[None, :]

    # NOTE(review): the load is masked only by rows_mask; this is safe
    # because masked-off columns gathered index 0 above (in bounds), and
    # out_mask discards them on the store.
    selected = tl.load(inp + inp_off, mask=rows_mask, other=0.0)
    tl.store(out + out_off, selected, mask=out_mask)

34 

35 

@libentry()
@triton.jit
def index_select_2d_opt_kernel(inp, out, M, N, index, BLOCK_SIZE: tl.constexpr):
    # Row gather on a row-major 2D input: copies input row index[pid]
    # into output row pid, one program per output row.
    # Here M is the row length (number of columns) and N is the number
    # of input rows — note this is the caller's M/N convention, where
    # N = size of the selected dim and M = numel // N.
    pid = tle.program_id(axis=0)

    row_index = tl.load(index + pid)
    row_offset = row_index * M
    # Guards the whole row against an out-of-range index value.
    rows_mask = row_index < N

    # Stream the source row in BLOCK_SIZE-element chunks.
    for m in range(0, tl.cdiv(M, BLOCK_SIZE)):
        cols_offsets = m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        cols_mask = cols_offsets < M
        block_mask = rows_mask and cols_mask
        cur_inp = tl.load(inp + row_offset + cols_offsets, mask=block_mask, other=0.0)
        out_offset = pid * M + cols_offsets
        tl.store(out + out_offset, cur_inp, mask=block_mask)

52 

53 

# Swap the two dimensions of a 2D tensor for a better memory access pattern.
def dim_transpose(inp):
    """Return a contiguous copy of *inp* with dims 0 and 1 swapped."""
    return inp.transpose(0, 1).contiguous()

57 

58 

def index_select(inp, dim, index):
    """Select slices of *inp* along *dim* at positions given by *index*.

    Equivalent to ``torch.index_select``. Dispatches to a per-row gather
    kernel for the common ``2D, dim == 0`` case; otherwise compresses the
    selected dim to the innermost position (``dim_compress``) and runs a
    tiled 2D gather kernel, permuting the result back afterwards.

    Args:
        inp: input tensor.
        dim: dimension to select along; may be negative.
        index: 0-D or 1-D integer tensor with values in ``[0, inp.size(dim))``.

    Returns:
        A new tensor shaped like *inp* with ``size(dim) == index.numel()``.
    """
    logger.debug("METAX GEMS INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"
    # BUGFIX: the original `assert (<genexpr> for i in index)` was always
    # truthy (a generator object is truthy) and never checked anything.
    # Do a real elementwise range check instead.
    assert bool(
        torch.all((index >= 0) & (index < inp.size(dim)))
    ), "Index out of range"

    if index.ndim == 0:
        index = index.unsqueeze(0)
    dim = dim % inp.ndim
    index_len = index.numel()
    inp_shape = list(inp.shape)
    N = inp_shape[dim]  # length of the selected dimension
    M = inp.numel() // N  # product of all the other dimensions

    if inp.ndim == 2 and dim == 0:
        # Fast path: whole-row gather; one program copies one output row.
        out_shape = list(inp.shape)
        out_shape[dim] = index_len
        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
        BLOCK_SIZE = min(triton.next_power_of_2(M), 4096)
        index_select_2d_opt_kernel[(index_len,)](
            inp, out, M, N, index, BLOCK_SIZE=BLOCK_SIZE
        )
        return out
    else:
        # General path: move the selected dim to the innermost position so
        # the kernel gathers along contiguous memory. N and M are unchanged
        # by the permutation (same numel, same selected-dim length).
        inp = dim_compress(inp, dim)
        out_shape = list(inp.shape)
        out_shape[inp.ndim - 1] = index_len
        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(index_len, meta["BLOCK_N"]),
        )
        index_select_kernel[grid](inp, out, M, N, index, index_len)
        if dim != out.ndim - 1:
            # Undo dim_compress: move the last axis back to position `dim`.
            order = list(range(out.ndim - 1))
            order.insert(dim, out.ndim - 1)
            return out.permute(order).contiguous()
        else:
            return out