Coverage for src/flag_gems/ops/slice_backward.py: 68%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def slice_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    numel,
    inner,
    slice_len,
    dim_size,
    start,
    step,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatter each element of the (contiguous) grad_output buffer back to its
    # source position inside a zero-initialized grad_input buffer.
    block_start = tl.program_id(0) * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < numel

    vals = tl.load(grad_output_ptr + offs, mask=in_bounds)

    # Decompose the flat grad_output offset into (outer, slice, inner) coords,
    # where `inner` is the product of the sizes after the sliced dim.
    inner_pos = offs % inner
    slice_pos = (offs // inner) % slice_len
    outer_pos = offs // (inner * slice_len)

    # Map the position within the slice back onto the original dimension.
    src_dim_pos = start + slice_pos * step
    dst = (outer_pos * dim_size + src_dim_pos) * inner + inner_pos
    tl.store(grad_input_ptr + dst, vals, mask=in_bounds)
def slice_backward(
    grad_output,
    input_sizes,
    dim,
    start,
    end,
    step,
):
    """Backward of ``aten::slice``: scatter ``grad_output`` into zeros.

    Args:
        grad_output: gradient w.r.t. the sliced view (assumed contiguous —
            it is read with flat offsets in the kernel).
        input_sizes: shape of the original (pre-slice) input tensor.
        dim: dimension that was sliced; may be negative (aten convention).
        start, end, step: slice parameters as recorded by autograd. Autograd
            passes ``end`` = INT64_MAX for "to the end" and allows negative
            ``start``/``end``, so both must be canonicalized here.

    Returns:
        A zero tensor of shape ``input_sizes`` with ``grad_output`` scattered
        into the sliced positions.
    """
    grad_input = torch.zeros(
        input_sizes,
        device=grad_output.device,
        dtype=grad_output.dtype,
    )

    shape = list(input_sizes)
    ndim = len(shape)
    # Normalize a possibly negative dim into [0, ndim).
    if ndim:
        dim = dim % ndim
    dim_size = shape[dim] if ndim else 0

    # Canonicalize start/end the way aten::slice does: wrap negatives, then
    # clamp both into [0, dim_size]. Without this, end = INT64_MAX (the value
    # autograd records for a default slice end) makes slice_len explode and
    # the kernel writes out of bounds.
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    start = min(max(start, 0), dim_size)
    end = min(max(end, start), dim_size)

    # Number of elements the slice selected along `dim` (ceil division).
    slice_len = (end - start + step - 1) // step

    numel = grad_output.numel()
    if numel == 0 or slice_len <= 0:
        # Empty slice: nothing to scatter, gradient is all zeros.
        return grad_input

    # Product of the sizes after the sliced dim (flat stride of one step
    # along `dim`). The outer product is implicit in the kernel's decompose.
    inner = 1
    for i in range(dim + 1, ndim):
        inner *= shape[i]

    BLOCK = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
    slice_backward_kernel[grid](
        grad_output,
        grad_input,
        numel,
        inner,
        slice_len,
        dim_size,
        start,
        step,
        BLOCK_SIZE=BLOCK,
    )
    return grad_input