Coverage for src/flag_gems/ops/nonzero.py: 72%
46 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; nonzero() uses it to emit a debug trace on each call.
logger = logging.getLogger(__name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("elementwise_generic"))
@triton.jit
def nonzero_kernel(
    inp,
    prefix_sum,
    out,
    n_elements,
    shape,
    ndim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatter the multi-dimensional coordinates of every non-zero element of a
    # flattened input into its own row of ``out``.
    #
    # inp:        flattened input; each value is cast to tl.int1 as the
    #             "is non-zero" flag (nonzero() passes a bool tensor here).
    # prefix_sum: running count of non-zeros over ``inp`` (nonzero() builds it
    #             with cumsum); entry i minus 1 is the destination row of
    #             element i when that element is non-zero.
    # out:        index buffer addressed as ``row * ndim + dim``, i.e. one row
    #             of ``ndim`` coordinates per non-zero element.
    # n_elements: number of elements in ``inp``; lanes at or past it are masked.
    # shape:      pointer to the ``ndim`` dimension sizes of the original input.
    # ndim:       compile-time rank; drives the unravel loop below.
    # BLOCK_SIZE: elements handled per program; presumably supplied by the
    #             "elementwise_generic" heuristic — confirm in runtime config.
    pid = tle.program_id(0)

    # Flat element indices covered by this program; the tail block is masked.
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_vals = tl.load(inp + offset, mask=mask).to(tl.int1)
    # Running count -> 0-based destination row. Only meaningful where
    # inp_vals is set; other lanes (e.g. -1 before the first non-zero) are
    # never stored because of nonzero_mask below.
    out_offset = tl.load(prefix_sum + offset, mask=mask) - 1

    # Store only in-bounds lanes holding a non-zero value.
    # NOTE(review): ``and`` between tensors relies on Triton lowering it to an
    # elementwise logical and — presumably why the linter is silenced here;
    # ``mask & inp_vals`` would be the explicit form.
    nonzero_mask = mask and inp_vals  # noqa

    # Unravel the flat index into per-dimension coordinates, starting from the
    # last (fastest-varying) dimension, i.e. row-major / C order.
    idx_flat = offset
    for dim in range(ndim - 1, -1, -1):
        dim_size = tl.load(shape + dim)
        remainder = idx_flat % dim_size
        idx_flat //= dim_size
        tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)
def nonzero(inp, *, as_tuple=False):
    """Return the indices of the non-zero elements of ``inp``.

    Mirrors ``torch.nonzero``:

    * ``as_tuple=False`` returns an int64 tensor of shape
      ``(num_nonzeros, inp.ndim)``. Each row holds the coordinates of one
      non-zero element, in row-major order of the elements.
    * ``as_tuple=True`` returns a tuple of ``inp.ndim`` 1-D int64 tensors,
      one per dimension. A 0-d input is treated as a one-element 1-D tensor,
      so a single tensor is returned.

    Args:
        inp: Input tensor of any shape and dtype.
        as_tuple: Keyword-only; selects the tuple-of-columns output form.

    Returns:
        The ``(num_nonzeros, ndim)`` index tensor, or a tuple of
        per-dimension index tensors when ``as_tuple`` is true.
    """
    logger.debug("GEMS NONZERO")

    inp_ndim = inp.ndim

    inp = inp.contiguous()
    n_elements = inp.numel()
    inp_view = inp.view(n_elements)

    # Reduce to a boolean "is non-zero" mask. Its running sum gives each
    # non-zero element its 1-based destination row in the output.
    inp_bool = inp_view if inp_view.dtype == torch.bool else inp_view != 0
    prefix_sum = inp_bool.cumsum(dim=0)

    # Read the total count (the single host sync) BEFORE allocating, so the
    # result is exactly sized instead of an (n_elements, ndim) worst-case
    # buffer that a returned slice would keep alive. Guarding on n_elements
    # also avoids indexing prefix_sum[-1] on an empty tensor.
    num_nonzeros = int(prefix_sum[-1].item()) if n_elements > 0 else 0
    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)

    if num_nonzeros > 0:
        shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        with torch_device_fn.device(inp.device):
            nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)

    if not as_tuple:
        return out
    if inp_ndim == 0:
        # torch treats a 0-d input as a one-element 1-D tensor for the tuple
        # form: index 0 if the scalar is non-zero, otherwise an empty tensor.
        return (out.new_zeros(num_nonzeros),)
    # One 1-D index tensor per dimension, i.e. the columns of ``out``.
    return torch.unbind(out, dim=1)