Coverage for src/flag_gems/runtime/backend/_cambricon/ops/nonzero.py: 0% (54 statements)
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

from ..utils import TOTAL_CORE_NUM

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

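# Overview: this is a two-pass stream compaction. The host computes an
# inclusive prefix sum over the boolean view of the input, which assigns each
# non-zero element its destination row in the output; the Triton kernel below
# then scatters each such element's multi-dimensional coordinates into that row.
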
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("nonzero"),
    key=[
        "n_elements",
    ],
)
@triton.jit
def nonzero_kernel(
    inp,
    prefix_sum,
    out,
    n_elements,
    shape,
    ndim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid-stride loop: each program walks the flattened input in steps of
    # num_programs * BLOCK_SIZE, so any grid size covers all n_elements.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    for block_start_offset in range(block_start, n_elements, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < n_elements

        # Treat the input as boolean: any non-zero value counts as a hit.
        inp_vals = tl.load(inp + offset, mask=mask, other=0.0).to(tl.int1)
        nonzero_mask = mask & inp_vals

        # prefix_sum holds a 1-based inclusive cumsum of the boolean input, so
        # prefix_sum[i] - 1 is the output row for the i-th non-zero element.
        out_row_offset = tl.load(prefix_sum + offset, mask=nonzero_mask) - 1
        out_col_offset = tl.arange(0, ndim)
        out_offsets = out_row_offset[:, None] * ndim + out_col_offset[None, :]

        # Decompose each flat index into per-dimension coordinates (row-major
        # order), innermost dimension first. ndim is a constexpr, so this loop
        # is unrolled at compile time.
        out_vals = tl.zeros((BLOCK_SIZE, ndim), tl.int32)
        idx_flat = offset
        for dim in range(ndim - 1, -1, -1):
            dim_size = tl.load(shape + dim)
            remainder = idx_flat % dim_size
            idx_flat //= dim_size
            out_vals[:, dim] = remainder
        tl.store(out + out_offsets.to(tl.int32), out_vals, mask=nonzero_mask[:, None])
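
# A minimal host-side sketch of the index decomposition unrolled in the kernel
# loop above, for documentation only; the helper name is hypothetical and
# nothing in this module calls it. For shape (2, 3, 4), flat index 17 maps to
# coordinates (1, 1, 1).
def _unravel_index_sketch(idx_flat, shape):
    coords = [0] * len(shape)
    # Peel off the innermost dimension first, mirroring the kernel's
    # `for dim in range(ndim - 1, -1, -1)` loop.
    for dim in range(len(shape) - 1, -1, -1):
        coords[dim] = idx_flat % shape[dim]
        idx_flat //= shape[dim]
    return tuple(coords)
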
def nonzero(inp, *, as_tuple=False):
    logger.debug("GEMS_CAMBRICON NONZERO")

    inp_ndim = inp.ndim

    inp = inp.contiguous()
    n_elements = inp.numel()
    inp_view = inp.view(n_elements)

    shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)

    inp_bool = inp_view
    if inp_view.dtype != torch.bool:
        inp_bool = inp_view != 0

    # Inclusive prefix sum over the boolean mask gives each non-zero element
    # its (1-based) position in the compacted output.
    prefix_sum = inp_bool.cumsum(axis=0)

    # Allocate for the worst case (every element non-zero); the result is
    # trimmed to the true count after the kernel runs.
    num_nonzeros = n_elements
    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)

    grid = lambda meta: (
        min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
    )
    with torch_device_fn.device(inp.device):
        nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)

    # The last prefix-sum entry is the total number of non-zero elements.
    num_nonzeros = prefix_sum[n_elements - 1].item()
    out = out[0:num_nonzeros]

    if as_tuple:
        # torch.nonzero(..., as_tuple=True) returns one 1-D index tensor per
        # input dimension, so unbind along the column axis.
        return torch.unbind(out, dim=1)
    else:
        return out
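
# Usage sketch: the expected behavior matches torch.nonzero. The "mlu" device
# string is an assumption about the Cambricon setup; adjust to your
# environment.
#
#   x = torch.tensor([[0, 5], [7, 0]], device="mlu")
#   nonzero(x)                 # tensor([[0, 1], [1, 0]]), as torch.nonzero(x)
#   nonzero(x, as_tuple=True)  # (tensor([0, 1]), tensor([1, 0]))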