Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/slice_scatter.py: 0%
29 statements
Created by coverage.py v7.6.9 at 2026-03-12 02:21 +0800.
1import logging
3import torch
4import triton
5from _kunlunxin.ops.copy import copy_slice
7from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping
# Module logger, registered under the package-wide "flag_gems" logger and
# named after this module (leading dots stripped).
logger = logging.getLogger("flag_gems." + __name__.lstrip("."))
def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
    """Return a copy of ``inp`` whose ``start:end:step`` slice along ``dim`` is ``src``.

    Intended to match ``torch.slice_scatter``. The result is written into a
    freshly allocated tensor with ``inp``'s shape, dtype and device. The base
    values are first copied from ``inp``. Then ``src`` is copied into the
    selected slice of the result.

    Args:
        inp: Tensor providing the base values.
        src: Tensor written into the slice. Its shape must equal the shape of
            ``inp`` sliced along ``dim``.
        dim: Dimension to slice. Negative values count from the end.
        start: First index of the slice (``None`` means the beginning).
            Negative values count from the end. Out-of-range values are
            clamped, as in Python slicing.
        end: Exclusive end index of the slice (``None`` means the end of
            ``dim``). Negative values count from the end. Out-of-range values
            are clamped, as in Python slicing.
        step: Slice step. Must be positive.

    Returns:
        The new tensor.

    Raises:
        AssertionError: If ``inp`` and ``src`` are on different devices,
            ``dim`` is out of range, ``step`` is not positive, or ``src``
            does not have the shape of the slice.
    """
    logger.debug("GEMS SLICE_SCATTER")
    assert src.device == inp.device, "inp and src reside on different devices."
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert step > 0, "slice step must be positive"
    dim = dim % inp.ndim

    # Normalize the bounds with Python slice semantics.
    # - None becomes the default bound, and end=0 stays 0 (an empty slice).
    # - Negative indices wrap around.
    # - Out-of-range indices are clamped to [0, size].
    # Without this, valid slices would be rejected by the shape check below.
    dim_size = inp.size(dim)
    start, end, step = slice(start, end, step).indices(dim_size)

    valid_shape = list(inp.shape)
    # len(range(...)) is 0 when start >= end, whereas a raw ceil-division of
    # (end - start) would go negative.
    valid_shape[dim] = len(range(start, end, step))
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    if has_internal_overlapping(inp) == MemOverlap.Yes:
        # An input whose elements alias each other cannot lend its strides to
        # a writable output, so allocate a contiguous result instead.
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        # Otherwise keep inp's memory layout for the result.
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    copy_slice(inp, out0=out)

    # Build a view of `out` that selects the slice along `dim` only.
    indices = [slice(None)] * inp.ndim
    indices[dim] = slice(start, end, step)
    # Index with a tuple. A list of slices relies on legacy
    # sequence-as-tuple indexing.
    out_ = out[tuple(indices)]
    copy_slice(src, out0=out_)

    return out