Coverage for src/flag_gems/runtime/backend/_cambricon/ops/index_select.py: 0%
167 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.utils import libentry, libtuner
11from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM
# Child logger under the shared "flag_gems" root; lstrip(".") normalizes a
# possibly-relative __name__ before using it as the child suffix.
_flag_gems_logger = logging.getLogger("flag_gems")
logger = _flag_gems_logger.getChild(__name__.lstrip("."))
def get_max_block_size(dtype_size):
    """Largest element count of ``dtype_size``-byte items that fits in one
    third of NRAM (presumably input/output/extra buffers each take a share —
    see the "ram: (input, output), half, extra" note in ``index_select``)."""
    per_buffer_bytes = MAX_NRAM_SIZE // 3
    return per_buffer_bytes // dtype_size
def config_prune(configs, named_args, **kwargs):
    """Prune autotune configs for the one-batch kernel.

    Keeps at most one config per distinct index-block size
    (``ceil(BLOCK_SIZE / N)``) that still fits in NRAM, then appends one
    extra config sized so the indices divide evenly across all cores.
    """
    N = named_args["N"]
    dtype_size = named_args["dtype_size"]
    nram_limit = get_max_block_size(dtype_size)

    kept = []
    seen_index_sizes = []
    for cfg in configs:
        block = cfg.kwargs["BLOCK_SIZE"]
        per_index = (block + N - 1) // N
        # Skip duplicates (same effective index-block size) and oversized blocks.
        if per_index in seen_index_sizes or per_index * N > nram_limit:
            continue
        seen_index_sizes.append(per_index)
        kept.append(cfg)

    in_n_elements = named_args["in_n_elements"]

    # Make sure at least one config sits at the load-balance sweet spot:
    # roughly in_n_elements / TOTAL_CORE_NUM indices per core (one extra
    # element when the split is uneven).
    per_core = max(in_n_elements // TOTAL_CORE_NUM, 1)
    if in_n_elements % TOTAL_CORE_NUM == 0:
        balanced_bs = min(per_core * N, nram_limit)
    else:
        balanced_bs = min(per_core * N + 1, nram_limit)
    if (balanced_bs + N - 1) // N not in seen_index_sizes:
        kept.append(
            triton.Config(kwargs={"BLOCK_SIZE": balanced_bs}, num_stages=1, num_warps=1)
        )

    return kept
@triton.jit
def ld_st_1(indices, N: tl.constexpr, weight_ptr, in_mask, in_offsets, out_ptr):
    # Gather one contiguous row of N elements per index from weight_ptr and
    # scatter it to the corresponding row of out_ptr. Lanes masked off by
    # in_mask (past the end of the index block) neither load nor store.
    weight_offsets = indices[:, None] * N + tl.arange(0, N)
    embedding_weight = tl.load(weight_ptr + weight_offsets, in_mask[:, None])
    out_offsets = in_offsets[:, None] * N + tl.arange(0, N)
    tl.store(out_ptr + out_offsets, embedding_weight, in_mask[:, None])
@libentry()
@libtuner(
    configs=[
        # [512, 65536]
        triton.Config(kwargs={"BLOCK_SIZE": 512 * 2**i}, num_stages=1, num_warps=1)
        for i in range(0, 8, 2)
    ],
    key=["N"],
    prune_configs_by={
        "early_config_prune": config_prune,
    },
)
@triton.jit
def one_batch_index_select_kernel(  # 2D
    out_ptr,  # output, viewed as [in_n_elements, N]
    in_ptr,  # index tensor, [in_n_elements]
    in_n_elements,  # number of indices to gather
    weight_ptr,  # source tensor, viewed as [select_dim, N]
    N: tl.constexpr,  # contiguous row length (c_dim)
    dtype_size,  # bytes per element of the source tensor
    inp_numel,  # total element count of the source tensor
    BLOCK_SIZE: tl.constexpr,
):
    # index_select for the batch_dim == 1 case: out[i, :] = weight[in[i], :].
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)

    # Index rows handled by one program per iteration (ceil division).
    INDEX_BLOCK_SIZE: tl.constexpr = (BLOCK_SIZE + N - 1) // N

    step = num_jobs * INDEX_BLOCK_SIZE
    iters = tl.cdiv(in_n_elements, step)

    # TODO: remove dtype_size once contiguous DMA is ensured
    # int32 offsets are only safe when the whole input spans <= 2**31 bytes.
    small_out = inp_numel.to(tl.int64) * dtype_size <= 2**31

    for i in tl.range(iters):
        iter_start = i * step
        iter_end = iter_start + step

        if iter_end <= in_n_elements:
            # Full iteration: every program takes an equal, full block.
            block_offset = iter_start + pid * INDEX_BLOCK_SIZE
            block_len = INDEX_BLOCK_SIZE
        else:
            # Tail iteration: spread the remainder as evenly as possible;
            # the first `remn_num` programs each take one extra row.
            rem_n_elements = in_n_elements - iter_start
            base_num = rem_n_elements // num_jobs
            remn_num = rem_n_elements % num_jobs
            extra_one = pid < remn_num

            block_offset = iter_start + (
                (base_num + 1) * pid if extra_one else (base_num * pid + remn_num)
            )
            block_len = base_num + extra_one

        in_offsets = block_offset + tl.arange(0, INDEX_BLOCK_SIZE)
        in_mask = in_offsets < (block_offset + block_len)
        indices = tl.load(in_ptr + in_offsets, in_mask, other=0.0)
        # Narrow indices to int32 when the input is small enough — narrower
        # offset arithmetic for the gather/scatter in ld_st_1.
        if indices.dtype != tl.int32 and small_out:
            indices_int32 = indices.to(tl.int32)
            ld_st_1(indices_int32, N, weight_ptr, in_mask, in_offsets, out_ptr)
        else:
            ld_st_1(indices, N, weight_ptr, in_mask, in_offsets, out_ptr)
# NOTE(review): this definition shadows the earlier `config_prune` used by
# one_batch_index_select_kernel. That still works because each @libtuner
# binds whichever definition exists at decoration time, but two distinct
# names would be less error-prone.
def config_prune(configs, named_args, **kwargs):
    """Prune/extend autotune configs for the multi-batch kernel.

    Collects candidate block sizes per axis (batch, index, c) from the
    incoming configs plus a few load-balance-friendly sizes, then emits
    every combination whose product fits in NRAM.
    """
    # TODO: bad perf when BLOCK_BATCH is 1
    batch_dim = max(named_args["batch_dim"], 2)
    index_dim = named_args["index_dim"]
    c_dim = named_args["c_dim"]
    dtype_size = named_args["dtype_size"]

    # difficult to include these critical configs while keeping number of configs small
    lb_block_batch_1 = triton.cdiv(batch_dim, TOTAL_CORE_NUM)
    lb_block_batch_2 = max(batch_dim // TOTAL_CORE_NUM, 1)
    lb_block_index_1 = triton.cdiv(index_dim, TOTAL_CORE_NUM)
    lb_block_index_2 = max(index_dim // TOTAL_CORE_NUM, 1)

    max_bs = get_max_block_size(dtype_size)

    # Seed the candidate sets with the load-balance sizes and the full dims.
    block_batches = set([lb_block_batch_1, lb_block_batch_2, batch_dim])
    block_indices = set([lb_block_index_1, lb_block_index_2, index_dim])
    block_cs = set([c_dim, min(max_bs, c_dim)])

    new_configs = []
    for config in configs:
        block_batch = config.kwargs["BLOCK_BATCH"]
        block_index = config.kwargs["BLOCK_INDEX"]
        block_c = config.kwargs["BLOCK_C"]

        # to keep the autotune space small: if c_dim is not very large, don't split c
        block_c_max = 2048 * 5
        block_c = c_dim if c_dim <= block_c_max else block_c

        # Only keep candidates that do not exceed the actual problem dims.
        if block_batch <= batch_dim and block_index <= index_dim and block_c <= c_dim:
            block_batches.add(block_batch)
            block_indices.add(block_index)
            block_cs.add(block_c)

    # Cartesian product of candidates, filtered by the NRAM capacity bound.
    for block_batch in block_batches:
        for block_index in block_indices:
            for block_c in block_cs:
                if block_batch * block_index * block_c <= max_bs:
                    new_configs.append(
                        triton.Config(
                            {
                                "BLOCK_BATCH": block_batch,
                                "BLOCK_INDEX": block_index,
                                "BLOCK_C": block_c,
                            },
                            num_warps=1,
                            num_stages=1,
                        )
                    )
    return new_configs
@triton.jit
def ld_st_2(
    inp,  # base pointer of the input, viewed as [batch, select, c]
    out,  # base pointer of the output, viewed as [batch, index, c]
    batch_offsets,  # [BLOCK_BATCH] absolute batch positions
    index_offsets,  # [BLOCK_INDEX] absolute output positions along dim
    c_offsets,  # [BLOCK_C] absolute positions along the contiguous tail
    inp_strides_0,  # input stride for the batch axis (select_dim * c_dim)
    inp_strides_1,  # input stride for the select axis (c_dim)
    out_strides_0,  # output stride for the batch axis (index_dim * c_dim)
    out_strides_1,  # output stride for the index axis (c_dim)
    index_cur,  # [BLOCK_INDEX] gathered index values for this tile
    input_output_mask,  # [BLOCK_BATCH, BLOCK_INDEX, BLOCK_C] in-bounds mask
):
    # 3-D gather/scatter tile: reads inp[b, index_cur[j], c] and writes it to
    # out[b, index_offsets[j], c]. The same mask guards both sides, so the
    # two offset tensors stay lane-for-lane aligned.
    input_offsets = (batch_offsets * inp_strides_0)[:, None, None] + (
        (index_cur * inp_strides_1)[:, None] + c_offsets[None, :]
    )[None, :, :]

    output_offsets = (batch_offsets * out_strides_0)[:, None, None] + (
        (index_offsets * out_strides_1)[:, None] + c_offsets[None, :]
    )[None, :, :]

    selected = tl.load(inp + input_offsets, mask=input_output_mask, other=0.0)
    tl.store(out + output_offsets, selected, mask=input_output_mask)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("index_select"),
    key=["batch_dim", "index_dim", "c_dim"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.jit
def multi_batch_index_select_kernel(
    inp,  # input, viewed as [batch_dim, select_dim, c_dim]
    index,  # index tensor, [index_dim]
    out,  # output, viewed as [batch_dim, index_dim, c_dim]
    batch_dim,
    select_dim,
    c_dim,
    index_dim,
    dtype_size,  # bytes per element of the input
    inp_numel,  # total element count of the input
    BLOCK_BATCH: tl.constexpr,
    BLOCK_INDEX: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    # General index_select: tiles the 3-D problem (batch x index x c) over
    # programs; each program walks tiles with a stride of num_programs.
    pid_x = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)

    block_id_start = pid_x
    block_id_step = num_programs

    # constexpr aliases so tl.arange below gets compile-time extents.
    block_batch: tl.constexpr = BLOCK_BATCH
    block_index: tl.constexpr = BLOCK_INDEX
    block_c: tl.constexpr = BLOCK_C

    block_num_batch = tl.cdiv(batch_dim, block_batch)
    block_num_index = tl.cdiv(index_dim, block_index)
    block_num_c = tl.cdiv(c_dim, block_c)

    block_num_total = block_num_batch * block_num_index * block_num_c

    # Row-major strides for contiguous input/output and for the tile grid.
    inp_strides_0, inp_strides_1 = [select_dim * c_dim, c_dim]
    out_strides_0, out_strides_1 = [index_dim * c_dim, c_dim]
    block_strides_0, block_strides_1 = [block_num_index * block_num_c, block_num_c]

    # TODO: remove dtype_size once contiguous DMA is ensured
    # int32 offsets are only safe when the whole input spans <= 2**31 bytes.
    small_out = inp_numel.to(tl.int64) * dtype_size <= 2**31

    for block_id in tl.range(block_id_start, block_num_total, block_id_step):
        # Decompose the linear tile id into (batch, index, c) tile coords.
        block_id_batch = block_id // block_strides_0
        block_id_index = (block_id // block_strides_1) % block_num_index
        block_id_c = block_id % block_num_c

        # arange requires constexpr
        batch_offsets = block_id_batch * block_batch + tl.arange(0, block_batch)
        batch_mask = batch_offsets < batch_dim

        index_offsets = block_id_index * block_index + tl.arange(0, block_index)
        index_mask = index_offsets < index_dim

        c_offsets = block_id_c * block_c + tl.arange(0, block_c)
        c_mask = c_offsets < c_dim

        # NOTE(review): `and` between tensor masks relies on this Triton
        # build accepting it; upstream Triton prefers `&` — confirm on the
        # Cambricon toolchain.
        input_output_mask = (
            batch_mask[:, None, None]
            and (index_mask[:, None] and c_mask[None, :])[None, :, :]
        )

        index_cur = tl.load(index + index_offsets, mask=index_mask, other=0)
        # TODO: remove dtype_size once contiguous DMA is ensured
        # NOTE(review): `index` is a pointer, so `index.dtype` is a pointer
        # type here — presumably this backend compares the pointee dtype;
        # the one-batch kernel checks the loaded tensor's dtype instead.
        # Verify the intended comparison.
        if index.dtype != tl.int32 and small_out:
            index_cur_int32 = index_cur.to(tl.int32)
            ld_st_2(
                inp,
                out,
                batch_offsets,
                index_offsets,
                c_offsets,
                inp_strides_0,
                inp_strides_1,
                out_strides_0,
                out_strides_1,
                index_cur_int32,
                input_output_mask,
            )
        else:
            ld_st_2(
                inp,
                out,
                batch_offsets,
                index_offsets,
                c_offsets,
                inp_strides_0,
                inp_strides_1,
                out_strides_0,
                out_strides_1,
                index_cur,
                input_output_mask,
            )
def index_select(inp, dim, index):
    """Cambricon implementation of ``torch.index_select``.

    Gathers entries of ``inp`` along dimension ``dim`` at the positions given
    by the integer tensor ``index``.

    Args:
        inp: input tensor (made contiguous internally).
        dim: dimension to select along; negative values are normalized.
        index: 0-D or 1-D integer tensor of positions in ``[0, inp.size(dim))``.

    Returns:
        New tensor with the same shape as ``inp`` except ``index.numel()``
        at ``dim``, same dtype and device as ``inp``.
    """
    logger.debug("GEMS_CAMBRICON INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"
    # BUGFIX: the previous check asserted a generator expression, which is
    # always truthy and never inspected any element. Do a real on-device
    # range check instead (this synchronizes with the device — the old TODO
    # about a device-side assert kernel still stands as a faster alternative).
    assert bool(
        torch.all((index >= 0) & (index < inp.size(dim)))
    ), "Index out of range"

    # TODO: make sure input is contiguous

    if index.ndim == 0:
        index = index.unsqueeze(0)
    dim = dim % inp.ndim
    inp_shape = list(inp.shape)
    index_dim = index.numel()

    # input  viewed as [batch_dim, select_dim, c_dim]
    # output viewed as [batch_dim, index_dim, c_dim]
    inp = inp.contiguous()
    index = index.contiguous()
    inp_numel = inp.numel()
    batch_dim = math.prod(inp_shape[:dim])
    select_dim = inp_shape[dim]
    c_dim = math.prod(inp_shape[(dim + 1) :])

    # Copy before mutating: the old code aliased inp_shape and mutated it.
    out_shape = list(inp_shape)
    out_shape[dim] = index_dim
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # element_size() covers every dtype; the previous finfo/iinfo branch
    # raised on torch.bool (and complex under iinfo).
    dtype_size = inp.element_size()

    if batch_dim == 1 and c_dim <= get_max_block_size(dtype_size):
        # ram: (input, output), half, extra
        # 2D, not split c_dim
        def grid_fn(meta):
            index_block_size_grid = max(meta["BLOCK_SIZE"] // c_dim, 1)
            index_block_num = triton.cdiv(index_dim, index_block_size_grid)
            return (min(index_block_num, TOTAL_CORE_NUM),)

        one_batch_index_select_kernel[grid_fn](
            out, index, index_dim, inp, c_dim, dtype_size, inp_numel
        )
    else:
        grid = lambda meta: (
            min(
                triton.cdiv(batch_dim, meta["BLOCK_BATCH"])
                * triton.cdiv(index_dim, meta["BLOCK_INDEX"])
                * triton.cdiv(c_dim, meta["BLOCK_C"]),
                TOTAL_CORE_NUM,
            ),
        )
        multi_batch_index_select_kernel[grid](
            inp,
            index,
            out,
            batch_dim,
            select_dim,
            c_dim,
            index_dim,
            dtype_size,
            inp_numel,
        )
    return out