Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/select_scatter.py: 0%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping
9logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def select_scatter_kernel(
    out_ptr,
    inp_ptr,
    src_ptr,
    total_elements,
    dim_size,
    dim_prod_post,
    index,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``inp`` to ``out`` with one slice along a dimension taken from ``src``.

    Each program instance handles ``BLOCK_SIZE`` consecutive flat offsets.
    ``out_ptr`` and ``inp_ptr`` are addressed directly by that flat offset;
    the offset is decomposed into (outer, along-dim, inner) coordinates using
    ``dim_size`` and ``dim_prod_post``, and ``src_ptr`` is addressed by
    ``outer * dim_prod_post + inner``.

    Args:
        out_ptr: Destination buffer, written at flat offsets.
        inp_ptr: Source of the unmodified elements, read at flat offsets.
        src_ptr: Source of the replacement slice.
        total_elements: Number of flat offsets to process.
        dim_size: Extent of the dimension being scattered into.
        dim_prod_post: Product of the extents after that dimension.
        index: Position along the dimension that receives ``src``.
        BLOCK_SIZE: Compile-time number of elements per program instance.
    """
    # Flat offsets owned by this program instance; the tail block is masked.
    flat = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = flat < total_elements

    # Split each flat offset into (outer, along-dim, inner) coordinates.
    outer = flat // (dim_size * dim_prod_post)
    along_dim = (flat // dim_prod_post) % dim_size
    inner = flat % dim_prod_post

    # Lanes whose coordinate along the dimension matches `index` take `src`.
    takes_src = along_dim == index

    base_vals = tl.load(inp_ptr + flat, mask=in_bounds)

    # `src` lacks the scattered dimension, so only outer/inner address it.
    src_offset = outer * dim_prod_post + inner
    src_vals = tl.load(src_ptr + src_offset, mask=in_bounds & takes_src)

    tl.store(
        out_ptr + flat,
        tl.where(takes_src, src_vals, base_vals),
        mask=in_bounds,
    )
def select_scatter(inp, src, dim, index):
    """Return a new tensor equal to ``inp`` with one slice replaced by ``src``.

    The slice at position ``index`` along dimension ``dim`` is taken from
    ``src``; every other element is copied from ``inp``. ``inp`` itself is
    not modified.

    Args:
        inp: Input tensor.
        src: Tensor whose shape equals ``inp.shape`` with ``dim`` removed.
        dim: Dimension to scatter into; negative values count from the end.
        index: Position along ``dim``; negative values count from the end.

    Returns:
        A tensor with the shape, dtype and device of ``inp``. It uses the
        strides of ``inp`` unless ``inp`` has internally overlapping memory,
        in which case a default-strided tensor is returned.

    Raises:
        AssertionError: If ``dim`` or ``index`` is out of range, or if the
            shape of ``src`` does not match the selected slice of ``inp``.
    """
    logger.debug("GEMS SELECT_SCATTER")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
    # Normalize negative dim/index to their non-negative equivalents.
    dim = dim % inp.ndim
    index = index % inp.size(dim)

    valid_shape = list(inp.shape)
    del valid_shape[dim]
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    if has_internal_overlapping(inp) == MemOverlap.Yes:
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    total_elements = inp.numel()
    if total_elements == 0:
        # Nothing to compute; avoid launching a kernel with an empty grid.
        return out

    # The kernel addresses its output by flat row-major offset, so it may only
    # write straight into `out` when `out` is contiguous. Otherwise compute
    # into a contiguous scratch buffer and copy into the strided result.
    if out.is_contiguous():
        kernel_out = out
    else:
        kernel_out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)

    inp = inp.contiguous()
    src = src.contiguous()

    dim_size = inp.size(dim)

    # Number of elements spanned by one step along `dim` in row-major order.
    dim_prod_post = 1
    for d in range(dim + 1, inp.ndim):
        dim_prod_post *= inp.size(d)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(total_elements, BLOCK_SIZE),)

    select_scatter_kernel[grid](
        kernel_out,
        inp,
        src,
        total_elements,
        dim_size,
        dim_prod_post,
        index,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    if kernel_out is not out:
        # Strided copy places each logical element at its position in `out`.
        out.copy_(kernel_out)

    return out