Coverage for src/flag_gems/runtime/backend/_cambricon/ops/count_nonzero.py: 0%
113 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as tle
11from ..utils import TOTAL_CORE_NUM
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Count the non-zeros of a flat tensor into the scalar out_ptr[0].

    Each program counts the non-zero elements in its own BLOCK_SIZE slice
    of x and atomically accumulates the partial count into out_ptr[0], so
    out must be zero-initialized by the caller.
    """
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel  # guard the final, possibly partial block
    # other=0 makes masked-out lanes contribute nothing to the count.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int32)
    nonzero_count = tl.sum(is_nonzero, axis=0)
    tl.atomic_add(out_ptr, nonzero_count)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Per-row non-zero counts for x viewed as a row-major (rows, N) matrix.

    rows = ceil(numel / N). Rows are split evenly across the launched
    programs; each program walks its rows in BLOCK_SIZE column chunks,
    accumulates a scalar count per row, and stores it at out_ptr[row].
    Leftover rows (rows % num_programs) are handled one-per-program below.
    """
    pid_0 = tle.program_id(0)
    num_p = tle.num_programs(0)
    rows = (numel + N - 1) // N  # ceil-divide: number of logical rows
    rows_per_p = rows // num_p  # evenly divisible share; remainder handled later

    for pid_n in range(0, rows_per_p):
        pid_x = pid_0 * rows_per_p + pid_n
        # Scalar accumulator in the output's element type (int64 from the caller).
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = pid_x * N + cols_offsets
            # Guard both the row end (cols_offsets < N) and the tensor end.
            mask = offset < numel and cols_offsets < N
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)
        tl.store(out_ptr + pid_x, nonzero_count)

    # The first `remain` programs each take exactly one leftover row.
    remain = rows % num_p
    if pid_0 < remain:
        pid_x = rows // num_p * num_p + pid_0
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = pid_x * N + cols_offsets
            mask = offset < numel and cols_offsets < N
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)
        tl.store(out_ptr + pid_x, nonzero_count)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Sum each length-N row of x into out_ptr[row].

    NOTE: this sums values directly (tl.sum(x)), not (x != 0) — it is the
    second reduction stage and expects x to already hold per-block
    non-zero counts produced by count_nonzero_combin_kernel.
    """
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # Guard both the row end and the tensor end; masked lanes load 0.
        mask = offset < numel and cols_offsets < N
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        nonzero_count += tl.sum(x)
    tl.store(out_ptr + pid_x, nonzero_count)
@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    """First reduction stage: per-block partial non-zero counts.

    x is viewed row-major with rows of length N; program (pid_x, pid_y)
    counts the non-zeros in BLOCK_SIZE-wide chunk pid_y of row pid_x and
    stores the partial count at combin[pid_x, pid_y], where each combin
    row has combin_N = ceil(N / BLOCK_SIZE) entries.
    """
    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)
    cols_offsets = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_x * N + cols_offsets
    # Guard both the row end and the tensor end; masked lanes load 0.
    mask = offset < numel and cols_offsets < N
    x = tl.load(x_ptr + offset, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int64)
    nonzero_count = tl.sum(is_nonzero)
    tl.store(combin_ptr + pid_x * combin_N + pid_y, nonzero_count)
def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, optionally along ``dim``.

    With ``dim is None`` the whole tensor is reduced to a single int64
    scalar. Otherwise the reduction runs along ``dim`` (which may be
    negative) and an int64 tensor with that dimension removed is returned.

    Args:
        x: input tensor.
        dim: optional dimension to reduce over; negative values count
            from the end, as in ``torch.count_nonzero``.

    Returns:
        An int64 tensor: 0-d when ``dim is None``, otherwise ``x``'s shape
        with ``dim`` removed.

    Raises:
        IndexError: if ``dim`` is out of range for ``x``.
    """
    logger.debug("GEMS_CAMBRICON COUNT NONZERO")
    if dim is not None:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`; IndexError matches torch's own dim-range convention.
        if not (-x.ndim <= dim < x.ndim):
            raise IndexError(
                f"Dimension out of range (expected to be in range of "
                f"[{-x.ndim}, {x.ndim - 1}], but got {dim})"
            )
        # Normalize once so all list indexing below uses a positive index.
        dim = dim % x.ndim
        shape = x.shape
        BLOCK_SIZE = 2048
        numel = x.numel()
        # Move `dim` to the innermost axis so every reduced row is a
        # contiguous run of shape[dim] elements after flattening.
        x = dim_compress(x, dim)
        x = x.contiguous().flatten()
        combin_shape = list(shape)
        combin_shape[dim] = triton.cdiv(combin_shape[dim], BLOCK_SIZE)
        if combin_shape[dim] != 1:
            # Long rows: two-stage reduction. Stage 1 writes one partial
            # count per BLOCK_SIZE chunk of each row into `combin`.
            combin = torch.zeros(combin_shape, dtype=torch.int64, device=x.device)
            grid = (triton.cdiv(numel, shape[dim]), combin_shape[dim], 1)
            count_nonzero_combin_kernel[grid](
                x, combin, shape[dim], combin_shape[dim], numel, BLOCK_SIZE
            )
            # Stage 2 sums the partial counts along the shrunken axis.
            x = combin
            shape = x.shape
            numel = x.numel()
            out_shape = list(shape)
            del out_shape[dim]
            out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)
            grid = lambda meta: (triton.cdiv(numel, shape[dim]),)
            count_nonzero_combin_kernel_1[grid](x, out, shape[dim], numel)
            return out
        # Short rows (fit in one block): a single fused kernel over all
        # rows, capped at TOTAL_CORE_NUM programs.
        out_shape = list(shape)
        del out_shape[dim]
        out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)
        grid = lambda meta: (min(TOTAL_CORE_NUM, triton.cdiv(numel, shape[dim])),)
        count_nonzero_kernel[grid](x, out, shape[dim], numel)
        return out
    else:
        x = x.contiguous().flatten()
        numel = x.numel()
        # int32 accumulator for the atomic-add kernel. NOTE(review): this
        # can overflow if the tensor holds more than 2**31 - 1 non-zeros.
        out = torch.zeros(1, dtype=torch.int32, device=x.device)
        BLOCK_SIZE = 1024 * 8
        grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
        count_nonzero_kernel_1[grid](x, out, numel, BLOCK_SIZE=BLOCK_SIZE)
        return out[0].to(torch.int64)