Coverage for src/flag_gems/runtime/backend/_ascend/ops/all.py: 0%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
# Module logger, nested under the Ascend ops namespace and named after the last
# dotted component of this module's name.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)
# torch.all: tests whether all elements of the input evaluate to True. If the
# input dtype is not bool, it tests whether all elements are non-zero.
# The Triton kernels below therefore just test every element against zero,
# which also covers bool input (True is non-zero).
@triton.jit
def reduce_all(a, b):
    # Combine function passed to tl.reduce by the kernels in this file: logical
    # AND of two partial results. The kernels feed it int1 (boolean) blocks.
    # NOTE(review): relies on Triton lowering `and` on tensor operands to an
    # element-wise logical AND; confirm this holds on the Ascend backend.
    return a and b
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """AND-reduce each row of a row-major (M, N) view of ``inp`` into ``out``.

    Row ``r`` is read from ``inp + r * N + [0, N)`` and its result (are all N
    values non-zero?) is stored at ``out + r``. Rows are grouped into tiles of
    BLOCK_M rows, and the launched programs take tiles round-robin: program
    ``p`` handles tiles ``p, p + num_programs, p + 2 * num_programs, ...``.
    Tiles past the end of M are fully masked out.
    """
    num_workers = tle.num_programs(0)
    worker = tle.program_id(0)

    num_tiles = tl.cdiv(M, BLOCK_M)
    tiles_per_worker = tl.cdiv(num_tiles, num_workers)

    for step in range(tiles_per_worker):
        tile = worker + step * num_workers
        row_idx = tile * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        row_ok = row_idx < M
        row_base = inp + row_idx * N

        # Running AND per (row, column lane). Masked lanes load 1 (non-zero),
        # so they never turn the result False.
        acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
        for col_start in range(0, N, BLOCK_N):
            col_idx = col_start + tl.arange(0, BLOCK_N)[None, :]
            tile_mask = row_ok and (col_idx < N)
            vals = tl.load(row_base + col_idx, mask=tile_mask, other=1.0)
            acc = acc and (vals != 0)

        row_result = tl.reduce(acc, axis=1, combine_fn=reduce_all)
        tl.store(out + row_idx, row_result[:, None], mask=row_ok)
@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the full reduction: one partial AND per BLOCK_SIZE chunk.

    Program ``i`` tests flat elements ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)``
    of ``inp`` for non-zero and stores the AND of those tests to ``mid[i]``.
    ``mid_size`` is part of the signature but is not read by the kernel body.
    """
    chunk = tle.program_id(0)
    idx = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Out-of-range lanes load 1 (non-zero) so they cannot affect the AND.
    vals = tl.load(inp + idx, mask=in_bounds, other=1.0)
    chunk_all = tl.reduce(vals != 0, axis=0, combine_fn=reduce_all)
    tl.store(mid + chunk, chunk_all)
@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full reduction: AND the MID_SIZE partials in ``mid`` into ``out``.

    Loads a single block of BLOCK_MID lanes, so BLOCK_MID must be >= MID_SIZE
    for every partial to be included. Lanes beyond MID_SIZE load 1 (True).
    """
    lanes = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=1).to(tl.int1)
    tl.store(out, tl.reduce(partials, axis=0, combine_fn=reduce_all))
def all(inp):
    """Return a 0-dim bool tensor that is True iff every element of ``inp`` is non-zero.

    Two-stage Triton reduction: ``all_kernel_1`` reduces chunks of
    ``next_power_of_2(ceil(sqrt(numel)))`` elements into a ``mid`` buffer, then
    ``all_kernel_2`` reduces ``mid`` to the scalar result.

    Note: this function intentionally shadows the builtin ``all`` inside this
    module.

    Args:
        inp: input tensor of any shape.

    Returns:
        A 0-dim ``torch.bool`` tensor on ``inp.device``. An empty input yields
        True, matching ``torch.all``.
    """
    logger.debug("GEMS_ASCEND ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # Vacuous truth. Without this, next_power_of_2(0) == 0 gives a block
        # size of 0 and triton.cdiv below divides by zero.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # all_kernel_1 indexes the input as one flat buffer of n_elements values,
    # which is only correct for contiguous memory. This is a no-op when inp is
    # already contiguous.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    # all_kernel_2 reduces all partials in one block, so round up to a power of 2.
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out
def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero, over everything or along one dim.

    Args:
        inp: input tensor.
        dim: dimension to reduce, in ``[-inp.ndim, inp.ndim)``. ``None`` reduces
            over every element by delegating to this module's ``all`` op (not
            the builtin).
        keepdim: if True, the reduced dimension(s) are kept with size 1.

    Returns:
        A ``torch.bool`` tensor. Reducing over a zero-size dim yields True.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_ASCEND ALL DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Two cases are covered here:
        # - N == 0: each output element reduces an empty slice and is True.
        # - M == 0: the output itself is empty.
        # torch.ones handles both, and avoids `numel // N` dividing by zero
        # and launching an empty grid.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Presumably dim_compress moves `dim` innermost and makes the tensor
        # contiguous, so the kernel can treat it as a row-major (M, N) matrix
        # with one row per output element.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            # One program per BLOCK_M-row tile, capped at 40 programs. Extra
            # tiles are handled round-robin inside the kernel.
            # NOTE(review): 40 presumably matches the Ascend core count; confirm.
            return (min(triton.cdiv(M, meta["BLOCK_M"]), 40),)

        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero, reducing over one or more dims.

    Args:
        inp: input tensor.
        dim: ``None`` (reduce everything), a single int, or a sequence of ints,
            each in ``[-inp.ndim, inp.ndim)``.
        keepdim: if True, the reduced dimensions are kept with size 1.

    Returns:
        A ``torch.bool`` tensor. Reducing over a zero-size dim yields True.

    Raises:
        AssertionError: if any entry of ``dim`` is out of range.
    """
    logger.debug("GEMS_ASCEND ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    # Check every dim explicitly. The builtin `all` is shadowed by this
    # module's `all` op, and the old `assert (<generator>)` was always truthy.
    # As a result, out-of-range dims were never rejected and were silently
    # wrapped by the `%` below.
    for d in dim:
        assert -inp.ndim <= d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Two cases are covered here:
        # - N == 0: each output element reduces an empty slice and is True.
        # - M == 0: the output itself is empty.
        # torch.ones handles both, and avoids `numel // N` dividing by zero
        # and launching an empty grid.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Presumably dim_compress moves the reduced dims innermost and makes
        # the tensor contiguous, giving a row-major (M, N) view.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            # One program per BLOCK_M-row tile, capped at 40 programs.
            # NOTE(review): 40 presumably matches the Ascend core count; confirm.
            return (min(triton.cdiv(M, meta["BLOCK_M"]), 40),)

        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out