Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/count_nonzero.py: 0%

140 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Child logger under the shared "flag_gems" root logger; lstrip(".") removes
# any leading dots so the child name is a valid dotted logger path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Single-pass 1D count: each program counts the nonzero elements in its
    # BLOCK_SIZE-wide slice of the flat input and atomically folds the partial
    # count into the scalar at out_ptr[0].
    prog = tle.program_id(0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < numel
    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    partial = tl.sum((vals != 0).to(tl.int64), axis=0)
    tl.atomic_add(out_ptr, partial)

26 

27 

"""***************************** TRITON XPU KERNEL *****************************"""

29 

30 

@libentry()
@triton.jit
def count_nonzero_kernel_1_part0_xpu(x_ptr, out_ptr, numel, BLOCK_SIZE_0: tl.constexpr):
    # Stage 0 of the two-stage 1D reduction: each program counts the nonzero
    # entries of its own BLOCK_SIZE_0-wide slice and writes that partial count
    # to out_ptr[pid]; stage 1 (part1) then sums the partials.
    prog = tle.program_id(0)
    idx = prog * BLOCK_SIZE_0 + tl.arange(0, BLOCK_SIZE_0)
    in_bounds = idx < numel
    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    partial = tl.sum((vals != 0).to(tl.int64), axis=0)
    tl.store(out_ptr + prog, partial)

42 

43 

@libentry()
@triton.jit
def count_nonzero_kernel_1_part1_xpu(x_ptr, out_ptr, numel, BLOCK_SIZE_1: tl.constexpr):
    # Stage 1 of the two-stage 1D reduction: a single program sums the `numel`
    # partial counts produced by stage 0 and stores the final scalar result.
    idx = tl.arange(0, BLOCK_SIZE_1)
    partials = tl.load(x_ptr + idx, mask=idx < numel, other=0)
    tl.store(out_ptr, tl.sum(partials, axis=0))

52 

53 

"""***************************** TRITON XPU KERNEL *****************************"""

55 

56 

def heur_block_size(args):
    # Heuristic BLOCK_SIZE: split the `numel` elements into roughly 12 chunks,
    # rounded up to a power of two as tl.arange block sizes require.
    per_chunk = triton.cdiv(args["numel"], 12)
    return triton.next_power_of_2(per_chunk)

59 

60 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.heuristics(
    {
        "BLOCK_SIZE": heur_block_size,
    }
)
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Row-wise nonzero count: x is treated as rows of length N; program pid_x
    # scans its row in BLOCK_SIZE-wide chunks and stores one count at
    # out_ptr[pid_x].
    # NOTE(review): not launched by count_nonzero() below, which uses the *_xpu
    # kernels instead — presumably kept as the reference implementation; confirm.
    pid_x = tle.program_id(0)

    # Scalar accumulator in the output element dtype.
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # Guard both the row tail (cols_offsets < N) and the flat-array tail
        # (offset < numel) of the final row.
        mask = offset < numel and cols_offsets < N
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        is_nonzero = (x != 0).to(tl.int64)
        nonzero_count += tl.sum(is_nonzero)

    tl.store(out_ptr + pid_x, nonzero_count)

82 

83 

"""***************************** TRITON XPU KERNEL *****************************"""

85 

86 

@libentry()
@triton.jit
def count_nonzero_kernel_xpu(
    x_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Dim-reduction kernel for the 2D case: x is viewed as (M, N) with the
    # reduced dimension laid out last (see dim_compress in count_nonzero).
    # Each program handles BLOCK_M rows, scanning the N columns in BLOCK_N
    # chunks, and writes one int64 nonzero count per row.
    # Fix: use tle.program_id (triton_lang_extension wrapper) instead of
    # tl.program_id, for consistency with every other kernel in this file.
    pid_x = tle.program_id(0)
    row = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_mask = row < M

    # Per-lane partial counts, reduced along the column axis after the loop.
    _nonzero_count = tl.zeros([BLOCK_M, BLOCK_N], dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        x = tl.load(x_ptr + row * N + cols, mask=mask, other=0)
        is_nonzero = (x != 0).to(tl.int64)
        _nonzero_count += is_nonzero

    # [:, None] keeps the reduced result 2D so the store shape matches `row`.
    nonzero_count = tl.sum(_nonzero_count, axis=1)[:, None]
    tl.store(out_ptr + row, nonzero_count, row_mask)

107 

108 

"""***************************** TRITON XPU KERNEL *****************************"""

110 

111 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.heuristics(
    {
        "BLOCK_SIZE": heur_block_size,
    }
)
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Row-wise sum over a (., N) view: program pid accumulates its row in
    # BLOCK_SIZE-wide chunks and stores the total at out_ptr[pid]. Values are
    # summed directly (no != 0 test) — the input is assumed to already hold
    # partial counts (e.g. from count_nonzero_combin_kernel).
    row_id = tle.program_id(0)
    total = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for chunk_start in range(0, N, BLOCK_SIZE):
        col = chunk_start + tl.arange(0, BLOCK_SIZE)
        flat = row_id * N + col
        # Guard the row tail and the flat-array tail of the final row.
        valid = flat < numel and col < N
        total += tl.sum(tl.load(x_ptr + flat, mask=valid, other=0))
    tl.store(out_ptr + row_id, total)

130 

131 

@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr,
    combin_ptr,
    N: tl.constexpr,
    combin_N: tl.constexpr,
    numel: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # First stage of a split row reduction on a 2D grid: program (pid_x, pid_y)
    # counts the nonzero entries of one BLOCK_SIZE-wide chunk of row pid_x and
    # stores the partial count at combin_ptr[pid_x * combin_N + pid_y], so
    # count_nonzero_combin_kernel_1 can then sum each row of partials.
    # NOTE(review): not launched anywhere in this file — confirm external callers.
    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)
    cols_offsets = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_x * N + cols_offsets
    # Mask off both the row tail and the flat-array tail.
    mask = offset < numel and cols_offsets < N
    x = tl.load(x_ptr + offset, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int64)
    nonzero_count = tl.sum(is_nonzero)
    tl.store(combin_ptr + pid_x * combin_N + pid_y, nonzero_count)

151 

152 

def count_nonzero(x, dim=None):
    """Count nonzero elements of ``x`` globally (dim is None) or along ``dim``.

    Returns an int64 scalar tensor for the global case, or an int64 tensor
    with ``dim`` removed for the per-dim case. Launch geometry and the
    TRITONXPU_* environment variables are Kunlunxin-runtime specifics.
    """
    logger.debug("GEMS COUNT NONZERO")

    # XPU layout constants used to size blocks and grids.
    CORE_NUM = 64
    SIZE_PER_CORE = 512
    SIZE_PER_CLUSTER = CORE_NUM * SIZE_PER_CORE

    elem_bytes = x.element_size()  # input dtype size; used only in the 1D branch
    if dim is not None:
        assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
        shape = x.shape
        numel = x.numel()
        # Permute so the reduced dim is innermost, then flatten to an (M*N,)
        # row-major buffer for the 2D kernel.
        # NOTE(review): the TRITONXPU_* env vars look like hints consumed by the
        # vendor runtime during the calls they bracket; they are not restored on
        # exception (no try/finally) — confirm that is acceptable.
        os.environ["TRITONXPU_IS_SCATTER_SLICE"] = "1"
        x = dim_compress(x, dim)
        x = x.contiguous().flatten()
        del os.environ["TRITONXPU_IS_SCATTER_SLICE"]
        # 2D count_nonzero: one int64 count per row of the (M, N) view.
        out_shape = list(shape)
        del out_shape[dim]  # negative dim is valid for Python list deletion
        os.environ["TRITONXPU_ELEMBYTES"] = "8"  # int64 output elements
        out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)
        del os.environ["TRITONXPU_ELEMBYTES"]
        N = shape[dim]  # reduced (innermost) length
        M = triton.cdiv(numel, shape[dim])  # number of rows
        BLOCK_M = CORE_NUM
        BLOCK_N = SIZE_PER_CORE
        grid = lambda meta: (triton.cdiv(M, BLOCK_M),)
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        count_nonzero_kernel_xpu[grid](
            x,
            out,
            M,
            N,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            groups_per_cluster=CORE_NUM,
            buffer_size_limit=SIZE_PER_CORE * 8,  # int64 bytes per core
            is_use_mask_zero=True,
        )
        del os.environ["TRITONXPU_ELEMBYTES"]
        return out
    else:
        # 1D count_nonzero: two-stage reduction — per-cluster partial counts,
        # then a single-program sum of the partials.
        x = x.contiguous().flatten()
        numel = x.numel()
        gridX = triton.cdiv(numel, SIZE_PER_CLUSTER)
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        out_mid = torch.zeros(gridX, dtype=torch.int64, device=x.device)
        del os.environ["TRITONXPU_ELEMBYTES"]
        count_nonzero_kernel_1_part0_xpu[(gridX,)](
            x,
            out_mid,
            numel,
            BLOCK_SIZE_0=SIZE_PER_CLUSTER,
            buffer_size_limit=SIZE_PER_CORE * elem_bytes,  # input-dtype bytes per core
            is_use_mask_zero=True,
        )
        # Stage 1 folds the gridX partial counts into a single scalar.
        BLOCK_SIZE_1 = triton.next_power_of_2(gridX)
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        out = torch.zeros(1, dtype=torch.int64, device=x.device)
        count_nonzero_kernel_1_part1_xpu[(1,)](
            out_mid,
            out,
            gridX,
            BLOCK_SIZE_1=BLOCK_SIZE_1,
            buffer_size_limit=SIZE_PER_CORE * 8,  # int64 bytes per core
            is_use_mask_zero=True,
        )
        del os.environ["TRITONXPU_ELEMBYTES"]

        return out[0]