Coverage for src/flag_gems/runtime/backend/_mthreads/ops/all.py: 0%
140 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
2import math
3from typing import Sequence
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import dim_compress, libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
# Module-level logger. Its name ends with the last component of this module's
# dotted import path, placed under the mthreads backend ops namespace.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)
# Autotuning candidates for the strided reduction kernel. Each entry of the
# table below is one (BLOCK_M, BLOCK_N, num_warps) launch configuration; the
# order of the table is the order of the resulting config list.
NAIVE_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": block_m, "BLOCK_N": block_n}, num_warps=warps)
    for block_m, block_n, warps in (
        (4, 1024, 4),
        (8, 1024, 4),
        (16, 1024, 8),
        (32, 512, 8),
        (32, 256, 4),
        (32, 128, 4),
        (32, 64, 4),
        (16, 256, 4),
        (16, 128, 4),
        (64, 256, 8),
        (64, 128, 4),
        (64, 64, 4),
        (64, 64, 8),
        (64, 32, 4),
        (128, 256, 8),
        (256, 256, 8),
    )
]
@triton.jit
def reduce_all(a, b):
    """Combine two partial results with logical AND.

    Passed as ``combine_fn`` to ``tl.reduce`` by the kernels in this module.
    """
    both = a and b
    return both
@triton.autotune(configs=NAIVE_REDUCTION_CONFIGS, key=["M", "N"])
@triton.jit
def all_kernel_dim_strided(
    inp,
    out,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Logical-AND reduction over one axis, addressed through explicit strides.

    Each program handles ``BLOCK_M`` consecutive output rows. A flat row index
    ``r`` is split into ``r // INNER`` and ``r % INNER``; the element at
    position ``k`` along the reduced axis is then read from
    ``inp + (r // INNER) * STRIDE_OUTER + (r % INNER) + k * STRIDE_REDUCE``.
    One boolean per row is written to ``out + r``.
    """
    # Row indices are widened to int64 before being used in address arithmetic.
    row_ids = (tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    valid_rows = row_ids < M

    # Pointer to the first reduced element of every row owned by this program.
    row_base = inp + (row_ids // INNER) * STRIDE_OUTER + (row_ids % INNER)

    # Per-lane running AND, started at True.
    running = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_ids = (start + tl.arange(0, BLOCK_N)).to(tl.int64)
        valid_cols = col_ids < N
        in_bounds = valid_rows[:, None] and valid_cols[None, :]
        # Masked-off lanes read 1 so they cannot turn the AND false.
        tile = tl.load(
            row_base[:, None] + col_ids[None, :] * STRIDE_REDUCE,
            in_bounds,
            other=1.0,
        )
        running = running and (tile != 0)
    row_result = tl.reduce(running, axis=1, combine_fn=reduce_all)
    tl.store(out + row_ids, row_result, mask=valid_rows)
def _flatten_dim(shape: Sequence[int], dim: int):
    """Split ``shape`` around ``dim`` into outer / reduced / inner extents.

    Args:
        shape: Dimension sizes; must be non-empty because ``dim`` is taken
            modulo ``len(shape)``.
        dim: Axis index, may be negative.

    Returns:
        ``(dim, n, inner, outer)`` where ``dim`` is the normalized axis, ``n``
        is ``shape[dim]``, ``inner`` is the product of the sizes after ``dim``
        and ``outer`` is the product of the sizes before it.
    """
    axis = dim % len(shape)
    leading, trailing = shape[:axis], shape[axis + 1 :]
    # math.prod of an empty slice is 1, which covers the first/last axis cases.
    return axis, shape[axis], math.prod(trailing), math.prod(leading)
def triton_all_dim_strided(
    inp: torch.Tensor, dim: int, keepdim: bool = False
) -> torch.Tensor:
    """Reduce ``inp`` with logical AND along a single dimension.

    Args:
        inp: Input tensor with at least one dimension.
        dim: Axis to reduce; negative values count from the end.
        keepdim: If true, the reduced axis is kept with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device`` with the reduced axis either
        removed or, when ``keepdim`` is true, kept with size 1.
    """
    shape = list(inp.shape)
    dim, n, inner, outer = _flatten_dim(shape, dim)
    m = outer * inner

    # The kernel reads element (o, k, i) at
    #     o * STRIDE_OUTER + k * STRIDE_REDUCE + i
    # so it needs a dense row-major buffer. Make the input contiguous and
    # derive the strides from the shape instead of from inp.stride(): tensor
    # strides are wrong for non-contiguous views and may be arbitrary on
    # size-1 dimensions even when the tensor reports itself as contiguous.
    inp = inp.contiguous()
    stride_reduce = inner
    stride_outer = n * inner

    out_flat = torch.empty((m,), dtype=torch.bool, device=inp.device)
    # Nothing to compute for an empty output; skip the zero-sized launch.
    if m > 0:
        grid = lambda meta: (triton.cdiv(m, meta["BLOCK_M"]),)
        all_kernel_dim_strided[grid](
            inp,
            out_flat,
            m,
            n,
            inner,
            stride_outer,
            stride_reduce,
        )

    shape[dim] = 1
    out = out_flat.view(shape)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise logical-AND reduction.

    The input is addressed as ``inp + row * N + col``, i.e. as ``M`` rows of
    ``N`` consecutive elements. Each program handles ``BLOCK_M`` rows and
    writes one boolean per row to ``out + row``.
    """
    row_ids = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    valid_rows = row_ids < M
    row_start = inp + row_ids * N
    out_ptrs = out + row_ids

    # Per-lane running AND, started at True.
    running = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_ids = start + tl.arange(0, BLOCK_N)[None, :]
        valid_cols = col_ids < N
        in_bounds = valid_rows and valid_cols

        # Masked-off lanes read 1 so they cannot turn the AND false.
        tile = tl.load(row_start + col_ids, in_bounds, other=1.0)
        running = running and (tile != 0)
    row_result = tl.reduce(running, axis=1, combine_fn=reduce_all)
    tl.store(out_ptrs, row_result[:, None], valid_rows)
@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full reduction: one partial AND per program.

    Program ``p`` reads ``BLOCK_SIZE`` consecutive elements starting at
    ``p * BLOCK_SIZE``, ANDs their "non-zero" flags and stores the result at
    ``mid + p``. ``mid_size`` is accepted but not used by the kernel body.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past the end read 1 so they do not affect the AND.
    chunk = tl.load(inp + idx, mask=in_bounds, other=1.0)
    nonzero = chunk != 0
    partial = tl.reduce(nonzero, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction: AND the partial results together.

    Loads up to ``BLOCK_MID`` entries from ``mid`` (only the first
    ``MID_SIZE`` are real), reduces them with logical AND and stores the
    single result at ``out``.
    """
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < MID_SIZE
    # Padding lanes read 1 so they do not affect the AND.
    partials = tl.load(mid + idx, mask=in_bounds, other=1).to(tl.int1)
    result = tl.reduce(partials, axis=0, combine_fn=reduce_all)
    tl.store(out, result)
def all(inp):
    """Return a 0-d bool tensor that is True iff every element of ``inp`` is non-zero.

    The reduction runs in two kernel launches: ``all_kernel_1`` produces one
    partial result per block of elements, and ``all_kernel_2`` combines the
    partial results into the final scalar.

    Note that this function intentionally carries the name ``all`` and
    therefore shadows the builtin inside this module.

    Args:
        inp: Input tensor of any shape.

    Returns:
        A 0-dimensional ``torch.bool`` tensor on ``inp.device``.
    """
    logger.debug("GEMS_MTHREADS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # AND over no elements is True. Without this guard the block size
        # computed below is 0 and triton.cdiv divides by zero.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # The first-stage kernel indexes the buffer linearly from the data
    # pointer, which only visits the right elements for a dense layout.
    inp = inp.contiguous()

    # Roughly sqrt(n) elements per block keeps both stages about the same
    # size; the block is capped at 4096 and never larger than needed.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    block_size = min(block_size * 2, 4096, triton.next_power_of_2(n_elements))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out
def all_dim(inp, dim=None, keepdim=False):
    """Logical-AND reduction over one dimension, or over everything.

    Args:
        inp: Input tensor.
        dim: Axis to reduce (negative values allowed), or ``None`` to reduce
            over all elements.
        keepdim: Keep the reduced axis/axes with size 1.

    Returns:
        A ``torch.bool`` tensor holding the reduction result.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_MTHREADS ALL DIM")

    if dim is not None:
        assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
        axis = dim % inp.ndim
        with torch_device_fn.device(inp.device):
            return triton_all_dim_strided(inp, dim=axis, keepdim=keepdim)

    # No axis given: reduce every element down to a scalar.
    out = all(inp)
    return torch.reshape(out, [1] * inp.ndim) if keepdim else out
def all_dims(inp, dim=None, keepdim=False):
    """Logical-AND reduction over several dimensions at once.

    Args:
        inp: Input tensor.
        dim: ``None`` or a single int (both delegated to ``all_dim``), or an
            iterable of axis indices; negative values count from the end.
        keepdim: Keep the reduced axes with size 1.

    Returns:
        A ``torch.bool`` tensor holding the reduction result.

    Raises:
        AssertionError: If any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_MTHREADS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    # Check each axis on its own. Asserting on a bare generator expression is
    # always true because the generator object itself is truthy, and the
    # builtin all() is not reachable by name here since this module defines
    # its own `all`.
    for d in dim:
        assert -inp.ndim <= d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # NOTE(review): dim_compress presumably moves the reduced axes to the end
    # and returns a dense tensor, which is what the kernel's row * N + col
    # addressing needs. Confirm against its implementation.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    out = torch.empty(shape, dtype=torch.bool, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        all_kernel_dim[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
# Public API of this module: the three logical-AND reduction entry points.
__all__ = ["all", "all_dim", "all_dims"]