Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/masked_select.py: 0%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable, libentry
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("masked_select"),
    key=["n_elements"],
    warmup=5,
    rep=5,
)
@triton.jit
def masked_select_kernel(
    inp_ptr,
    select_mask_ptr,
    prefix_sum_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatter the mask-selected elements of `inp_ptr` into a dense output.
    # `prefix_sum_ptr` holds the inclusive cumulative sum of the selection
    # mask, so (prefix_sum - 1) is the 0-based destination slot of each
    # selected element.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # bounds guard for the last partial block

    inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
    select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(tl.int1)
    # For unselected lanes this offset is stale (points at the previous
    # selected element's slot, or -1) — the store below masks them out.
    out_offset = tl.load(prefix_sum_ptr + offsets, mask=mask, other=0.0) - 1

    tl.store(out_ptr + out_offset, inp, mask=(select_mask & mask))
def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of `inp` selected by the boolean
    `mask`, in row-major (flattened) order.

    Mirrors ``torch.masked_select``: `inp` and `mask` are broadcast against
    each other, and the result length equals the number of True entries in
    the broadcast mask.

    Raises:
        AssertionError: if the shapes of `inp` and `mask` are not
            broadcastable.
    """
    logger.debug("GEMS_TSINGMICRO MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    inp = inp.contiguous()
    mask = mask.contiguous()

    n_elements = inp.numel()
    # Empty broadcast result: `prefix_sum[-1]` below would raise IndexError
    # on a zero-element tensor, so short-circuit with an empty output
    # (matches torch.masked_select on empty inputs).
    if n_elements == 0:
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    mask_flattened = mask.ravel()

    # Inclusive prefix sum of the mask: the last entry is the total number
    # of selected elements, and (prefix_sum - 1) gives each selected
    # element's destination slot inside the kernel.
    prefix_sum = mask_flattened.cumsum(axis=0)
    out = torch.empty(prefix_sum[-1].item(), dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(inp.device):
        masked_select_kernel[grid](inp, mask_flattened, prefix_sum, out, n_elements)
    return out