Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/slice_scatter.py: 0%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5from _kunlunxin.ops.copy import copy_slice 

6 

7from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping 

8 

9logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

10 

11 

def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
    """Return a new tensor equal to ``inp`` with ``src`` embedded in a slice.

    The slice is ``inp[..., start:end:step, ...]`` taken along ``dim``.
    ``inp`` itself is not written to; the result is a freshly allocated
    tensor.

    Args:
        inp: Tensor providing the shape, dtype, device and base values.
        src: Tensor whose shape must equal the shape of the selected slice.
        dim: Dimension to slice along; negative values count from the end.
        start: First index of the slice. ``None`` means 0. Negative values
            are offset by the dimension size, then clamped into range.
        end: Exclusive end index of the slice. ``None`` means the dimension
            size. Negative values are offset by the dimension size, then
            clamped into ``[start, size]``.
        step: Positive slice stride.

    Returns:
        A tensor with the size, dtype and device of ``inp``.

    Raises:
        AssertionError: If the devices differ, ``dim`` is out of range,
            ``step`` is not positive, or ``src`` does not have the shape of
            the selected slice.
    """
    logger.debug("GEMS SLICE_SCATTER")
    assert src.device == inp.device, "inp and src reside on different devices."
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert step > 0, "slice step must be positive"
    dim = dim % inp.ndim
    dim_size = inp.size(dim)

    # Only ``None`` means "not given". An explicit 0 is a real bound, so the
    # falsy-based ``x or default`` shortcut must not be used here.
    start = 0 if start is None else start
    end = dim_size if end is None else end

    # Normalize the bounds the way Python/PyTorch slicing does: shift negative
    # indices by the dimension size, then clamp so that
    # 0 <= start <= end <= dim_size. This keeps the slice length non-negative.
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    start = min(max(start, 0), dim_size)
    end = min(max(end, start), dim_size)

    valid_shape = list(inp.shape)
    valid_shape[dim] = triton.cdiv(end - start, step)
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    if has_internal_overlapping(inp) == MemOverlap.Yes:
        # inp's strides make elements alias each other, so the layout is not
        # replicated; a default-layout tensor is allocated instead.
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        # Otherwise the output mirrors inp's strides.
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    # NOTE(review): copy_slice is a project helper; it is presumed to copy its
    # first argument into ``out0`` -- confirm against _kunlunxin.ops.copy.
    copy_slice(inp, out0=out)

    # Build the index as a tuple: indexing a tensor with a list of slices is
    # deprecated non-tuple sequence indexing.
    indices = [slice(None)] * inp.ndim
    indices[dim] = slice(start, end, step)
    out_ = out[tuple(indices)]
    copy_slice(src, out0=out_)

    return out