Coverage for src/flag_gems/ops/count_nonzero.py: 62%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Global count of nonzero elements over a flattened tensor.

    One program handles one BLOCK_SIZE-sized slice of the input and
    atomically adds its local nonzero count into the single-element
    output at ``out_ptr``.
    """
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < numel
    # Out-of-bounds lanes load 0 and therefore never count as nonzero.
    vals = tl.load(x_ptr + offsets, mask=in_bounds, other=0)
    block_count = tl.sum((vals != 0).to(tl.int64), axis=0)
    # Programs race on the same scalar slot; atomic add keeps the total exact.
    tl.atomic_add(out_ptr, block_count)

25 

26 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Per-row nonzero count: each program reduces one contiguous row of
    # length N (the reduction dim, moved innermost by the host wrapper)
    # and stores a single int64 count at out_ptr[pid_x].
    pid_x = tle.program_id(0)

    # Scalar accumulator in the output's element type (int64 from the host).
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # Guard both the flat tensor bound and the row bound; masked lanes
        # load 0 and do not contribute.
        mask = offset < numel and cols_offsets < N
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        is_nonzero = (x != 0).to(tl.int64)
        nonzero_count += tl.sum(is_nonzero)

    tl.store(out_ptr + pid_x, nonzero_count)

43 

44 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Second stage of the two-pass reduction: the input already holds
    # per-block partial counts (int64), so each program simply SUMS its
    # row of N partials — no `!= 0` test here, unlike count_nonzero_kernel.
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # Masked lanes load 0, which is the identity for the sum.
        mask = offset < numel and cols_offsets < N
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        nonzero_count += tl.sum(x)
    tl.store(out_ptr + pid_x, nonzero_count)

58 

59 

@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    """First stage of the two-pass per-dim count.

    2-D launch: axis 0 walks rows of length ``N`` (the reduction dim),
    axis 1 walks BLOCK_SIZE-wide tiles within a row. Each program writes
    one partial nonzero count into ``combin_ptr[row, tile]``.
    """
    row = tle.program_id(0)
    tile = tle.program_id(1)
    col = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    flat = row * N + col
    # Guard both the flat tensor bound and the row bound.
    valid = flat < numel and col < N
    vals = tl.load(x_ptr + flat, mask=valid, other=0)
    partial = tl.sum((vals != 0).to(tl.int64))
    tl.store(combin_ptr + row * combin_N + tile, partial)

74 

75 

def count_nonzero(x, dim=None):
    """Count the nonzero elements of ``x``.

    Args:
        x: Input tensor.
        dim: Optional dimension along which to count (may be negative).
            When ``None`` the whole tensor is reduced to a scalar count.

    Returns:
        An int64 tensor: a 0-d scalar when ``dim is None``, otherwise a
        tensor with ``dim`` removed from ``x``'s shape.
    """
    logger.debug("GEMS COUNT NONZERO")
    if dim is None:
        # Global count: flatten and let each program atomically add its
        # block's nonzero count into a single-element accumulator.
        x = x.contiguous().flatten()
        numel = x.numel()
        out = torch.zeros(1, dtype=torch.int64, device=x.device)
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
        count_nonzero_kernel_1[grid](x, out, numel, BLOCK_SIZE=BLOCK_SIZE)
        return out[0]

    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    shape = x.shape
    BLOCK_SIZE = 2048
    numel = x.numel()
    # Move the reduction dim innermost and flatten, so each kernel "row"
    # is one contiguous slice of length shape[dim].
    x = dim_compress(x, dim)
    x = x.contiguous().flatten()
    combin_shape = list(shape)
    combin_shape[dim] = triton.cdiv(combin_shape[dim], BLOCK_SIZE)
    if combin_shape[dim] != 1:
        # Long reduction dim: first produce per-tile partial counts
        # (stage 1), then sum the partials per row (stage 2).
        combin = torch.zeros(combin_shape, dtype=torch.int64, device=x.device)
        grid = (triton.cdiv(numel, shape[dim]), combin_shape[dim], 1)
        count_nonzero_combin_kernel[grid](
            x, combin, shape[dim], combin_shape[dim], numel, BLOCK_SIZE
        )
        x = combin
        shape = x.shape
        numel = x.numel()
        # Partials are already counts, so stage 2 only sums them.
        kernel = count_nonzero_combin_kernel_1
    else:
        # Short reduction dim: a single-pass per-row count suffices.
        kernel = count_nonzero_kernel
    # Common tail (was duplicated verbatim in both branches): drop the
    # reduced dim from the output shape and launch one program per row.
    out_shape = list(shape)
    del out_shape[dim]
    out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)
    grid = lambda meta: (triton.cdiv(numel, shape[dim]),)
    kernel[grid](x, out, shape[dim], numel)
    return out