Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/count_nonzero.py: 0%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger, parented under the package-wide "flag_gems" logger and
# suffixed with this module's (leading-dot-stripped) import name.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

13 

14 

@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Count non-zero elements of a flat buffer into a single accumulator.

    Program ``p`` handles elements ``[p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE)``.
    It counts the non-zeros in that chunk as int32 and atomically adds the
    partial count to ``out_ptr[0]``.
    """
    idx = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < numel
    # Lanes past `numel` load 0, so they never contribute to the count.
    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    partial = tl.sum((vals != 0).to(tl.int32), axis=0)
    tl.atomic_add(out_ptr, partial)

26 

27 

@triton.jit
def _store_row_nonzero_count(
    x_ptr, out_ptr, row, N, numel, BLOCK_SIZE: tl.constexpr
):
    """Count non-zeros in row ``row`` (length ``N``) and store at ``out_ptr[row]``.

    The row occupies ``x_ptr[row * N : (row + 1) * N]``. It is scanned in
    ``BLOCK_SIZE`` chunks, and the count is accumulated in ``out_ptr``'s
    element type.
    """
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = row * N + cols_offsets
        # Mask both the tail of the row and the end of the whole buffer.
        mask = (offset < numel) & (cols_offsets < N)
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        nonzero_count += tl.sum((x != 0).to(tl.int64))
    tl.store(out_ptr + row, nonzero_count)


@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Per-row non-zero count over ``ceil(numel / N)`` rows of length ``N``.

    The rows are split across the launched programs. Each program first
    handles ``rows // num_programs`` consecutive rows. The
    ``rows % num_programs`` leftover rows are then handled one each by the
    lowest-numbered programs.
    """
    pid_0 = tle.program_id(0)
    num_p = tle.num_programs(0)
    rows = (numel + N - 1) // N
    rows_per_p = rows // num_p

    # Contiguous slice of rows owned by this program.
    for pid_n in range(0, rows_per_p):
        _store_row_nonzero_count(
            x_ptr, out_ptr, pid_0 * rows_per_p + pid_n, N, numel, BLOCK_SIZE
        )

    # Leftover rows after the even split: one each for programs 0..remain-1.
    remain = rows % num_p
    if pid_0 < remain:
        _store_row_nonzero_count(
            x_ptr, out_ptr, rows // num_p * num_p + pid_0, N, numel, BLOCK_SIZE
        )

64 

65 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    """Sum each length-``N`` row of ``x_ptr`` into ``out_ptr[row]``.

    This runs as the second pass of the two-pass reduction, where ``x_ptr``
    holds per-chunk partial counts. There is one program per row, and
    accumulation is in ``out_ptr``'s element type.
    """
    row = tle.program_id(0)
    row_base = row * N
    total = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for chunk_start in range(0, N, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        idx = row_base + cols
        partial = tl.load(x_ptr + idx, mask=(idx < numel) & (cols < N), other=0)
        total += tl.sum(partial)
    tl.store(out_ptr + row, total)

79 

80 

@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    """First pass of the two-pass reduction: per-(row, chunk) non-zero counts.

    Program ``(row, chunk)`` counts the non-zeros in columns
    ``[chunk * BLOCK_SIZE, (chunk + 1) * BLOCK_SIZE)`` of row ``row`` (row
    length ``N``). It stores the int64 count at
    ``combin_ptr[row * combin_N + chunk]``.
    """
    row = tle.program_id(0)
    chunk = tle.program_id(1)
    cols = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    idx = row * N + cols
    vals = tl.load(x_ptr + idx, mask=(idx < numel) & (cols < N), other=0)
    tl.store(
        combin_ptr + row * combin_N + chunk,
        tl.sum((vals != 0).to(tl.int64)),
    )

95 

96 

def _count_nonzero_flat(x):
    """Count non-zeros over every element of ``x``; return a 0-d int64 tensor."""
    x = x.contiguous().flatten()
    numel = x.numel()

    # NOTE(review): partial counts are atomically accumulated in int32, so a
    # total above 2**31 - 1 would overflow before the final int64 cast.
    out = torch.zeros(1, dtype=torch.int32, device=x.device)

    BLOCK_SIZE = 1024 * 8
    grid = (triton.cdiv(numel, BLOCK_SIZE),)
    count_nonzero_kernel_1[grid](x, out, numel, BLOCK_SIZE=BLOCK_SIZE)

    return out[0].to(torch.int64)


def _count_nonzero_along_dim(x, dim):
    """Count non-zeros of ``x`` along the already-validated ``dim``.

    The result is int64 and has ``x``'s shape with ``dim`` removed. The
    kernels treat the compressed data as ``numel // shape[dim]`` rows of
    length ``shape[dim]``. ``dim_compress`` presumably moves ``dim``
    innermost to make that layout hold; confirm in ``flag_gems.utils``.
    """
    shape = x.shape
    out_shape = list(shape)
    del out_shape[dim]
    numel = x.numel()

    # Empty input (including a zero-length `dim`) has only zero counts. It
    # must return early: a zero-length `dim` would otherwise reach
    # `triton.cdiv(numel, 0)` below and raise ZeroDivisionError.
    if numel == 0:
        return torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    BLOCK_SIZE = 2048
    x = dim_compress(x, dim)
    x = x.contiguous().flatten()
    n_rows = triton.cdiv(numel, shape[dim])

    # Number of BLOCK_SIZE chunks needed to cover one row.
    combin_shape = list(shape)
    combin_shape[dim] = triton.cdiv(combin_shape[dim], BLOCK_SIZE)
    n_chunks = combin_shape[dim]

    out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    if n_chunks != 1:
        # Rows span several chunks. Pass 1 writes one partial count per
        # (row, chunk); pass 2 sums each row's partials.
        combin = torch.zeros(combin_shape, dtype=torch.int64, device=x.device)
        count_nonzero_combin_kernel[(n_rows, n_chunks, 1)](
            x, combin, shape[dim], n_chunks, numel, BLOCK_SIZE
        )
        combin_numel = combin.numel()
        count_nonzero_combin_kernel_1[(triton.cdiv(combin_numel, n_chunks),)](
            combin, out, n_chunks, combin_numel
        )
        return out

    # Each row fits in one chunk. Launch at most one program per
    # multiprocessor and let each program loop over its share of rows.
    grid = (
        min(
            torch_device_fn.get_device_properties().multi_processor_count,
            n_rows,
        ),
    )
    count_nonzero_kernel[grid](x, out, shape[dim], numel)
    return out


def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, optionally along one dimension.

    Args:
        x: input tensor.
        dim: ``None`` to count over all elements, or a single int dimension
            in ``[-x.ndim, x.ndim)`` to reduce over.

    Returns:
        With ``dim=None``, a 0-d int64 tensor. Otherwise, an int64 tensor
        shaped like ``x`` with ``dim`` removed.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_TSINGMICRO COUNT NONZERO")
    if dim is None:
        return _count_nonzero_flat(x)
    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    return _count_nonzero_along_dim(x, dim)