Coverage for src/flag_gems/ops/count_nonzero.py: 62%
95 statements
coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; count_nonzero emits a debug message through it on each call.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Add the number of non-zero elements in one chunk of ``x`` to ``out_ptr[0]``.

    Program ``i`` covers flat indices ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)``.
    Lanes at or beyond ``numel`` are masked off and read as 0, so they never
    contribute to the count.
    """
    offs = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    vals = tl.load(x_ptr + offs, mask=offs < numel, other=0)
    # Count in int64 so the partial total matches the int64 output buffer.
    block_total = tl.sum((vals != 0).to(tl.int64), axis=0)
    # Every program adds into the same scalar, hence the atomic update.
    tl.atomic_add(out_ptr, block_total)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Count the non-zero elements of one length-``N`` row of ``x``.

    ``x`` is addressed as a contiguous row-major buffer with rows of length
    ``N`` and ``numel`` elements in total. Program ``i`` scans row ``i`` in
    ``BLOCK_SIZE``-wide chunks and stores the row's count to ``out_ptr[i]``,
    accumulated in ``out_ptr``'s element type.
    """
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # Combine tensor predicates with the elementwise `&` operator rather
        # than Python's `and`, which only works via Triton's AST rewriting.
        mask = (offset < numel) & (cols_offsets < N)
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        is_nonzero = (x != 0).to(tl.int64)
        nonzero_count += tl.sum(is_nonzero)
    tl.store(out_ptr + pid_x, nonzero_count)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Sum each length-``N`` row of ``x`` into ``out_ptr``.

    Unlike ``count_nonzero_kernel``, the values are summed directly rather
    than tested against zero. ``x`` is expected to already hold per-chunk
    non-zero counts. ``x`` is addressed as a contiguous row-major buffer with
    rows of length ``N`` and ``numel`` elements in total. Program ``i``
    stores the sum of row ``i`` to ``out_ptr[i]``, accumulated in
    ``out_ptr``'s element type.
    """
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # Elementwise `&` instead of Python `and` for combining tensor masks.
        mask = (offset < numel) & (cols_offsets < N)
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        nonzero_count += tl.sum(x)
    tl.store(out_ptr + pid_x, nonzero_count)
@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    """Count non-zeros in one ``BLOCK_SIZE`` chunk of one row of ``x``.

    ``x`` is addressed as a contiguous row-major buffer with rows of length
    ``N`` and ``numel`` elements in total. Program ``(i, j)`` counts the
    non-zero elements of columns ``[j * BLOCK_SIZE, (j + 1) * BLOCK_SIZE)`` of
    row ``i``. It writes the int64 count to ``combin_ptr[i * combin_N + j]``,
    so ``combin_ptr`` is laid out as rows of ``combin_N`` partial counts.
    """
    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)
    cols_offsets = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_x * N + cols_offsets
    # Elementwise `&` instead of Python `and` for combining tensor masks.
    mask = (offset < numel) & (cols_offsets < N)
    x = tl.load(x_ptr + offset, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int64)
    nonzero_count = tl.sum(is_nonzero)
    tl.store(combin_ptr + pid_x * combin_N + pid_y, nonzero_count)
def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, in total or along one dimension.

    Args:
        x: Input tensor.
        dim: Integer dimension to reduce over, or ``None`` to count over all
            elements. Negative values count from the last dimension and are
            normalized before use.

    Returns:
        An int64 tensor on ``x``'s device. It is a 0-dim tensor holding the
        total count when ``dim`` is ``None``; otherwise it has ``x``'s shape
        with ``dim`` removed.

    Raises:
        AssertionError: if ``dim`` is outside ``[-x.ndim, x.ndim)``.
    """
    logger.debug("GEMS COUNT NONZERO")
    if dim is None:
        x = x.contiguous().flatten()
        numel = x.numel()
        # count_nonzero_kernel_1 accumulates with atomic_add, so the output
        # must start from zero.
        out = torch.zeros(1, dtype=torch.int64, device=x.device)
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
        count_nonzero_kernel_1[grid](x, out, numel, BLOCK_SIZE=BLOCK_SIZE)
        return out[0]

    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    # Normalize a negative dim to its non-negative position before it reaches
    # dim_compress. NOTE(review): confirm whether dim_compress accepts negative
    # dims itself; normalizing here is harmless either way.
    dim = dim % x.ndim

    shape = x.shape
    out_shape = list(shape)
    del out_shape[dim]
    numel = x.numel()
    if numel == 0:
        # Nothing to count. Returning early also avoids triton.cdiv(numel, 0)
        # when the reduced dimension itself has size 0.
        return torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    BLOCK_SIZE = 2048
    row_len = shape[dim]
    num_rows = numel // row_len
    # Move `dim` innermost so each output element owns one contiguous row.
    x = dim_compress(x, dim).contiguous().flatten()
    num_chunks = triton.cdiv(row_len, BLOCK_SIZE)
    # Every output element is stored by exactly one program below, so the
    # buffer does not need zero-filling.
    out = torch.empty(out_shape, dtype=torch.int64, device=x.device)

    if num_chunks == 1:
        # Short rows: one autotuned program per row counts the whole row.
        grid = lambda meta: (num_rows,)
        count_nonzero_kernel[grid](x, out, row_len, numel)
        return out

    # Long rows. Stage 1 counts every BLOCK_SIZE chunk of every row in
    # parallel into a (num_rows, num_chunks) buffer. Stage 2 sums each row's
    # partial counts.
    partial = torch.empty(
        (num_rows, num_chunks), dtype=torch.int64, device=x.device
    )
    count_nonzero_combin_kernel[(num_rows, num_chunks, 1)](
        x, partial, row_len, num_chunks, numel, BLOCK_SIZE
    )
    grid = lambda meta: (num_rows,)
    count_nonzero_combin_kernel_1[grid](partial, out, num_chunks, partial.numel())
    return out