Coverage for src/flag_gems/runtime/backend/_ascend/ops/count_nonzero.py: 0%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as tle
# Module-level logger named after this op file, e.g.
# "flag_gems.runtime._ascend.ops.count_nonzero".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Global nonzero count over a flat tensor: each program counts the
    # nonzeros in its BLOCK_SIZE slice and atomically folds the partial
    # count into the single-element int accumulator at out_ptr.
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < numel
    # Out-of-range lanes load 0, which contributes nothing to the count.
    values = tl.load(x_ptr + offsets, mask=in_bounds, other=0)
    flags = (values != 0).to(tl.int32)
    partial = tl.sum(flags, axis=0)
    tl.atomic_add(out_ptr, partial)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Per-row nonzero count: the flat input is viewed as (numel // N) rows of
    # length N; out_ptr[row] receives the nonzero count of that row.
    # Rows ("tasks") are distributed round-robin over a fixed-size grid, so a
    # single program may reduce several rows.
    n_workers = tle.num_programs(0)
    pid = tle.program_id(0)
    n_tasks = tl.cdiv(numel, N)
    tasks_per_worker = tl.cdiv(n_tasks, n_workers)
    for task_index in range(tasks_per_worker):
        task_id = pid + task_index * n_workers
        # BUGFIX: when n_tasks is not a multiple of n_workers, the last
        # round hands some workers a task_id >= n_tasks. The original code
        # stored unconditionally, writing a (zero) count past the end of the
        # output buffer. Guard the whole task so those workers do nothing.
        if task_id < n_tasks:
            nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
            # Walk the row in BLOCK_SIZE tiles and accumulate nonzeros.
            for start_n in range(0, N, BLOCK_SIZE):
                cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
                offset = task_id * N + cols_offsets
                mask = offset < numel and cols_offsets < N
                x = tl.load(x_ptr + offset, mask=mask, other=0)
                is_nonzero = (x != 0).to(tl.int64)
                nonzero_count += tl.sum(is_nonzero)
            tl.store(out_ptr + task_id, nonzero_count)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Final reduction stage of the two-pass path: each length-N row of x
    # already holds partial nonzero COUNTS (produced by
    # count_nonzero_combin_kernel), so the row total is a plain sum — no
    # (x != 0) test here. One program per row; result goes to out_ptr[row].
    row = tle.program_id(0)
    total = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for col_start in range(0, N, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        flat = row * N + cols
        valid = flat < numel and cols < N
        chunk = tl.load(x_ptr + flat, mask=valid, other=0)
        total += tl.sum(chunk)
    tl.store(out_ptr + row, total)
@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    # First reduction stage of the two-pass path: each length-N row of the
    # flat input is split into combin_N tiles of BLOCK_SIZE. Program
    # (row, tile) counts the nonzeros in its tile and writes the partial
    # count to combin_ptr[row * combin_N + tile].
    row = tle.program_id(0)
    tile = tle.program_id(1)
    cols = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    flat = row * N + cols
    valid = flat < numel and cols < N
    chunk = tl.load(x_ptr + flat, mask=valid, other=0)
    partial = tl.sum((chunk != 0).to(tl.int64))
    tl.store(combin_ptr + row * combin_N + tile, partial)
def count_nonzero(x, dim=None):
    """Count the nonzero elements of ``x``.

    With ``dim=None`` the whole tensor is reduced and a scalar int64 tensor
    is returned. With an integer ``dim`` (negative allowed) the count is
    taken along that dimension and an int64 tensor with ``dim`` removed
    from ``x``'s shape is returned.
    """
    logger.debug("GEMS_ASCEND COUNT NONZERO")
    if dim is None:
        # Full reduction: single int32 accumulator, one atomic add per block.
        flat = x.contiguous().flatten()
        numel = flat.numel()
        total = torch.zeros(1, dtype=torch.int32, device=flat.device)
        BLOCK_SIZE = 8192
        grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
        count_nonzero_kernel_1[grid](flat, total, numel, BLOCK_SIZE=BLOCK_SIZE)
        return total[0].to(torch.int64)

    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    shape = x.shape
    BLOCK_SIZE = 8192
    numel = x.numel()
    # dim_compress moves the reduced dimension last, so after flattening each
    # length-shape[dim] row is contiguous.
    x = dim_compress(x, dim)
    x = x.contiguous().flatten()
    combin_shape = list(shape)
    combin_shape[dim] = triton.cdiv(combin_shape[dim], BLOCK_SIZE)
    if combin_shape[dim] != 1:
        # Long rows: two-stage reduction. Stage 1 tiles each row and writes
        # per-tile partial counts into `combin`; stage 2 sums the partials.
        combin = torch.zeros(combin_shape, dtype=torch.int64, device=x.device)
        grid = (triton.cdiv(numel, shape[dim]), combin_shape[dim], 1)
        count_nonzero_combin_kernel[grid](
            x, combin, shape[dim], combin_shape[dim], numel, BLOCK_SIZE
        )
        # Re-point the reduction at the partial-count tensor.
        x = combin
        shape = x.shape
        numel = x.numel()
        out_shape = list(shape)
        del out_shape[dim]
        out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)
        grid = lambda meta: (triton.cdiv(numel, shape[dim]),)
        count_nonzero_combin_kernel_1[grid](x, out, shape[dim], numel)
        return out

    # Short rows (one tile per row): single-pass per-row kernel, grid capped
    # at 240 programs with tasks distributed round-robin inside the kernel.
    out_shape = list(shape)
    del out_shape[dim]
    out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    def grid(meta):
        rows = triton.cdiv(numel, shape[dim])
        return (min(rows, 240),)

    count_nonzero_kernel[grid](x, out, shape[dim], numel)
    return out