Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/nonzero.py: 0%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def nonzero_kernel_heur_block_size(args):
    """Heuristic BLOCK_SIZE: next power of two covering n_elements / 12 (cluster_num)."""
    elems_per_cluster = triton.cdiv(args["n_elements"], 12)  # cluster_num
    return triton.next_power_of_2(elems_per_cluster)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("nonzero"),
#     key=[
#         "n_elements",
#     ],
# )
@triton.heuristics(
    values={
        "BLOCK_SIZE": nonzero_kernel_heur_block_size,
    },
)
@triton.jit
def nonzero_kernel(
    inp,
    prefix_sum,
    out,
    n_elements: tl.constexpr,
    shape,
    ndim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter the multi-dim coordinates of nonzero elements into `out`.

    Args:
        inp: flattened input of length ``n_elements``; loaded and cast to
            int1, so it is treated as a boolean nonzero mask.
        prefix_sum: inclusive cumulative sum of the boolean input, so
            ``prefix_sum[i] - 1`` is the output row for element ``i``.
        out: buffer of ``ndim``-wide rows; only rows for nonzero elements
            are written (rows past the nonzero count are left untouched).
        n_elements: total number of flattened input elements.
        shape: per-dimension sizes of the original input, read on device
            to unravel flat indices into coordinates.
        ndim: rank of the original input.
        BLOCK_SIZE: elements handled per program (set by the heuristic above).
    """
    pid = tle.program_id(0)

    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_vals = tl.load(inp + offset, mask=mask).to(tl.int1)
    # prefix_sum counts from 1; subtract 1 to get this element's output row.
    out_offset = tl.load(prefix_sum + offset, mask=mask) - 1

    # Store only for lanes that are in-bounds AND whose input is nonzero.
    # (Triton compiles `and` on tensors elementwise from the kernel AST.)
    nonzero_mask = mask and inp_vals  # noqa

    # Unravel the flat index into per-dimension coordinates, innermost
    # (last, fastest-varying) dimension first.
    idx_flat = offset
    for dim in range(ndim - 1, -1, -1):
        dim_size = tl.load(shape + dim)
        remainder = idx_flat % dim_size
        idx_flat //= dim_size
        tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)
def nonzero(inp, *, as_tuple=False):
    """Return the indices of the nonzero elements of *inp*.

    Mirrors ``torch.nonzero``:
      * ``as_tuple=False`` (default): a ``(num_nonzeros, ndim)`` int64
        tensor, one coordinate row per nonzero element.
      * ``as_tuple=True``: a tuple of ``ndim`` 1-D int64 index tensors,
        one per input dimension.
    """
    logger.debug("GEMS NONZERO")

    inp_ndim = inp.ndim

    inp = inp.contiguous()
    n_elements = inp.numel()

    # Empty input: nothing to launch, and prefix_sum[n_elements - 1] below
    # would wrap to index -1 on an empty tensor. Return an empty result
    # with the shape torch.nonzero would produce.
    if n_elements == 0:
        out = torch.empty(0, inp_ndim, dtype=torch.int64, device=inp.device)
        return torch.unbind(out, dim=1) if as_tuple else out

    inp_view = inp.view(n_elements)

    # Per-dimension sizes, read by the kernel to unravel flat indices.
    shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)

    inp_bool = inp_view
    if inp_view.dtype != torch.bool:
        inp_bool = inp_view != 0

    # Inclusive cumsum gives a 1-based running nonzero count; the kernel
    # subtracts 1 to obtain each element's output row.
    prefix_sum = inp_bool.cumsum(axis=0)

    # Allocate for the worst case (every element nonzero); trimmed below.
    num_nonzeros = n_elements
    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(inp.device):
        nonzero_kernel[grid](
            inp_bool,
            prefix_sum,
            out,
            n_elements,
            shape,
            inp_ndim,
            isCloseUnrollControl=True,
            is_use_mask_zero=True,
        )

    # Total nonzero count is the last prefix-sum entry (forces a device sync).
    num_nonzeros = prefix_sum[n_elements - 1].item()
    out = out[0:num_nonzeros]

    if as_tuple:
        # BUGFIX: torch.nonzero(as_tuple=True) returns one 1-D tensor per
        # input dimension — the columns of `out`. Unbind along dim=1;
        # dim=0 would instead yield one (ndim,) row per nonzero element.
        return torch.unbind(out, dim=1)
    else:
        return out