Coverage for src/flag_gems/ops/argmax.py: 44%
163 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
14logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flattened (dim=None) argmax.

    Each program reduces one BLOCK_SIZE chunk of the flat input of length M,
    storing the chunk's maximum value in ``mid_value[pid]`` and the flat index
    of that maximum in ``mid_index[pid]``.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Out-of-range lanes read the dtype's minimum so they can never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
    max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
    # Convert the block-local winner index into a global flat index.
    max_index = max_index + pid * BLOCK_SIZE
    mid_value_ptr = mid_value + pid
    max_index_ptr = mid_index + pid
    tl.store(mid_value_ptr, max_val)
    tl.store(max_index_ptr, max_index)
@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flattened (dim=None) argmax.

    Runs as a single program: reduces the ``mid_size`` per-block partial
    maxima from stage 1, then stores the winning block's recorded flat
    input index into the scalar ``out``.
    """
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid_value + offset
    mask = offset < mid_size
    # Padding lanes read the dtype's minimum so they can never win.
    min_value = get_dtype_min(mid_value.type.element_ty)
    mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
    index_val = tl.argmax(mid_val, axis=0)
    # Look up the flat input index recorded by the winning block in stage 1.
    mid_index_ptrs = mid_index + index_val
    out_val = tl.load(mid_index_ptrs)
    tl.store(out, out_val)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_non_inner"))
@triton.jit
def argmax_kernel_non_inner(
    inp,
    out_index,
    M,
    N,
    K,
    TILE_K: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Argmax along a non-innermost dimension.

    The contiguous input is viewed as (M, N, K) with the reduction over the
    middle axis N. Each program handles one m-row and one TILE_K slice of the
    K axis, writing int64 indices into ``out_index`` viewed as (M, K).
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)

    # Compare fp16/bf16 values in float32; other dtypes compare natively.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # NOTE(review): min of cdtype (float32 for fp16/bf16 inputs) is used as the
    # masked-load fill below even though the load produces the input dtype —
    # verify the downcast still yields a value no real element can lose to.
    min_value = get_dtype_min(cdtype)

    if ONE_TILE_PER_CTA:
        # The whole N extent fits in one tile: single masked load + reduction.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset
        mask = k_offset < K and n_offset[:, None] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, local_argmax, mask=mask1)
    else:
        # N is tiled: keep a running max/argmax per k-lane across TILE_N chunks.
        max_values = tl.full([TILE_K], dtype=cdtype, value=min_value)
        argmax_values = tl.full([TILE_K], dtype=tl.int64, value=0)

        for start_n in range(0, N, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset
            mask = k_offset < K and n_offset[:, None] < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict '>' keeps the earliest (lowest-n) index on cross-tile ties.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(update, start_n + local_argmax, argmax_values)
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, argmax_values, mask=mask1)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_inner"))
@triton.jit
def argmax_kernel_inner(
    inp,
    out_index,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Argmax along the innermost dimension.

    The contiguous input is viewed as (M, N) with the reduction over N.
    Each program reduces one row and stores a single int64 index into
    ``out_index[pid_m]``.
    """
    pid_m = tle.program_id(0)

    dtype = inp.type.element_ty
    min_value = get_dtype_min(dtype)

    if ONE_TILE_PER_CTA:
        # The whole row fits in one tile: single masked load + reduction.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offset
        mask = n_offset < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        out_index_ptrs = out_index + pid_m
        tl.store(out_index_ptrs, local_argmax)
    else:
        # Row is tiled: process full tiles with unmasked loads, then one
        # masked tail tile for the remainder.
        max_values = min_value
        argmax_values = 0

        loop_time = N // TILE_N
        remainder = N % TILE_N
        for start_n in range(0, loop_time):
            n_offset = start_n * TILE_N + tl.arange(0, TILE_N)
            offset = pid_m * N + n_offset
            inp_ptrs = inp + offset
            # Full tile: no mask needed, all offsets are in range.
            inp_vals = tl.load(inp_ptrs)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict '>' keeps the earliest (lowest-n) index on cross-tile ties.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(
                update, start_n * TILE_N + local_argmax, argmax_values
            )

        if remainder:
            n_offset = loop_time * TILE_N + tl.arange(0, TILE_N)
            offset = pid_m * N + n_offset
            mask = n_offset < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(
                update, loop_time * TILE_N + local_argmax, argmax_values
            )

        out_index_ptrs = out_index + pid_m
        tl.store(out_index_ptrs, argmax_values)
def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp``.

    Args:
        inp: input tensor.
        dim: dimension to reduce. ``None`` reduces over all elements and
            returns the index into the (row-major) flattened tensor.
        keepdim: whether the output retains the reduced dim with size 1.
        dtype: dtype for the intermediate per-block maxima of the flattened
            reduction; defaults to ``inp.dtype``.

    Returns:
        An int64 tensor of indices.
    """
    logger.debug("GEMS ARGMAX")
    if dim is None:
        # Fix: the stage-1 kernel indexes the raw buffer linearly, so memory
        # order must match the logical row-major order; without this a
        # non-contiguous view (transpose/slice) yields wrong indices. The
        # dim-reduction branch below already does the same.
        inp = inp.contiguous()
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # ~sqrt(M)-sized blocks balance work between the two kernel stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        if keepdim:
            # All-dims reduction with keepdim: every dim collapses to 1.
            shape = [1] * inp.dim()
            out = torch.empty(shape, dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            # Stage 1: per-block max + flat index; stage 2: reduce the partials.
            argmax_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        if inp.numel() == 0:
            # Empty input: return an empty int64 tensor of the reduced shape.
            out_shape = list(shape)
            if keepdim:
                out_shape[dim] = 1
            else:
                del out_shape[dim]
            return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)
        # View the tensor as (M, N, K) with the reduction over N = shape[dim].
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        # Kernels assume row-major contiguous memory.
        inp = inp.contiguous()

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # dim is not innermost: 2D grid over (M rows, tiles of K).
                grid = lambda meta: (
                    M,
                    triton.cdiv(K, meta["TILE_K"]),
                )
                argmax_kernel_non_inner[grid](
                    inp,
                    out_index,
                    M,
                    N,
                    K,
                )
            else:
                # dim is innermost (K == 1): one program per row of length N.
                grid = lambda meta: (M, 1, 1)
                argmax_kernel_inner[grid](
                    inp,
                    out_index,
                    M,
                    N,
                )
        return out_index