Coverage for src/flag_gems/runtime/backend/_cambricon/ops/count_nonzero.py: 0%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11from ..utils import TOTAL_CORE_NUM 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Global (dim=None) count: each program counts the non-zero elements in
    # its own BLOCK_SIZE chunk of the flattened input and accumulates the
    # partial count into the single scalar at out_ptr with an atomic add.
    # out_ptr is an int32 cell zero-initialized by the host wrapper.
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel  # guard the tail chunk
    # other=0 keeps masked-out lanes from contributing to the count.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int32)
    nonzero_count = tl.sum(is_nonzero, axis=0)
    tl.atomic_add(out_ptr, nonzero_count)

27 

28 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Per-row count for count_nonzero(x, dim): the input is viewed as `rows`
    # contiguous rows of length N (the reduced dim, moved last by the host's
    # dim_compress). Rows are distributed evenly over the launched programs;
    # each program writes one count per row it owns into out_ptr[row].
    pid_0 = tle.program_id(0)
    num_p = tle.num_programs(0)
    rows = (numel + N - 1) // N  # ceil-div: number of output elements
    rows_per_p = rows // num_p  # even share per program; remainder handled below

    for pid_n in range(0, rows_per_p):
        pid_x = pid_0 * rows_per_p + pid_n  # global row index owned by this program

        # Scalar accumulator in the output element type (int64 in the host wrapper).
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        # Walk the row in BLOCK_SIZE tiles.
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = pid_x * N + cols_offsets
            mask = offset < numel and cols_offsets < N  # guard row tail and buffer end
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)

        tl.store(out_ptr + pid_x, nonzero_count)

    # The first (rows % num_p) programs each take one of the leftover rows,
    # starting right after the evenly distributed block of rows.
    remain = rows % num_p
    if pid_0 < remain:
        pid_x = rows // num_p * num_p + pid_0
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = pid_x * N + cols_offsets
            mask = offset < numel and cols_offsets < N
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)

        tl.store(out_ptr + pid_x, nonzero_count)

65 

66 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Second stage of the two-stage dim reduction: x_ptr holds per-tile
    # non-zero COUNTS produced by count_nonzero_combin_kernel, so this kernel
    # sums the values themselves (tl.sum(x), not (x != 0)) — N partial counts
    # per output row, one program per row.
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        mask = offset < numel and cols_offsets < N  # guard row tail and buffer end
        x = tl.load(x_ptr + offset, mask=mask, other=0)  # masked lanes add 0
        nonzero_count += tl.sum(x)
    tl.store(out_ptr + pid_x, nonzero_count)

80 

81 

@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    # First stage of the two-stage dim reduction, launched on a 2D grid:
    # axis 0 = row (output element), axis 1 = BLOCK_SIZE tile within the row.
    # Each program counts the non-zeros in its tile and stores the partial
    # count at combin[row, tile] (combin has combin_N tiles per row); the
    # partials are then summed by count_nonzero_combin_kernel_1.
    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)
    cols_offsets = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_x * N + cols_offsets
    mask = offset < numel and cols_offsets < N  # guard row tail and buffer end
    x = tl.load(x_ptr + offset, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int64)
    nonzero_count = tl.sum(is_nonzero)
    tl.store(combin_ptr + pid_x * combin_N + pid_y, nonzero_count)

96 

97 

def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``.

    When ``dim`` is ``None`` the result is a 0-dim int64 tensor holding the
    global count; otherwise the count is taken along ``dim`` and a tensor
    with that dimension removed is returned.
    """
    logger.debug("GEMS_CAMBRICON COUNT NONZERO")

    if dim is None:
        # Global count: per-block partial counts are accumulated atomically
        # into a single int32 cell, then widened to int64 on the way out.
        flat = x.contiguous().flatten()
        n_elements = flat.numel()

        total = torch.zeros(1, dtype=torch.int32, device=flat.device)

        BLOCK_SIZE = 1024 * 8
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        count_nonzero_kernel_1[grid](flat, total, n_elements, BLOCK_SIZE=BLOCK_SIZE)

        return total[0].to(torch.int64)

    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    shape = x.shape
    BLOCK_SIZE = 2048
    n_elements = x.numel()
    # Move the reduced dimension to the end and flatten, so every output
    # element owns a contiguous run of shape[dim] input values.
    flat = dim_compress(x, dim).contiguous().flatten()

    # Shape of the intermediate buffer: the reduced dim shrinks to the
    # number of BLOCK_SIZE tiles needed to cover it.
    combin_shape = list(shape)
    combin_shape[dim] = triton.cdiv(combin_shape[dim], BLOCK_SIZE)

    if combin_shape[dim] != 1:
        # Long rows: two-stage reduction. Stage 1 writes one partial count
        # per (row, tile); stage 2 sums the partials per row.
        partial = torch.zeros(combin_shape, dtype=torch.int64, device=flat.device)
        grid = (triton.cdiv(n_elements, shape[dim]), combin_shape[dim], 1)
        count_nonzero_combin_kernel[grid](
            flat, partial, shape[dim], combin_shape[dim], n_elements, BLOCK_SIZE
        )
        shape = partial.shape
        n_elements = partial.numel()
        out_shape = list(shape)
        del out_shape[dim]  # del (not slicing) so a negative dim still works
        out = torch.zeros(out_shape, dtype=torch.int64, device=flat.device)
        grid = lambda meta: (triton.cdiv(n_elements, shape[dim]),)
        count_nonzero_combin_kernel_1[grid](partial, out, shape[dim], n_elements)
        return out

    # Short rows (dim size fits in one tile): a single kernel reduces each
    # row directly, with rows spread over at most TOTAL_CORE_NUM programs.
    out_shape = list(shape)
    del out_shape[dim]
    out = torch.zeros(out_shape, dtype=torch.int64, device=flat.device)
    grid = lambda meta: (min(TOTAL_CORE_NUM, triton.cdiv(n_elements, shape[dim])),)
    count_nonzero_kernel[grid](flat, out, shape[dim], n_elements)
    return out