Coverage for src/flag_gems/runtime/backend/_mthreads/ops/min.py: 0%
123 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import builtins
2import logging
3import math
4from collections import namedtuple
6import torch
7import triton
8import triton.language as tl
10from flag_gems.ops import min_dim as base_min_dim
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry
13from flag_gems.utils import triton_lang_extension as tle
14from flag_gems.utils.limits import get_dtype_max
# Logger scoped to this op module, e.g. "flag_gems.runtime.backend._mthreads.ops.min".
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Return type of min_dim; field layout mirrors torch.return_types.min
# (values first, indices second).
MinOut = namedtuple("min", ["values", "indices"])

# Autotune candidates for the row-wise reduction kernel (min_kernel below).
# Expanded coverage favors smaller column tiles and more warps for tall shapes.
NAIVE_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=2, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=16, num_stages=2),
]
34def _prune_reduction_configs(configs, nargs, **meta):
35 n = meta.get("N", None)
36 if n is None:
37 n = nargs["N"]
38 max_block_n = 64 if n <= 128 else 256
39 return [cfg for cfg in configs if cfg.kwargs["BLOCK_N"] <= max_block_n]
42def _flatten_dim(shape, dim):
43 dim = dim % len(shape)
44 n = shape[dim]
45 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1
46 outer = math.prod(shape[:dim]) if dim > 0 else 1
47 return dim, n, inner, outer
@libentry()
@triton.jit
def min_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the global min: each program reduces one BLOCK_SIZE tile
    of the flattened input and writes its partial minimum to ``mid[pid]``.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Masked-out lanes load the dtype's maximum so they can never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_val = tl.load(inp_ptrs, mask=mask, other=max_value)
    min_val = tl.min(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, min_val)
@libentry()
@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the global min: a single program reduces all partial
    minima in ``mid`` to one scalar stored at ``out``.

    BLOCK_MID must be >= mid_size (the launcher passes next_power_of_2).
    """
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # Pad out-of-range lanes with the dtype max so they don't affect the result.
    max_value = get_dtype_max(mid.type.element_ty)
    mid_val = tl.load(mid_ptrs, mask=mask, other=max_value)
    min_val = tl.min(mid_val)
    tl.store(out, min_val)
@libentry()
@triton.autotune(
    configs=NAIVE_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M,          # number of output rows = outer * inner
    N,          # reduction length along the reduced dim
    INNER,      # product of dims after the reduced dim (>= 1)
    STRIDE_OUTER,   # element stride between consecutive outer indices
    STRIDE_REDUCE,  # element stride along the reduced dim
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise min + argmin: for each of the M (outer, inner) positions,
    scan N elements along the reduced dim in BLOCK_N chunks and write the
    minimum value and its index (int32) to out_value/out_index.
    """
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # int64 row offsets so pointer arithmetic can't wrap on large tensors.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Decompose the flat row id into (outer, inner) coordinates; the inner
    # dims are contiguous (launcher only handles contiguous inputs), so the
    # inner index contributes stride 1.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    dtype = inp.type.element_ty
    # Accumulate bf16 in fp32; presumably to avoid bf16 precision loss in
    # the running-min compare — other dtypes accumulate natively.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int32, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        # Padding with dtype max keeps masked lanes out of the running min.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value, cache_modifier=".cg")
        local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
        local_argmin = local_argmin.to(tl.int32)
        # Strict "<" keeps the FIRST occurrence on ties across chunks.
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(
            update, (start_n + local_argmin).to(tl.int32), argmin_values
        )

    out_value_ptrs = out_value + rows
    out_index_ptrs = out_index + rows
    tl.store(out_value_ptrs, min_values, mask=row_mask)
    tl.store(out_index_ptrs, argmin_values, mask=row_mask)
def min(inp):
    """Return the global minimum of ``inp`` as a 0-dim tensor.

    Two-stage reduction: stage 1 produces one partial minimum per tile,
    stage 2 reduces the partials with a single program.
    """
    logger.debug("GEMS_MTHREADS MIN")

    def _warps(block):
        # Roughly one warp per 128 lanes, clamped to [1, 8].
        return builtins.min(8, max(1, block // 128))

    numel = inp.numel()
    # Tile near 4*sqrt(numel), capped at 4096 and at the whole input.
    tile = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    tile = builtins.min(tile * 4, 4096, triton.next_power_of_2(numel))
    num_tiles = triton.cdiv(numel, tile)
    stage2_block = triton.next_power_of_2(num_tiles)

    partial = torch.empty((num_tiles,), dtype=inp.dtype, device=inp.device)
    result = torch.empty([], dtype=inp.dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        min_kernel_1[(num_tiles, 1, 1)](
            inp, partial, numel, tile, num_warps=_warps(tile), num_stages=2
        )
        min_kernel_2[(1, 1, 1)](
            partial,
            result,
            num_tiles,
            stage2_block,
            num_warps=_warps(stage2_block),
            num_stages=2,
        )
    return result
def min_dim(inp, dim=None, keepdim=False):
    """Return (values, indices) of the minimum of ``inp`` along ``dim``.

    Contiguous inputs are flattened to (outer, N, inner) and reduced by
    min_kernel; non-contiguous inputs fall back to the generic op.
    """
    logger.debug("GEMS_MTHREADS MIN DIM")
    assert dim is not None, "dim must be specified"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    if not inp.is_contiguous():
        # Fall back to the generic implementation (handles arbitrary strides).
        return base_min_dim(inp, dim=dim, keepdim=keepdim)

    shape = list(inp.shape)
    dim, reduce_len, inner, outer = _flatten_dim(shape, dim)
    rows = outer * inner
    strides = inp.stride()
    stride_reduce = strides[dim]
    stride_outer = stride_reduce * reduce_len

    values = torch.empty((rows,), dtype=inp.dtype, device=inp.device)
    indices = torch.empty((rows,), dtype=torch.int32, device=inp.device)

    def grid(meta):
        return (triton.cdiv(rows, meta["BLOCK_M"]),)

    with torch_device_fn.device(inp.device):
        min_kernel[grid](
            inp,
            values,
            indices,
            rows,
            reduce_len,
            max(inner, 1),  # guard the kernel's // and % against inner == 0
            stride_outer,
            stride_reduce,
        )

    kept = shape.copy()
    kept[dim] = 1
    values = values.view(kept)
    # Kernel writes int32 indices; torch's contract is int64.
    indices = indices.view(kept).to(torch.int64)
    if not keepdim:
        values = torch.squeeze(values, dim)
        indices = torch.squeeze(indices, dim)
    return MinOut(values=values, indices=indices)
204__all__ = ["min", "min_dim"]