Coverage for src/flag_gems/experimental_ops/slice_backward.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
import torch
import triton
import triton.language as tl
# Scatter kernel for slice_backward: every element of the flattened,
# contiguous `grad` tensor is written to its corresponding position in the
# full-size output gradient. The host is expected to have zero-filled the
# output beforehand; this kernel only writes the sliced positions.
# Flat indices are decomposed assuming dense row-major (C-contiguous) layout.
@triton.jit
def _slice_backward_scatter_kernel(
    grad_ptr,  # *Pointer* to grad (input) vector
    out_ptr,  # *Pointer* to output (full grad) vector
    n_elements,  # numel of grad
    inner,  # product of sizes after 'dim'
    gdim,  # size of grad along 'dim'
    odim,  # size of output along 'dim'
    start,  # normalized start index along 'dim'
    step,  # step along 'dim'
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    # Promote to int64 so the index arithmetic below cannot overflow int32
    # on large tensors.
    offs = offs.to(tl.int64)
    mask = offs < n_elements

    # Materialize the scalar shape parameters as 1-element int64 tensors so
    # all subsequent index math is carried out in 64-bit.
    inner_i64 = tl.full([1], inner, tl.int64)
    gdim_i64 = tl.full([1], gdim, tl.int64)
    odim_i64 = tl.full([1], odim, tl.int64)
    start_i64 = tl.full([1], start, tl.int64)
    step_i64 = tl.full([1], step, tl.int64)

    # Decompose linear index into (outer, g_idx_dim, inner_idx)
    outer = offs // (gdim_i64 * inner_i64)
    inner_idx = offs % inner_i64
    g_idx_dim = (offs // inner_i64) % gdim_i64

    # Map the slice-relative index along 'dim' back to the full tensor:
    # out position = start + i * step.
    out_dim_index = start_i64 + g_idx_dim * step_i64
    # Defensive bounds check; host-side normalization should already keep
    # these indices inside [0, odim).
    valid_o = (out_dim_index >= 0) & (out_dim_index < odim_i64)
    o = outer * (odim_i64 * inner_i64) + out_dim_index * inner_i64 + inner_idx

    m = mask & valid_o
    val = tl.load(grad_ptr + offs, mask=m, other=0)
    tl.store(out_ptr + o, val, mask=m)
44def _normalize_slice_params(input_sizes, dim, start, end, step):
45 D = len(input_sizes)
46 if dim < 0:
47 dim += D
48 size_dim = int(input_sizes[dim])
50 if step is None:
51 step = 1
52 if step == 0:
53 raise ValueError("slice step cannot be zero")
55 if start is None:
56 start = 0 if step > 0 else size_dim - 1
58 # Normalize start into valid index range
59 if start < 0:
60 start += size_dim
62 if step > 0:
63 # Clamp into [0, size_dim]
64 if start < 0:
65 start = 0
66 if start > size_dim:
67 start = size_dim
68 else:
69 # Clamp into [0, size_dim-1]
70 if start < 0:
71 start = 0
72 if start >= size_dim:
73 start = size_dim - 1
75 return dim, int(start), int(step)
def _launch_slice_backward_kernel(grad, input_sizes, dim, start, end, step, out):
    """Zero `out` and scatter `grad` into it according to the slice params.

    `out` receives the gradient w.r.t. the *unsliced* input: zeros
    everywhere except the positions selected by (dim, start, step), which
    receive the corresponding elements of `grad`.

    Returns:
        `out`, with its data updated (even when a contiguous staging copy
        was needed internally).
    """
    # The kernel indexes with dense row-major arithmetic, so both tensors
    # must be contiguous. Note .contiguous() copies when the input is not.
    grad_c = grad.contiguous()
    out_c = out.contiguous()

    # Resolve negative / None slice parameters.
    dim, start_n, step_n = _normalize_slice_params(
        list(input_sizes), int(dim), start, end, step
    )

    # inner = product of sizes after `dim`; together with gdim/odim this
    # lets the kernel decompose a flat index into (outer, dim, inner).
    sizes = list(input_sizes)
    odim = int(sizes[dim])
    gdim = int(grad_c.shape[dim])
    inner = 1
    for s in sizes[dim + 1 :]:
        inner *= int(s)

    # Zero first: the result must be all zeros outside the scattered
    # region — including when `grad` is empty and the kernel never runs.
    # (Previously the empty-grad path returned `out` uninitialized.)
    out_c.zero_()

    n_elements = grad_c.numel()
    if n_elements > 0:
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _slice_backward_scatter_kernel[grid](
            grad_c,
            out_c,
            n_elements,
            inner,
            gdim,
            odim,
            start_n,
            step_n,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    # If .contiguous() had to copy (non-contiguous `out`), write the result
    # back so the caller-provided tensor actually holds the gradient.
    if out_c is not out:
        out.copy_(out_c)
    return out
def slice_backward(grad, input_sizes, dim, start, end, step):
    """
    Python wrapper for aten::slice_backward

    Allocates a gradient tensor shaped like the original (unsliced) input
    and scatters `grad` into it via the Triton kernel launcher.
    """
    full_grad = torch.empty(
        tuple(input_sizes), device=grad.device, dtype=grad.dtype
    )
    return _launch_slice_backward_kernel(
        grad, input_sizes, dim, start, end, step, full_grad
    )
def slice_backward_out(grad, input_sizes, dim, start, end, step, out):
    """
    Python wrapper for aten::slice_backward.out

    Writes the gradient w.r.t. the unsliced input into `out` and returns it.

    Raises:
        ValueError: if `out`'s shape does not match `input_sizes`, or its
        device/dtype differ from `grad`'s.
    """
    if tuple(out.shape) != tuple(input_sizes):
        raise ValueError("Output tensor shape must match input_sizes")
    if out.device != grad.device or out.dtype != grad.dtype:
        raise ValueError("Output tensor must have same device and dtype as grad")
    result = _launch_slice_backward_kernel(grad, input_sizes, dim, start, end, step, out)
    # The launcher may have written into a contiguous staging copy (when
    # `out` is non-contiguous); copy the data back so the caller-provided
    # tensor is actually updated instead of silently discarded.
    if result is not out:
        out.copy_(result)
    return out