Coverage for src/flag_gems/ops/masked_select.py: 49%
105 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import broadcastable, libentry
9from flag_gems.utils.shape_utils import bracket_next_power_of_2
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def masked_select_single_pass_kernel(
    inp_ptr, mask_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr
):
    # Single-CTA masked select: the caller launches grid=(1,) with
    # BLOCK_SIZE = next_power_of_2(N), so one program covers all N elements.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp = tl.load(inp_ptr + offsets, mask=offsets < N)
    # NOTE(review): no `other=` on this masked load, so lanes with
    # offsets >= N read undefined values. Those lanes sit at the tail of the
    # block, so the cumsum prefixes of in-range lanes are unaffected, and the
    # store below masks the tail lanes out — but `other=0` would be safer.
    mask = tl.load(mask_ptr + offsets, mask=offsets < N).to(tl.int1)
    mask_ints = mask.to(tl.int32)
    # Inclusive cumsum minus 1 = each selected element's packed output index.
    out_offsets = tl.cumsum(mask_ints, axis=0) - 1
    tl.store(out_ptr + out_offsets, inp, mask=(offsets < N) & mask)
def masked_select_single_pass(inp, mask, out, N):
    """Pack the elements of ``inp`` selected by ``mask`` into ``out``
    using a single-CTA kernel launch, and return ``out``.

    ``N`` is the (flattened) element count; ``out`` must already be sized
    to hold exactly ``mask.sum()`` elements.
    """
    block = triton.next_power_of_2(N)
    # Scale warp count with block size: 4 warps up to 512 lanes,
    # 8 up to 2048, 16 beyond that.
    warps = 4 if block <= 512 else (8 if block <= 2048 else 16)
    masked_select_single_pass_kernel[(1,)](
        inp, mask, out, N, BLOCK_SIZE=block, num_warps=warps
    )
    return out
@libentry()
@triton.jit(do_not_specialize=["N", "nr", "row_stride"])
def mask_part_sum_kernel(
    inp_ptr,  # unused in this pass; kept so both passes share a signature shape
    mask_ptr,
    part_sums_ptr,
    counter_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    NP_BLOCK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of the two-pass masked_select. Each program ("row") owns a
    # contiguous run of num_blocks_per_row blocks and counts how many mask
    # elements are set in its run. The last program to finish converts the
    # per-row counts in-place into exclusive prefix sums (per-row write
    # offsets) and stores the grand total at part_sums_ptr[np].
    row_id = tl.program_id(0)
    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    acc = tl.zeros((BLOCK_SIZE,), dtype=part_sums_ptr.dtype.element_ty)

    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    # Full interior blocks need no bounds check; the possibly-partial last
    # block is peeled out of the loop below.
    for block_id in range(start_block, last_block_id):
        select = tl.load(mask_ptr + offset)
        select_ints = select.to(part_sums_ptr.dtype.element_ty)
        acc += select_ints
        offset += BLOCK_SIZE
    # Peeled last block: bounds-masked, padded with 0 so tail lanes add nothing.
    select = tl.load(mask_ptr + offset, mask=offset < N, other=0)
    select_ints = select.to(part_sums_ptr.dtype.element_ty)
    acc += select_ints

    part_sum = tl.sum(acc, axis=0)
    tl.store(part_sums_ptr + row_id, part_sum)
    # cumsum the part_sums: every program increments the shared counter with
    # acq_rel ordering; the program that observes count == np - 1 is last,
    # so all part_sum stores above are visible to it and it finalizes.
    count = tl.atomic_add(counter_ptr, 1, sem="acq_rel")
    np = tl.num_programs(0)
    if count == np - 1:
        mask = tl.arange(0, NP_BLOCK) < np
        part_sums = tl.load(part_sums_ptr + tl.arange(0, NP_BLOCK), mask=mask)
        final_sum = tl.sum(part_sums, axis=0)
        pre_sums = tl.cumsum(part_sums, axis=0)
        # Inclusive cumsum minus self = exclusive prefix (row write offsets).
        tl.store(
            part_sums_ptr + tl.arange(0, NP_BLOCK), pre_sums - part_sums, mask=mask
        )
        # Total selected count goes in the extra slot; host reads it to size out.
        tl.store(part_sums_ptr + np, final_sum)
@libentry()
@triton.jit(do_not_specialize=["N", "nr", "row_stride"])
def write_back_kernel(
    inp_ptr,
    mask_ptr,
    part_sums_ptr,
    out_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    NP_BLOCK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 of the two-pass masked_select. Each program re-scans the mask
    # over its run of blocks and scatters the selected inp elements into out,
    # starting at the exclusive prefix offset that mask_part_sum_kernel left
    # in part_sums_ptr[row_id].
    row_id = tl.program_id(0)

    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Elements written before the current block; out_ptr is advanced by this
    # amount at the top of each iteration, then `advance` is refreshed with
    # the current block's count — the ordering of these two lines matters.
    advance = tl.load(part_sums_ptr + row_id)

    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    for block_id in range(start_block, last_block_id):
        inp = tl.load(inp_ptr + offset)
        select_mask = tl.load(mask_ptr + offset).to(tl.int1)
        select_ints = select_mask.to(tl.constexpr(part_sums_ptr.dtype.element_ty))
        out_ptr += advance
        advance = tl.sum(select_ints, axis=0)
        # Inclusive cumsum minus 1 = in-block packed position per lane.
        pre_sums = tl.cumsum(select_ints, axis=0) - 1
        tl.store(out_ptr + pre_sums, inp, mask=select_mask)
        offset += BLOCK_SIZE
    # Peeled last block: bounds-masked; other=0 keeps padded lanes unselected.
    inp = tl.load(inp_ptr + offset, mask=offset < N)
    select_mask = tl.load(mask_ptr + offset, mask=offset < N, other=0).to(tl.int1)
    select_ints = select_mask.to(tl.constexpr(part_sums_ptr.dtype.element_ty))
    out_ptr += advance
    pre_sums = tl.cumsum(select_ints, axis=0) - 1
    tl.store(out_ptr + pre_sums, inp, mask=(offset < N) & select_mask)
def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of ``inp`` selected by the
    boolean ``mask`` (after broadcasting), in row-major order.

    Equivalent to ``torch.masked_select``. Small inputs (N <= 4096) use a
    single-CTA kernel. Larger inputs use a two-pass scheme: pass 1 computes
    per-row selected counts and reduces them in-kernel to exclusive prefix
    sums (write offsets) plus the total count; pass 2 scatters the selected
    elements to their final positions.
    """
    logger.debug("GEMS MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    # The kernels index with flat offsets, so both operands must be dense.
    inp = inp.contiguous()
    mask = mask.contiguous()

    N = inp.numel()
    if N <= 4096:
        # mask.sum() synchronizes with the device to size the output exactly.
        out = torch.empty(mask.sum(), dtype=inp.dtype, device=inp.device)
        return masked_select_single_pass(inp, mask, out, N)

    BLOCK_SIZE = bracket_next_power_of_2(N, 128, 4096)
    num_warps = min(16, BLOCK_SIZE // 32)

    # Max degree of parallelism: at most one program per SM.
    # (Renamed from `np` to avoid confusion with the numpy alias.)
    num_progs = torch_device_fn.get_device_properties(mask.device).multi_processor_count

    # Arrange the blocks as `num_progs` rows; each program scans one row.
    n_blocks = triton.cdiv(N, BLOCK_SIZE)
    num_progs = min(n_blocks, num_progs)
    n_blocks_per_row = triton.cdiv(n_blocks, num_progs)
    num_progs = triton.cdiv(n_blocks, n_blocks_per_row)
    NP_BLOCK = triton.next_power_of_2(num_progs)

    with torch_device_fn.device(inp.device):
        # Pass 1: per-row counts -> exclusive prefix sums; the extra slot
        # part_sums[num_progs] receives the total selected count.
        # int32 offsets suffice while N fits in a signed 32-bit range.
        dtype = torch.int32 if N < 2**31 else torch.int64
        part_sums = torch.empty(num_progs + 1, dtype=dtype, device=mask.device)
        barrier = torch.zeros([], dtype=torch.int, device=mask.device)
        mask_part_sum_kernel[(num_progs,)](
            inp,
            mask,
            part_sums,
            barrier,
            N,
            n_blocks,
            n_blocks_per_row,
            NP_BLOCK=NP_BLOCK,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        # Pass 2: write back selected data at the precomputed offsets.
        # Reading part_sums[-1] synchronizes so `out` is sized exactly.
        out = torch.empty(part_sums[-1], dtype=inp.dtype, device=mask.device)
        write_back_kernel[(num_progs,)](
            inp,
            mask,
            part_sums,
            out,
            N,
            n_blocks,
            n_blocks_per_row,
            NP_BLOCK=NP_BLOCK,  # reuse the value computed above
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    return out