Coverage for src/flag_gems/runtime/backend/_ascend/ops/count_nonzero.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module-level logger under the flag_gems Ascend-backend namespace, suffixed
# with this module's basename (per the coverage header, "count_nonzero").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

12 

13 

@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Global (dim=None) count: each program inspects one BLOCK_SIZE slice of
    # the flattened input and atomically folds its partial non-zero count
    # into the single cell at out_ptr.
    idx = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < numel
    # Out-of-range lanes read 0 and therefore contribute nothing to the count.
    vals = tl.load(x_ptr + idx, mask=in_range, other=0)
    flags = (vals != 0).to(tl.int32)
    tl.atomic_add(out_ptr, tl.sum(flags, axis=0))

25 

26 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Per-row non-zero count for rows of length N laid out contiguously.
    # The host caps the launch grid (at most 240 programs), so each program
    # strides over rows: pid, pid + n_workers, pid + 2*n_workers, ...
    n_workers = tle.num_programs(0)
    pid = tle.program_id(0)

    n_tasks = tl.cdiv(numel, N)
    tasks_per_worker = tl.cdiv(n_tasks, n_workers)

    for task_index in range(tasks_per_worker):
        task_id = pid + task_index * n_workers
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = task_id * N + cols_offsets
            mask = offset < numel and cols_offsets < N
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)

        # Fix: tasks_per_worker is rounded UP, so when n_tasks is not a
        # multiple of n_workers the last round yields task_id >= n_tasks for
        # some programs. The loads above are masked, but the store was not,
        # writing zeros past the end of the output buffer. Mask it.
        tl.store(out_ptr + task_id, nonzero_count, mask=task_id < n_tasks)

49 

50 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Second stage of the two-stage dim-reduction: each program sums one
    # length-N row into a single scalar at out_ptr[pid_x].
    #
    # NOTE: this sums the loaded VALUES (no `!= 0` test). That is only correct
    # because the host passes it the int64 per-block partial counts produced
    # by count_nonzero_combin_kernel, never raw input data (see the wrapper:
    # the raw-data case goes through count_nonzero_kernel instead).
    pid_x = tle.program_id(0)
    nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for start_n in range(0, N, BLOCK_SIZE):
        cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
        offset = pid_x * N + cols_offsets
        # NOTE(review): `and` between tensor conditions relies on Triton
        # lowering it to an elementwise logical-and — confirm for the
        # targeted Triton/Ascend toolchain version.
        mask = offset < numel and cols_offsets < N
        x = tl.load(x_ptr + offset, mask=mask, other=0)
        nonzero_count += tl.sum(x)
    tl.store(out_ptr + pid_x, nonzero_count)

64 

65 

@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    # First stage of the two-stage dim-reduction for rows longer than one
    # block: program (pid_x, pid_y) counts the non-zeros in the pid_y-th
    # BLOCK_SIZE chunk of row pid_x, storing the partial count at
    # combin[pid_x, pid_y] (row-major, combin_N chunks per row). These
    # partials are then summed by count_nonzero_combin_kernel_1.
    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)
    cols_offsets = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_x * N + cols_offsets
    # Guard both the flat bound and the row bound (last chunk may be partial).
    mask = offset < numel and cols_offsets < N
    x = tl.load(x_ptr + offset, mask=mask, other=0)
    is_nonzero = (x != 0).to(tl.int64)
    nonzero_count = tl.sum(is_nonzero)
    tl.store(combin_ptr + pid_x * combin_N + pid_y, nonzero_count)

80 

81 

def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, optionally along one dimension.

    With ``dim=None`` the whole tensor is counted and a 0-d int64 tensor is
    returned; otherwise an int64 tensor with ``dim`` removed is returned,
    mirroring ``torch.count_nonzero``.
    """
    logger.debug("GEMS_ASCEND COUNT NONZERO")

    if dim is None:
        # Global count: one flat pass; every program atomically adds its
        # block's partial count into a single int32 accumulator.
        flat = x.contiguous().flatten()
        total = flat.numel()
        acc = torch.zeros(1, dtype=torch.int32, device=flat.device)
        block = 8192
        launch = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
        count_nonzero_kernel_1[launch](flat, acc, total, BLOCK_SIZE=block)
        return acc[0].to(torch.int64)

    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    sizes = x.shape
    BLOCK = 8192
    total = x.numel()
    # Move `dim` innermost so every output element owns a contiguous
    # length-sizes[dim] run in memory.
    work = dim_compress(x, dim).contiguous().flatten()

    partial_sizes = list(sizes)
    partial_sizes[dim] = triton.cdiv(partial_sizes[dim], BLOCK)
    if partial_sizes[dim] != 1:
        # Rows longer than one block: two-stage reduction. Stage 1 produces
        # per-(row, block) partial counts; stage 2 sums each row of counts.
        partial = torch.zeros(partial_sizes, dtype=torch.int64, device=work.device)
        launch = (triton.cdiv(total, sizes[dim]), partial_sizes[dim], 1)
        count_nonzero_combin_kernel[launch](
            work, partial, sizes[dim], partial_sizes[dim], total, BLOCK
        )
        # Re-aim the reduction at the partial-count tensor.
        sizes = partial.shape
        total = partial.numel()
        reduced = list(sizes)
        del reduced[dim]
        out = torch.zeros(reduced, dtype=torch.int64, device=work.device)
        launch = lambda meta: (triton.cdiv(total, sizes[dim]),)
        count_nonzero_combin_kernel_1[launch](partial, out, sizes[dim], total)
        return out

    # Short rows (one block each): count directly, one row per task.
    reduced = list(sizes)
    del reduced[dim]
    out = torch.zeros(reduced, dtype=torch.int64, device=work.device)

    def launch(meta):
        # Cap the grid at 240 programs; the kernel loops extra rows per program.
        return (min(triton.cdiv(total, sizes[dim]), 240),)

    count_nonzero_kernel[launch](work, out, sizes[dim], total)
    return out