Coverage for src/flag_gems/runtime/backend/_mthreads/ops/argmin.py: 0%
121 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import builtins
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max
14logger = logging.getLogger(
15 f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
16)
# Favor wider column tiles for long rows and more rows per block for tall shapes.
# BLOCK_M = number of independent rows (non-reduced elements) handled per
# program; BLOCK_N = number of elements along the reduced axis loaded per
# inner-loop step. The list is pruned at autotune time by
# _prune_reduction_configs, which caps BLOCK_N based on the reduction length N.
ARGMIN_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
]
30def _prune_reduction_configs(configs, nargs, **meta):
31 n = meta.get("N", nargs["N"])
32 if n <= 128:
33 max_block_n = 128
34 elif n <= 2048:
35 max_block_n = 256
36 else:
37 max_block_n = 512
38 return [cfg for cfg in configs if cfg.kwargs["BLOCK_N"] <= max_block_n]
41def _flatten_dim(shape, dim):
42 dim = dim % len(shape)
43 n = shape[dim]
44 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1
45 outer = math.prod(shape[:dim]) if dim > 0 else 1
46 return dim, n, inner, outer
@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flattened argmin: per-block partial reduction.

    Each program reduces one BLOCK_SIZE slice of the flat input of length M,
    writing the slice's minimum value to ``mid_value[pid]`` and its *global*
    flat index to ``mid_index[pid]``. Stage 2 (argmin_kernel_2) then reduces
    these partials to a single index.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M

    # Out-of-range lanes load the dtype's maximum so they can never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_val = tl.load(inp + offset, mask=mask, other=max_value, cache_modifier=".cg")
    # Tie-break left keeps the lowest index on equal values within the block.
    min_val, min_index = tl.min(
        inp_val, axis=0, return_indices=True, return_indices_tie_break_left=True
    )
    tl.store(mid_value + pid, min_val)
    # Convert the block-local winner index to a global flat index.
    tl.store(mid_index + pid, min_index + pid * BLOCK_SIZE)
@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Stage 2 of the flattened argmin: reduce the stage-1 partials.

    Runs as a single program. Loads all ``mid_size`` partial minima in one
    BLOCK_MID-wide tile (BLOCK_MID is the next power of two >= mid_size),
    finds the winning partial, and stores that partial's global flat index
    to the scalar output ``out``.
    """
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    # Padding lanes read the dtype max so they never win the min.
    max_value = get_dtype_max(mid_value.type.element_ty)
    mid_val = tl.load(mid_value + offset, mask=mask, other=max_value)
    # Tie-break left: on equal partial minima the earlier block (and hence
    # the lower global index) wins, matching leftmost-argmin semantics.
    _, index_val = tl.min(
        mid_val,
        axis=0,
        return_indices=True,
        return_indices_tie_break_left=True,
    )
    out_val = tl.load(mid_index + index_val)
    tl.store(out, out_val)
@libentry()
@triton.autotune(
    configs=ARGMIN_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmin along one axis of a contiguous tensor.

    The non-reduced dimensions are flattened to M rows (M = outer * inner);
    each program handles BLOCK_M rows and sweeps the N-long reduced axis in
    BLOCK_N-wide tiles, keeping a running (min value, argmin) pair per row.
    Writes int32 indices (index along the reduced axis) to ``out_index``.
    """
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # int64 row ids avoid overflow in the pointer arithmetic below.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # A flat row id decomposes into (outer, inner) coordinates; the inner
    # coordinate has stride 1 because the wrapper passes a contiguous input.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    dtype = inp.type.element_ty
    # Accumulate bf16 in fp32; other dtypes accumulate natively.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int32, value=0)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        # Masked lanes read the dtype max so they never become the minimum.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value, cache_modifier=".cg")
        local_min, local_argmin = tl.min(
            inp_vals,
            1,
            return_indices=True,
            return_indices_tie_break_left=True,
        )
        local_argmin = local_argmin.to(tl.int32)
        # Strict < keeps the earlier tile's index on ties, preserving
        # leftmost-argmin semantics across BLOCK_N tiles.
        # NOTE(review): a NaN local_min fails `<` and is skipped, whereas
        # torch.argmin propagates NaN positions — confirm intended behavior.
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(
            update, (start_n + local_argmin).to(tl.int32), argmin_values
        )

    out_index_ptrs = out_index + rows
    tl.store(out_index_ptrs, argmin_values, mask=row_mask)
def _argmin_flat(inp, keepdim, dtype):
    """Global argmin over a flattened contiguous tensor (two-stage reduction)."""
    numel = inp.numel()
    value_dtype = inp.dtype if dtype is None else dtype

    # Block size: next power of two near 4*sqrt(numel), capped at 4096 and
    # never wider than the input itself.
    blk = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    blk = builtins.min(blk * 4, 4096, triton.next_power_of_2(numel))
    n_blocks = triton.cdiv(numel, blk)
    blk_mid = triton.next_power_of_2(n_blocks)

    partial_vals = torch.empty((n_blocks,), dtype=value_dtype, device=inp.device)
    partial_idx = torch.empty((n_blocks,), dtype=torch.int64, device=inp.device)
    # keepdim keeps the rank with all-ones shape; otherwise a 0-d scalar.
    result_shape = [1] * inp.dim() if keepdim else []
    result = torch.empty(result_shape, dtype=torch.int64, device=inp.device)

    warps_stage1 = builtins.min(8, max(1, blk // 128))
    warps_stage2 = builtins.min(8, max(1, blk_mid // 128))
    with torch_device_fn.device(inp.device):
        argmin_kernel_1[(n_blocks, 1, 1)](
            inp,
            partial_vals,
            partial_idx,
            numel,
            blk,
            num_warps=warps_stage1,
            num_stages=2,
        )
        argmin_kernel_2[(1, 1, 1)](
            partial_vals,
            partial_idx,
            result,
            n_blocks,
            blk_mid,
            num_warps=warps_stage2,
            num_stages=2,
        )
    return result


def _argmin_along_dim(inp, dim, keepdim):
    """Argmin along a single axis of a contiguous tensor via a tiled kernel."""
    shape = list(inp.shape)
    axis, reduce_len, inner, outer = _flatten_dim(shape, dim)
    rows = outer * inner
    stride_reduce = inp.stride()[axis]
    stride_outer = stride_reduce * reduce_len

    # Kernel writes int32 reduced-axis indices; widened to int64 below.
    flat_index = torch.empty((rows,), dtype=torch.int32, device=inp.device)
    grid = lambda meta: (triton.cdiv(rows, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        argmin_kernel[grid](
            inp,
            flat_index,
            rows,
            reduce_len,
            max(inner, 1),
            stride_outer,
            stride_reduce,
        )

    result_shape = shape.copy()
    result_shape[axis] = 1
    result = flat_index.view(result_shape).to(torch.int64)
    return result if keepdim else torch.squeeze(result, axis)


def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return indices of the minimum values of ``inp`` (torch.argmin shape semantics).

    With ``dim=None`` the tensor is reduced as a flat array in two kernel
    stages; otherwise a single autotuned kernel reduces along ``dim``.
    ``dtype`` only overrides the intermediate value buffer's dtype in the
    flat path. Non-contiguous inputs are copied to contiguous first.
    """
    logger.debug("GEMS_MTHREADS ARGMIN")
    if not inp.is_contiguous():
        inp = inp.contiguous()

    if dim is None:
        return _argmin_flat(inp, keepdim, dtype)

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    return _argmin_along_dim(inp, dim % inp.ndim, keepdim)
232__all__ = ["argmin"]