Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/all.py: 0%
160 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.min import min_kernel_1, min_kernel_2
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
13# import math
16# torch.all tests whether all elements of the input evaluate to True. If the dtype
17# of the input is not BOOL, it tests whether all elements are non-zero instead.
18# Inside the Triton kernels it is therefore sufficient to compare every element against zero.
# Launch-geometry constants shared by the heuristics and host wrappers below.
# NOTE(review): presumably these mirror the target device topology - confirm
# against the backend documentation before changing any of them.
cluster_num = 12
core_num = 64
# Product of the two counts above; used as a size threshold by the host code.
thread_num = cluster_num * core_num
buf_len_per_core = 2048
vector_size = 16
def get_block(n: int) -> int:
    """Round ``n`` up to a multiple of ``cluster_num``, never below ``cluster_num``.

    Args:
        n: Element count to round.

    Returns:
        ``cluster_num`` when ``n`` is smaller than it, otherwise the smallest
        multiple of ``cluster_num`` that is greater than or equal to ``n``.
    """
    if n >= cluster_num:
        return cluster_num * triton.cdiv(n, cluster_num)
    return cluster_num
def heur_m_block_size(args):
    """Heuristic for ``BLOCK_M``, derived from the kernel argument ``M``.

    Splits ``M`` across ``cluster_num``, caps the share at ``core_num`` and
    rounds the result up to a power of two.

    Args:
        args: Mapping of kernel argument names to values; only ``"M"`` is read.

    Returns:
        The power-of-two block size along the row dimension.
    """
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    capped_rows = min(rows_per_cluster, core_num)
    return triton.next_power_of_2(capped_rows)
def heur_n_block_size(args):
    """Heuristic for ``BLOCK_N``, derived from the kernel argument ``N``.

    Args:
        args: Mapping of kernel argument names to values; only ``"N"`` is read.

    Returns:
        ``N`` capped at 512 and rounded up to a power of two.
    """
    capped_cols = min(args["N"], 512)
    return triton.next_power_of_2(capped_cols)
@triton.jit
def reduce_all(a, b):
    # Combine function handed to tl.reduce by the "all" kernels in this file:
    # merges two partial results with a logical AND.
    return a and b
48# def heur_m_block_size(args):
49# return triton.next_power_of_2(triton.cdiv(args["M"], 12)) # cluster_num
52# def heur_n_block_size(args):
53# import builtins
55# return builtins.min(triton.next_power_of_2(args["N"]), 8192 * 4)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise "all": row r of the input is addressed as inp + r * N + column,
    # and one flag per row is stored to out + r.
    # Each program instance handles BLOCK_M consecutive rows.
    block_id = tle.program_id(0)
    row_idx = block_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_valid = row_idx < M
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx

    # Running per-lane flag, initialised to True for every (row, column) lane.
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        col_valid = col_idx < N
        valid = row_valid and col_valid

        # Masked-out lanes load a non-zero filler so they cannot clear the flag.
        vals = tl.load(row_base + col_idx, valid, other=1.0)
        acc = acc and (vals != 0)
    # Fold the BLOCK_N lanes of each row into a single flag.
    row_all = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out_ptrs, row_all[:, None], row_valid)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def min_kernel_dim(
    in_ptr,
    out_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise minimum: row r of the input is addressed as in_ptr + r * N + column,
    # and one value per row is stored to out_ptr + r.
    # Each program instance handles BLOCK_M consecutive rows.
    row_start = tl.program_id(0) * BLOCK_M
    rows = row_start + tl.arange(0, BLOCK_M)[:, None]
    row_valid = rows < M
    col_base = tl.arange(0, BLOCK_N)[None, :]

    # Running per-lane minimum in float32, starting from +inf.
    running = tl.full([BLOCK_M, BLOCK_N], float("inf"), tl.float32)
    for col_start in range(0, N, BLOCK_N):
        cols = col_start + col_base
        col_valid = cols < N
        # Masked-out lanes load +inf so they never lower the minimum.
        tile = tl.load(
            in_ptr + (cols + (N * rows)), col_valid & row_valid, other=float("inf")
        ).to(tl.float32)
        tile_full = tl.broadcast_to(tile, [BLOCK_M, BLOCK_N])
        running = tl.minimum(running, tile_full)
    # Collapse the BLOCK_N lanes of each row into one value.
    row_min = tl.min(running, axis=1, return_indices=False)[:, None]
    tl.store(out_ptr + rows, row_min, row_valid)
@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # First stage of the flat "all" reduction: each program reduces one
    # BLOCK_SIZE chunk of the input and writes a single flag to mid[program id].
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Out-of-range lanes load a non-zero filler so they cannot clear the flag.
    vals = tl.load(inp + idx, mask=in_bounds, other=1.0)
    chunk_all = tl.reduce(vals != 0, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, chunk_all)
@libentry()
@triton.jit
def all_kernel_2(
    mid,
    out,
    MID_SIZE,
    BLOCK_MID: tl.constexpr,
):
    # Second stage of the flat "all" reduction: a single program folds the
    # MID_SIZE partial flags in mid into one flag stored at out.
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < MID_SIZE
    # Lanes past MID_SIZE load 1 (True) so they do not affect the result.
    partial = tl.load(mid + idx, mask=in_bounds, other=1).to(tl.int1)
    final = tl.reduce(partial, axis=0, combine_fn=reduce_all)
    tl.store(out, final)
def all(inp):
    """Return a 0-dim bool tensor telling whether every element of ``inp`` is non-zero.

    Two code paths are used depending on the element count:

    * large inputs (``>= vector_size * thread_num`` elements) go through the
      two-stage float ``min`` kernels;
    * smaller inputs go through the two-stage boolean ``all`` kernels.

    Args:
        inp: Input tensor of any dtype.

    Returns:
        A 0-dim ``torch.bool`` tensor on ``inp.device``. An empty input
        yields ``True``.
    """
    logger.debug("GEMS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # Without this guard mid_size and block_mid are both 0 and the kernel
        # launches below cannot work. "All of nothing" is True, as in torch.all.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    block_size = min(
        triton.cdiv(get_block(n_elements), cluster_num),
        triton.cdiv(buf_len_per_core * core_num, 4),
    )
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    if n_elements >= vector_size * thread_num:
        # according to api, op == all, use min to calculate
        # Reduce over the 0/1 truth values, not the raw values: with raw
        # values a negative element hides a zero (min(-1, 0) = -1 -> True),
        # and tiny float64 values can underflow to 0 when cast to float32.
        inpf = (inp != 0).to(torch.float)
        midf = torch.empty((mid_size,), dtype=torch.float, device=inp.device)
        outf = torch.empty([], dtype=torch.float, device=inp.device)

        with torch_device_fn.device(inp.device):
            min_kernel_1[(mid_size, 1)](
                inpf, midf, n_elements, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                # A single partial result is already the final answer.
                return midf.to(torch.bool).reshape([])
            min_kernel_2[(1, 1)](
                midf, outf, mid_size, block_mid, buffer_size_limit=2048
            )
        out = outf.to(torch.bool)
    else:
        mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
        out = torch.empty([], dtype=torch.bool, device=inp.device)

        with torch_device_fn.device(inp.device):
            all_kernel_1[(mid_size, 1)](
                inp, mid, n_elements, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                # A single partial result is already the final answer.
                return mid.reshape([])
            all_kernel_2[(1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)

    return out
def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements along one dimension of ``inp`` are non-zero.

    Args:
        inp: Input tensor of any dtype.
        dim: Dimension to reduce (negative values count from the end).
            ``None`` reduces over the whole tensor via ``all``.
        keepdim: When true, the reduced dimension is kept with size 1
            (or, for ``dim=None``, the result is reshaped to all-ones shape).

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS ALL DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Empty input: N == 0 would divide by zero below, and M == 0 would
        # give a zero block size. Reducing over nothing is True; if another
        # dimension is empty the result is simply an empty tensor.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)

        if N >= vector_size * vector_size:
            # according to api, op == all, use min to calculate
            # Reduce over the 0/1 truth values, not the raw values: with raw
            # values a negative element hides a zero (min(-1, 0) = -1 -> True).
            inpf = (inp != 0).to(torch.float)
            outf = torch.empty(shape, dtype=torch.float, device=inp.device)
            with torch_device_fn.device(inp.device):
                min_kernel_dim[grid](inpf, outf, M, N, buffer_size_limit=2048)
            out = outf.to(torch.bool)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements along several dimensions of ``inp`` are non-zero.

    Args:
        inp: Input tensor of any dtype.
        dim: ``None`` or a single int (both delegated to ``all_dim``), or an
            iterable of dimensions to reduce together.
        keepdim: When true, the reduced dimensions are kept with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    # The module-level ``all`` in this file is the tensor op and shadows the
    # builtin, so the builtin must be reached explicitly.
    import builtins

    logger.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    # The original asserted a bare generator expression, which is always
    # truthy, so out-of-range dims were never rejected.
    assert builtins.all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Empty input: N == 0 would divide by zero below, and M == 0 would
        # give a zero block size. Reducing over nothing is True; if a kept
        # dimension is empty the result is simply an empty tensor.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)

        if N >= vector_size * core_num:
            # according to api, op == all, use min to calculate
            # Reduce over the 0/1 truth values, not the raw values: with raw
            # values a negative element hides a zero (min(-1, 0) = -1 -> True).
            inpf = (inp != 0).to(torch.float)
            outf = torch.empty(shape, dtype=torch.float, device=inp.device)
            with torch_device_fn.device(inp.device):
                min_kernel_dim[grid](inpf, outf, M, N, buffer_size_limit=2048)
            out = outf.to(torch.bool)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out