Coverage for src/flag_gems/runtime/backend/_cambricon/ops/slice_scatter.py: 0%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping
# Module logger: a child of the "flag_gems" logger, named after this module
# (with any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(
    suffix=__name__.lstrip("."),
)
# Copy `inp` into `out`, replacing one strided slice with the values of `src`.
#
# All three buffers are addressed with flat row-major indices, so the launcher
# must pass contiguous tensors. Each flat index `idx` is split into
# (pre_idx, dim_idx, post_idx) coordinates around the sliced dimension:
#   dim_size       extent of the sliced dimension in inp/out
#   dim_prod_post  product of the extents after the sliced dimension
# The slice covers positions `start + k * step` for k in [0, src_dim_size)
# along that dimension. `start` is expected to be non-negative (it is not
# wrapped here). Elements on a slice position take the matching `src` value;
# all other elements are copied from `inp`.
@triton.jit
def slice_scatter_kernel(
    out_ptr,
    inp_ptr,
    src_ptr,
    total_elements,
    dim_size,
    dim_prod_post,
    start,
    step,
    src_dim_size,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    # Widen before multiplying: with int32 arithmetic, flat indices of tensors
    # holding 2**31 or more elements would overflow. Every index derived from
    # `idx` below is int64 as a result.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    mask = idx < total_elements

    # Decompose the flat index. Dividing in two steps avoids forming the
    # scalar product dim_size * dim_prod_post, which could overflow int32.
    row = idx // dim_prod_post
    pre_idx = row // dim_size
    dim_idx = row % dim_size
    post_idx = idx % dim_prod_post

    # Offset from the slice start and the matching index along src's dim.
    # Both may be negative for elements before `start`; those lanes are
    # excluded by the `dim_idx >= start` term, so their values never matter.
    rel = dim_idx - start
    src_dim_idx = rel // step
    # `src_dim_idx < src_dim_size` is equivalent to
    # `dim_idx < start + src_dim_size * step` for on-step positions, without
    # the potentially overflowing product.
    slice_mask = (dim_idx >= start) & (rel % step == 0) & (src_dim_idx < src_dim_size)

    inp_data = tl.load(inp_ptr + idx, mask=mask)

    src_idx = (pre_idx * src_dim_size + src_dim_idx) * dim_prod_post + post_idx
    src_data = tl.load(src_ptr + src_idx, mask=mask & slice_mask)

    result = tl.where(slice_mask, src_data, inp_data)
    tl.store(out_ptr + idx, result, mask=mask)
def _normalize_slice_bounds(start, end, step, dim_size):
    """Resolve slice bounds the way ``torch.slice_scatter`` does.

    ``None`` selects the beginning (``start``) or the end (``end``) of the
    dimension, negative values count back from ``dim_size``, and both bounds
    are then clamped into ``[0, dim_size]`` with ``end`` never before
    ``start``.

    Args:
        start: Requested first index, or ``None``.
        end: Requested exclusive end index, or ``None``.
        step: Positive distance between selected indices.
        dim_size: Extent of the sliced dimension.

    Returns:
        ``(start, length)``: the clamped, non-negative first index and the
        number of indices the slice selects.
    """
    # Explicit None checks: `end or dim_size` would turn end=0 into a full slice.
    start = 0 if start is None else start
    end = dim_size if end is None else end
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    start = min(max(start, 0), dim_size)
    end = min(max(end, start), dim_size)
    return start, (end - start + step - 1) // step


def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
    """Embed ``src`` into a copy of ``inp`` along a strided slice of ``dim``.

    Cambricon Triton implementation of ``torch.slice_scatter``. The result
    equals ``inp`` except at indices ``start, start + step, ...`` (below
    ``end``) along ``dim``, which hold the values of ``src``.

    Args:
        inp: Tensor supplying every value outside the slice.
        src: Tensor written into the slice; its shape must equal the shape of
            the slice of ``inp``.
        dim: Dimension to slice along; negative values count from the end.
        start: First slice index (``None`` means 0). Negative values count
            from the end; out-of-range values are clamped.
        end: Exclusive end of the slice (``None`` means ``inp.size(dim)``).
            Negative values count from the end; out-of-range values are
            clamped.
        step: Positive stride between slice positions.

    Returns:
        A new tensor with ``inp``'s shape and dtype.

    Raises:
        AssertionError: on a device mismatch, an invalid ``dim``, a
            non-positive ``step``, or a ``src`` shape that does not match the
            slice.
    """
    logger.debug("GEMS_CAMBRICON SLICE_SCATTER")
    assert src.device == inp.device, "inp and src reside on different devices."
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert step > 0, "slice step must be positive"
    dim = dim % inp.ndim

    dim_size = inp.size(dim)
    start, src_dim_size = _normalize_slice_bounds(start, end, step, dim_size)

    valid_shape = list(inp.shape)
    valid_shape[dim] = src_dim_size
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # Output layout: reuse inp's strides unless inp aliases its own memory,
    # in which case a fresh contiguous tensor is used.
    if has_internal_overlapping(inp) == MemOverlap.Yes:
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    # Nothing to scatter (empty slice or empty tensor): the result is a plain
    # copy of inp, and no kernel is launched.
    if src.numel() == 0:
        out.copy_(inp)
        return out

    inp = inp.contiguous()
    src = src.contiguous()
    # The kernel stores through flat row-major indices, so it needs a
    # contiguous destination; a non-contiguous `out` is filled afterwards.
    result = (
        out
        if out.is_contiguous()
        else torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    )

    total_elements = inp.numel()
    # Number of elements spanned by one step along `dim` (product of the
    # extents after it).
    dim_prod_post = 1
    for size in inp.shape[dim + 1 :]:
        dim_prod_post *= size

    BLOCK_SIZE = 2048
    grid = (triton.cdiv(total_elements, BLOCK_SIZE),)
    slice_scatter_kernel[grid](
        result,
        inp,
        src,
        total_elements,
        dim_size,
        dim_prod_post,
        start,
        step,
        src_dim_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    if result is not out:
        out.copy_(result)
    return out