Coverage for src/flag_gems/ops/select_backward.py: 59%
49 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import math
3import torch
4import triton
5import triton.language as tl
@triton.jit
def _select_backward_kernel(
    grad_ptr,
    out_ptr,
    outer_size,
    inner_size,
    dim_stride,
    index,
    BLOCK: tl.constexpr,
):
    """Scatter a contiguous (outer, inner) gradient into the zeroed output.

    Each program copies up to BLOCK flat elements of ``grad_ptr``; the
    element at flat position ``outer * inner_size + inner`` is written to
    ``out_ptr + outer * dim_stride + index * inner_size + inner``.
    """
    block_id = tl.program_id(0)
    flat_idx = block_id * BLOCK + tl.arange(0, BLOCK)
    n_valid = outer_size * inner_size
    in_bounds = flat_idx < n_valid

    outer_idx = flat_idx // inner_size
    # Equivalent to flat_idx % inner_size for non-negative offsets.
    inner_idx = flat_idx - outer_idx * inner_size

    # grad is contiguous (outer, inner), so its flat offset is flat_idx itself.
    vals = tl.load(grad_ptr + flat_idx, mask=in_bounds)

    dst = outer_idx * dim_stride + index * inner_size + inner_idx
    tl.store(out_ptr + dst, vals, mask=in_bounds)
35def _launch_select_backward(grad, input_sizes, dim, index, out=None):
36 if not grad.is_cuda:
37 raise ValueError("grad must be CUDA tensor")
39 dim = int(dim)
40 index = int(index)
42 sizes = list(input_sizes)
43 ndim = len(sizes)
45 if dim < 0:
46 dim += ndim
48 if dim < 0 or dim >= ndim:
49 raise ValueError("invalid dim")
51 dim_size = sizes[dim]
53 if index < 0 or index >= dim_size:
54 raise ValueError("index out of range")
56 outer_size = math.prod(sizes[:dim]) if dim > 0 else 1
57 inner_size = math.prod(sizes[dim + 1 :]) if dim < ndim - 1 else 1
59 grad_view = grad.contiguous().view(outer_size, inner_size)
61 if out is None:
62 out = torch.zeros(
63 sizes,
64 dtype=grad.dtype,
65 device=grad.device,
66 )
67 else:
68 if tuple(out.shape) != tuple(sizes):
69 raise ValueError("out shape mismatch")
70 if out.dtype != grad.dtype:
71 raise ValueError("dtype mismatch")
72 if out.device != grad.device:
73 raise ValueError("device mismatch")
75 out.zero_()
77 dim_stride = dim_size * inner_size
79 BLOCK = 1024
80 n_elements = outer_size * inner_size
81 grid = (triton.cdiv(n_elements, BLOCK),)
83 _select_backward_kernel[grid](
84 grad_view,
85 out,
86 outer_size,
87 inner_size,
88 dim_stride,
89 index,
90 BLOCK=BLOCK,
91 )
93 return out
def select_backward(grad, input_sizes, dim, index, out=None):
    """Public entry point for the backward pass of ``torch.select``.

    Forwards all arguments unchanged to ``_launch_select_backward``, which
    validates them and launches the Triton scatter kernel.
    """
    return _launch_select_backward(grad, input_sizes, dim, index, out=out)