Coverage for src/flag_gems/ops/masked_scatter.py: 50%
122 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import broadcastable, libentry
9from flag_gems.utils.shape_utils import bracket_next_power_of_2
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def masked_scatter_single_pass_kernel(
    inp_ptr, mask_ptr, src_ptr, N, BLOCK_SIZE: tl.constexpr
):
    # Single-program masked_scatter for small inputs: one kernel instance
    # covers the whole flattened tensor (caller picks BLOCK_SIZE >= N).
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Lanes past the end of the tensor are inert.
    block_mask = offsets < N

    mask_val = tl.load(mask_ptr + offsets, mask=block_mask, other=0).to(tl.int1)

    # Inclusive cumsum of the boolean mask minus one maps the k-th selected
    # lane to source index k (i.e. an exclusive prefix count).
    mask_ints = mask_val.to(tl.int32)
    src_indices = tl.cumsum(mask_ints, axis=0) - 1

    # Only in-range, selected lanes load/store; the masked load makes the
    # -1 index produced on unselected lanes harmless.
    active = block_mask & mask_val
    src_val = tl.load(src_ptr + src_indices, mask=active)
    tl.store(inp_ptr + offsets, src_val, mask=active)
@libentry()
@triton.jit(do_not_specialize=["N", "num_blocks", "num_blocks_per_row"])
def mask_part_sum_kernel(
    mask_ptr,
    part_sums_ptr,
    counter_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    NP_BLOCK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of the two-pass masked_scatter: each program ("row") owns a
    # contiguous run of num_blocks_per_row blocks and counts how many mask
    # elements are set in its span. The last program to finish then turns
    # the per-row counts into exclusive prefix offsets, in place.
    row_id = tl.program_id(0)
    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    acc = tl.zeros((BLOCK_SIZE,), dtype=part_sums_ptr.dtype.element_ty)

    # Only the row's final block can be partial, so the loop loads full
    # blocks without a bounds mask; the tail block is handled after it.
    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    for block_id in range(start_block, last_block_id):
        select = tl.load(mask_ptr + offset)
        select_ints = select.to(part_sums_ptr.dtype.element_ty)
        acc += select_ints
        offset += BLOCK_SIZE

    # Tail (possibly partial) block: out-of-range lanes contribute 0.
    select = tl.load(mask_ptr + offset, mask=offset < N, other=0)
    select_ints = select.to(part_sums_ptr.dtype.element_ty)
    acc += select_ints

    part_sum = tl.sum(acc, axis=0)
    tl.store(part_sums_ptr + row_id, part_sum)

    # Grid-wide rendezvous: acq_rel on the counter orders the part_sums
    # stores, so the program that arrives last sees every row's count.
    count = tl.atomic_add(counter_ptr, 1, sem="acq_rel")
    np = tl.num_programs(0)

    if count == np - 1:
        mask = tl.arange(0, NP_BLOCK) < np
        part_sums = tl.load(part_sums_ptr + tl.arange(0, NP_BLOCK), mask=mask)
        final_sum = tl.sum(part_sums, axis=0)
        # Inclusive cumsum minus self = exclusive prefix sum: each row's
        # starting offset into the source tensor.
        pre_sums = tl.cumsum(part_sums, axis=0)
        tl.store(
            part_sums_ptr + tl.arange(0, NP_BLOCK), pre_sums - part_sums, mask=mask
        )
        # The grand total of selected elements is appended at index np.
        tl.store(part_sums_ptr + np, final_sum)
@libentry()
@triton.jit(do_not_specialize=["N", "num_blocks", "num_blocks_per_row"])
def masked_scatter_kernel(
    inp_ptr,
    mask_ptr,
    src_ptr,
    part_sums_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 of the two-pass masked_scatter: using the exclusive per-row
    # offsets computed by mask_part_sum_kernel, each program scatters source
    # elements into the masked positions of its run of blocks.
    row_id = tl.program_id(0)

    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Number of selected elements in all rows before this one, i.e. the
    # index into src of this row's first selected position.
    advance = tl.load(part_sums_ptr + row_id)

    # Mirror of pass 1: full blocks in the loop, masked tail afterwards.
    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    for block_id in range(start_block, last_block_id):
        select_mask = tl.load(mask_ptr + offset).to(tl.int1)
        select_ints = select_mask.to(tl.int32)

        # Within-block exclusive prefix count, shifted by the row's advance.
        block_cumsum = tl.cumsum(select_ints, axis=0) - 1
        global_src_idx = advance + block_cumsum

        advance += tl.sum(select_ints, axis=0)

        src_val = tl.load(src_ptr + global_src_idx, mask=select_mask)
        tl.store(inp_ptr + offset, src_val, mask=select_mask)

        offset += BLOCK_SIZE

    # Tail (possibly partial) block.
    block_mask = offset < N
    select_mask = tl.load(mask_ptr + offset, mask=block_mask, other=0).to(tl.int1)

    select_ints = select_mask.to(tl.int32)
    block_cumsum = tl.cumsum(select_ints, axis=0) - 1
    global_src_idx = advance + block_cumsum

    active = block_mask & select_mask
    src_val = tl.load(src_ptr + global_src_idx, mask=active)
    tl.store(inp_ptr + offset, src_val, mask=active)
def masked_scatter_impl(inp, mask, source, N):
    """Scatter ``source`` into ``inp`` at positions where ``mask`` is set.

    All tensors are treated as flat, contiguous buffers of ``N`` elements
    (``source`` holds at least as many elements as ``mask`` has set bits).
    Small problems run a single-program kernel; larger ones use a two-pass
    scheme (per-row mask counts, then scatter). Mutates ``inp`` in place
    and returns it.
    """
    if N <= 4096:
        # One program covers everything; scale warps with the block size.
        block = triton.next_power_of_2(N)
        if block >= 4096:
            warps = 16
        elif block >= 2048:
            warps = 8
        else:
            warps = 4
        masked_scatter_single_pass_kernel[(1,)](
            inp, mask, source, N, BLOCK_SIZE=block, num_warps=warps
        )
        return inp

    block = bracket_next_power_of_2(N, 128, 4096)
    warps = min(16, block // 32)

    # At most one program per SM; each program owns a contiguous run of
    # blocks. Recompute the grid so no program ends up with an empty run.
    sm_count = torch_device_fn.get_device_properties(mask.device).multi_processor_count
    total_blocks = triton.cdiv(N, block)
    grid = min(total_blocks, sm_count)
    blocks_per_row = triton.cdiv(total_blocks, grid)
    grid = triton.cdiv(total_blocks, blocks_per_row)
    np_block = triton.next_power_of_2(grid)

    with torch_device_fn.device(inp.device):
        # part_sums holds one count per program plus the grand total;
        # 32-bit indices suffice unless N itself needs 64 bits.
        idx_dtype = torch.int32 if N < 2**31 else torch.int64
        part_sums = torch.empty(grid + 1, dtype=idx_dtype, device=mask.device)
        barrier = torch.zeros([], dtype=torch.int, device=mask.device)

        mask_part_sum_kernel[(grid,)](
            mask,
            part_sums,
            barrier,
            N,
            total_blocks,
            blocks_per_row,
            NP_BLOCK=np_block,
            BLOCK_SIZE=block,
            num_warps=warps,
        )
        masked_scatter_kernel[(grid,)](
            inp,
            mask,
            source,
            part_sums,
            N,
            total_blocks,
            blocks_per_row,
            BLOCK_SIZE=block,
            num_warps=warps,
        )

    return inp
def masked_scatter(inp, mask, source):
    """Out-of-place masked_scatter.

    Returns a copy of ``inp`` in which the positions where ``mask`` is True
    are overwritten with consecutive elements of ``source`` (in flattened
    order). ``mask`` is broadcast against ``inp`` first.
    """
    logger.debug("GEMS MASKED SCATTER")

    assert broadcastable(
        inp.shape, mask.shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"

    _, mask = torch.broadcast_tensors(inp, mask)

    # The kernels index flat buffers, so normalize everything to contiguous.
    out = inp.clone()
    out = out if out.is_contiguous() else out.contiguous()
    mask = mask if mask.is_contiguous() else mask.contiguous()
    source = source if source.is_contiguous() else source.contiguous()

    masked_scatter_impl(out, mask, source, out.numel())
    return out
def masked_scatter_(inp, mask, source):
    """In-place masked_scatter.

    Overwrites the positions of ``inp`` where ``mask`` is True with
    consecutive elements of ``source`` (in flattened order) and returns
    ``inp``. ``mask`` is broadcast against ``inp`` first.

    Raises:
        RuntimeError: if ``inp`` is not contiguous (the kernels write into
            a flat buffer, so a non-contiguous copy could not be written
            back in place).
    """
    logger.debug("GEMS MASKED SCATTER_")

    # Match the out-of-place variant: same check, same explanatory message.
    assert broadcastable(
        inp.shape, mask.shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    _, mask = torch.broadcast_tensors(inp, mask)

    if not inp.is_contiguous():
        raise RuntimeError(
            "in-place operation currently requires contiguous input tensor. "
        )

    # Only the auxiliary tensors may be copied; `inp` must be written directly.
    mask = mask if mask.is_contiguous() else mask.contiguous()
    source = source if source.is_contiguous() else source.contiguous()

    masked_scatter_impl(inp, mask, source, inp.numel())
    return inp