Coverage for src/flag_gems/runtime/backend/_ascend/ops/all.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Logger name is this module's last dotted component under the Ascend ops namespace.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)

14 

15 

# torch.all: tests whether every element of the input evaluates to True. For a
# non-bool input, that means testing whether every element is non-zero.
# The Triton kernels below always test for non-zero values, which gives the
# same answer for bool inputs too.

19 

20 

@triton.jit
def reduce_all(lhs, rhs):
    # Combine function handed to tl.reduce: logical AND of two partial results,
    # so a single False/zero anywhere makes the reduced value False.
    return lhs and rhs

24 

25 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise `all` over `inp` viewed as M rows of N elements (row stride N).

    Program `pid` handles row tiles pid, pid + P, pid + 2P, ... where P is the
    number of launched programs. Each tile ANDs `value != 0` across its rows in
    BLOCK_N-wide column chunks and stores one result per row at `out + row`.
    """
    num_workers = tle.num_programs(0)
    worker_id = tle.program_id(0)

    num_row_tiles = tl.cdiv(M, BLOCK_M)
    tiles_per_worker = tl.cdiv(num_row_tiles, num_workers)
    # Tile-relative offsets; identical on every iteration, so built once.
    row_offsets = tl.arange(0, BLOCK_M)[:, None]
    col_offsets = tl.arange(0, BLOCK_N)[None, :]

    for step in range(tiles_per_worker):
        tile = worker_id + step * num_workers
        rows = tile * BLOCK_M + row_offsets
        row_ptrs = inp + rows * N
        # Tiles past the end (tile >= num_row_tiles) are fully masked out here.
        row_valid = rows < M

        acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
        for start in range(0, N, BLOCK_N):
            cols = start + col_offsets
            valid = row_valid and (cols < N)
            # Masked lanes read 1 (non-zero) so they never turn a row False.
            vals = tl.load(row_ptrs + cols, valid, other=1.0)
            acc = acc and (vals != 0)
        row_all = tl.reduce(acc, axis=1, combine_fn=reduce_all)
        tl.store(out + rows, row_all[:, None], row_valid)

61 

62 

@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the flat `all` reduction.

    Program `pid` ANDs `value != 0` over elements [pid*BLOCK_SIZE, (pid+1)*BLOCK_SIZE)
    of the flat input and stores the partial result at `mid[pid]`.
    `mid_size` is accepted but not used by this kernel.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Out-of-range lanes read 1 (non-zero) so they cannot flip the result to False.
    chunk = tl.load(inp + idx, mask=in_bounds, other=1.0)
    partial = tl.reduce(chunk != 0, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, partial)

80 

81 

@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second pass: AND the first MID_SIZE partial results in `mid` into `out`."""
    lanes = tl.arange(0, BLOCK_MID)
    # Padding lanes beyond MID_SIZE read 1 (True), the identity for AND.
    partials = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=1).to(tl.int1)
    tl.store(out, tl.reduce(partials, axis=0, combine_fn=reduce_all))

91 

92 

def all(inp):
    """Return a 0-d bool tensor that is True iff every element of `inp` is non-zero.

    Two-pass reduction: `all_kernel_1` reduces chunks of roughly sqrt(n)
    elements into a `mid` buffer, then `all_kernel_2` reduces `mid` into the
    scalar result. An empty input returns True, matching `torch.all`.
    """
    logger.debug("GEMS_ASCEND ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # Vacuously True. Launching the kernels would fail here:
        # next_power_of_2(0) == 0, so the chunk size would be zero and
        # cdiv(n_elements, block_size) would divide by zero.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # The kernels address the input as a flat buffer of n_elements; that is only
    # correct for contiguous memory. No-op when `inp` is already contiguous.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out

108 

109 

def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements of `inp` are non-zero, optionally along one dim.

    Args:
        inp: input tensor.
        dim: dimension to reduce, or None to reduce over every element.
        keepdim: if True, keep the reduced dimension(s) with size 1.

    Returns:
        A bool tensor. With `dim=None` it is 0-d, or has shape [1] * inp.ndim
        when `keepdim` is True. Otherwise `dim` is reduced (kept as size 1
        when `keepdim` is True).

    Raises:
        AssertionError: if `dim` is not in [-inp.ndim, inp.ndim).
    """
    logger.debug("GEMS_ASCEND ALL DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Either the reduced dim is empty (N == 0: vacuously True) or the output
        # itself has no elements (M == 0). A True-filled tensor is correct in both
        # cases and avoids dividing by N == 0 and launching a zero-size grid.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # presumably moves `dim` last and makes data contiguous so the kernel can
        # read an (M, N) row-major view — verify against flag_gems.utils.dim_compress
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            axis0 = triton.cdiv(M, meta["BLOCK_M"])
            # NOTE(review): 40 caps the number of programs; presumably tied to the
            # core count of the target Ascend device — confirm.
            axis0 = axis0 if axis0 < 40 else 40
            return (axis0,)

        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

137 

138 

def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements of `inp` are non-zero over one or more dims.

    Args:
        inp: input tensor.
        dim: None, a single int, or a sequence of ints to reduce over. None and
            int are delegated to `all_dim`.
        keepdim: if True, keep the reduced dimensions with size 1.

    Returns:
        A bool tensor with the requested dimensions reduced.

    Raises:
        AssertionError: if any entry of `dim` is not in [-inp.ndim, inp.ndim).
    """
    logger.debug("GEMS_ASCEND ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    # Check each index explicitly. Asserting a generator expression is always
    # truthy, and the builtin `all` is shadowed by this module's `all` op.
    for d in dim:
        assert d >= -inp.ndim and d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Either a reduced dim is empty (N == 0: vacuously True) or the output
        # has no elements (M == 0). A True-filled tensor is correct in both
        # cases and avoids dividing by N == 0 and launching a zero-size grid.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # presumably moves the reduced dims last and makes data contiguous so the
        # kernel can read an (M, N) row-major view — verify against dim_compress
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            axis0 = triton.cdiv(M, meta["BLOCK_M"])
            # NOTE(review): 40 caps the number of programs; presumably tied to the
            # core count of the target Ascend device — confirm.
            axis0 = axis0 if axis0 < 40 else 40
            return (axis0,)

        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
166 return out