Coverage for src/flag_gems/experimental_ops/masked_scatter.py: 0%
75 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _masked_scatter_count_kernel(
    mask_ptr,  # *Pointer* to mask tensor (bool)
    counts_ptr,  # *Pointer* to per-block counts (int32)
    n_elements,  # Number of elements in the flattened input
    BLOCK_SIZE: tl.constexpr,
):
    """Store, per program, the number of True mask elements in its block."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements
    # Out-of-range lanes read 0 and therefore do not contribute to the count.
    flags = tl.load(mask_ptr + idx, mask=valid, other=0).to(tl.int32)
    tl.store(counts_ptr + block_id, tl.sum(flags, axis=0))
@triton.jit
def _masked_scatter_apply_kernel(
    in_ptr,  # *Pointer* to input tensor
    mask_ptr,  # *Pointer* to mask tensor (bool)
    src_ptr,  # *Pointer* to source tensor (1D)
    out_ptr,  # *Pointer* to output tensor
    n_elements,  # Number of elements in the flattened input
    prefix_ptr,  # *Pointer* to per-block exclusive prefix sums (int32)
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter src values into out where mask is True; copy input elsewhere.

    Each program handles one BLOCK_SIZE chunk. The k-th True element of the
    chunk reads src at index prefix_ptr[pid] + k, i.e. the count of True
    elements in all preceding blocks plus the exclusive rank in this block.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    x = tl.load(in_ptr + offsets, mask=in_bounds)
    m = tl.load(mask_ptr + offsets, mask=in_bounds, other=0)
    m_i32 = m.to(tl.int32)

    # Compute per-block exclusive ranks for True mask elements
    inclusive = tl.cumsum(m_i32, axis=0)
    rank = inclusive - m_i32  # exclusive rank within the block

    # Per-block offset from the host-computed exclusive prefix sum of counts.
    block_offset = tl.load(prefix_ptr + pid, mask=True, other=0).to(rank.dtype)
    global_rank = block_offset + rank

    take = m_i32 != 0
    # Only lanes that are both in bounds and masked-in read from src;
    # other lanes get 0 and are discarded by the tl.where below.
    gathered = tl.load(src_ptr + global_rank, mask=(in_bounds & take), other=0)

    out_vals = tl.where(take, gathered, x)
    tl.store(out_ptr + offsets, out_vals, mask=in_bounds)
def _launch_masked_scatter(
    input_tensor: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
    out_tensor: torch.Tensor = None,
):
    """Run the two-pass masked_scatter (count, then apply) on CUDA tensors.

    Elements of *input_tensor* at True positions of *mask* are replaced, in
    row-major order, by consecutive elements of *source*. Results go to
    *out_tensor* if given, otherwise to a freshly allocated tensor.

    Args:
        input_tensor: Tensor whose values fill the False-mask positions.
        mask: Boolean tensor (cast if needed) with the same numel as input.
        source: 1D-viewable tensor supplying at least as many elements as
            there are True entries in the mask.
        out_tensor: Optional destination; must match input's shape, dtype,
            and device.

    Returns:
        The output tensor (``out_tensor`` when supplied).

    Raises:
        ValueError: on missing arguments, shape/dtype/device mismatches,
            non-CUDA inputs, or a too-small source.
    """
    # Validate inputs
    if input_tensor is None or mask is None or source is None:
        raise ValueError("masked_scatter requires input, mask, and source tensors")

    if mask.dtype != torch.bool:
        mask = mask.to(torch.bool)

    if input_tensor.numel() != mask.numel():
        raise ValueError("input and mask must have the same number of elements")

    if out_tensor is None:
        out = torch.empty_like(input_tensor)
    else:
        out = out_tensor
        if out.shape != input_tensor.shape:
            raise ValueError("out tensor must have the same shape as input")
        if out.dtype != input_tensor.dtype:
            raise ValueError("out tensor must have the same dtype as input")
        if out.device != input_tensor.device:
            raise ValueError("out tensor must be on the same device as input")

    device = input_tensor.device
    if device.type != "cuda":
        raise ValueError("Triton kernels require CUDA tensors")

    # Flatten to 1D contiguous views (contiguous() copies when non-contiguous)
    x_flat = input_tensor.contiguous().view(-1)
    m_flat = mask.contiguous().view(-1)
    s_flat = source.contiguous().view(-1)
    out_flat = out.contiguous().view(-1)

    n_elements = x_flat.numel()
    if n_elements == 0:
        # Nothing to do
        out.copy_(input_tensor)
        return out

    BLOCK_SIZE = 1024
    n_blocks = triton.cdiv(n_elements, BLOCK_SIZE)

    # 1) Count number of True mask elements per block
    counts = torch.empty(n_blocks, dtype=torch.int32, device=device)
    grid = (n_blocks,)
    _masked_scatter_count_kernel[grid](
        m_flat, counts, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )

    # 2) Compute exclusive prefix sums of per-block counts
    counts_prefix = torch.cumsum(counts, dim=0)
    total_true = int(counts_prefix[-1].item()) if n_blocks > 0 else 0
    if s_flat.numel() < total_true:
        raise ValueError(
            f"source has fewer elements ({s_flat.numel()}) than required by mask ({total_true})"
        )
    prefix_exclusive = counts_prefix - counts  # same device; exclusive sums

    # 3) Apply masked_scatter using per-block prefix offsets
    _masked_scatter_apply_kernel[grid](
        x_flat,
        m_flat,
        s_flat,
        out_flat,
        n_elements,
        prefix_exclusive,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # If out was non-contiguous, out_flat is a detached copy; write the
    # results back. (The previous out.view(-1).copy_(...) raised a
    # RuntimeError here because view() is illegal on non-contiguous tensors.)
    if out.data_ptr() != out_flat.data_ptr():
        out.copy_(out_flat.view(out.shape))
    return out
def masked_scatter(input: torch.Tensor, mask: torch.Tensor, source: torch.Tensor):
    """Return a new tensor where elements of *input* at True positions of
    *mask* are replaced, in order, by consecutive elements of *source*.

    Out-of-place counterpart of :func:`masked_scatter_out`; delegates to the
    Triton launcher with a freshly allocated output.
    """
    return _launch_masked_scatter(input, mask, source, out_tensor=None)
def masked_scatter_out(
    input: torch.Tensor, mask: torch.Tensor, source: torch.Tensor, out: torch.Tensor
):
    """masked_scatter variant that writes results into a caller-provided
    *out* tensor (must match input's shape, dtype, and device) and returns it.
    """
    return _launch_masked_scatter(input, mask, source, out_tensor=out)