Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/count_nonzero.py: 0%
113 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
# Child logger nested under the package-level "flag_gems" logger so backend
# ops inherit the package's logging configuration.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Global nonzero count over a flat buffer.

    Each program counts the nonzeros in its BLOCK_SIZE chunk of x_ptr and
    atomically accumulates the partial count into the single cell out_ptr[0].
    """
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the ragged tail block; masked lanes load 0 below and so
    # contribute nothing to the count.
    mask = offsets < numel
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int32)
    nonzero_count = tl.sum(is_nonzero, axis=0)
    # NOTE(review): out_ptr is int32 in the caller — the atomic dtype
    # follows the output tensor's dtype.
    tl.atomic_add(out_ptr, nonzero_count)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Per-row nonzero count with a fixed-size 1D grid.

    The flat input is viewed as rows of length N (rows = ceil(numel / N)).
    The grid may be smaller than the row count, so each program handles
    `rows // num_programs` rows, and the first `rows % num_programs`
    programs each take one extra remainder row. out_ptr[row] receives the
    nonzero count of that row, in out_ptr's element dtype.
    """
    pid_0 = tle.program_id(0)
    num_p = tle.num_programs(0)
    rows = (numel + N - 1) // N
    rows_per_p = rows // num_p

    # Evenly distributed rows: program pid_0 owns a contiguous run of rows.
    for pid_n in range(0, rows_per_p):
        pid_x = pid_0 * rows_per_p + pid_n

        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        # Sweep the row in BLOCK_SIZE-wide tiles.
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = pid_x * N + cols_offsets
            # Mask both the row tail (cols < N) and the buffer tail (< numel).
            mask = offset < numel and cols_offsets < N
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)

        tl.store(out_ptr + pid_x, nonzero_count)

    # Remainder rows: the first `rows % num_p` programs each process one
    # extra row past the evenly distributed region (same tile loop as above).
    remain = rows % num_p
    if pid_0 < remain:
        pid_x = rows // num_p * num_p + pid_0
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = pid_x * N + cols_offsets
            mask = offset < numel and cols_offsets < N
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)

        tl.store(out_ptr + pid_x, nonzero_count)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Second stage of the two-stage per-dim reduction.

    x_ptr holds per-block partial nonzero COUNTS (produced by
    count_nonzero_combin_kernel), not raw input values — hence the row is
    summed directly instead of being compared against zero. One program
    reduces one row of length N into out_ptr[row].
    """
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        mask = offset < numel and cols_offsets < N
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        # Inputs are counts already; summing them yields the row's total.
        nonzero_count += tl.sum(x)
    tl.store(out_ptr + pid_x, nonzero_count)
@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    """First stage of the two-stage per-dim reduction.

    Program (pid_x, pid_y) counts the nonzeros in block pid_y of row pid_x
    (rows of length N, blocks of BLOCK_SIZE) and writes the partial count
    to combin_ptr[pid_x * combin_N + pid_y], where combin_N is the number
    of blocks per row.
    """
    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)
    cols_offsets = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_x * N + cols_offsets
    # Mask both the row tail (cols < N) and the buffer tail (< numel).
    mask = offset < numel and cols_offsets < N
    x = tl.load(x_ptr + offset, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int64)
    nonzero_count = tl.sum(is_nonzero)
    tl.store(combin_ptr + pid_x * combin_N + pid_y, nonzero_count)
def count_nonzero(x, dim=None):
    """Count the nonzero elements of ``x``.

    With ``dim=None`` the whole tensor is reduced to a single int64
    scalar.  Otherwise the count is taken independently for each slice
    along ``dim`` and an int64 tensor with that dimension removed is
    returned.
    """
    logger.debug("GEMS_TSINGMICRO COUNT NONZERO")

    if dim is None:
        # Global reduction: every program counts its chunk and atomically
        # folds the partial into a single int32 cell.
        flat = x.contiguous().flatten()
        total = flat.numel()
        result = torch.zeros(1, dtype=torch.int32, device=flat.device)
        block = 1024 * 8
        grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
        count_nonzero_kernel_1[grid](flat, result, total, BLOCK_SIZE=block)
        return result[0].to(torch.int64)

    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    shape = x.shape
    BLOCK_SIZE = 2048
    numel = x.numel()
    # Move the reduced dimension to the end so every slice is a
    # contiguous row in the flattened view.
    x = dim_compress(x, dim).contiguous().flatten()

    partial_shape = list(shape)
    partial_shape[dim] = triton.cdiv(partial_shape[dim], BLOCK_SIZE)
    if partial_shape[dim] != 1:
        # Long rows: two-stage reduction.  Stage 1 emits one partial
        # count per (row, block); stage 2 sums the partials per row.
        partial = torch.zeros(partial_shape, dtype=torch.int64, device=x.device)
        launch = (triton.cdiv(numel, shape[dim]), partial_shape[dim], 1)
        count_nonzero_combin_kernel[launch](
            x, partial, shape[dim], partial_shape[dim], numel, BLOCK_SIZE
        )
        # Stage 2 reduces the partial-count tensor exactly like a fresh input.
        x = partial
        shape = x.shape
        numel = x.numel()
        result_shape = list(shape)
        del result_shape[dim]
        result = torch.zeros(result_shape, dtype=torch.int64, device=x.device)
        grid = lambda meta: (triton.cdiv(numel, shape[dim]),)
        count_nonzero_combin_kernel_1[grid](x, result, shape[dim], numel)
        return result

    # Short rows: a single fixed-size grid, capped at the device's SM count.
    result_shape = list(shape)
    del result_shape[dim]
    result = torch.zeros(result_shape, dtype=torch.int64, device=x.device)
    grid = lambda meta: (
        min(
            torch_device_fn.get_device_properties().multi_processor_count,
            triton.cdiv(numel, shape[dim]),
        ),
    )
    count_nonzero_kernel[grid](x, result, shape[dim], numel)
    return result