Coverage for src/flag_gems/runtime/backend/_cambricon/ops/any.py: 0%
96 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
11from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op2
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
14# torch.any: tests whether any element of the input evaluates to True. If the dtype of
15# the input is not bool, it tests whether any element of the input is non-zero instead.
16# Inside the Triton kernels it is therefore sufficient to compare every element against zero.
@triton.jit
def reduce_any(a, b):
    """Combine two partial results with a logical OR.

    Passed as ``combine_fn`` to ``tl.reduce`` by ``any_kernel_dim``.
    """
    either = a or b
    return either
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
@triton.jit
def any_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write, for each of M rows, whether any of its N elements is non-zero.

    Element (r, c) is read from ``inp + r * N + c`` and the flag for row r is
    stored at ``out + r``. Each program handles BLOCK_M rows and walks the
    columns in tiles of BLOCK_N.
    """
    # Row indices owned by this program, shaped as a column vector so they
    # broadcast against the row vector of column indices built in the loop.
    row_idx = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ok = row_idx < M
    # NOTE(review): these offsets are not widened to int64, whereas
    # any_kernel_1 widens its start offset - confirm that very large inputs
    # cannot overflow here.
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx

    # Per-tile accumulator of "saw a non-zero element" flags.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)
    for col_start in range(0, N, BLOCK_N):
        col_idx = col_start + tl.arange(0, BLOCK_N)[None, :]
        col_ok = col_idx < N
        in_bounds = row_ok and col_ok

        # Out-of-range lanes load 0 so they can never set a flag.
        vals = tl.load(row_base + col_idx, in_bounds, other=0.0)
        acc = acc or (vals != 0)
    # Collapse the BLOCK_N flags of every row into a single flag.
    row_any = tl.reduce(acc, axis=1, combine_fn=reduce_any)
    tl.store(out_ptrs, row_any[:, None], row_ok)
@libentry()
@triton.autotune(configs=cfggen_reduce_op2(), key=["M"])
@triton.jit
def any_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
    ITER_NUM: tl.constexpr,
):
    """OR into ``out`` whether any of the first M elements of ``inp`` is non-zero.

    Every program starts at ``pid * BLOCK_SIZE`` and advances by
    ``num_programs * BLOCK_SIZE`` until M is reached, then folds its
    BLOCK_SIZE flags together and combines the result into ``out`` with an
    atomic OR.
    """
    pid = tl.program_id(0)
    stride = tl.num_programs(axis=0) * BLOCK_SIZE
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.int1)
    # The start offset is widened to int64 before it is used as the loop start.
    first = (pid * BLOCK_SIZE).to(tl.int64)
    for base in range(first, M, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < M
        # Out-of-range lanes load 0 so they can never set a flag.
        vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
        acc = acc or (vals != 0)

    # Manual halving reduction: step x ORs the slice [B/2**x, 2*B/2**x) into
    # the slice [0, B/2**x), where B is BLOCK_SIZE.
    # The original note says to return to the regular tl.reduce form once
    # tl.reduce has been optimized.
    # NOTE(review): presumably ITER_NUM is chosen so that acc[0] ends up
    # covering the whole block - confirm against cfggen_reduce_op2.
    for x in tl.static_range(1, int(ITER_NUM), 1):
        acc[: BLOCK_SIZE // (2**x)] = (
            acc[: BLOCK_SIZE // (2**x)]
            or acc[BLOCK_SIZE // (2**x) : (BLOCK_SIZE // (2**x)) * 2]
        )

    tl.atomic_or(out, acc[0].to(tl.int32))
def any(inp):
    """Return a 0-dim bool tensor: True if any element of ``inp`` is non-zero.

    Launches ``any_kernel_1`` over all ``inp.numel()`` elements; the kernel
    ORs its findings into a zero-initialised int32 scalar, which is converted
    to bool before returning.
    """
    logger.debug("GEMS_CAMBRICON ANY")
    n_elements = inp.numel()

    def grid(meta):
        # One program per BLOCK_SIZE chunk, capped at TOTAL_CORE_NUM programs.
        blocks = triton.cdiv(n_elements, meta["BLOCK_SIZE"])
        return (min(blocks, TOTAL_CORE_NUM),)

    # Must start at zero because the kernel only ever ORs into it.
    result = torch.zeros([], dtype=torch.int32, device=inp.device)

    with torch_device_fn.device(inp.device):
        any_kernel_1[grid](inp, result, n_elements)

    return result.to(torch.bool)
def any_dim(inp, dim=None, keepdim=False):
    """Test whether any element along one dimension of ``inp`` is non-zero.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce (negative values count from the end). When
            None, the whole tensor is reduced via ``any``.
        keepdim: Keep the reduced dimension(s) with size 1 in the result.

    Returns:
        A bool tensor. With ``dim=None`` it is 0-dim (or all-ones shaped when
        ``keepdim`` is set); otherwise it has the shape of ``inp`` with
        ``dim`` reduced to size 1, or removed when ``keepdim`` is False.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_CAMBRICON ANY DIM")
    if dim is None:
        out = any(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Nothing to reduce: an empty reduction is False everywhere. This also
        # avoids dividing by N == 0 when the reduced dimension is empty.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            # One program per BLOCK_M output rows.
            return (triton.cdiv(M, meta["BLOCK_M"]),)

        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def any_dims(inp, dim=None, keepdim=False):
    """Test whether any element along several dimensions of ``inp`` is non-zero.

    Args:
        inp: Input tensor.
        dim: None or an int (both delegated to ``any_dim``), or an iterable of
            dimensions to reduce together (negative values count from the end).
        keepdim: Keep the reduced dimensions with size 1 in the result.

    Returns:
        A bool tensor with the shape of ``inp`` where every reduced dimension
        has size 1, or is removed when ``keepdim`` is False.

    Raises:
        AssertionError: If any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_CAMBRICON ANY DIMS")

    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)

    # Materialise once so the validation below cannot exhaust a one-shot iterable.
    dim = list(dim)
    # all() is required here: a bare generator expression is always truthy, so
    # asserting on it directly would never reject an invalid dimension.
    assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # N is the number of elements folded into each output element.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Nothing to reduce: an empty reduction is False everywhere. This also
        # avoids dividing by N == 0 when a reduced dimension is empty.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            # One program per BLOCK_M output rows.
            return (triton.cdiv(M, meta["BLOCK_M"]),)

        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out