Coverage for src/flag_gems/runtime/backend/_mthreads/ops/max.py: 0%
150 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import builtins
2import logging
3import math
4from collections import namedtuple
6import torch
7import triton
8import triton.language as tl
10from flag_gems.ops.max import max as base_max
11from flag_gems.ops.max import max_dim as base_max_dim
12from flag_gems.runtime import torch_device_fn
13from flag_gems.utils import libentry
14from flag_gems.utils import triton_lang_extension as tle
15from flag_gems.utils.limits import get_dtype_min
# Module logger scoped under the mthreads backend op namespace (".../ops.max").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Result container mirroring torch.return_types.max: (values, indices).
MaxOut = namedtuple("max", ["values", "indices"])

# Candidate tile shapes for the autotuned reduction kernel (max_kernel).
# BLOCK_M = output rows per program, BLOCK_N = reduction-axis chunk width;
# oversized BLOCK_N entries are filtered at tune time by _prune_reduction_configs.
MAX_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
]
37def _prune_reduction_configs(configs, nargs, **meta):
38 n = meta.get("N", nargs["N"])
39 if n <= 128:
40 max_block_n = 128
41 elif n <= 2048:
42 max_block_n = 256
43 elif n <= 8192:
44 max_block_n = 512
45 else:
46 max_block_n = 1024
47 return [cfg for cfg in configs if cfg.kwargs["BLOCK_N"] <= max_block_n]
50def _flatten_dim(shape, dim):
51 dim = dim % len(shape)
52 n = shape[dim]
53 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1
54 outer = math.prod(shape[:dim]) if dim > 0 else 1
55 return dim, n, inner, outer
@libentry()
@triton.jit
def max_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the flattened global max: each program reduces one
    # BLOCK_SIZE-wide chunk of the M-element input to a single partial
    # maximum stored at mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M
    # Masked-out lanes read the dtype's minimum so they can never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + offset, mask=mask, other=min_value, cache_modifier=".cg")
    tl.store(mid + pid, tl.max(vals))
@libentry()
@triton.jit
def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the flattened global max: a single program reduces the
    # mid_size partial maxima from stage 1 into the scalar `out`.
    # Caller guarantees BLOCK_MID >= mid_size (next power of two).
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    # Padding lanes read the dtype's minimum so they never affect the result.
    min_value = get_dtype_min(mid.type.element_ty)
    vals = tl.load(mid + offset, mask=mask, other=min_value)
    tl.store(out, tl.max(vals))
@libentry()
@triton.jit
def max_kernel_small(
    inp,
    out_value,
    out_index,
    M,
    N,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_N: tl.constexpr,
):
    # Fast path for short reductions: one program per output row, with the
    # whole reduction axis loaded as a single tile (caller guarantees
    # N <= BLOCK_N). Writes the row max and its int32 argmax.
    row = tle.program_id(0)
    row_mask = row < M
    cols = tl.arange(0, BLOCK_N)
    col_mask = cols < N

    # Do address arithmetic in int64 to avoid overflow on large tensors.
    stride_outer = tl.full((), STRIDE_OUTER, tl.int64)
    stride_reduce = tl.full((), STRIDE_REDUCE, tl.int64)
    offsets = row.to(tl.int64) * stride_outer + cols.to(tl.int64) * stride_reduce

    # Accumulate fp16/bf16 inputs in fp32; other dtypes reduce natively.
    dtype = inp.type.element_ty
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    min_value = get_dtype_min(dtype)
    vals = tl.load(inp + offsets, mask=row_mask & col_mask, other=min_value).to(
        acc_type
    )
    # Ties resolve to the lowest column index (tie_break_left).
    row_max, row_argmax = tl.max(
        vals,
        axis=0,
        return_indices=True,
        return_indices_tie_break_left=True,
    )
    tl.store(out_value + row, row_max, mask=row_mask)
    tl.store(out_index + row, row_argmax.to(tl.int32), mask=row_mask)
@libentry()
@triton.autotune(
    configs=MAX_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # General max+argmax over the reduction axis: each program owns BLOCK_M
    # of the M (= outer * INNER) flattened output rows and sweeps the
    # N-long reduction axis in BLOCK_N chunks.
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Recover (outer, inner) coordinates of each flattened row; the reduced
    # axis sits between them, so its elements are STRIDE_REDUCE apart.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    # Accumulate fp16/bf16 inputs in fp32; other dtypes reduce natively.
    dtype = inp.type.element_ty
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    min_value = get_dtype_min(dtype)
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
    argmax_values = tl.full([BLOCK_M], dtype=tl.int32, value=0)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value, cache_modifier=".cg")
        inp_vals = inp_vals.to(acc_type)
        # Per-chunk max with left tie-break; combined with the strict ">"
        # update below, the first (lowest-index) maximum wins across chunks.
        local_max, local_argmax = tl.max(
            inp_vals,
            axis=1,
            return_indices=True,
            return_indices_tie_break_left=True,
        )
        local_argmax = local_argmax.to(tl.int32)
        # NOTE(review): a NaN local_max never satisfies ">", so NaN handling
        # may differ from torch.max's propagation — confirm against tests.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        # Chunk-local argmax is rebased by start_n to a global column index.
        argmax_values = tl.where(
            update, (start_n + local_argmax).to(tl.int32), argmax_values
        )

    out_value_ptrs = out_value + rows
    out_index_ptrs = out_index + rows
    tl.store(out_value_ptrs, max_values, mask=row_mask)
    tl.store(out_index_ptrs, argmax_values, mask=row_mask)
def max(inp):
    """Global maximum of all elements of `inp` via a two-stage reduction.

    Stage 1 (max_kernel_1) produces one partial max per block; stage 2
    (max_kernel_2) collapses the partials into a 0-dim tensor of the input
    dtype. Empty tensors are delegated to the reference implementation.
    """
    logger.debug("GEMS_MTHREADS MAX")
    if not inp.is_contiguous():
        inp = inp.contiguous()
    if inp.numel() == 0:
        # Reference op defines empty-input behavior (error/identity).
        return base_max(inp)

    numel = inp.numel()
    # Block size: ~4x sqrt(numel) rounded to a power of two, capped at 4096
    # and at the whole-input width.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    block_size = builtins.min(block_size * 4, 4096, triton.next_power_of_2(numel))
    num_blocks = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(num_blocks)

    partials = torch.empty((num_blocks,), dtype=inp.dtype, device=inp.device)
    result = torch.empty([], dtype=inp.dtype, device=inp.device)

    # One warp per 128 lanes, clamped to [1, 8].
    warps_stage1 = builtins.min(8, builtins.max(1, block_size // 128))
    warps_stage2 = builtins.min(8, builtins.max(1, block_mid // 128))

    with torch_device_fn.device(inp.device):
        max_kernel_1[(num_blocks, 1, 1)](
            inp, partials, numel, block_size, num_warps=warps_stage1, num_stages=2
        )
        max_kernel_2[(1, 1, 1)](
            partials, result, num_blocks, block_mid, num_warps=warps_stage2, num_stages=2
        )
    return result
def max_dim(inp, dim=None, keepdim=False):
    """Max values and int64 argmax indices of `inp` along `dim`.

    Contiguous inputs are reduced by a custom kernel (a single-tile fast
    path when the reduced axis is innermost and short, otherwise the
    autotuned chunked kernel); non-contiguous inputs fall back to the
    reference implementation. Returns MaxOut(values, indices), squeezing
    `dim` unless `keepdim` is set.
    """
    logger.debug("GEMS_MTHREADS MAX DIM")
    assert dim is not None, "dim must be specified"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    if not inp.is_contiguous():
        # Strided layouts are handled by the reference op.
        return base_max_dim(inp, dim=dim, keepdim=keepdim)

    shape = list(inp.shape)
    dim, n_reduce, inner_sz, outer_sz = _flatten_dim(shape, dim)
    num_rows = outer_sz * inner_sz
    strides = inp.stride()
    stride_reduce = strides[dim]
    stride_outer = stride_reduce * n_reduce

    values = torch.empty((num_rows,), dtype=inp.dtype, device=inp.device)
    # Kernels emit int32 indices; widened to int64 below to match torch.max.
    indices = torch.empty((num_rows,), dtype=torch.int32, device=inp.device)

    small_path = inner_sz == 1 and n_reduce <= 128
    if small_path:
        # Innermost, short reduction: one program per row, single tile.
        block_n = builtins.min(triton.next_power_of_2(n_reduce), 128)
        launch_grid = (triton.cdiv(num_rows, 1),)
        with torch_device_fn.device(inp.device):
            max_kernel_small[launch_grid](
                inp,
                values,
                indices,
                num_rows,
                n_reduce,
                stride_outer,
                stride_reduce,
                block_n,
                num_warps=1,
                num_stages=1,
            )
    else:
        launch_grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            max_kernel[launch_grid](
                inp,
                values,
                indices,
                num_rows,
                n_reduce,
                builtins.max(inner_sz, 1),
                stride_outer,
                stride_reduce,
            )

    reduced_shape = shape.copy()
    reduced_shape[dim] = 1
    values = values.view(reduced_shape)
    indices = indices.view(reduced_shape).to(torch.int64)
    if not keepdim:
        values = torch.squeeze(values, dim)
        indices = torch.squeeze(indices, dim)

    return MaxOut(values=values, indices=indices)
273__all__ = ["max", "max_dim"]