Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/any.py: 0%

160 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.max import max_kernel_1, max_kernel_2 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

# torch.any: Tests if any elements in input evaluate to True. If the dtype of input
# is not BOOL, then test if any elements in input evaluate to non-zero value
# In triton function, test if any elements in input evaluate to non-zero value is ok.

# Launch-geometry constants used by the block-size heuristics below.
# NOTE(review): presumably these describe the Kunlunxin XPU target
# (12 clusters x 64 cores, 2048-byte per-core buffer, 16-lane vectors) —
# confirm against the device specification.
cluster_num = 12
core_num = 64
thread_num = core_num * cluster_num
buf_len_per_core = 2048
vector_size = 16

22 

23 

def get_block(n: int) -> int:
    """Round ``n`` up to a positive multiple of ``cluster_num``.

    Guarantees the result is at least ``cluster_num`` so every cluster
    has at least one unit of work even for tiny ``n``.
    """
    if n >= cluster_num:
        return cluster_num * triton.cdiv(n, cluster_num)
    return cluster_num

30 

31 

def heur_m_block_size(args):
    """Heuristic for BLOCK_M: rows handled per program.

    Splits M across clusters, caps at one core's worth of rows, and
    rounds up to a power of two as Triton block shapes require.
    """
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    capped = min(rows_per_cluster, core_num)
    return triton.next_power_of_2(capped)

34 

35 

def heur_n_block_size(args):
    """Heuristic for BLOCK_N: columns handled per inner-loop step.

    Caps the column tile by a quarter of the per-core buffer and rounds
    up to a power of two as Triton block shapes require.
    """
    buffer_cap = triton.cdiv(buf_len_per_core, 4)
    capped = min(args["N"], buffer_cap)
    return triton.next_power_of_2(capped)

38 

39 

@triton.jit
def reduce_any(a, b):
    # Combine function for tl.reduce: logical OR of two lanes, giving an
    # "any element is true" reduction when folded over an axis.
    return a or b

43 

44 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def any_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise any-reduction over a contiguous (M, N) view.

    Writes out[m] = any(inp[m, :] != 0) for the BLOCK_M rows owned by
    this program. Columns are swept in BLOCK_N-wide tiles.
    """
    # Map the program id to the row of inp it should compute.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # Per-lane accumulator; OR-ed across column tiles, reduced at the end.
    _any = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and`/`or` on Triton tensors appear to be used here
        # as elementwise logical ops — confirm against the Triton version
        # this backend targets.
        mask = row_mask and col_mask

        # other=0.0 keeps masked-out lanes falsy so they don't flip _any.
        a = tl.load(inp + cols, mask, other=0.0)
        _any = _any or (a != 0)
    # Fold the BLOCK_N lanes of each row with logical OR.
    any = tl.reduce(_any, axis=1, combine_fn=reduce_any)
    tl.store(out, any[:, None], row_mask)

79 

80 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def max_kernel_dim(
    in_ptr,
    out_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise max over a contiguous (M, N) view, computed in float32.

    Used by the wrappers below as a stand-in for `any`: for non-negative
    inputs, max(row) != 0 iff any(row != 0). Results are stored as
    float32 and converted to bool by the caller.
    """
    xoffset = tl.program_id(0) * BLOCK_M
    xindex = xoffset + tl.arange(0, BLOCK_M)[:, None]
    xmask = xindex < M
    rbase = tl.arange(0, BLOCK_N)[None, :]
    # -inf is the identity element for max.
    _max = tl.full([BLOCK_M, BLOCK_N], float("-inf"), tl.float32)
    for roffset in range(0, N, BLOCK_N):
        rindex = roffset + rbase
        rmask = rindex < N
        r1 = rindex
        # Masked lanes load -inf so they never win the running max.
        inp = tl.load(
            in_ptr + (r1 + (N * xindex)), rmask & xmask, other=float("-inf")
        ).to(tl.float32)
        inpb = tl.broadcast_to(inp, [BLOCK_M, BLOCK_N])
        _max = tl.maximum(_max, inpb)
    tmp2 = tl.max(_max, axis=1, return_indices=False)[:, None]
    tl.store(out_ptr + xindex, tmp2, xmask)

113 

114 

@libentry()
@triton.jit
def any_kernel_1(
    inp,
    mid,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of flat any-reduction: one partial result per program.

    Each program reduces its BLOCK_SIZE slice of the flattened input
    with logical OR and stores the single boolean into mid[pid].
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # other=0.0 keeps out-of-range lanes falsy.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)
    any_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_any)
    mid_ptr = mid + pid
    tl.store(mid_ptr, any_val)

131 

132 

@libentry()
@triton.jit
def any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of flat any-reduction: fold the partials into one bool.

    Launched as a single program; BLOCK_MID is MID_SIZE rounded up to a
    power of two, so a mask guards the tail.
    """
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # other=0 keeps padding lanes falsy before the OR-reduction.
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(tl.int1)
    any_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_any)
    tl.store(out, any_val)

142 

143 

def any(inp):
    """Return a 0-dim bool tensor: True iff any element of inp is non-zero.

    Two paths:
      * large inputs: reinterpret bytes as uint8 and use the two-stage
        max kernels (max of bytes != 0 iff some byte is non-zero);
      * small inputs: two-stage boolean OR-reduction kernels.
    Shadows the builtin `any` deliberately — this module implements the
    `any` op for the flag_gems registry.
    """
    logger.debug("GEMS ANY")
    n_elements = inp.numel()
    # Per-program slice size: at least one cluster's share, at least a
    # quarter of the aggregate per-core buffer.
    block_size = max(
        triton.cdiv(get_block(n_elements), cluster_num),
        triton.cdiv(buf_len_per_core * core_num, 4),
    )

    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    if n_elements >= vector_size * thread_num:
        # NOTE(review): for dtypes wider than one byte, view(torch.uint8)
        # multiplies the element count, yet the kernel below is launched
        # with the original n_elements, so only a prefix of the bytes is
        # scanned; also -0.0 has a non-zero byte pattern. This fast path
        # looks correct only for bool/uint8 inputs — confirm with callers.
        inp_uint8 = inp.view(torch.uint8)

        mid = torch.empty((mid_size,), dtype=torch.uint8, device=inp.device)
        out = torch.empty([], dtype=torch.uint8, device=inp.device)

        with torch_device_fn.device(inp.device):
            max_kernel_1[(mid_size, 1)](
                inp_uint8, mid, n_elements, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                # Single partial: its byte pattern already encodes the answer.
                return mid.view(torch.bool).reshape([])

            max_kernel_2[(1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
        out = out.view(torch.bool)
    else:
        mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
        out = torch.empty([], dtype=torch.bool, device=inp.device)

        with torch_device_fn.device(inp.device):
            any_kernel_1[(mid_size, 1)](
                inp, mid, n_elements, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                # Single partial: skip the second reduction stage.
                return mid.reshape([])
            any_kernel_2[(1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)

    return out

183 

184 

def any_dim(inp, dim=None, keepdim=False):
    """Test whether any element along a single dimension is non-zero.

    Args:
        inp: input tensor.
        dim: dimension to reduce; None reduces over all elements.
        keepdim: keep the reduced dimension as size 1.

    Returns:
        A bool tensor with `dim` reduced (or a 0-dim tensor for dim=None).
    """
    logger.debug("GEMS ANY DIM")
    shape = list(inp.shape)
    if dim is None:
        # Full reduction delegates to the flat `any` above.
        out = any(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        dim = dim % inp.ndim
        # Move the reduced dim to be innermost so rows are contiguous.
        inp = dim_compress(inp, dim)
        N = shape[dim]
        shape[dim] = 1
        M = inp.numel() // N

        if N >= vector_size * vector_size:
            # according to api, op == any, use max to calculate
            # NOTE(review): valid when inp.to(torch.float) is non-negative
            # (e.g. bool input); for signed inputs max != any — confirm.
            inpf = inp.to(torch.float)
            outf = torch.empty(shape, dtype=torch.float, device=inp.device)

            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                max_kernel_dim[grid](inpf, outf, M, N, buffer_size_limit=2048)
            out = outf.to(torch.bool)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                any_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

        if not keepdim:
            out = out.squeeze(dim=dim)
    return out

218 

219 

def any_dims(inp, dim=None, keepdim=False):
    """Test whether any element along multiple dimensions is non-zero.

    Args:
        inp: input tensor.
        dim: None, an int, or a sequence of ints naming the dims to
            reduce; None/int delegate to `any_dim`.
        keepdim: keep the reduced dimensions as size 1.

    Returns:
        A bool tensor with the given dims reduced.

    Raises:
        AssertionError: if any requested dim is out of range.
    """
    logger.debug("GEMS ANY DIMS")

    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)
    # Bug fix: the original asserted on a bare generator expression, which
    # is always truthy, so invalid dims were never rejected. all() makes
    # the range check actually execute.
    assert all(i >= -inp.ndim and i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # Move the reduced dims to be innermost so rows are contiguous.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    if N >= vector_size * core_num:
        # according to api, op == any, use max to calculate
        # NOTE(review): valid when inp.to(torch.float) is non-negative
        # (e.g. bool input); for signed inputs max != any — confirm.
        inpf = inp.to(torch.float)
        outf = torch.empty(shape, dtype=torch.float, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            max_kernel_dim[grid](inpf, outf, M, N, buffer_size_limit=2048)
        out = outf.to(torch.bool)
    else:
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out