Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/argmax.py: 0%
132 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils.shape_utils import can_use_int32_index
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def cfggen_reduce_op():
    """Return the autotuning configs registered under ``argmax_kernel_1``."""
    configs = runtime.get_tuned_config("argmax_kernel_1")
    return configs
@libentry()
@triton.jit
def argmax_kernel_once(
    inp,
    out,
    M: tl.constexpr,
):
    """Single-program argmax over a flat buffer of M elements.

    Loads the whole input in one shot, reduces along the only axis, and
    stores the winning index as int64.
    """
    idx = tl.arange(0, M)
    vals = tl.load(inp + idx)
    winner = tl.argmax(vals, axis=0)
    tl.store(out, winner.to(tl.int64))
@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
)
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    real_size,
    M,
    BLOCK_SIZE: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    """Stage 1 of a two-stage global argmax over a flat buffer.

    Each program reduces one contiguous slice of the M-element input and
    writes its local maximum value / flat index to ``mid_value[pid]`` /
    ``mid_index[pid]``; stage 2 (``argmax_kernel_2``) picks the winner.
    """
    pid = tl.program_id(0)
    if INT64_INDEX:
        # Widen the program id so pointer math below cannot overflow int32.
        pid = pid.to(tl.int64)
    num_jobs = tl.num_programs(axis=0)
    # Split M as evenly as possible; the last program may get a short slice.
    size_per_job = (M + num_jobs - 1) // num_jobs
    start_idx = pid * size_per_job
    end_idx = min(start_idx + size_per_job, M)
    max_tmp = -float("inf")
    index_tmp = 0
    if INT64_INDEX:
        index_tmp = index_tmp.to(tl.int64)
    for off in range(start_idx, end_idx, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        # Out-of-slice lanes read -inf so they can never win the reduction.
        mask = offset < end_idx
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf"))
        max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
        # Strict '>' keeps the earliest block's winner on cross-block ties.
        if max_val > max_tmp:
            max_tmp = max_val.to(tl.float32)
            # Block-local index + block base = flat index into the input.
            index_tmp = max_index + off
    mid_value_ptr = mid_value + pid
    max_index_ptr = mid_index + pid
    tl.store(mid_value_ptr, max_tmp)
    tl.store(max_index_ptr, index_tmp)
    # Every program stores the same launch width; stage 2 reads it to know
    # how many mid slots hold valid partials.
    tl.store(real_size, num_jobs)
@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, real_size, mid_size: tl.constexpr):
    """Stage 2: reduce the per-program partials from stage 1 to one index."""
    valid = tl.load(real_size)
    slots = tl.arange(0, mid_size)
    # Slots beyond the number of launched programs are padded with -inf so
    # they never win the reduction.
    partial_max = tl.load(mid_value + slots, mask=slots < valid, other=-float("inf"))
    winner_slot = tl.argmax(partial_max, axis=0)
    global_index = tl.load(mid_index + winner_slot)
    tl.store(out, global_index)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("argmax"),
    key=[
        "M",
        "N",
    ],
    strategy=["log", "log"],
)
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    """Argmax along the middle axis of an (M, N, K) contiguous view.

    The caller flattens the tensor so the reduced dimension has extent N,
    everything before it extent M, everything after it extent K. Each
    program handles BLOCK_M rows for one fixed k-column and writes the
    winning n-index for each row into ``out_index``.
    """
    # set offset
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    if INT64_INDEX:
        # Widen program ids so address arithmetic cannot overflow int32.
        pid_m = pid_m.to(tl.int64)
        pid_k = pid_k.to(tl.int64)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Running best value/index per row, initialized so any real element wins.
    max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf"))
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset for element (m, n, k) in a contiguous (M, N, K) layout.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # Masked lanes load -inf so they never beat in-bounds elements.
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmax = tl.argmax(inp_vals, 1)
        # Strict '>' keeps the earliest tile's winner on cross-tile ties,
        # matching the left tie-break within a tile.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)
def _argmax_flat(inp, keepdim):
    # Global argmax over every element of ``inp`` (the ``dim is None`` path).
    M = inp.numel()
    # Both flat kernels index the raw buffer linearly, so a contiguous
    # layout is required — the dim-reduction path below already does this;
    # without it non-contiguous inputs would reduce over the wrong elements.
    inp = inp.contiguous()
    use_int64_index = not can_use_int32_index(inp)

    if keepdim:
        # keepdim on a full reduction yields a tensor of all-1 dims.
        out = torch.empty([1] * inp.dim(), dtype=torch.int64, device=inp.device)
    else:
        out = torch.empty([], dtype=torch.int64, device=inp.device)

    if M <= 65530:
        # Small input: one program reduces everything in a single pass.
        with torch_device_fn.device(inp.device):
            argmax_kernel_once[(1, 1, 1)](inp, out, M)
        return out

    # Large input: stage 1 writes one partial max/index per program,
    # stage 2 reduces the partials to the final index.
    grid = lambda meta: (
        min(
            triton.cdiv(M, meta["BLOCK_SIZE"]),
            torch_device_fn.get_device_properties().multi_processor_count,
        ),
    )
    mid_size = torch_device_fn.get_device_properties().multi_processor_count
    real_size = torch.empty([], dtype=torch.int32, device=inp.device)
    mid_value = torch.empty((mid_size,), dtype=torch.float32, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    with torch_device_fn.device(inp.device):
        argmax_kernel_1[grid](
            inp,
            mid_value,
            mid_index,
            real_size,
            M,
            INT64_INDEX=use_int64_index,
        )
        argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, real_size, mid_size)
    return out


def _argmax_dim(inp, dim, keepdim):
    # Argmax along a single dimension ``dim``.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    if inp.numel() == 0:
        # Empty input: return an empty int64 tensor of the reduced shape.
        out_shape = list(shape)
        if keepdim:
            out_shape[dim] = 1
        else:
            del out_shape[dim]
        return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)

    # View the tensor as (M, N, K): N is the reduced extent, M the product
    # of leading dims, K the product of trailing dims.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()
    use_int64_index = not can_use_int32_index(inp)

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        argmax_kernel[grid](inp, out_index, M, N, K, INT64_INDEX=use_int64_index)
    return out_index


def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Triton-backed ``torch.argmax`` for the tsingmicro backend.

    Args:
        inp: input tensor (made contiguous internally).
        dim: dimension to reduce, or ``None`` for a global argmax.
        keepdim: keep the reduced dimension with size 1.
        dtype: accepted for signature compatibility; the result is always
            int64, matching ``torch.argmax``.

    Returns:
        An int64 tensor of indices of the maximum values.
    """
    logger.debug("GEMS_TSINGMICRO ARGMAX")
    if dim is None:
        return _argmax_flat(inp, keepdim)
    return _argmax_dim(inp, dim, keepdim)