Coverage for src/flag_gems/runtime/backend/_ascend/ops/slice_scatter.py: 0%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils.codegen_config_utils import CodeGenConfig
7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
8from flag_gems.utils.shape_utils import has_internal_overlapping
# Module-level logger. Its name is the Ascend ops namespace followed by the
# last component of this module's dotted import path.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)
# Code generation settings passed to ``pointwise_dynamic`` for the kernel
# defined in this module.
# NOTE(review): the four positional values are forwarded in CodeGenConfig's
# declared parameter order; their meaning is not visible from this file --
# confirm against flag_gems.utils.codegen_config_utils before changing them.
config_ = CodeGenConfig(
    1536,
    (40, 1, 1),
    32,
    False,
    # Enable 1D tiling only for Triton major versions below 3. The whole
    # leading version component is parsed (not just its first character) so
    # that a multi-digit major version is compared correctly.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)
# Element-wise identity: the generated pointwise kernel writes each element of
# ``src`` to its output unchanged, which makes it usable as a tensor copy.
@pointwise_dynamic(
    is_tensor=[True],
    promotion_methods=[(0, "DEFAULT")],
    config=config_,
)
@triton.jit
def copy(src):
    return src
def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
    """Return a copy of ``inp`` with ``src`` written into a strided slice.

    The slice is ``start:end:step`` taken along ``dim``; every other
    dimension is selected in full. ``inp`` itself is not modified.

    Args:
        inp: Tensor supplying the base values and the output size, dtype
            and device.
        src: Tensor written into the slice. Its shape must equal the shape
            of the selected slice of ``inp``.
        dim: Dimension to slice. Negative values count from the last
            dimension.
        start: First index of the slice, or ``None`` for the beginning of
            ``dim``. Negative values count from the end.
        end: Exclusive end index of the slice, or ``None`` for the end of
            ``dim``. Negative values count from the end; ``0`` selects an
            empty slice.
        step: Distance between consecutive written positions; must be
            positive.

    Returns:
        A newly allocated tensor holding ``inp`` with the slice replaced by
        ``src``.

    Raises:
        AssertionError: If ``src`` and ``inp`` are on different devices,
            ``dim`` is out of range, ``step`` is not positive, or ``src``
            does not have the shape of the selected slice.
    """
    logger.debug("GEMS_ASCEND SLICE_SCATTER")
    assert src.device == inp.device, "inp and src reside on different devices."
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    assert step > 0, "slice step must be positive"
    dim = dim % inp.ndim

    # Resolve the bounds with Python's own slice rules: ``None`` means "open
    # ended", negative indices wrap once, and out-of-range values are clamped
    # to [0, size]. Testing ``end`` for truthiness instead would turn an
    # explicit ``end=0`` into "until the end of the dimension".
    first, last, stride = slice(start, end, step).indices(inp.size(dim))

    # ``range`` yields exactly the positions the slice selects, so its length
    # is the extent ``src`` must have along ``dim`` (0 when last <= first).
    valid_shape = list(inp.shape)
    valid_shape[dim] = len(range(first, last, stride))
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # Allocate the result: a default-layout buffer when ``inp`` is reported as
    # internally overlapping, otherwise a buffer that mirrors ``inp``'s strides.
    # NOTE(review): this relies on the truthiness of the value returned by
    # has_internal_overlapping -- confirm it is falsy for non-overlapping input.
    if has_internal_overlapping(inp):
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    # First fill the whole result with ``inp``.
    copy(inp, out0=out)

    # Then overwrite the selected region. Indexing with a tuple of slices
    # returns a view of ``out``, so the second copy writes into ``out`` itself.
    # A tuple is used because indexing with a list of slices is deprecated.
    indices = [slice(None)] * inp.ndim
    indices[dim] = slice(first, last, stride)
    out_ = out[tuple(indices)]
    copy(src, out0=out_)

    return out