Coverage for src/flag_gems/experimental_ops/slice_scatter.py: 0%
87 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy ``n_elements`` contiguous values from ``x_ptr`` to ``y_ptr``.

    One program instance handles one BLOCK_SIZE-wide chunk; a mask guards
    the tail block so out-of-range lanes neither load nor store.
    """
    block_id = tl.program_id(axis=0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements
    values = tl.load(x_ptr + offsets, mask=in_bounds)
    tl.store(y_ptr + offsets, values, mask=in_bounds)
@triton.jit
def _slice_scatter_kernel(
    src_ptr,  # pointer to src tensor (flattened, contiguous)
    out_ptr,  # pointer to output tensor (flattened, contiguous)
    outer,  # number of chunks before the sliced dimension
    dim_size,  # size of the sliced dimension in the output
    inner,  # number of elements after the sliced dimension
    start,  # start index along the sliced dimension
    step,  # step along the sliced dimension
    m_size,  # number of indices along the sliced dimension to scatter (len of slice)
    n_src_elements,  # total number of elements in src (outer * m_size * inner)
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter the flattened ``src`` into the strided slice of ``out``.

    Each src element is addressed as (o, m, i) against the logical shape
    (outer, m_size, inner) and written to the output flat index
    ``o * dim_size * inner + (start + m * step) * inner + i``.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_src_elements

    # Promote to int64 so the index arithmetic cannot overflow for tensors
    # with more than 2**31 elements.
    offs_i64 = offs.to(tl.int64)

    inner64 = tl.full([BLOCK_SIZE], inner, tl.int64)
    m_size64 = tl.full([BLOCK_SIZE], m_size, tl.int64)
    dim_size64 = tl.full([BLOCK_SIZE], dim_size, tl.int64)
    start64 = tl.full([BLOCK_SIZE], start, tl.int64)
    step64 = tl.full([BLOCK_SIZE], step, tl.int64)

    # Decompose the flat src offset into (o, m, i).
    chunk64 = m_size64 * inner64
    o = offs_i64 // chunk64
    rem = offs_i64 - o * chunk64
    m = rem // inner64
    i = rem - m * inner64

    dest_d = start64 + m * step64  # index along sliced dimension
    dest_linear = o * (dim_size64 * inner64) + dest_d * inner64 + i

    val = tl.load(src_ptr + offs, mask=mask)
    # BUG FIX: the destination offset was previously truncated with
    # ``dest_linear.to(tl.int32)``, which silently wraps for outputs larger
    # than 2**31 elements and defeats the int64 promotion above. Store with
    # the full-width int64 index instead.
    tl.store(out_ptr + dest_linear, val, mask=mask)
55def _normalize_slice_params(size, start, end, step):
56 assert step is not None and step != 0, "step must be non-zero"
57 # This implementation supports only positive step for simplicity
58 assert step > 0, "Only positive step is supported in this Triton implementation"
59 if start is None:
60 start = 0
61 if end is None:
62 end = size
63 if start < 0:
64 start += size
65 if end < 0:
66 end += size
67 # Clamp to [0, size]
68 start = max(0, min(start, size))
69 end = max(0, min(end, size))
70 if end <= start:
71 m = 0
72 else:
73 m = (end - start + step - 1) // step
74 return start, end, step, m
def _slice_scatter_impl(input, src, out, dim=0, start=None, end=None, step=1):
    """Copy ``input`` into ``out``, then overwrite the strided slice along
    ``dim`` (``start:end:step``) with ``src``.

    All three tensors must be contiguous CUDA tensors sharing one dtype,
    and ``out`` must have ``input``'s shape. Returns ``out``.
    """
    assert (
        input.is_cuda and src.is_cuda and out.is_cuda
    ), "All tensors must be CUDA tensors"
    assert (
        input.is_contiguous() and src.is_contiguous() and out.is_contiguous()
    ), "Tensors must be contiguous"
    assert input.dtype == src.dtype == out.dtype, "All tensors must have the same dtype"
    assert input.shape == out.shape, "Output must have same shape as input"

    ndim = input.dim()
    if ndim == 0:
        # 0-d tensor: there is no dimension to slice, so only the copy runs.
        if input.numel() > 0:
            total = input.numel()
            grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
            _copy_kernel[grid](input, out, total, BLOCK_SIZE=1024)
        return out

    dim = dim if dim >= 0 else dim + ndim
    assert 0 <= dim < ndim, "dim out of range"

    size_d = input.size(dim)
    s, e, st, m = _normalize_slice_params(size_d, start, end, step)

    # Collapse the shape around `dim` into (outer, size_d, inner).
    shape = input.shape
    outer = 1
    for extent in shape[:dim]:
        outer *= extent
    inner = 1
    for extent in shape[dim + 1:]:
        inner *= extent

    # Validate src shape/numel against the selected slice.
    expected_src_numel = outer * m * inner
    assert src.numel() == expected_src_numel, (
        f"src numel mismatch: got {src.numel()}, expected {expected_src_numel} "
        f"(outer={outer}, m={m}, inner={inner})"
    )

    # 1) Copy input -> out
    total = input.numel()
    if total > 0:
        copy_grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
        _copy_kernel[copy_grid](input, out, total, BLOCK_SIZE=1024)

    # 2) Scatter src into the sliced region
    if expected_src_numel > 0:
        scatter_grid = lambda meta: (
            triton.cdiv(expected_src_numel, meta["BLOCK_SIZE"]),
        )
        _slice_scatter_kernel[scatter_grid](
            src,
            out,
            outer,
            size_d,
            inner,
            s,
            st,
            m,
            expected_src_numel,
            BLOCK_SIZE=1024,
        )

    return out
def slice_scatter(input, src, dim=0, start=None, end=None, step=1):
    """Return a new tensor equal to ``input`` with the slice
    ``start:end:step`` along ``dim`` replaced by ``src``."""
    result = torch.empty_like(input)
    return _slice_scatter_impl(
        input, src, result, dim=dim, start=start, end=end, step=step
    )
def slice_scatter_out(input, src, dim=0, start=None, end=None, step=1, out=None):
    """Like ``slice_scatter`` but writes into ``out``; a fresh tensor is
    allocated when ``out`` is None. Returns the destination tensor."""
    destination = torch.empty_like(input) if out is None else out
    return _slice_scatter_impl(
        input, src, destination, dim=dim, start=start, end=end, step=step
    )