Coverage for src/flag_gems/ops/argmin.py: 45%
161 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max
14logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the flat argmin: reduce one BLOCK_SIZE-wide slice.

    Program ``pid`` reads elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``
    of ``inp`` (clipped to the first ``M`` elements), then writes the slice
    minimum to ``mid_value[pid]`` and the flat position of that minimum to
    ``mid_index[pid]``.
    """
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE
    lanes = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < M

    # Out-of-range lanes are filled with get_dtype_max(...) so they should not
    # win the reduction (assumes it returns the dtype's largest value).
    pad = get_dtype_max(inp.type.element_ty)
    values = tl.load(inp + lanes, mask=in_bounds, other=pad)

    block_min, local_pos = tl.min(values, axis=0, return_indices=True)
    # Convert the position inside the slice into a position in the flat input.
    flat_pos = local_pos + block_start

    tl.store(mid_value + pid, block_min)
    tl.store(mid_index + pid, flat_pos)
@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Second pass of the flat argmin: pick the winner among partial results.

    Loads the first ``mid_size`` entries of ``mid_value``, finds which entry
    holds the smallest value, and copies the matching ``mid_index`` entry to
    ``out``. The kernel does not read a program id, so every launched program
    does the same work.
    """
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size

    # Padding lanes get get_dtype_max(...) so they should not be selected
    # (assumes it returns the dtype's largest value).
    pad = get_dtype_max(mid_value.type.element_ty)
    partial_mins = tl.load(mid_value + lanes, mask=valid, other=pad)

    winner = tl.argmin(partial_mins, axis=0)
    tl.store(out, tl.load(mid_index + winner))
def heur_block_n(args):
    """Return ``args["N"]`` rounded up to a power of two, capped at 4096."""
    padded_n = triton.next_power_of_2(args["N"])
    return min(4096, padded_n)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel_opt_k1(
    inp,
    out_index,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise argmin for the K == 1 layout.

    ``inp`` is addressed as a row-major ``(M, N)`` matrix (element ``(m, n)``
    at ``m * N + n``). Each program handles ``BLOCK_M`` rows, scans the row in
    ``BLOCK_N``-wide tiles and writes the column index of each row's minimum
    to ``out_index[m]`` as int64. Ties keep the leftmost index.
    """
    pid_m = tle.program_id(0)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Rows beyond M must neither be read nor written.
    m_mask = m_offset < M

    dtype = inp.type.element_ty
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_val = get_dtype_max(dtype)

    min_vals = tl.full([BLOCK_M], dtype=acc_type, value=max_val)
    argmin_vals = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        # Columns beyond N would alias the next row (or run past the buffer),
        # so they are masked and padded with max_val, which should never beat
        # a real element (assumes get_dtype_max returns the dtype's largest
        # value).
        mask = m_mask[:, None] & (n_offset[None, :] < N)
        inp_vals = tl.load(inp + offset, mask=mask, other=max_val)

        local_min, local_argmin = tl.min(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # Strict '<' keeps the earliest tile on ties between tiles.
        update = local_min < min_vals
        min_vals = tl.where(update, local_min, min_vals)
        argmin_vals = tl.where(update, start_n + local_argmin, argmin_vals)

    out_ptr = out_index + m_offset
    tl.store(out_ptr, argmin_vals, mask=m_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("argmin_split_k"), key=["M", "N", "K"]
)
@triton.jit
def argmin_split_K_kernel_merged(
    inp,
    out_index,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Argmin over the middle axis, tiled over both M and K.

    ``inp`` is addressed as a contiguous ``(M, N, K)`` array (element
    ``(m, n, k)`` at ``m * N * K + n * K + k``). Each program owns a
    ``BLOCK_M x BLOCK_K`` tile of the output, walks the N axis in ``BLOCK_N``
    steps and stores the winning n index at ``out_index[m * K + k]``.
    bfloat16 inputs are compared in float32.
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]  # (BLOCK_M, 1)
    cols = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)[None, :]  # (1, BLOCK_K)

    row_ok = rows < M
    col_ok = cols < K

    compute_dtype = tl.float32 if dtype == tl.bfloat16 else dtype
    pad = get_dtype_max(compute_dtype)

    best_val = tl.full((BLOCK_M, BLOCK_K), pad, dtype=compute_dtype)
    best_idx = tl.full((BLOCK_M, BLOCK_K), 0, dtype=tl.int64)

    for start_n in range(0, N, BLOCK_N):
        n = start_n + tl.arange(0, BLOCK_N)
        n_ok = n < N

        # Broadcasts to a (BLOCK_N, BLOCK_M, BLOCK_K) tile of element offsets.
        tile_offset = rows * N * K + n[:, None, None] * K + cols[None, :, :]
        tile_mask = row_ok & n_ok[:, None, None] & col_ok[None, :, :]

        tile = tl.load(inp + tile_offset, mask=tile_mask, other=pad)
        tile = tile.to(compute_dtype)

        # Reduce the leading (N) axis; ties keep the smallest n.
        tile_min, tile_arg = tl.min(
            tile, 0, return_indices=True, return_indices_tie_break_left=True
        )
        candidate_idx = tile_arg + start_n

        # Strict '<' keeps the earlier N tile when values are equal.
        better = tile_min < best_val
        best_val = tl.where(better, tile_min, best_val)
        best_idx = tl.where(better, candidate_idx, best_idx)

    out_offset = rows * K + cols  # (BLOCK_M, BLOCK_K)
    tl.store(out_index + out_offset, best_idx, mask=row_ok & col_ok)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """General argmin over the middle axis of an ``(M, N, K)`` layout.

    ``inp`` is addressed with element ``(m, n, k)`` at ``m * N * K + n * K + k``.
    Program ``(pid_m, pid_k)`` handles ``BLOCK_M`` rows for the single column
    ``k = pid_k``, scans N in ``BLOCK_N``-wide tiles and stores the winning n
    index at ``out_index[m * K + k]`` as int64. Ties keep the leftmost index.
    """
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Row validity does not change across N tiles, so compute it once.
    m_mask = m_offset < M

    dtype = inp.type.element_ty
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # Elementwise '&' is the tensor mask combinator (Python 'and' is meant
        # for scalar conditions); parentheses are required because '&' binds
        # tighter than '<'.
        mask = m_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
        # tl.bfloat is promoted to tl.float32 by tl.min
        local_min, local_argmin = tl.min(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmin in addition
        # local_argmin = tl.argmin(inp_vals, 1)
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    tl.store(out_index_ptrs, argmin_values, mask=m_mask)
def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the minimum values of ``inp`` as an int64 tensor.

    Args:
        inp: Input tensor.
        dim: Axis to reduce. ``None`` reduces over the flattened input and
            yields a single index.
        keepdim: Keep the reduced axis (or, for ``dim=None``, every axis) with
            size 1 in the result.
        dtype: Only used when ``dim`` is ``None``, as the dtype of the
            intermediate per-block minimum buffer; defaults to ``inp.dtype``.
            NOTE(review): it does not change the int64 output dtype.

    Returns:
        An int64 tensor on ``inp.device`` holding the argmin indices.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS ARGMIN")
    if dim is None:
        # The flat kernels address memory linearly, so the data must be laid
        # out contiguously for the result to be the index into the flattened
        # input (a strided view would otherwise be scanned in storage order).
        inp = inp.contiguous()
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # Two-pass reduction: ~sqrt(M) blocks of ~sqrt(M) elements each.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmin_kernel_2[(1, 1, 1)](
                mid_value,
                mid_index,
                out,
                mid_size,
                block_mid,
            )
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    # View the input as (M, N, K) with N the reduced axis.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N
    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    # Derive the split-K grid from the block sizes the autotuner actually
    # selected instead of assuming fixed 8 x 32 tiles.
    grid_for_split_K = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(K, meta["BLOCK_K"]),
    )
    # The split-K kernel takes its element type explicitly; only dtypes
    # listed here can use it, everything else falls back to the general
    # kernel (previously unlisted dtypes raised KeyError).
    torch2triton_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
    }

    with torch_device_fn.device(inp.device):
        if K == 1 and inp.dtype != torch.int32 and inp.dtype != torch.int16:
            argmin_kernel_opt_k1[grid](
                inp,
                out_index,
                M,
                N,
            )
        elif (
            N % 64 == 0
            and K % 32 == 0
            and M % 8 == 0
            and inp.dtype in torch2triton_dtype
        ):
            argmin_split_K_kernel_merged[grid_for_split_K](
                inp,
                out_index,
                M,
                N,
                K,
                dtype=torch2triton_dtype[inp.dtype],
            )
        else:
            # general support for other (N, K)
            argmin_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
            )

    return out_index