Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/argmax.py: 0%
137 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
# Map a torch dtype to (Triton compute dtype, sentinel minimum value) used to
# fill masked-out lanes so they can never win an argmax reduction.
# NOTE(review): torch.bfloat16 maps to tl.float32 and the float32 minimum —
# presumably bf16 inputs are reduced in fp32 on this backend; confirm that the
# float32 sentinel is intended (it is not representable as a finite bf16).
torch_dtype_to_tl_dtype_and_min_value = {
    torch.int16: (tl.int16, torch.iinfo(torch.int16).min),
    torch.int32: (tl.int32, torch.iinfo(torch.int32).min),
    torch.float16: (tl.float16, torch.finfo(torch.float16).min),
    torch.float32: (tl.float32, torch.finfo(torch.float32).min),
    torch.bfloat16: (tl.float32, torch.finfo(torch.float32).min),
}

# Module-level logger for this op.
logger = logging.getLogger(__name__)
21logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def argmax_kernel_1(
    inp,        # *T: flattened input, M elements
    mid_value,  # *T: per-program maximum value (one slot per program)
    mid_index,  # *i64: per-program argmax as a global flat index
    M,          # total element count of inp
    BLOCK_SIZE: tl.constexpr,  # elements reduced by each program
):
    """Stage 1 of the flat (dim=None) argmax.

    Each program reduces one BLOCK_SIZE chunk of `inp` and writes its local
    maximum and the corresponding global flat index to mid_value/mid_index;
    stage 2 (argmax_kernel_2) reduces those partial results.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Out-of-range lanes load the dtype minimum so they cannot win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
    max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
    # Convert the block-local winner index into a global flat index.
    max_index = max_index + pid * BLOCK_SIZE
    mid_value_ptr = mid_value + pid
    max_index_ptr = mid_index + pid
    tl.store(mid_value_ptr, max_val)
    tl.store(max_index_ptr, max_index)
@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flat argmax: a single program reduces the stage-1
    partial maxima and stores the winning block's global index in `out`.

    BLOCK_MID is next_power_of_2(mid_size), so one tile covers all partials.
    """
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid_value + offset
    mask = offset < mid_size
    # Pad unused lanes with the dtype minimum so they cannot win.
    min_value = get_dtype_min(mid_value.type.element_ty)
    mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
    index_val = tl.argmax(mid_val, axis=0)
    # index_val picks the winning block; fetch that block's global argmax.
    mid_index_ptrs = mid_index + index_val
    out_val = tl.load(mid_index_ptrs)
    tl.store(out, out_val)
def heur_m_block_size(args):
    """Heuristic BLOCK_M: spread M rows over the device's 12 clusters,
    rounded up to the next power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # cluster_num
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    """Heuristic BLOCK_N: power-of-two tile covering N, capped at 8192."""
    block_n = triton.next_power_of_2(args["N"])
    return block_n if block_n < 8192 else 8192
@libentry()
# @triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def argmax_kernel(
    inp,        # *T: contiguous input viewed as (M, N, K)
    out_index,  # *i64: output indices for the (M, K) result
    M: tl.constexpr,   # product of dims before the reduced dim
    N: tl.constexpr,   # size of the reduced dim
    K: tl.constexpr,   # product of dims after the reduced dim
    BLOCK_M: tl.constexpr,  # rows per program (heuristic)
    BLOCK_N: tl.constexpr,  # reduction tile width (heuristic)
):
    """Argmax along the middle axis of an (M, N, K) view of `inp`.

    Grid is (cdiv(M, BLOCK_M), K): program (pid_m, pid_k) reduces the full N
    axis for BLOCK_M rows at a fixed k, accumulating BLOCK_N at a time.
    Ties resolve to the leftmost (smallest) index.
    """
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # bf16 is accumulated in fp32 (matches the module-level dtype map).
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    min_value = get_dtype_min(dtype)
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset for element (m, n, k) of a contiguous (M, N, K) tensor.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # NOTE(review): relies on Triton lowering `and` elementwise (as `&`).
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmax = tl.argmax(inp_vals, 1)
        # Strict `>` keeps the earlier tile's index on cross-tile ties,
        # preserving leftmost-tie-break semantics across the whole axis.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel_small_n(
    inp,        # *T: contiguous input viewed as (M, N, K)
    out_index,  # *i64: output indices for the (M, K) result
    M,
    N,
    K,
    tl_dtype: tl.constexpr,        # Triton dtype of inp (from the module-level map)
    dtype_min_value: tl.constexpr,  # sentinel fill value for masked lanes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Single-pass argmax for small N (the wrapper dispatches here when
    N == 1): the whole reduced axis fits in one BLOCK_N tile, so no
    accumulation loop is needed."""
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Widen int16 to int32.
    # NOTE(review): the widened tl_dtype is not referenced afterwards in this
    # kernel — dead unless something downstream consumes it; confirm intent.
    if tl_dtype is tl.int16:
        tl_dtype = tl.int32
    n_offset = tl.arange(0, BLOCK_N)
    # Flat offset for element (m, n, k) of a contiguous (M, N, K) tensor.
    offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
    offset_index = m_offset * K + pid_k
    # set mask
    mask1 = m_offset < M
    # NOTE(review): relies on Triton lowering `and` elementwise (as `&`).
    mask = m_offset[:, None] < M and n_offset[None, :] < N
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_min_value)
    _, result_index = tl.max(inp_vals, axis=1, return_indices=True)

    out_index_ptrs = out_index + offset_index

    tl.store(out_index_ptrs, result_index, mask=mask1)
def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of `inp` (torch.argmax).

    With dim=None the flattened tensor is reduced by a two-stage kernel
    (argmax_kernel_1 + argmax_kernel_2); otherwise the tensor is viewed as
    (M, N, K) with N the reduced axis and a tiled kernel reduces along N.

    Args:
        inp: input tensor; made contiguous before the dim-reduction path.
        dim: dimension to reduce, or None for a global argmax.
        keepdim: keep the reduced dimension with size 1.
        dtype: dtype of the intermediate per-block max buffer (dim=None path
            only); defaults to inp.dtype. NOTE(review): torch.argmax has no
            such kwarg — presumably a backend extension; confirm semantics.

    Returns:
        An int64 tensor of indices.
    """
    logger.debug("GEMS ARGMAX")
    if dim is None:
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # ~sqrt(M)-sized blocks balance work between the two stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        if keepdim:
            # Global reduction with keepdim: output is all-ones shaped.
            shape = list(inp.shape)
            for i in range(0, inp.dim()):
                shape[i] = 1
            out = torch.empty(shape, dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmax_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
        return out
    else:
        # NOTE(review): assert is stripped under -O; validation relies on it.
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        if inp.numel() == 0:
            # Empty input: return an empty int64 tensor of the reduced shape.
            out_shape = list(shape)
            if keepdim:
                out_shape[dim] = 1
            else:
                del out_shape[dim]
            return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)
        # View the tensor as (M, N, K): N is the reduced axis, M the leading
        # dims' product, K the trailing dims' product.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )

        if N == 1:
            # Degenerate reduction: single-tile kernel (index is always 0).
            tl_dtype, dtype_min_value = torch_dtype_to_tl_dtype_and_min_value[inp.dtype]
            with torch_device_fn.device(inp.device):
                argmax_kernel_small_n[grid](
                    inp,
                    out_index,
                    M,
                    N,
                    K,
                    tl_dtype,
                    dtype_min_value,
                )
            return out_index

        with torch_device_fn.device(inp.device):
            argmax_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
                # NOTE(review): not a parameter of argmax_kernel above —
                # presumably a kunlunxin-specific launch option consumed by
                # libentry/Triton; confirm it is accepted at launch.
                is_use_mask_zero=True,
            )

        return out_index