Coverage for src/flag_gems/ops/all.py: 68%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; every public entry point below emits a debug record on it.
logger = logging.getLogger(name=__name__)
# torch.all: tests whether all elements of the input evaluate to True. If the
# input dtype is not bool, it tests whether all elements are non-zero.
# The Triton kernels therefore only need to test each element for a non-zero
# value; this covers every dtype (for bool input, non-zero is the same as True).
@triton.jit
def reduce_all(a, b):
    """Combine function for ``tl.reduce``: logical AND of two operands."""
    both = a and b
    return both
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise logical AND over ``inp`` viewed as M rows of N elements.

    Row ``r`` occupies ``inp[r * N : (r + 1) * N]`` (row-major indexing below).
    ``out[r]`` is set to whether every element of that row is non-zero. Each
    program handles BLOCK_M consecutive rows, sweeping columns BLOCK_N at a time.
    """
    # Map the program id to the block of rows it should compute.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # Running per-lane AND, seeded with True (the identity element of AND).
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # Masked-off lanes load 1 so they can never flip the AND to False.
        a = tl.load(inp + cols, mask, other=1.0)
        acc = acc and (a != 0)

    # Named `row_all` (not `all`) to avoid shadowing the builtin and this
    # module's `all` function.
    row_all = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out, row_all[:, None], row_mask)
@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the full reduction.

    Program ``i`` ANDs together the non-zero tests of elements
    ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)`` and writes the result to ``mid[i]``.
    ``mid_size`` is accepted but not read here.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past the end read 1, which is neutral for logical AND.
    vals = tl.load(inp + idx, mask=in_bounds, other=1.0)
    block_all = tl.reduce(vals != 0, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, block_all)
@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second pass: AND together the MID_SIZE partial results in ``mid``."""
    lanes = tl.arange(0, BLOCK_MID)
    # Padding lanes (>= MID_SIZE) read 1, i.e. True, so they do not change the result.
    partials = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=1).to(tl.int1)
    tl.store(out, tl.reduce(partials, axis=0, combine_fn=reduce_all))
def all(inp):
    """Return a 0-dim bool tensor telling whether every element of ``inp`` is non-zero.

    Two-pass reduction: ``all_kernel_1`` reduces blocks of ``block_size``
    elements into ``mid``, then ``all_kernel_2`` reduces ``mid`` into ``out``.
    An empty input yields True, matching ``torch.all``.
    """
    logger.debug("GEMS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # Vacuous truth. Also required: block_size would be 0 below and
        # triton.cdiv(0, 0) would raise ZeroDivisionError.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # The kernels read the input as a flat dense buffer; non-contiguous or
    # expanded (stride-0) views would otherwise be read incorrectly or out of
    # bounds. No-op for already-contiguous tensors.
    inp = inp.contiguous()

    # Block size ~ sqrt(n), so both passes handle roughly sqrt(n) elements
    # per program.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out
def all_dim(inp, dim=None, keepdim=False):
    """``torch.all(inp, dim, keepdim)`` for a single (or no) reduction dim.

    Args:
        inp: input tensor.
        dim: dimension to reduce; ``None`` reduces over all elements.
        keepdim: keep the reduced dimension(s) as size 1.

    Returns:
        A bool tensor. Reducing over an empty dimension yields True, matching
        ``torch.all``.
    """
    logger.debug("GEMS ALL DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1
    out = torch.empty(shape, dtype=torch.bool, device=inp.device)

    if N == 0:
        # AND over an empty dimension is vacuously True; also avoids the
        # ZeroDivisionError from `inp.numel() // N`.
        out.fill_(True)
    else:
        # Move `dim` to the innermost position so each output is one row of N.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        if M > 0:
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def all_dims(inp, dim=None, keepdim=False):
    """``torch.all(inp, dim, keepdim)`` where ``dim`` may be a sequence of dims.

    ``None`` or a single int is delegated to ``all_dim``. For a sequence, all
    listed dims are reduced together. Reducing over any empty dimension yields
    True, matching ``torch.all``.
    """
    logger.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    # Check each dim individually: asserting a generator expression is always
    # truthy (so it never checked anything), and the builtin all() is shadowed
    # by this module's all().
    for d in dim:
        assert d >= -inp.ndim and d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    out = torch.empty(shape, dtype=torch.bool, device=inp.device)

    if N == 0:
        # AND over an empty set of elements is vacuously True; also avoids the
        # ZeroDivisionError from `inp.numel() // N`.
        out.fill_(True)
    else:
        # Gather the reduced dims innermost so each output is one row of N.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        if M > 0:
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out