Coverage for src/flag_gems/experimental_ops/take.py: 0%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def take_kernel(
    in_ptr,  # pointer to input flattened tensor
    idx_ptr,  # pointer to flattened indices (int32)
    out_ptr,  # pointer to flattened output tensor
    n_index,  # number of indices (int32)
    in_numel,  # number of elements in input (int32)
    BLOCK_SIZE: tl.constexpr,  # output elements handled per program instance
):
    # Each program gathers BLOCK_SIZE consecutive output positions:
    # out[i] = in[idx[i]] for i in [block_start, block_start + BLOCK_SIZE).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the index array are masked out of all loads/stores.
    mask = offs < n_index

    idx = tl.load(idx_ptr + offs, mask=mask, other=0).to(tl.int32)

    # Bounds check to avoid OOB memory access; assumes valid indices in normal use.
    in_range = (idx >= 0) & (idx < in_numel) & mask
    # Clamp so even an invalid index produces an in-bounds (harmless) address.
    idx_safe = tl.maximum(0, tl.minimum(idx, in_numel - 1))

    vals = tl.load(in_ptr + idx_safe, mask=mask, other=0)
    # Zero out values for invalid indices (shouldn't happen if inputs are valid)
    vals = tl.where(in_range, vals, 0)
    tl.store(out_ptr + offs, vals, mask=mask)

30 

31 

def _launch_take(input: torch.Tensor, index: torch.Tensor, out_flat: torch.Tensor):
    """Launch ``take_kernel`` to gather ``input.flatten()[index.flatten()]`` into ``out_flat``.

    ``out_flat`` must be a contiguous 1-D CUDA tensor with ``index.numel()``
    elements and ``input``'s dtype; it is written in place.
    """
    assert (
        input.is_cuda and index.is_cuda and out_flat.is_cuda
    ), "All tensors must be CUDA tensors"
    # torch.take treats the input as flattened; flatten a contiguous copy so
    # pointer arithmetic in the kernel matches flat element order.
    src = input.contiguous().view(-1)
    # The kernel consumes int32 indices.
    indices = index.contiguous().view(-1).to(torch.int32)
    total = indices.numel()
    if total == 0:
        return  # nothing to gather

    block = 1024
    grid = (triton.cdiv(total, block),)
    take_kernel[grid](
        src,
        indices,
        out_flat,
        total,
        src.numel(),
        BLOCK_SIZE=block,
    )

52 

53 

def take(input: torch.Tensor, index: torch.Tensor):
    """
    Wrapper for aten::take
    Gathers elements of ``input`` (viewed as 1-D) at the flat positions in
    ``index``. The result has ``index``'s shape and ``input``'s dtype.
    (The previous docstring incorrectly claimed a 1-D result; the final
    ``view(index.shape)`` restores ``index``'s shape, matching torch.take.)
    """
    assert input.device == index.device, "input and index must be on the same device"
    # Gather into a flat buffer, then reshape to the index tensor's shape.
    out_flat = torch.empty(index.numel(), device=input.device, dtype=input.dtype)
    _launch_take(input, index, out_flat)
    return out_flat.view(index.shape)

63 

64 

def take_out(input: torch.Tensor, index: torch.Tensor, out: torch.Tensor):
    """
    Wrapper for aten::take.out
    Writes result into 'out' and returns it.

    ``out`` must share ``input``'s dtype; it is resized to ``index``'s shape
    when the shapes differ.
    """
    assert (
        input.device == index.device == out.device
    ), "All tensors must be on the same device"
    # Ensure output has correct dtype and shape; resize if needed
    if out.dtype != input.dtype:
        raise TypeError(
            f"out dtype {out.dtype} does not match input dtype {input.dtype}"
        )
    # A shape mismatch subsumes a numel mismatch, so one comparison suffices
    # (the former extra numel check was redundant).
    if tuple(out.shape) != tuple(index.shape):
        out.resize_(index.shape)

    if out.is_contiguous():
        # Fast path: gather straight into out's storage — no temp, no copy.
        _launch_take(input, index, out.view(-1))
    else:
        # Non-contiguous 'out': gather into a contiguous scratch buffer first,
        # then copy so strided layouts are written correctly.
        tmp_flat = torch.empty(index.numel(), device=input.device, dtype=input.dtype)
        _launch_take(input, index, tmp_flat)
        out.copy_(tmp_flat.view(index.shape))
    return out