Coverage for src/flag_gems/experimental_ops/masked_select.py: 0%
85 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _masked_select_count_kernel(
    mask_ptr,  # int32* flattened mask (0/1)
    n_elements,  # int32 number of elements
    counts_ptr,  # int32* per-block counts
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of the two-pass compaction: each program counts how many
    # selected (mask != 0) elements fall inside its BLOCK_SIZE-wide slice
    # and writes that single count to counts_ptr[pid].
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements
    # Out-of-range lanes load 0 so they never contribute to the count.
    selected = tl.load(mask_ptr + idx, mask=valid, other=0)  # int32 0/1
    tl.store(counts_ptr + block_id, tl.sum(selected, axis=0))
@triton.jit
def _masked_select_scatter_kernel(
    input_ptr,  # * input data (flattened, contiguous)
    mask_ptr,  # int32* flattened mask (0/1)
    block_offsets_ptr,  # int32* per-block exclusive offsets
    output_ptr,  # * output data
    n_elements,  # int32 number of elements
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 of the compaction: write the selected values of this block
    # into output, starting at the block's precomputed exclusive offset.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements
    selected = tl.load(mask_ptr + idx, mask=valid, other=0)  # int32 0/1
    # Exclusive slot of each selected lane inside the block: inclusive
    # cumsum minus one. Only meaningful on lanes where selected == 1.
    local_slot = tl.cumsum(selected, axis=0) - 1
    block_base = tl.load(block_offsets_ptr + block_id)  # int32
    dest = block_base + local_slot
    # Only in-bounds, selected lanes load and store anything.
    keep = valid & (selected != 0)
    data = tl.load(input_ptr + idx, mask=keep, other=0)
    tl.store(output_ptr + dest, data, mask=keep)
def _prepare_broadcast_flatten(input: torch.Tensor, mask: torch.Tensor):
    """Broadcast ``input`` and ``mask`` to a common shape and flatten both.

    Returns ``(inp_flat, msk_flat_i32)``: the contiguous 1-D broadcast
    input, and the contiguous 1-D mask converted to int32 (0/1) as the
    Triton kernels expect.
    """
    common = torch.broadcast_shapes(tuple(input.shape), tuple(mask.shape))
    flat_input = input.expand(common).contiguous().view(-1)
    # Normalize the mask to bool first so any numeric mask becomes 0/1.
    flat_mask_bool = mask.to(torch.bool).expand(common).contiguous().view(-1)
    return flat_input, flat_mask_bool.to(torch.int32)
def masked_select(input: torch.Tensor, mask: torch.Tensor):
    """Return a new 1-D tensor holding the elements of ``input`` selected
    by ``mask``.

    ``input`` and ``mask`` are broadcast against each other first. Runs as
    two Triton passes: a per-block count pass, then a scatter pass driven
    by exclusive per-block offsets.
    """
    flat_vals, flat_mask = _prepare_broadcast_flatten(input, mask)
    device = flat_vals.device
    assert flat_mask.device == device, "input and mask must be on the same device"

    numel = flat_vals.numel()
    if numel == 0:
        # Nothing to select from; skip the kernel launches entirely.
        return torch.empty(0, dtype=input.dtype, device=device)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(numel, BLOCK_SIZE),)

    # Pass 1: how many selected elements land in each block.
    per_block = torch.empty(grid[0], dtype=torch.int32, device=device)
    _masked_select_count_kernel[grid](flat_mask, numel, per_block, BLOCK_SIZE=BLOCK_SIZE)

    # Turn the counts into exclusive per-block starting offsets, and read
    # back the total so the output can be sized exactly.
    if grid[0] == 1:
        starts = torch.zeros(1, dtype=torch.int32, device=device)
        total = int(per_block[0].item())
    else:
        running = torch.cumsum(per_block, dim=0)
        starts = torch.empty_like(per_block)  # stays int32 for the kernel
        starts[0] = 0
        starts[1:] = running[:-1]
        total = int(running[-1].item())

    # Pass 2: scatter the selected values into their compacted slots.
    result = torch.empty(total, dtype=input.dtype, device=device)
    _masked_select_scatter_kernel[grid](
        flat_vals, flat_mask, starts, result, numel, BLOCK_SIZE=BLOCK_SIZE
    )
    return result
def masked_select_out(input: torch.Tensor, mask: torch.Tensor, out: torch.Tensor):
    """Like :func:`masked_select`, but writes the result into ``out``.

    ``out`` is resized in place to the number of selected elements and
    returned. Raises ``RuntimeError`` if ``out`` is on a different device
    than ``input`` or has a different dtype.
    """
    inp_flat, msk_flat_i32 = _prepare_broadcast_flatten(input, mask)
    device = inp_flat.device
    assert msk_flat_i32.device == device, "input and mask must be on the same device"
    if out.device != device:
        raise RuntimeError("out tensor must be on the same device as input")
    # Fix: validate dtype up front. Previously this check ran only after the
    # count kernel launch and a device sync — wasting GPU work before failing —
    # and was skipped entirely on the empty-input path.
    if out.dtype != input.dtype:
        raise RuntimeError("out tensor dtype must match input dtype")

    n_elements = inp_flat.numel()
    if n_elements == 0:
        out.resize_(0)
        return out

    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(n_elements, BLOCK_SIZE)
    grid = (num_blocks,)

    # Pass 1: per-block counts of selected elements.
    counts = torch.empty(num_blocks, dtype=torch.int32, device=device)
    _masked_select_count_kernel[grid](
        msk_flat_i32, n_elements, counts, BLOCK_SIZE=BLOCK_SIZE
    )

    # Compute per-block exclusive offsets and total number of selected elements.
    if num_blocks == 1:
        block_offsets = torch.zeros(1, dtype=torch.int32, device=device)
        total_selected = int(counts[0].item())
    else:
        prefix = torch.cumsum(counts, dim=0)
        block_offsets = torch.empty_like(counts)  # keeps int32 for the kernel
        block_offsets[0] = 0
        block_offsets[1:] = prefix[:-1]
        total_selected = int(prefix[-1].item())

    out.resize_(total_selected)

    # Pass 2: scatter the selected values into the compacted output.
    _masked_select_scatter_kernel[grid](
        inp_flat, msk_flat_i32, block_offsets, out, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )
    return out