Coverage for src/flag_gems/runtime/backend/_ascend/ops/any.py: 0%
106 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# torch.any: tests whether any element of the input evaluates to True. If the
# dtype of the input is not BOOL, it tests whether any element is non-zero.
# In the Triton kernels below, testing every element against zero is sufficient.
@triton.jit
def reduce_any(a, b):
    """Combine two partial results of an "any" reduction with a logical or."""
    either = a or b
    return either
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
@triton.jit
def any_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise "any" over an input addressed as an M x N row-major matrix.

    Each program handles BLOCK_M consecutive rows, walks the N columns in
    chunks of BLOCK_N, and stores one flag per row into ``out``.
    """
    # Rows owned by this program, shaped (BLOCK_M, 1) so they broadcast
    # against the (1, BLOCK_N) column offsets below.
    block_id = tle.program_id(0)
    row_idx = block_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_valid = row_idx < M
    inp_rows = inp + row_idx * N
    out_rows = out + row_idx

    # Per-lane accumulator; it is collapsed along the columns after the loop.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)
    for col_start in range(0, N, BLOCK_N):
        col_idx = col_start + tl.arange(0, BLOCK_N)[None, :]
        col_valid = col_idx < N
        in_bounds = row_valid and col_valid

        # Out-of-range lanes read 0 so they never flip the accumulator.
        vals = tl.load(inp_rows + col_idx, in_bounds, other=0.0)
        acc = acc or (vals != 0)
    row_any = tl.reduce(acc, axis=1, combine_fn=reduce_any)
    tl.store(out_rows, row_any[:, None], row_valid)
@libentry()
@triton.jit
def any_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full reduction: one partial "any" per program.

    Each program tests BLOCK_SIZE consecutive elements of ``inp`` against zero
    and writes the reduced flag to ``mid[program_id]``. ``mid_size`` is
    accepted but not read by this kernel.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past the end read 0, which is neutral for "any".
    vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
    nonzero = vals != 0
    block_any = tl.reduce(nonzero, axis=0, combine_fn=reduce_any)
    tl.store(mid + block_id, block_any)
@libentry()
@triton.jit
def any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction: fold the partial flags into one.

    Loads up to BLOCK_MID entries of ``mid`` (only the first MID_SIZE are
    valid), reduces them with ``reduce_any`` and stores the result at ``out``.
    """
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < MID_SIZE
    # Padding lanes load 0 so they do not affect the result.
    partials = tl.load(mid + idx, mask=in_bounds, other=0).to(tl.int1)
    result = tl.reduce(partials, axis=0, combine_fn=reduce_any)
    tl.store(out, result)
def any(inp):
    """Return a 0-d bool tensor telling whether any element of ``inp`` is non-zero.

    The reduction runs in two kernel launches: ``any_kernel_1`` produces one
    partial flag per block of roughly sqrt(numel) elements, then
    ``any_kernel_2`` folds those partial flags into the final scalar.

    An empty input yields ``False`` without launching any kernel.
    """
    logger.debug("GEMS_ASCEND ANY")
    n_elements = inp.numel()
    if n_elements == 0:
        # With zero elements the block size below becomes 0 and the cdiv
        # would divide by zero; "any" over nothing is False.
        return torch.zeros([], dtype=torch.bool, device=inp.device)

    # Size the blocks near sqrt(n) so both stages see a similar amount of work.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        any_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        any_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out
def any_dim(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with "any" along a single dimension.

    Args:
        inp: input tensor.
        dim: dimension to reduce; negative values count from the end.
            ``None`` reduces over every element via ``any``.
        keepdim: if true, the reduced dimension is kept with size 1
            (for ``dim=None`` the result is reshaped to all-ones dims).

    Returns:
        A bool tensor holding the reduction result.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_ASCEND ANY DIM")
    if dim is None:
        out = any(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Either the reduced dim is empty (every output is False) or the
        # output itself is empty. Skipping the kernel also avoids the
        # division by N == 0 below.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid_fn(meta):
            # NOTE(review): the launch is capped at 65535 programs, but
            # any_kernel_dim handles one BLOCK_M tile per program without
            # looping over rows; if cdiv(M, BLOCK_M) exceeds the cap the
            # remaining rows look like they are never written - confirm.
            return (min(triton.cdiv(M, meta["BLOCK_M"]), 65535),)

        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid_fn](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def any_dims(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with "any" along one or several dimensions.

    Args:
        inp: input tensor.
        dim: ``None`` or an int (both delegated to ``any_dim``), or an
            iterable of ints naming the dimensions to reduce; negative
            values count from the end.
        keepdim: if true, reduced dimensions are kept with size 1.

    Returns:
        A bool tensor holding the reduction result.

    Raises:
        AssertionError: if any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_ASCEND ANY DIMS")

    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)
    # all(...) is required: a bare generator expression is always truthy, so
    # without it the assert can never fire and out-of-range dims would be
    # silently wrapped by the modulo below.
    assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]

    # N is the number of elements folded into each output value.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Either a reduced dim is empty (every output is False) or the
        # output itself is empty. Skipping the kernel also avoids the
        # division by N == 0 below.
        out = torch.zeros(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid_fn(meta):
            return (triton.cdiv(M, meta["BLOCK_M"]),)

        with torch_device_fn.device(inp.device):
            any_kernel_dim[grid_fn](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out