Coverage for src/flag_gems/runtime/backend/_ascend/ops/any.py: 0%

106 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module-level logger; its name ends with the last dotted component of
# this module's __name__.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

14 

15 

16# torch.any: Tests whether any element of the input evaluates to True. If the dtype of

17# the input is not BOOL, it tests whether any element of the input is non-zero.

18# In the Triton kernels, testing whether any element is non-zero is sufficient.

19 

20 

# Combine function handed to tl.reduce by the kernels in this file: merges two
# partial results with a logical OR.
@triton.jit
def reduce_any(a, b):
    return a or b

24 

25 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
@triton.jit
def any_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise "any": inp is addressed as M rows of N consecutive elements and
    # out[m] is set to True when row m contains at least one non-zero value.
    #
    # Each program starts at the row block matching its program id and then
    # advances by the number of launched programs, so every row block is
    # processed even when the launcher clamps the grid below
    # cdiv(M, BLOCK_M). With an unclamped grid the loop body runs exactly once
    # per program, which is the original one-block-per-program behavior.
    pid = tle.program_id(0)
    num_row_blocks = tl.cdiv(M, BLOCK_M)
    for row_block in range(pid, num_row_blocks, tl.num_programs(0)):
        rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        row_ptrs = inp + rows * N
        out_ptrs = out + rows
        row_mask = rows < M

        # Element-wise accumulator over column tiles; collapsed along the
        # column axis once the whole row has been scanned.
        acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            # Masked-out lanes read 0 so they can never flip the result.
            a = tl.load(row_ptrs + cols, mask, other=0.0)
            acc = acc or (a != 0)
        row_any = tl.reduce(acc, axis=1, combine_fn=reduce_any)
        tl.store(out_ptrs, row_any[:, None], row_mask)

54 

55 

@libentry()
@triton.jit
def any_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    # First stage of the full reduction: every program ORs together one
    # BLOCK_SIZE-wide slice of inp (addressed by linear offset) and writes a
    # single flag to mid[program id]. mid_size is accepted but not read here.
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past n_elements load 0 so they cannot turn the flag True.
    vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
    block_any = tl.reduce(vals != 0, axis=0, combine_fn=reduce_any)
    tl.store(mid + block_id, block_any)

73 

74 

@libentry()
@triton.jit
def any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Second stage of the full reduction: a single program loads the first
    # MID_SIZE flags from mid, ORs them together and stores the scalar result
    # at out.
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < MID_SIZE
    # Padding lanes (idx >= MID_SIZE) load 0, i.e. False after the cast.
    partials = tl.load(mid + idx, mask=valid, other=0).to(tl.int1)
    result = tl.reduce(partials, axis=0, combine_fn=reduce_any)
    tl.store(out, result)

84 

85 

def any(inp):
    """Return a 0-dim bool tensor telling whether any element of ``inp`` is non-zero.

    The reduction runs in two kernel launches: ``any_kernel_1`` reduces
    ``mid_size`` slices of ``block_size`` elements into the ``mid`` buffer,
    then ``any_kernel_2`` reduces ``mid`` into the scalar output.

    Args:
        inp: Input tensor.

    Returns:
        A tensor of shape ``[]`` and dtype ``torch.bool`` on ``inp.device``.
        An empty input yields ``False``.
    """
    logger.debug("GEMS_ASCEND ANY")
    n_elements = inp.numel()
    if n_elements == 0:
        # No element can be non-zero; answer directly instead of deriving
        # launch sizes (block_size, mid_size) from a zero element count.
        return torch.zeros([], dtype=torch.bool, device=inp.device)

    # Split the work roughly as sqrt(n) blocks of sqrt(n) elements each.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        any_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        any_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out

101 

102 

def any_dim(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with logical "any" along a single dimension.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce, in ``[-inp.ndim, inp.ndim)``. ``None``
            reduces over every element via ``any``.
        keepdim: When true, the reduced dimension is kept with size 1 (all
            dimensions are kept with size 1 when ``dim`` is ``None``).

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_ASCEND ANY DIM")
    shape = list(inp.shape)
    if dim is None:
        out = any(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Nothing to scan: every output element (if any) is False. This also
        # avoids dividing by N == 0 when the reduced dimension is empty.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid_fn(meta):
            # NOTE(review): the grid is clamped to 65535 programs. Row blocks
            # beyond that are only reduced if any_kernel_dim itself iterates
            # over the remaining blocks — confirm against the kernel.
            grid = triton.cdiv(M, meta["BLOCK_M"])
            grid = grid if grid <= 65535 else 65535
            return (grid,)

        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid_fn](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

129 

130 

def any_dims(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with logical "any" along one or several dimensions.

    Args:
        inp: Input tensor.
        dim: ``None`` or an ``int`` (both delegated to ``any_dim``), or an
            iterable of dimensions, each in ``[-inp.ndim, inp.ndim)``.
        keepdim: When true, reduced dimensions are kept with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If any entry of ``dim`` is out of range.
    """
    logger.debug("GEMS_ASCEND ANY DIMS")

    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)
    # all(...) is required here: asserting the bare generator object is always
    # truthy and would never reject an out-of-range dimension.
    assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # N is the number of elements folded into each output element.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Nothing to scan: every output element (if any) is False. This also
        # avoids dividing by N == 0 when a reduced dimension is empty.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            return (triton.cdiv(M, meta["BLOCK_M"]),)

        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out