Coverage for src/flag_gems/experimental_ops/take.py: 0%
39 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def take_kernel(
    in_ptr,  # pointer to input flattened tensor
    idx_ptr,  # pointer to flattened indices (int32)
    out_ptr,  # pointer to flattened output tensor
    n_index,  # number of indices (int32)
    in_numel,  # number of elements in input (int32)
    BLOCK_SIZE: tl.constexpr,
):
    """Gather out[i] = in[idx[i]] for all i < n_index (flat torch.take gather).

    Each program instance processes one BLOCK_SIZE-wide slice of the index
    array. Following torch.take semantics, negative indices are interpreted
    as counting from the end of the flattened input (idx + in_numel).
    Indices that are still out of range after wrapping yield 0 instead of
    faulting (the load address is clamped in-bounds first).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_index
    idx = tl.load(idx_ptr + offs, mask=mask, other=0).to(tl.int32)
    # torch.take treats a negative index as an offset from the end of the
    # flattened input; wrap once so e.g. -1 addresses the last element.
    idx = tl.where(idx < 0, idx + in_numel, idx)
    # Bounds check to avoid OOB memory access; assumes valid indices in normal use.
    in_range = (idx >= 0) & (idx < in_numel) & mask
    # Clamp so the gather address is always legal even for invalid indices.
    idx_safe = tl.maximum(0, tl.minimum(idx, in_numel - 1))
    vals = tl.load(in_ptr + idx_safe, mask=mask, other=0)
    # Zero out values for invalid indices (shouldn't happen if inputs are valid)
    vals = tl.where(in_range, vals, 0)
    tl.store(out_ptr + offs, vals, mask=mask)
def _launch_take(input: torch.Tensor, index: torch.Tensor, out_flat: torch.Tensor) -> None:
    """Launch take_kernel to gather input.flatten()[index] into out_flat.

    Args:
        input: source tensor (any shape); read through a contiguous flat view.
        index: tensor of flat indices; flattened and cast to int32 for the kernel.
        out_flat: preallocated 1-D CUDA tensor with index.numel() elements,
            written in place.

    Raises:
        OverflowError: if the flattened input is too large for int32 indexing.
    """
    assert (
        input.is_cuda and index.is_cuda and out_flat.is_cuda
    ), "All tensors must be CUDA tensors"
    # Flatten input as per torch.take semantics (use contiguous flattened memory)
    input_flat = input.contiguous().view(-1)
    in_numel = input_flat.numel()
    # The kernel addresses the input with int32 offsets; refuse inputs whose
    # flat indices could overflow instead of silently clamping to garbage.
    if in_numel > torch.iinfo(torch.int32).max:
        raise OverflowError(
            "take kernel uses int32 indexing; input has too many elements"
        )
    # Indices flattened and converted to int32 for kernel
    index_flat = index.contiguous().view(-1).to(torch.int32)
    n_index = index_flat.numel()
    if n_index == 0:
        # Nothing to gather; skip the (empty-grid) launch entirely.
        return
    grid = lambda meta: (triton.cdiv(n_index, meta["BLOCK_SIZE"]),)
    take_kernel[grid](
        input_flat,
        index_flat,
        out_flat,
        n_index,
        in_numel,
        BLOCK_SIZE=1024,
    )
def take(input: torch.Tensor, index: torch.Tensor):
    """Gather elements of ``input`` at the flat positions given by ``index``.

    Wrapper for aten::take. ``input`` is treated as flattened; the returned
    tensor has the same shape as ``index`` and the dtype of ``input``.
    """
    assert input.device == index.device, "input and index must be on the same device"
    # Allocate a flat result buffer, fill it on-device, then restore shape.
    result = torch.empty(index.numel(), device=input.device, dtype=input.dtype)
    _launch_take(input, index, result)
    return result.view(index.shape)
def take_out(input: torch.Tensor, index: torch.Tensor, out: torch.Tensor):
    """Gather elements of ``input`` at flat positions ``index`` into ``out``.

    Wrapper for aten::take.out. ``out`` is resized to ``index``'s shape when
    needed and returned.

    Raises:
        TypeError: if ``out`` and ``input`` dtypes differ.
    """
    devices_match = input.device == index.device == out.device
    assert devices_match, "All tensors must be on the same device"
    if out.dtype != input.dtype:
        raise TypeError(
            f"out dtype {out.dtype} does not match input dtype {input.dtype}"
        )
    # Reshape 'out' whenever it disagrees with the index tensor's shape.
    if tuple(out.shape) != tuple(index.shape) or out.numel() != index.numel():
        out.resize_(index.shape)
    # Gather into a contiguous scratch buffer first, so that a
    # non-contiguous 'out' still receives correct values via copy_.
    scratch = torch.empty(index.numel(), device=input.device, dtype=input.dtype)
    _launch_take(input, index, scratch)
    out.copy_(scratch.view(index.shape))
    return out