Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/any.py: 0%
160 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.max import max_kernel_1, max_kernel_2
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# torch.any: Tests if any elements in input evaluate to True. If the dtype of input
# is not BOOL, then test if any elements in input evaluate to non-zero value
# In triton function, test if any elements in input evaluate to non-zero value is ok.

# Kunlunxin device geometry used by the launch-size heuristics below.
cluster_num = 12  # clusters per device — presumably XPU hardware spec; TODO confirm
core_num = 64  # cores per cluster — presumably hardware spec; TODO confirm
thread_num = core_num * cluster_num  # total cores across the device
buf_len_per_core = 2048  # per-core scratch budget; divided by 4 below, so presumably bytes with 4-byte elements — TODO confirm
vector_size = 16  # SIMD width used as a small/large-input threshold — presumably hardware spec; TODO confirm
def get_block(n: int) -> int:
    """Round *n* up to the nearest positive multiple of ``cluster_num``.

    Inputs below ``cluster_num`` (including non-positive ones) are clamped
    to a single cluster's worth of work.
    """
    if n >= cluster_num:
        return cluster_num * triton.cdiv(n, cluster_num)
    return cluster_num
def heur_m_block_size(args):
    """Heuristic BLOCK_M: rows handled per program, capped at ``core_num``
    and padded up to a power of two."""
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(min(rows_per_cluster, core_num))
def heur_n_block_size(args):
    """Heuristic BLOCK_N: columns per reduction step, limited by the
    per-core buffer budget (divided by 4, presumably 4-byte elements),
    padded up to a power of two."""
    col_budget = triton.cdiv(buf_len_per_core, 4)
    return triton.next_power_of_2(min(args["N"], col_budget))
@triton.jit
def reduce_any(a, b):
    # Combine function for tl.reduce: logical OR of two reduction lanes.
    return a or b
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def any_kernel_dim(
    inp,  # input viewed as an (M, N) row-major matrix
    out,  # per-row boolean result, length M
    M,  # number of kept rows
    N,  # reduction length per row
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise any: out[m] = True iff some inp[m, n] is non-zero."""
    # Map the program id to the rows of inp this program reduces.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M
    # Per-lane OR flags, accumulated across N in BLOCK_N-wide steps.
    _any = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        # Masked lanes read 0.0, which contributes False to the OR.
        a = tl.load(inp + cols, mask, other=0.0)
        _any = _any or (a != 0)
    # Collapse the BLOCK_N lanes of each row with a logical-OR reduction.
    any = tl.reduce(_any, axis=1, combine_fn=reduce_any)
    tl.store(out, any[:, None], row_mask)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def max_kernel_dim(
    in_ptr,  # input viewed as an (M, N) row-major matrix
    out_ptr,  # per-row fp32 maximum, length M
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise max in fp32; callers use it as an `any` surrogate
    (a non-zero max implies some non-zero element)."""
    xoffset = tl.program_id(0) * BLOCK_M
    xindex = xoffset + tl.arange(0, BLOCK_M)[:, None]
    xmask = xindex < M
    rbase = tl.arange(0, BLOCK_N)[None, :]
    # Running maxima start at -inf, the identity element for max.
    _max = tl.full([BLOCK_M, BLOCK_N], float("-inf"), tl.float32)
    for roffset in range(0, N, BLOCK_N):
        rindex = roffset + rbase
        rmask = rindex < N
        r1 = rindex
        # Out-of-range lanes read -inf so they never win the max.
        inp = tl.load(
            in_ptr + (r1 + (N * xindex)), rmask & xmask, other=float("-inf")
        ).to(tl.float32)
        inpb = tl.broadcast_to(inp, [BLOCK_M, BLOCK_N])
        _max = tl.maximum(_max, inpb)
    # Reduce the BLOCK_N lanes to one max per row and store it.
    tmp2 = tl.max(_max, axis=1, return_indices=False)[:, None]
    tl.store(out_ptr + xindex, tmp2, xmask)
@libentry()
@triton.jit
def any_kernel_1(
    inp,
    mid,  # per-program partial results, one boolean per program id
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flat any-reduction: each program ORs BLOCK_SIZE
    elements and writes the partial result to mid[pid]."""
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # Tail lanes load 0.0 so they contribute False to the OR reduction.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)
    any_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_any)
    mid_ptr = mid + pid
    tl.store(mid_ptr, any_val)
@libentry()
@triton.jit
def any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flat any-reduction: OR the MID_SIZE partials in
    `mid` into the scalar `out`. BLOCK_MID must cover MID_SIZE."""
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # Masked lanes read 0 -> False, the identity for OR.
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(tl.int1)
    any_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_any)
    tl.store(out, any_val)
def any(inp):
    """Reduce *inp* over all elements: scalar bool, True iff any element
    is non-zero.

    Large inputs take a byte-wise max-reduction fast path (max_kernel_*);
    smaller ones use the dedicated boolean kernels (any_kernel_*). Both
    paths reduce in two stages: per-block partials into ``mid``, then a
    final combine into ``out`` — skipped when a single block suffices.
    """
    logger.debug("GEMS ANY")
    total = inp.numel()
    block_size = max(
        triton.cdiv(get_block(total), cluster_num),
        triton.cdiv(buf_len_per_core * core_num, 4),
    )
    mid_size = triton.cdiv(total, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    if total >= vector_size * thread_num:
        # Fast path: reinterpret storage as bytes and max-reduce — any
        # non-zero byte implies a non-zero element.
        # NOTE(review): for multi-byte dtypes inp.view(torch.uint8) has
        # more elements than `total`, yet only `total` bytes are scanned —
        # confirm this path is only reached for 1-byte inputs.
        as_bytes = inp.view(torch.uint8)
        mid = torch.empty((mid_size,), dtype=torch.uint8, device=inp.device)
        out = torch.empty([], dtype=torch.uint8, device=inp.device)
        with torch_device_fn.device(inp.device):
            max_kernel_1[(mid_size, 1)](
                as_bytes, mid, total, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                return mid.view(torch.bool).reshape([])
            max_kernel_2[(1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
        return out.view(torch.bool)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)
    with torch_device_fn.device(inp.device):
        any_kernel_1[(mid_size, 1)](
            inp, mid, total, block_size, buffer_size_limit=2048
        )
        if mid_size == 1:
            return mid.reshape([])
        any_kernel_2[(1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
    return out
def any_dim(inp, dim=None, keepdim=False):
    """torch.any along one dimension (all dimensions when *dim* is None).

    Long reductions go through the float max kernel (non-zero max means
    some element was non-zero); short ones use the boolean any kernel.
    """
    logger.debug("GEMS ANY DIM")
    shape = list(inp.shape)

    if dim is None:
        result = any(inp)
        return torch.reshape(result, [1] * inp.ndim) if keepdim else result

    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    # Move the reduced dim to the end so rows are contiguous.
    inp = dim_compress(inp, dim)
    N = shape[dim]
    shape[dim] = 1
    M = inp.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    if N >= vector_size * vector_size:
        # according to api, op == any, use max to calculate
        floats = inp.to(torch.float)
        partial = torch.empty(shape, dtype=torch.float, device=inp.device)
        with torch_device_fn.device(inp.device):
            max_kernel_dim[grid](floats, partial, M, N, buffer_size_limit=2048)
        out = partial.to(torch.bool)
    else:
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)
        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    return out if keepdim else out.squeeze(dim=dim)
def any_dims(inp, dim=None, keepdim=False):
    """torch.any reduced over multiple dimensions.

    Args:
        inp: input tensor.
        dim: None, an int, or an iterable of ints naming the dims to reduce.
            None / int delegates to any_dim.
        keepdim: keep reduced dims as size-1 when True.

    Returns:
        Bool tensor with the named dims reduced.

    Raises:
        AssertionError: if any entry of *dim* is out of range.
    """
    logger.debug("GEMS ANY DIMS")

    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)
    # BUGFIX: the original asserted a bare generator expression, which is
    # always truthy, so invalid dims were never rejected. Validate for real.
    assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # Move the reduced dims to the end so each reduction row is contiguous.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    if N >= vector_size * core_num:
        # according to api, op == any, use max to calculate
        # (a non-zero max implies some non-zero element)
        inpf = inp.to(torch.float)
        outf = torch.empty(shape, dtype=torch.float, device=inp.device)
        with torch_device_fn.device(inp.device):
            max_kernel_dim[grid](inpf, outf, M, N, buffer_size_limit=2048)
        out = outf.to(torch.bool)
    else:
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)
        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out