Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/count_nonzero.py: 0%
140 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Accumulate the global non-zero count of a flat buffer via atomics.

    Each program inspects its own ``BLOCK_SIZE`` slice of ``x_ptr`` and
    atomically adds its partial count into the single-element ``out_ptr``.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < numel
    # Out-of-range lanes load 0, so they never contribute to the count.
    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    partial = tl.sum((vals != 0).to(tl.int64), axis=0)
    tl.atomic_add(out_ptr, partial)
"""***************************** TRITON XPU KERNEL *****************************"""
@libentry()
@triton.jit
def count_nonzero_kernel_1_part0_xpu(x_ptr, out_ptr, numel, BLOCK_SIZE_0: tl.constexpr):
    """Stage 0 of the two-pass XPU count: per-program partial counts.

    Program ``i`` counts the non-zeros in its ``BLOCK_SIZE_0`` slice and
    stores the partial result at ``out_ptr + i`` (no atomics needed).
    """
    prog = tle.program_id(0)
    idx = prog * BLOCK_SIZE_0 + tl.arange(0, BLOCK_SIZE_0)
    valid = idx < numel
    vals = tl.load(x_ptr + idx, mask=valid, other=0)
    partial = tl.sum((vals != 0).to(tl.int64), axis=0)
    tl.store(out_ptr + prog, partial)
@libentry()
@triton.jit
def count_nonzero_kernel_1_part1_xpu(x_ptr, out_ptr, numel, BLOCK_SIZE_1: tl.constexpr):
    """Stage 1: sum the stage-0 partial counts into a single scalar.

    Launched with a single program; ``BLOCK_SIZE_1`` must cover ``numel``
    partials (the caller rounds up to a power of two).
    """
    idx = tl.arange(0, BLOCK_SIZE_1)
    partials = tl.load(x_ptr + idx, mask=idx < numel, other=0)
    tl.store(out_ptr, tl.sum(partials, axis=0))
"""***************************** TRITON XPU KERNEL *****************************"""
def heur_block_size(args):
    """Heuristic BLOCK_SIZE: split ``numel`` across 12 blocks, rounded up to a power of two."""
    per_block = triton.cdiv(args["numel"], 12)
    return triton.next_power_of_2(per_block)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.heuristics(
    {
        "BLOCK_SIZE": heur_block_size,
    }
)
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Row-wise non-zero count over a row-major (numel // N, N) buffer:
    # program `pid_x` scans one row in BLOCK_SIZE chunks and writes a single
    # count (in out_ptr's element dtype) at out_ptr + pid_x.
    pid_x = tle.program_id(0)

    # Scalar accumulator typed to match the output buffer.
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # NOTE(review): `and` between two tensor masks — upstream Triton
        # usually requires `&`; presumably the kunlunxin fork accepts this.
        # Confirm before porting.
        mask = offset < numel and cols_offsets < N
        # Masked-out lanes read 0 and therefore never count as non-zero.
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        is_nonzero = (x != 0).to(tl.int64)
        nonzero_count += tl.sum(is_nonzero)

    tl.store(out_ptr + pid_x, nonzero_count)
"""***************************** TRITON XPU KERNEL *****************************"""
@libentry()
@triton.jit
def count_nonzero_kernel_xpu(
    x_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Row-wise non-zero count over a row-major (M, N) buffer: program `pid_x`
    # handles BLOCK_M consecutive rows at once and writes one count per row.
    pid_x = tl.program_id(0)
    row = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]  # shape (BLOCK_M, 1)
    row_mask = row < M

    # Per-(row, lane) partial counts; dtype follows the output buffer
    # (int64 when launched from count_nonzero below).
    _nonzero_count = tl.zeros([BLOCK_M, BLOCK_N], dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)[None, :]  # shape (1, BLOCK_N)
        col_mask = cols < N
        # NOTE(review): `and` between tensor masks — upstream Triton usually
        # requires `&`; presumably the kunlunxin fork accepts this. Confirm.
        mask = row_mask and col_mask
        # Masked-out positions read 0 and contribute nothing to the count.
        x = tl.load(x_ptr + row * N + cols, mask=mask, other=0)
        is_nonzero = (x != 0).to(tl.int64)
        _nonzero_count += is_nonzero

    # Reduce across lanes; keep a (BLOCK_M, 1) shape so it matches `row`
    # for the masked store below.
    nonzero_count = tl.sum(_nonzero_count, axis=1)[:, None]
    tl.store(out_ptr + row, nonzero_count, row_mask)
"""***************************** TRITON XPU KERNEL *****************************"""
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.heuristics(
    {
        "BLOCK_SIZE": heur_block_size,
    }
)
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Row-wise SUM (note: no `!= 0` test) — unlike count_nonzero_kernel this
    # adds the raw values, so x_ptr is presumably a buffer of partial counts
    # produced by count_nonzero_combin_kernel; verify against the caller.
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # NOTE(review): `and` between tensor masks — upstream Triton usually
        # requires `&`; presumably the kunlunxin fork accepts this. Confirm.
        mask = offset < numel and cols_offsets < N
        # Masked-out lanes read 0 and therefore do not affect the sum.
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        nonzero_count += tl.sum(x)
    tl.store(out_ptr + pid_x, nonzero_count)
@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr,
    combin_ptr,
    N: tl.constexpr,       # row length of the input
    combin_N: tl.constexpr,  # number of BLOCK_SIZE tiles per row (row length of combin_ptr)
    numel: tl.constexpr,   # total element count of the input
    BLOCK_SIZE: tl.constexpr,
):
    # Tiled first stage of a two-stage row reduction: program (pid_x, pid_y)
    # counts the non-zeros in one BLOCK_SIZE tile of row pid_x and writes the
    # partial count to combin_ptr[pid_x, pid_y] for a later combine pass.
    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)
    cols_offsets = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_x * N + cols_offsets
    # NOTE(review): `and` between tensor masks — upstream Triton usually
    # requires `&`; presumably the kunlunxin fork accepts this. Confirm.
    mask = offset < numel and cols_offsets < N
    # Masked-out lanes read 0 and never count as non-zero.
    x = tl.load(x_ptr + offset, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int64)
    nonzero_count = tl.sum(is_nonzero)
    tl.store(combin_ptr + pid_x * combin_N + pid_y, nonzero_count)
def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x`` on the kunlunxin XPU backend.

    Args:
        x: input tensor.
        dim: optional dimension to count along. ``None`` counts over the
            whole tensor.

    Returns:
        ``torch.int64`` tensor of counts: a 0-d scalar when ``dim`` is
        ``None``, otherwise ``x``'s shape with ``dim`` removed.

    Fix over the original: every ``os.environ[...] = ...`` /
    ``del os.environ[...]`` pair is now wrapped in ``try``/``finally`` so an
    exception inside the guarded region cannot leak the TRITONXPU_* variable
    into the process environment and silently affect later launches.
    """
    logger.debug("GEMS COUNT NONZERO")

    # XPU launch geometry: 64 cores per cluster, 512 elements per core.
    CORE_NUM = 64
    SIZE_PER_CORE = 512
    SIZE_PER_CLUSTER = CORE_NUM * SIZE_PER_CORE

    elem_bytes = x.element_size()
    if dim is not None:
        assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
        shape = x.shape
        numel = x.numel()
        # Permute so the reduced dim is innermost, then flatten to M*N elements.
        os.environ["TRITONXPU_IS_SCATTER_SLICE"] = "1"
        try:
            x = dim_compress(x, dim)
            x = x.contiguous().flatten()
        finally:
            del os.environ["TRITONXPU_IS_SCATTER_SLICE"]

        # 2D count_nonzero: one int64 output element per remaining row.
        out_shape = list(shape)
        del out_shape[dim]  # also correct for negative dim (asserted above)
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        try:
            out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)
        finally:
            del os.environ["TRITONXPU_ELEMBYTES"]

        N = shape[dim]
        M = triton.cdiv(numel, shape[dim])
        BLOCK_M = CORE_NUM
        BLOCK_N = SIZE_PER_CORE
        grid = lambda meta: (triton.cdiv(M, BLOCK_M),)
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        try:
            count_nonzero_kernel_xpu[grid](
                x,
                out,
                M,
                N,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                groups_per_cluster=CORE_NUM,
                buffer_size_limit=SIZE_PER_CORE * 8,
                is_use_mask_zero=True,
            )
        finally:
            del os.environ["TRITONXPU_ELEMBYTES"]
        return out
    else:
        # 1D count_nonzero: two-pass reduction — per-cluster partial counts,
        # then a single-program combine of the partials.
        x = x.contiguous().flatten()
        numel = x.numel()
        gridX = triton.cdiv(numel, SIZE_PER_CLUSTER)
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        try:
            out_mid = torch.zeros(gridX, dtype=torch.int64, device=x.device)
        finally:
            del os.environ["TRITONXPU_ELEMBYTES"]

        count_nonzero_kernel_1_part0_xpu[(gridX,)](
            x,
            out_mid,
            numel,
            BLOCK_SIZE_0=SIZE_PER_CLUSTER,
            buffer_size_limit=SIZE_PER_CORE * elem_bytes,
            is_use_mask_zero=True,
        )

        BLOCK_SIZE_1 = triton.next_power_of_2(gridX)
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        try:
            out = torch.zeros(1, dtype=torch.int64, device=x.device)
            count_nonzero_kernel_1_part1_xpu[(1,)](
                out_mid,
                out,
                gridX,
                BLOCK_SIZE_1=BLOCK_SIZE_1,
                buffer_size_limit=SIZE_PER_CORE * 8,
                is_use_mask_zero=True,
            )
        finally:
            del os.environ["TRITONXPU_ELEMBYTES"]

        return out[0]