Coverage for src/flag_gems/runtime/backend/_cambricon/ops/masked_select.py: 0%
73 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable
11from ..utils import TOTAL_CORE_NUM
# Module logger: a child of the package-level "flag_gems" logger named after
# this module (any leading dots are stripped from __name__ before use).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.autotune(configs=runtime.get_tuned_config("masked_select"), key=["n_elements"])
@triton.jit
def masked_select_kernel(
    inp_ptr,          # *T: flattened, contiguous input values
    select_mask_ptr,  # *bool-ish: flattened selection mask, same length as input
    select_val_ptr,   # *T out: per-tile compacted selected values (tile-local layout)
    select_num_ptr,   # *int out: per-tile count of selected elements
    n_elements,       # total number of elements to process
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of a two-pass masked_select: each program walks the input with a
    # grid-stride loop (step = BLOCK_SIZE * num_programs) and, for every
    # BLOCK_SIZE tile it owns, records (a) the tile's selected values and
    # (b) how many elements were selected in that tile.  Stage 2
    # (get_out_kernel) later prefix-sums the counts to place each tile's
    # values contiguously in the final output.
    pid = tl.program_id(axis=0)
    num_p = tl.num_programs(axis=0)
    # Total number of BLOCK_SIZE tiles covering the input; used to bound the
    # per-tile count writes below.
    split_n = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
    step = BLOCK_SIZE * num_p
    offset_start = pid * BLOCK_SIZE
    loop = 0
    for offset in tl.range(offset_start, n_elements, step):
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
        # Out-of-range lanes load 0 -> int1 False, so the tail tile never
        # over-counts.
        select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(
            tl.int1
        )
        # NOTE(review): tl.masked_select is a Cambricon Triton extension;
        # based on how get_out_kernel consumes select_val_ptr, it appears to
        # compact the selected lanes to the front of the tile and return the
        # selected count — confirm against the backend docs.
        select_val, select_num = tl.masked_select(inp, select_mask)
        # Compacted values are stored back at the tile's own offsets; only the
        # first `select_num` entries of each tile are meaningful downstream.
        tl.store(select_val_ptr + offsets, select_val, mask=mask)
        # Tiles are indexed in launch order: tile_id = loop * num_p + pid,
        # matching the indexing used by get_out_kernel.
        num_select_offset = loop * num_p + pid + tl.arange(0, 1)
        loop += 1
        num_select_mask = num_select_offset < split_n
        tl.store(select_num_ptr + num_select_offset, select_num, mask=num_select_mask)
@triton.jit
def get_out_kernel(
    select_val_ptr,  # *T: per-tile compacted values produced by masked_select_kernel
    select_num_ptr,  # *int: per-tile selected counts produced by masked_select_kernel
    output_ptr,      # *T out: final densely-packed selected values
    n_elements: tl.constexpr,  # total element count (constexpr so split_n is static)
    BLOCK_SIZE: tl.constexpr,  # must equal the block size stage 1 was tuned to
):
    # Stage 2 of masked_select: every program loads ALL per-tile counts,
    # prefix-sums them, then scatters its own tiles' compacted values into the
    # output at the offset given by the cumulative count of all earlier tiles.
    pid = tl.program_id(axis=0)
    num_p = tl.num_programs(axis=0)
    step = BLOCK_SIZE * num_p
    # split_n must be constexpr because it sizes tl.arange below.
    split_n: tl.constexpr = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE

    all_select_num_offset = tl.arange(0, split_n)
    all_select_num_mask = all_select_num_offset < split_n
    all_select_num = tl.load(
        select_num_ptr + all_select_num_offset, mask=all_select_num_mask, other=0.0
    )
    # prefix_select_num[i] = total selected elements in tiles 0..i (inclusive).
    prefix_select_num = tl.cumsum(all_select_num, 0)

    offset_start = pid * BLOCK_SIZE
    loop = 0
    for offset in tl.range(offset_start, n_elements, step):
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        select_val = tl.load(select_val_ptr + offsets, mask=mask, other=0.0)
        # Same tile-id scheme as stage 1: tile_id = loop * num_p + pid.
        select_num_offset = loop * num_p + pid + tl.arange(0, 1)
        select_num_mask = select_num_offset < split_n
        select_num = tl.load(
            select_num_ptr + select_num_offset, mask=select_num_mask, other=0.0
        )
        # Tile 0 writes from output position 0; any other tile starts at the
        # inclusive prefix sum of the PREVIOUS tile (hence the -1 index).
        if loop == 0 and pid == 0:
            output_offset = tl.arange(0, BLOCK_SIZE)
        else:
            # NOTE(review): scalar indexing of a tensor value here relies on
            # Cambricon Triton semantics — confirm it extracts element
            # [tile_id - 1] of the prefix sum.
            output_offset = prefix_select_num[loop * num_p + pid - 1] + tl.arange(
                0, BLOCK_SIZE
            )
        loop += 1
        # Only the first `select_num` compacted lanes of this tile are real
        # selected values; mask the rest out of the store.
        output_mask = tl.arange(0, BLOCK_SIZE) < select_num
        tl.store(output_ptr + output_offset, select_val, mask=output_mask)
def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of ``inp`` selected by ``mask``.

    Drop-in equivalent of ``torch.masked_select`` for the Cambricon backend:
    ``inp`` and ``mask`` are broadcast against each other, and every element
    of the broadcast input whose mask entry is True is copied, in row-major
    order, into a new contiguous 1-D tensor.

    Args:
        inp: the input tensor.
        mask: a tensor broadcastable with ``inp``; nonzero entries select.

    Returns:
        A 1-D tensor of ``inp.dtype`` on ``inp.device`` holding the selected
        elements (length ``mask.sum()``).

    Raises:
        AssertionError: if the shapes are not broadcastable.
    """
    logger.debug("GEMS_CAMBRICON MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    # Both kernels index flat, contiguous memory.
    inp = inp.contiguous()
    mask = mask.contiguous()

    n_elements = inp.numel()
    # Robustness fix: with zero elements the grid below is empty, the
    # autotuned kernel never runs, and reading
    # ``masked_select_kernel.best_config`` would be unreliable — so return
    # the (empty) result directly.
    if n_elements == 0:
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (
        min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
    )
    with torch_device_fn.device(inp.device):
        # Stage 1: per-tile compaction. ``select_val`` holds each tile's
        # compacted selected values (tile-local layout); ``select_num`` holds
        # one selected-count per tile (over-allocated to n_elements entries).
        select_val = torch.empty(n_elements, dtype=inp.dtype, device=inp.device)
        select_num = torch.empty(n_elements, dtype=torch.int32, device=inp.device)
        masked_select_kernel[grid](inp, mask, select_val, select_num, n_elements)

        # Stage 2 must use the exact block size stage 1 was tuned to, so the
        # tile decomposition (and tile ids) line up between the two kernels.
        cur_block_size = masked_select_kernel.best_config.kwargs["BLOCK_SIZE"]
        num_select = mask.sum().item()
        output = torch.empty(num_select, dtype=inp.dtype, device=inp.device)
        get_out_kernel[grid](select_val, select_num, output, n_elements, cur_block_size)

    return output