Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/masked_select.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable, libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_size(args):
    """Heuristic for the kernel's ``BLOCK_SIZE`` meta-parameter.

    Splits ``args["n_elements"]`` into 12 shares (rounding up) and rounds
    the share size up to the next power of two.

    Args:
        args: Mapping of kernel argument names to values; only
            ``"n_elements"`` is read.

    Returns:
        The block size as computed by ``triton.next_power_of_2``.
    """
    cluster_num = 12
    elements_per_cluster = triton.cdiv(args["n_elements"], cluster_num)
    return triton.next_power_of_2(elements_per_cluster)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_SIZE": heur_block_size,
    },
)
@triton.jit
def masked_select_kernel(
    inp_ptr,
    select_mask_ptr,
    prefix_sum_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter the selected elements of ``inp_ptr`` into ``out_ptr``.

    Each program instance covers ``BLOCK_SIZE`` consecutive offsets starting
    at ``pid * BLOCK_SIZE``. For every in-range offset whose selection mask
    is set, the input value is stored at ``out_ptr + (prefix_sum - 1)``.

    Args:
        inp_ptr: Pointer to the flattened input values.
        select_mask_ptr: Pointer to the flattened selection mask; values are
            cast to ``tl.int1`` after loading.
        prefix_sum_ptr: Pointer to per-element running counts; the value
            minus one is used as the destination index (the caller in this
            file passes the cumulative sum of the mask).
        out_ptr: Pointer to the output buffer.
        n_elements: Number of valid elements; offsets at or beyond it are
            masked off.
        BLOCK_SIZE: Compile-time number of offsets handled per program,
            supplied by ``heur_block_size``.
    """
    block_id = tle.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    values = tl.load(inp_ptr + offs, mask=in_bounds, other=0.0)
    selected = tl.load(select_mask_ptr + offs, mask=in_bounds, other=0.0).to(tl.int1)

    # Only lanes that are both in range and selected read a prefix value
    # and write an output element.
    active = selected & in_bounds
    dst = tl.load(prefix_sum_ptr + offs, mask=active, other=0.0) - 1

    tl.store(out_ptr + dst, values, mask=active)
def masked_select(inp, mask):
    """Return a 1-D tensor holding the elements of ``inp`` where ``mask`` is set.

    ``inp`` and ``mask`` are broadcast against each other, made contiguous
    and flattened. The cumulative sum of the flattened mask gives each
    selected element its 1-based position in the output, and
    ``masked_select_kernel`` scatters the values accordingly.

    Args:
        inp: Input tensor.
        mask: Selection mask, broadcastable with ``inp``.

    Returns:
        A new 1-D tensor with ``inp``'s dtype and device containing the
        selected elements in flattened order. Empty when the broadcast input
        has no elements or nothing is selected.

    Raises:
        AssertionError: If the shapes of ``inp`` and ``mask`` are not
            broadcastable.
    """
    logger.debug("GEMS MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    inp = inp.contiguous()
    mask = mask.contiguous()

    mask_flattened = mask.ravel()
    n_elements = inp.numel()

    # Empty input: indexing ``prefix_sum[-1]`` would raise IndexError, and the
    # block-size heuristic would yield 0 and make the grid computation divide
    # by zero. There is nothing to select, so return an empty result directly.
    if n_elements == 0:
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    prefix_sum = mask_flattened.cumsum(axis=0)
    # The last entry of the inclusive prefix sum is the number of selected
    # elements, i.e. the output length.
    n_selected = prefix_sum[-1].item()
    out = torch.empty(n_selected, dtype=inp.dtype, device=inp.device)
    if n_selected == 0:
        # No lane would store anything; skip the launch.
        return out

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    import os

    # The kernel is launched with these TRITONXPU flags enabled. Remember any
    # values the caller already had and restore them afterwards, even if the
    # launch raises, so the process environment is never left modified.
    sim_flags = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    saved_flags = {name: os.environ.get(name) for name in sim_flags}
    for name in sim_flags:
        os.environ[name] = "1"
    try:
        with torch_device_fn.device(inp.device):
            masked_select_kernel[grid](
                inp, mask_flattened, prefix_sum, out, n_elements
            )
    finally:
        for name, previous in saved_flags.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    return out