Coverage for src/flag_gems/runtime/backend/_hygon/ops/any.py: 0%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
# torch.any: tests whether any element of the input evaluates to True. If the
# input dtype is not bool, it tests whether any element is non-zero instead.
# In the Triton kernels below, testing for a non-zero value covers both cases.
@triton.jit
def reduce_any(a, b):
    # Combine function handed to tl.reduce: the logical OR of two partial
    # "any" results.
    either = a or b
    return either
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
@triton.jit
def any_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise "any": inp is addressed as a row-major M x N array and one flag
    # per row is written to out. Each program handles a tile of BLOCK_M rows.
    row_idx = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    valid_rows = row_idx < M
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx

    # Per-(row, lane) flags recording whether a non-zero value has been seen
    # while walking the columns in chunks of BLOCK_N.
    seen = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        in_bounds = valid_rows & (col_idx < N)
        # Out-of-range lanes load 0 so they can never set a flag.
        vals = tl.load(row_base + col_idx, in_bounds, other=0.0)
        seen = seen | (vals != 0)

    # Collapse the BLOCK_N lanes of every row into a single flag.
    row_any = tl.reduce(seen, axis=1, combine_fn=reduce_any)
    tl.store(out_ptrs, row_any[:, None], valid_rows)
@libentry()
@triton.jit
def any_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    # First stage of the full reduction: each program ORs together one chunk
    # of BLOCK_SIZE elements and writes the partial result to mid[program id].
    # mid_size is not read by the kernel body.
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the input load 0 and therefore contribute False.
    vals = tl.load(inp + idx, mask=idx < n_elements, other=0.0)
    block_any = tl.reduce(vals != 0, axis=0, combine_fn=reduce_any)
    tl.store(mid + block_id, block_any)
@libentry()
@triton.jit
def any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Second stage of the full reduction: OR the MID_SIZE partial results in
    # mid into the single value stored at out.
    lanes = tl.arange(0, BLOCK_MID)
    # Lanes beyond MID_SIZE load 0, i.e. False after the int1 cast.
    partials = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=0).to(tl.int1)
    result = tl.reduce(partials, axis=0, combine_fn=reduce_any)
    tl.store(out, result)
def any(inp):
    """Return a 0-dim bool tensor that is True if any element of ``inp`` is non-zero.

    The reduction runs in two kernel launches: ``any_kernel_1`` reduces
    chunks of ``block_size`` elements into ``mid``, then ``any_kernel_2``
    reduces ``mid`` into the scalar result.

    Args:
        inp: Input tensor; it is read through flat element offsets.

    Returns:
        A 0-dim ``torch.bool`` tensor on ``inp.device``. An empty input
        yields ``False``.
    """
    logger.debug("GEMS ANY")
    n_elements = inp.numel()
    if n_elements == 0:
        # Nothing can be non-zero. Without this guard sqrt(0) leads to a zero
        # block size and a division by zero when computing mid_size below.
        return torch.zeros([], dtype=torch.bool, device=inp.device)

    # Roughly sqrt(n) elements per block keeps both stages about the same size.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        any_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        any_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out
def any_dim(inp, dim=None, keepdim=False):
    """Test whether any element along one dimension of ``inp`` is non-zero.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce; negative values count from the end.
            ``None`` reduces over every element via ``any``.
        keepdim: If True, the reduced dimension is kept with size 1 (or, for
            ``dim=None``, the result is reshaped to ``inp.ndim`` dimensions
            of size 1).

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS ANY DIM")
    if dim is None:
        out = any(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Reducing over no elements is False everywhere. This also avoids the
        # 0 // 0 that computing M would hit when the reduced dimension is empty.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def any_dims(inp, dim=None, keepdim=False):
    """Test whether any element along one or several dimensions is non-zero.

    Args:
        inp: Input tensor.
        dim: ``None`` or a single int (both delegated to ``any_dim``), or an
            iterable of ints naming the dimensions to reduce. Negative
            values count from the end.
        keepdim: If True, reduced dimensions are kept with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS ANY DIMS")

    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)

    dim = list(dim)
    # Check every entry individually: asserting on a generator expression is
    # always truthy and therefore never rejects anything.
    for d in dim:
        assert -inp.ndim <= d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Reducing over no elements is False everywhere. This also avoids the
        # 0 // 0 that computing M would hit when a reduced dimension is empty.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out