Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/all.py: 0%

160 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.min import min_kernel_1, min_kernel_2 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger, namespaced under the flag_gems root logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# torch.all: Tests if all elements in input evaluate to True. If the dtype of input
# is not BOOL, then test if all elements in input evaluate to non-zero value.
# In the triton kernels below, testing "element != 0" is therefore sufficient.

# Launch-sizing constants for the Kunlunxin backend, used to pick block sizes
# and grid shapes. NOTE(review): presumably these mirror the hardware layout
# (clusters x cores, per-core buffer bytes, SIMD width) — confirm against the
# device documentation.
cluster_num = 12
core_num = 64
thread_num = core_num * cluster_num
buf_len_per_core = 2048
vector_size = 16

25 

26 

def get_block(n: int) -> int:
    """Round *n* up to a multiple of ``cluster_num`` (at least one cluster's worth)."""
    if n < cluster_num:
        return cluster_num
    return cluster_num * triton.cdiv(n, cluster_num)

33 

34 

def heur_m_block_size(args):
    """Heuristic BLOCK_M: one cluster's share of the M rows, capped at core_num,
    rounded up to a power of two."""
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    capped = min(rows_per_cluster, core_num)
    return triton.next_power_of_2(capped)

37 

38 

def heur_n_block_size(args):
    """Heuristic BLOCK_N: the reduced length N, capped at 512, rounded up to a
    power of two."""
    n = args["N"]
    capped = 512 if n > 512 else n
    return triton.next_power_of_2(capped)

41 

42 

@triton.jit
def reduce_all(a, b):
    # Combine function for tl.reduce: logical AND of two partial "all" results.
    return a and b

46 

47 

48# def heur_m_block_size(args): 

49# return triton.next_power_of_2(triton.cdiv(args["M"], 12)) # cluster_num 

50 

51 

52# def heur_n_block_size(args): 

53# import builtins 

54 

55# return builtins.min(triton.next_power_of_2(args["N"]), 8192 * 4) 

56 

57 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise "all" reduction: inp is viewed as (M, N) with contiguous rows
    # of length N; out receives one bool per row (out[row] = all(inp[row, :] != 0)).
    # Map the program id to the row of inp it should compute.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # Running AND accumulator, initialised to all-True (identity of AND).
    _all = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # other=1.0 keeps masked-out lanes truthy so they don't affect the AND.
        a = tl.load(inp + cols, mask, other=1.0)
        _all = _all and (a != 0)
    # Collapse the BLOCK_N lane dimension with a logical-AND reduction.
    # (Local name "all" shadows builtins.all inside the jitted body only.)
    all = tl.reduce(_all, axis=1, combine_fn=reduce_all)
    tl.store(out, all[:, None], row_mask)

92 

93 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def min_kernel_dim(
    in_ptr,
    out_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise min reduction in float32: in_ptr viewed as (M, N) with
    # contiguous rows; out_ptr receives one float per row. Used by the
    # wrappers below to implement "all" via min.
    xoffset = tl.program_id(0) * BLOCK_M
    xindex = xoffset + tl.arange(0, BLOCK_M)[:, None]
    xmask = xindex < M
    rbase = tl.arange(0, BLOCK_N)[None, :]
    # +inf is the identity element for min.
    _min = tl.full([BLOCK_M, BLOCK_N], float("inf"), tl.float32)
    for roffset in range(0, N, BLOCK_N):
        rindex = roffset + rbase
        rmask = rindex < N
        r1 = rindex
        # other=+inf keeps masked lanes neutral for the min.
        inp = tl.load(
            in_ptr + (r1 + (N * xindex)), rmask & xmask, other=float("inf")
        ).to(tl.float32)
        inpb = tl.broadcast_to(inp, [BLOCK_M, BLOCK_N])
        _min = tl.minimum(_min, inpb)
    # Collapse the BLOCK_N lane dimension; keepdim via [:, None] for the store.
    tmp2 = tl.min(_min, axis=1, return_indices=False)[:, None]
    tl.store(out_ptr + xindex, tmp2, xmask)

126 

127 

@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the flat two-stage "all": each program reduces one
    # BLOCK_SIZE chunk of the flattened input and writes a single partial
    # bool to mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # other=1.0 keeps out-of-range lanes truthy (neutral for AND).
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0)
    all_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_all)
    mid_ptr = mid + pid
    tl.store(mid_ptr, all_val)

144 

145 

@libentry()
@triton.jit
def all_kernel_2(
    mid,
    out,
    MID_SIZE,
    BLOCK_MID: tl.constexpr,
):
    # Stage 2 of the flat two-stage "all": a single program ANDs the
    # MID_SIZE partial results from all_kernel_1 into one scalar bool.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # other=1 keeps padding lanes truthy (neutral for AND).
    mid_val = tl.load(mid_ptrs, mask=mask, other=1).to(tl.int1)
    all_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_all)
    tl.store(out, all_val)

160 

161 

def all(inp):
    """Full reduction torch.all: True iff every element of *inp* is non-zero.

    Returns a 0-d bool tensor on *inp*'s device. Large inputs are routed
    through the shared min kernels (min over a 0/1 truthiness indicator is
    0 iff some element is falsy); small inputs use the dedicated AND-reduce
    kernels above.
    """
    logger.debug("GEMS ALL")
    n_elements = inp.numel()
    # One cluster's share of the cluster-aligned element count, capped by the
    # per-core buffer capacity.
    block_size = min(
        triton.cdiv(get_block(n_elements), cluster_num),
        triton.cdiv(buf_len_per_core * core_num, 4),
    )
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    if n_elements >= vector_size * thread_num:
        # according to api, op == all, use min to calculate.
        # BUGFIX: reduce over a 0/1 indicator rather than the raw values.
        # min of the raw input is wrong when a negative value is present,
        # e.g. all([-1.0, 0.0]) would see min == -1.0 (truthy) and return
        # True; min of (inp != 0) is 0 exactly when some element is falsy.
        inpf = (inp != 0).to(torch.float)
        midf = torch.empty((mid_size,), dtype=torch.float, device=inp.device)
        outf = torch.empty([], dtype=torch.float, device=inp.device)

        with torch_device_fn.device(inp.device):
            min_kernel_1[(mid_size, 1)](
                inpf, midf, n_elements, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                # A single partial result already is the final reduction.
                return midf.to(torch.bool).reshape([])
            min_kernel_2[(1, 1)](
                midf, outf, mid_size, block_mid, buffer_size_limit=2048
            )
        out = outf.to(torch.bool)
    else:
        mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
        out = torch.empty([], dtype=torch.bool, device=inp.device)

        with torch_device_fn.device(inp.device):
            all_kernel_1[(mid_size, 1)](
                inp, mid, n_elements, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                # A single partial result already is the final reduction.
                return mid.reshape([])
            all_kernel_2[(1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)

    return out

201 

202 

def all_dim(inp, dim=None, keepdim=False):
    """torch.all along a single dimension.

    Args:
        inp: input tensor.
        dim: int dimension to reduce, or None for a full reduction.
        keepdim: keep the reduced dimension with size 1.

    Returns:
        Bool tensor with *dim* reduced (0-d when dim is None and not keepdim).
    """
    logger.debug("GEMS ALL DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)  # module-level full reduction (shadows builtins.all)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        dim = dim % inp.ndim
        # Make the reduced dim innermost so each length-N row is contiguous.
        inp = dim_compress(inp, dim)
        N = shape[dim]
        shape[dim] = 1
        M = inp.numel() // N

        if N >= vector_size * vector_size:
            # according to api, op == all, use min to calculate.
            # BUGFIX: reduce over a 0/1 indicator rather than the raw values.
            # min of raw values is negative (truthy) whenever a negative
            # element is present, even if the row also contains a zero;
            # min(inp != 0) per row is 0 exactly when some element is falsy.
            inpf = (inp != 0).to(torch.float)
            outf = torch.empty(shape, dtype=torch.float, device=inp.device)

            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                min_kernel_dim[grid](inpf, outf, M, N, buffer_size_limit=2048)
            out = outf.to(torch.bool)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

        if not keepdim:
            out = out.squeeze(dim=dim)
    return out

236 

237 

def all_dims(inp, dim=None, keepdim=False):
    """torch.all over one or several dimensions.

    Args:
        inp: input tensor.
        dim: None, an int, or a sequence of ints to reduce over.
        keepdim: keep the reduced dimensions with size 1.

    Returns:
        Bool tensor with every dimension in *dim* reduced.
    """
    logger.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    # BUGFIX: the original asserted a generator expression, which is always
    # truthy and therefore validated nothing. Check each dim explicitly
    # (builtins.all is shadowed by the module-level `all` above).
    for d in dim:
        assert d >= -inp.ndim and d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # Make the reduced dims innermost; each row then has N = prod(shape[d]).
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    if N >= vector_size * core_num:
        # according to api, op == all, use min to calculate.
        # BUGFIX: reduce over a 0/1 indicator rather than the raw values.
        # min of raw values is negative (truthy) whenever a negative element
        # is present, even if the row also contains a zero; min(inp != 0)
        # per row is 0 exactly when some element is falsy.
        inpf = (inp != 0).to(torch.float)
        outf = torch.empty(shape, dtype=torch.float, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            min_kernel_dim[grid](inpf, outf, M, N, buffer_size_limit=2048)
        out = outf.to(torch.bool)
    else:
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out