Coverage for src/flag_gems/runtime/backend/_mthreads/ops/amax.py: 0%
143 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import builtins
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems.ops.amax import amax as base_amax
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
13from flag_gems.utils.limits import get_dtype_min
# Per-op logger under the flag_gems backend namespace; the suffix is this
# module's own name (e.g. "amax") so log output can be filtered per op.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
# Autotuning candidates for the tiled `amax_kernel`: BLOCK_M = output rows per
# program, BLOCK_N = reduction-axis tile width.  Candidates whose BLOCK_N is
# oversized for the actual reduction length are dropped before benchmarking by
# `_prune_reduction_configs`.
AMAX_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
]
def _prune_reduction_configs(configs, nargs, **meta):
    """Drop autotune configs whose BLOCK_N is oversized for the reduction.

    Used as triton.autotune's ``early_config_prune`` hook: reads the
    reduction length N (from ``meta`` if present, else the kernel args) and
    keeps only configs whose BLOCK_N does not exceed the cap for that N.
    """
    reduce_len = meta.get("N", nargs["N"])
    # Smallest BLOCK_N cap whose threshold covers the reduction length.
    for threshold, cap in ((128, 128), (2048, 256), (8192, 512)):
        if reduce_len <= threshold:
            limit = cap
            break
    else:
        limit = 1024
    kept = []
    for cfg in configs:
        if cfg.kwargs["BLOCK_N"] <= limit:
            kept.append(cfg)
    return kept
def _flatten_dim(shape, dim):
    """Normalize ``dim`` and factor ``shape`` around it.

    Returns ``(dim, n, inner, outer)`` where ``n = shape[dim]``, ``inner`` is
    the product of the dims after it and ``outer`` the product of the dims
    before it (each 1 when that slice is empty, as an empty product is 1).
    """
    axis = dim % len(shape)
    reduce_len = shape[axis]
    trailing = math.prod(shape[axis + 1:])
    leading = math.prod(shape[:axis])
    return axis, reduce_len, trailing, leading
@libentry()
@triton.jit
def amax_kernel_1(
    inp,   # pointer to the flattened, contiguous input
    mid,   # pointer to per-program partial maxima (one element per program)
    M,     # total number of input elements
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the global amax: each program reduces one BLOCK_SIZE chunk
    of the flat input and stores its maximum into mid[pid]."""
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M
    # Out-of-range lanes load the dtype's minimum so they can never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + offset, mask=mask, other=min_value, cache_modifier=".cg")
    tl.store(mid + pid, tl.max(vals))
@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the global amax: a single program reduces the `mid_size`
    partial maxima produced by amax_kernel_1 into the scalar `out`.
    BLOCK_MID must be >= mid_size (caller passes next power of two)."""
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    # Padding lanes take the dtype minimum so they cannot affect the result.
    min_value = get_dtype_min(mid.type.element_ty)
    vals = tl.load(mid + offset, mask=mask, other=min_value)
    tl.store(out, tl.max(vals))
@libentry()
@triton.jit
def amax_kernel_small(
    inp,            # pointer to the input tensor
    out_value,      # pointer to per-row maxima, length M
    M,              # number of output rows
    N,              # reduction length
    STRIDE_OUTER,   # element stride between consecutive rows
    STRIDE_REDUCE,  # element stride along the reduced dimension
    BLOCK_N: tl.constexpr,  # tile width covering the whole row (caller ensures N <= BLOCK_N)
):
    """Single-tile row reduction for small N: one program computes the max of
    one row with a single masked load (used when inner == 1 and N <= 128)."""
    row = tle.program_id(0)
    row_mask = row < M
    cols = tl.arange(0, BLOCK_N)
    col_mask = cols < N

    # Compute offsets in 64-bit to avoid overflow on large tensors.
    stride_outer = tl.full((), STRIDE_OUTER, tl.int64)
    stride_reduce = tl.full((), STRIDE_REDUCE, tl.int64)
    offsets = row.to(tl.int64) * stride_outer + cols.to(tl.int64) * stride_reduce

    dtype = inp.type.element_ty
    # Half-precision inputs are widened to fp32 before reducing (matches amax_kernel).
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    # Masked lanes read the dtype minimum so they cannot win the max.
    min_value = get_dtype_min(dtype)
    vals = tl.load(inp + offsets, mask=row_mask & col_mask, other=min_value).to(
        acc_type
    )
    row_max = tl.max(vals, axis=0)
    tl.store(out_value + row, row_max, mask=row_mask)
@libentry()
@triton.autotune(
    configs=AMAX_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def amax_kernel(
    inp,            # pointer to the input tensor (caller guarantees contiguous)
    out_value,      # pointer to per-row maxima, length M
    M,              # number of output elements (outer * inner)
    N,              # reduction length
    INNER,          # product of dims after the reduced dim (>= 1)
    STRIDE_OUTER,   # element stride of the outer index (= STRIDE_REDUCE * N)
    STRIDE_REDUCE,  # element stride along the reduced dim
    BLOCK_M: tl.constexpr,  # output rows per program
    BLOCK_N: tl.constexpr,  # reduction tile width
):
    """Tiled max-reduction over one dimension.

    Each program handles BLOCK_M output elements.  A flat output index r maps
    to coordinates (outer = r // INNER, inner = r % INNER); its reduction
    walks N elements spaced STRIDE_REDUCE apart, BLOCK_N at a time, keeping a
    running per-row maximum.
    """
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rows = rows.to(tl.int64)  # 64-bit indexing for large tensors
    row_mask = rows < M

    # Decompose the flat output index into (outer, inner) coordinates; the
    # inner coordinate has unit element stride for a contiguous input.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    dtype = inp.type.element_ty
    # Widen half-precision inputs to fp32 for the running reduction.
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    min_value = get_dtype_min(dtype)
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        # Masked lanes read the dtype minimum so they never affect the max.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value, cache_modifier=".cg")
        inp_vals = inp_vals.to(acc_type)
        local_max = tl.max(inp_vals, axis=1)
        max_values = tl.maximum(max_values, local_max)

    out_value_ptrs = out_value + rows
    tl.store(out_value_ptrs, max_values, mask=row_mask)
def amax(inp, dim=None, keepdim=False):
    """Compute torch.amax on the MThreads backend.

    Args:
        inp: input tensor.
        dim: None, int, or list/tuple of ints. None or an empty list/tuple
            reduces over all elements.
        keepdim: keep reduced dimensions with size 1.

    Returns:
        Tensor of maxima with torch.amax semantics.

    Falls back to the generic flag_gems implementation for empty tensors,
    multi-dim reductions, and non-contiguous single-dim inputs.
    """
    logger.debug("GEMS_MTHREADS AMAX")

    if dim is None or (isinstance(dim, (list, tuple)) and len(dim) == 0):
        # ---- Global reduction over all elements ----
        if not inp.is_contiguous():
            inp = inp.contiguous()
        if inp.numel() == 0:
            # Let the base implementation handle the empty-tensor behavior.
            return base_amax(inp, dim=dim, keepdim=keepdim)

        M = inp.numel()
        # Two-stage reduction: ~sqrt(M)-sized blocks (scaled, capped at 4096),
        # then a single program reduces the per-block partial maxima.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        block_size = builtins.min(block_size * 4, 4096, triton.next_power_of_2(M))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        dtype = inp.dtype
        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

        if not keepdim:
            out = torch.empty([], dtype=dtype, device=inp.device)
        else:
            shape = [1] * inp.dim()
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        # One warp per 128 lanes, clamped to [1, 8].
        num_warps_block = builtins.min(8, builtins.max(1, block_size // 128))
        num_warps_mid = builtins.min(8, builtins.max(1, block_mid // 128))

        with torch_device_fn.device(inp.device):
            amax_kernel_1[(mid_size, 1, 1)](
                inp, mid, M, block_size, num_warps=num_warps_block, num_stages=2
            )
            amax_kernel_2[(1, 1, 1)](
                mid, out, mid_size, block_mid, num_warps=num_warps_mid, num_stages=2
            )
        return out
    else:
        # ---- Reduction over specific dimension(s) ----
        if isinstance(dim, int):
            dim = [dim]

        # Multi-dim reduction is delegated to the base implementation.
        if len(dim) > 1:
            return base_amax(inp, dim=dim, keepdim=keepdim)

        dim_val = dim[0]
        assert dim_val >= -inp.ndim and dim_val < inp.ndim, "Invalid dim"
        dim_val = dim_val % inp.ndim

        # The kernels below assume contiguous strides and a non-zero reduction
        # length; empty tensors would otherwise be silently filled with the
        # dtype minimum instead of matching torch.amax's error/empty behavior.
        if not inp.is_contiguous() or inp.numel() == 0:
            return base_amax(inp, dim=dim, keepdim=keepdim)

        shape = list(inp.shape)
        dim_val, N, inner, outer = _flatten_dim(shape, dim_val)
        M = outer * inner  # number of output elements
        stride = inp.stride()
        stride_reduce = stride[dim_val]
        stride_outer = stride_reduce * N  # stride between consecutive outer rows

        out_value = torch.empty((M,), dtype=inp.dtype, device=inp.device)

        if inner == 1 and N <= 128:
            # Innermost, short reduction: one tiny single-warp program per row.
            block_n = builtins.min(triton.next_power_of_2(N), 128)
            grid = (triton.cdiv(M, 1),)
            with torch_device_fn.device(inp.device):
                amax_kernel_small[grid](
                    inp,
                    out_value,
                    M,
                    N,
                    stride_outer,
                    stride_reduce,
                    block_n,
                    num_warps=1,
                    num_stages=1,
                )
        else:
            # General tiled reduction; BLOCK_M/BLOCK_N picked by autotuning.
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                amax_kernel[grid](
                    inp,
                    out_value,
                    M,
                    N,
                    builtins.max(inner, 1),
                    stride_outer,
                    stride_reduce,
                )

        # Restore the output shape (reduced dim kept as size 1, then squeezed
        # unless keepdim was requested).
        out_shape = shape.copy()
        out_shape[dim_val] = 1
        out_value = out_value.view(out_shape)
        if not keepdim:
            out_value = torch.squeeze(out_value, dim_val)

        return out_value
256__all__ = ["amax"]