Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_select.py: 0%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.index_select import index_select as default_index_select 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-scoped logger, namespaced under the mthreads backend so that log
# filtering by backend/op works; the suffix is this module's bare name.
_op_name = __name__.rsplit(".", 1)[-1]
logger = logging.getLogger(f"flag_gems.runtime.backend._mthreads.ops.{_op_name}")

15 

16 

@libentry()
@triton.jit
def index_select_dim0_1d_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    inp_row_stride,
    out_row_stride,
    row_size,
    num_indices,
    BLOCK_SIZE: tl.constexpr,
):
    """Kernel for dim=0 index_select - each program handles one row.

    Launched on a 1D grid of ``num_indices`` programs. Program ``pid`` reads
    ``index_ptr[pid]`` to find the source row in ``inp_ptr`` and copies that
    entire row (``row_size`` elements) to row ``pid`` of ``out_ptr``, chunked
    ``BLOCK_SIZE`` elements at a time.

    NOTE(review): ``index_ptr`` is dereferenced with unit stride, so the index
    tensor must be contiguous. ``num_indices`` is not used in the body; it is
    kept for a launch signature uniform with the split kernel.
    """
    pid = tle.program_id(axis=0)

    # Load the index for this row
    row_index = tl.load(index_ptr + pid)

    # Calculate input and output row offsets (row index times elements-per-row
    # stride; rows are assumed dense along the last axis by the caller).
    inp_row_offset = row_index * inp_row_stride
    out_row_offset = pid * out_row_stride

    # Process row in chunks of BLOCK_SIZE columns; the mask guards the final
    # partial chunk when row_size is not a multiple of BLOCK_SIZE.
    for offset in range(0, row_size, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < row_size

        # Load from input and store to output
        data = tl.load(inp_ptr + inp_row_offset + cols, mask=mask, other=0.0)
        tl.store(out_ptr + out_row_offset + cols, data, mask=mask)

47 

48 

@libentry()
@triton.jit
def index_select_dim0_split_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    inp_row_stride,
    out_row_stride,
    row_size,
    num_indices,
    BLOCK_SIZE: tl.constexpr,
):
    """Kernel for dim=0 index_select - 2D grid for large row_size.

    First dimension: indices, Second dimension: column chunks.

    Unlike the 1D kernel, each program copies only one BLOCK_SIZE-wide column
    chunk of one selected row, so wide rows are parallelized across axis 1 of
    the grid instead of being looped over inside a single program.

    NOTE(review): ``index_ptr`` is dereferenced with unit stride, so the index
    tensor must be contiguous. ``num_indices`` is unused in the body; the grid
    size along axis 0 already equals the number of indices.
    """
    pid_idx = tle.program_id(axis=0)
    pid_col = tle.program_id(axis=1)

    # Load the index for this row
    row_index = tl.load(index_ptr + pid_idx)

    # Calculate input and output row offsets
    inp_row_offset = row_index * inp_row_stride
    out_row_offset = pid_idx * out_row_stride

    # Calculate column offset for this program; mask guards the last partial
    # chunk when row_size is not a multiple of BLOCK_SIZE.
    col_offset = pid_col * BLOCK_SIZE
    cols = col_offset + tl.arange(0, BLOCK_SIZE)
    mask = cols < row_size

    # Load from input and store to output
    data = tl.load(inp_ptr + inp_row_offset + cols, mask=mask, other=0.0)
    tl.store(out_ptr + out_row_offset + cols, data, mask=mask)

82 

83 

@libentry()
@triton.jit
def index_select_dim1_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_rows,
    inp_row_stride,
    out_row_stride,
    num_indices,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Optimized kernel for dim=1 index_select on 2D tensors.

    Each program handles a tile of rows x indices.

    Launched on a 2D grid: axis 0 tiles the rows (BLOCK_M at a time), axis 1
    tiles the index positions (BLOCK_N at a time). For each (row, j) cell the
    program gathers ``inp[row, index[j]]`` into ``out[row, j]``.

    NOTE(review): the input offset ``rows * inp_row_stride + indices`` adds the
    gathered column index with an implicit element stride of 1, i.e. the input
    must be contiguous along dim 1 (the caller checks ``inp.is_contiguous()``).
    ``index_ptr`` is likewise read with unit stride, so the index tensor must
    be contiguous.
    """
    pid_m = tle.program_id(axis=0)
    pid_n = tle.program_id(axis=1)

    row_start = pid_m * BLOCK_M
    idx_start = pid_n * BLOCK_N

    # Column vector of row ids and row vector of index-slot ids; broadcasting
    # the two produces the full (BLOCK_M, BLOCK_N) tile of coordinates.
    rows = row_start + tl.arange(0, BLOCK_M)[:, None]
    idx_offsets = idx_start + tl.arange(0, BLOCK_N)[None, :]

    # Masks guard the ragged edges of the last tile in each grid dimension.
    rows_mask = rows < num_rows
    idx_mask = idx_offsets < num_indices
    mask = rows_mask & idx_mask

    # Load indices (out-of-range lanes read a harmless 0; their stores are
    # masked out below anyway)
    indices = tl.load(index_ptr + idx_offsets, mask=idx_mask, other=0)

    # Calculate offsets
    inp_offsets = rows * inp_row_stride + indices
    out_offsets = rows * out_row_stride + idx_offsets

    # Load and store
    data = tl.load(inp_ptr + inp_offsets, mask=mask, other=0.0)
    tl.store(out_ptr + out_offsets, data, mask=mask)

123 

124 

125def _get_num_warps(total_elements): 

126 """Get optimal num_warps based on workload size.""" 

127 if total_elements < 1024: 

128 return 2 

129 elif total_elements < 4096: 

130 return 4 

131 elif total_elements < 16384: 

132 return 8 

133 else: 

134 return 16 

135 

136 

def index_select(inp, dim, index):
    """Select slices of ``inp`` along ``dim`` at the positions in ``index``.

    Mirrors ``torch.index_select``: the result has the same shape as ``inp``
    except that ``result.shape[dim] == index.numel()``. Contiguous 2D inputs
    with dim 0 or 1 take fast Triton paths; everything else falls back to the
    generic flag_gems implementation.

    Args:
        inp: Input tensor.
        dim: Dimension to select along; negative values are normalized.
        index: 0-D or 1-D integer tensor of indices into ``inp.shape[dim]``.

    Returns:
        A newly allocated tensor with the selected slices.
    """
    logger.debug("GEMS_MTHREADS INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"

    if index.ndim == 0:
        index = index.unsqueeze(0)

    # BUGFIX: the Triton kernels dereference ``index_ptr + i`` with unit
    # stride, so a non-contiguous index (e.g. a strided view such as
    # ``idx[::2]``, which torch.index_select accepts) would gather the wrong
    # positions. Densify it up front; this is a no-op for contiguous indices.
    index = index.contiguous()

    dim = dim % inp.ndim
    index_len = index.numel()

    # Create output shape: same as input except along the selected dim.
    out_shape = list(inp.shape)
    out_shape[dim] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Nothing to copy for empty inputs or an empty index.
    if inp.numel() == 0 or index_len == 0:
        return out

    # Optimized path for 2D tensors with dim=0
    if inp.ndim == 2 and dim == 0 and inp.is_contiguous():
        num_rows, row_size = inp.shape
        inp_row_stride = inp.stride(0)
        out_row_stride = out.stride(0)

        # For large row_size, use 2D grid (indices x column_chunks) for more parallelism
        if row_size >= 16384:
            BLOCK_SIZE = 1024
            num_col_chunks = triton.cdiv(row_size, BLOCK_SIZE)
            grid = (index_len, num_col_chunks)
            num_warps = _get_num_warps(BLOCK_SIZE)

            with torch_device_fn.device(inp.device):
                index_select_dim0_split_kernel[grid](
                    inp,
                    out,
                    index,
                    inp_row_stride,
                    out_row_stride,
                    row_size,
                    index_len,
                    BLOCK_SIZE=BLOCK_SIZE,
                    num_warps=num_warps,
                )
            return out
        else:
            # Use 1D kernel - each program handles one complete row
            BLOCK_SIZE = min(triton.next_power_of_2(row_size), 2048)
            num_warps = _get_num_warps(BLOCK_SIZE)

            with torch_device_fn.device(inp.device):
                index_select_dim0_1d_kernel[(index_len,)](
                    inp,
                    out,
                    index,
                    inp_row_stride,
                    out_row_stride,
                    row_size,
                    index_len,
                    BLOCK_SIZE=BLOCK_SIZE,
                    num_warps=num_warps,
                )
            return out

    # Optimized path for 2D tensors with dim=1
    if inp.ndim == 2 and dim == 1 and inp.is_contiguous():
        num_rows, num_cols = inp.shape
        inp_row_stride = inp.stride(0)
        out_row_stride = out.stride(0)

        # Tile sizes are clamped powers of two so small inputs do not launch
        # oversized blocks.
        BLOCK_M = min(triton.next_power_of_2(num_rows), 64)
        BLOCK_N = min(triton.next_power_of_2(index_len), 128)

        grid = (triton.cdiv(num_rows, BLOCK_M), triton.cdiv(index_len, BLOCK_N))
        num_warps = _get_num_warps(BLOCK_M * BLOCK_N)

        with torch_device_fn.device(inp.device):
            index_select_dim1_kernel[grid](
                inp,
                out,
                index,
                num_rows,
                inp_row_stride,
                out_row_stride,
                index_len,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                num_warps=num_warps,
            )
        return out

    # Fall back to default implementation for other cases
    return default_index_select(inp, dim, index)