Coverage for src/flag_gems/runtime/backend/_mthreads/ops/any.py: 0%
135 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2import math
3from typing import Sequence
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
# Per-op logger named after this module's file (e.g. "...ops.any") so that
# backend op logging can be filtered individually.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
18def _flatten_dim(shape: Sequence[int], dim: int):
19 dim = dim % len(shape)
20 n = shape[dim]
21 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1
22 outer = math.prod(shape[:dim]) if dim > 0 else 1
23 return dim, n, inner, outer
26# Favor smaller tiles to keep occupancy high on MUSA; wide tiles trigger register
27# pressure and hurt latency for large reductions.
28def _select_reduction_config(m_rows: int, n_cols: int):
29 block_n = min(256, max(64, 1 << int(math.ceil(math.log2(n_cols)))))
30 max_block_m = 1 << int(math.floor(math.log2(max(1, m_rows))))
31 block_m = min(32, max_block_m)
32 num_warps = 8 if block_n >= 256 else 4
33 return block_m, block_n, num_warps
@libentry()
@triton.jit
def any_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise logical ANY: `inp` is addressed as a contiguous (M, N) view
    # (row offset = row * N); out[row] is True iff any element of that row
    # is non-zero.
    pid = tle.program_id(0)
    # Each program instance owns BLOCK_M consecutive rows; int64 offsets
    # avoid overflow for large tensors.
    rows = (pid * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    row_mask = rows < M
    row_offsets = rows * N
    acc = tl.zeros((BLOCK_M,), dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        col_mask = cols < N
        # Rows whose result is already True are masked out of further loads,
        # so later column tiles are not fetched for them.
        active = acc == 0
        mask = row_mask[:, None] & col_mask[None, :] & active[:, None]
        vals = tl.load(inp + row_offsets[:, None] + cols[None, :], mask=mask, other=0.0)
        # Per-row max over the (vals != 0) comparison == "any in this tile".
        block_any = tl.max(vals != 0, axis=1).to(tl.int1)
        acc = acc | block_any
    tl.store(out + rows, acc, mask=row_mask)
@triton.jit
def any_kernel_dim_strided(
    inp,
    out,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Logical ANY along one (possibly non-innermost) dimension. Each of the
    # M output positions is identified by (outer_idx, inner_idx); the reduced
    # dimension of extent N is walked with STRIDE_REDUCE.
    # NOTE(review): inner_idx is added with an implicit stride of 1, which
    # assumes the dimensions after the reduced one are laid out contiguously
    # — confirm this holds for non-contiguous inputs.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    rows = rows.to(tl.int64)
    row_mask = rows < M
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx
    acc = tl.zeros((BLOCK_M,), dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        cols = cols.to(tl.int64)
        col_mask = cols < N
        # Skip loads for rows whose ANY result is already known to be True.
        active = acc == 0
        mask = row_mask[:, None] & col_mask[None, :] & active[:, None]
        vals = tl.load(
            base_ptr[:, None] + cols[None, :] * STRIDE_REDUCE, mask=mask, other=0.0
        )
        # Per-row max over (vals != 0) is this tile's contribution to ANY.
        block_any = tl.max(vals != 0, axis=1).to(tl.int1)
        acc = acc | block_any
    tl.store(out + rows, acc, mask=row_mask)
@libentry()
@triton.jit
def any_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the full (no-dim) reduction: each program reduces one
    # BLOCK_SIZE chunk of the flattened input to one boolean in mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # Out-of-range lanes load 0.0, which is falsy and cannot affect ANY.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)
    # max over the (val != 0) comparison == "any non-zero in this chunk".
    any_val = tl.max(inp_val != 0, axis=0)
    mid_ptr = mid + pid
    tl.store(mid_ptr, any_val)
@libentry()
@triton.jit
def any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2: a single program folds all stage-1 partials into one boolean.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # other=0 (False) keeps the padding lanes from flipping the result.
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(tl.int1)
    any_val = tl.max(mid_val, axis=0)
    tl.store(out, any_val)
def any(inp):
    """Reduce *inp* over all elements with logical OR; return a 0-d bool tensor.

    Two-stage reduction: ``any_kernel_1`` collapses fixed-size chunks into
    per-block partials, then ``any_kernel_2`` folds those partials into the
    final scalar.
    """
    logger.debug("GEMS_MTHREADS ANY")
    numel = inp.numel()
    # Chunk size ~ 2*sqrt(numel), capped at 4096 and at the problem size.
    chunk = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    chunk = min(chunk * 2, 4096, triton.next_power_of_2(numel))
    n_partials = triton.cdiv(numel, chunk)
    fold_width = triton.next_power_of_2(n_partials)

    partials = torch.empty((n_partials,), dtype=torch.bool, device=inp.device)
    result = torch.empty([], dtype=torch.bool, device=inp.device)

    # Scale warps with the tile width, between 1 and 8.
    warps_stage1 = min(8, max(1, chunk // 128))
    warps_stage2 = min(8, max(1, fold_width // 128))
    with torch_device_fn.device(inp.device):
        any_kernel_1[(n_partials, 1)](
            inp,
            partials,
            numel,
            n_partials,
            chunk,
            num_warps=warps_stage1,
            num_stages=2,
        )
        any_kernel_2[(1, 1)](
            partials,
            result,
            n_partials,
            fold_width,
            num_warps=warps_stage2,
            num_stages=2,
        )
    return result
def triton_any_dim_strided(
    inp: torch.Tensor, dim: int, keepdim: bool = False
) -> torch.Tensor:
    """Logical-any reduction of *inp* along *dim* using the strided kernel.

    The output keeps the reduced dimension (size 1) when *keepdim* is True,
    otherwise it is squeezed out.
    """
    dim = dim % inp.ndim
    sizes = list(inp.shape)
    dim, reduce_len, inner, outer = _flatten_dim(sizes, dim)
    num_rows = outer * inner

    strides = inp.stride()
    reduce_stride = strides[dim]
    # One "outer" step spans the whole reduced extent in memory.
    outer_stride = reduce_stride * reduce_len

    flat_out = torch.empty((num_rows,), dtype=torch.bool, device=inp.device)
    blk_m, blk_n, warps = _select_reduction_config(num_rows, reduce_len)
    launch_grid = (triton.cdiv(num_rows, blk_m),)
    with torch_device_fn.device(inp.device):
        any_kernel_dim_strided[launch_grid](
            inp,
            flat_out,
            num_rows,
            reduce_len,
            inner,
            outer_stride,
            reduce_stride,
            BLOCK_M=blk_m,
            BLOCK_N=blk_n,
            num_warps=warps,
            num_stages=2,
        )

    sizes[dim] = 1
    result = flat_out.view(sizes)
    return result if keepdim else result.squeeze(dim=dim)
def any_dim(inp, dim=None, keepdim=False):
    """Logical-any over a single dimension, or over everything when *dim* is None."""
    logger.debug("GEMS_MTHREADS ANY DIM")
    if dim is None:
        # Full reduction; optionally restore the original rank with 1-sized dims.
        result = any(inp)
        return torch.reshape(result, [1] * inp.ndim) if keepdim else result
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    return triton_any_dim_strided(inp, dim, keepdim=keepdim)
def any_dims(inp, dim=None, keepdim=False):
    """Logical-any over one or several dimensions (Triton-backed)."""
    logger.debug("GEMS_MTHREADS ANY DIMS")
    # Scalar-dim and full-reduction requests delegate to the single-dim path.
    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)
    # Normalize, dedupe and order the axes; reduce one axis at a time while
    # keeping rank stable so later axis indices stay valid.
    axes = sorted({d % inp.ndim for d in dim})
    result = inp
    for axis in axes:
        result = triton_any_dim_strided(result, axis, keepdim=True)
    if not keepdim:
        # Squeeze the reduced axes highest-first so earlier indices don't shift.
        for axis in reversed(axes):
            result = result.squeeze(dim=axis)
    return result
# Public op entry points exported by this module.
__all__ = ["any", "any_dim", "any_dims"]