Coverage for src/flag_gems/runtime/backend/_cambricon/ops/argmax.py: 0% (133 statements)
import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils.shape_utils import can_use_int32_index

from ..utils import TOTAL_CORE_NUM

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

def cfggen_reduce_op():
    return runtime.get_tuned_config("argmax_kernel_1")


# Single-program kernel for small inputs: the whole tensor is loaded with one
# tl.arange block and reduced in a single tl.argmax call.
@libentry()
@triton.jit
def argmax_kernel_once(
    inp,
    out,
    M: tl.constexpr,
):
    offset = tl.arange(0, M)
    inp_val = tl.load(inp + offset)
    index_val = tl.argmax(inp_val, axis=0)
    tl.store(out, index_val.to(tl.int64))

# Stage 1 of the two-stage flat argmax: each program scans a contiguous chunk
# of the input in BLOCK_SIZE tiles and writes its local maximum value and flat
# index to mid_value / mid_index; the number of programs is stored in real_size.
@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
)
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    real_size,
    M,
    BLOCK_SIZE: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    pid = tl.program_id(0)
    if INT64_INDEX:
        pid = pid.to(tl.int64)
    num_jobs = tl.num_programs(axis=0)

    size_per_job = (M + num_jobs - 1) // num_jobs
    start_idx = pid * size_per_job
    end_idx = min(start_idx + size_per_job, M)

    max_tmp = -float("inf")
    index_tmp = 0
    if INT64_INDEX:
        index_tmp = index_tmp.to(tl.int64)
    for off in range(start_idx, end_idx, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < end_idx
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf"))
        max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
        if max_val > max_tmp:
            max_tmp = max_val.to(tl.float32)
            index_tmp = max_index + off
    mid_value_ptr = mid_value + pid
    max_index_ptr = mid_index + pid
    tl.store(mid_value_ptr, max_tmp)
    tl.store(max_index_ptr, index_tmp)
    tl.store(real_size, num_jobs)

# Stage 2 of the two-stage flat argmax: a single program reduces the partial
# results of stage 1 and stores the winning flat index.
@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, real_size, mid_size: tl.constexpr):
    size = tl.load(real_size)
    offset = tl.arange(0, mid_size)
    mid_ptrs = mid_value + offset
    mid_val = tl.load(mid_ptrs, mask=offset < size, other=-float("inf"))
    index_val = tl.argmax(mid_val, axis=0)
    mid_index_ptrs = mid_index + index_val
    out_val = tl.load(mid_index_ptrs)
    tl.store(out, out_val)

# Argmax over a given dimension: the contiguous input is viewed as (M, N, K)
# with the reduced dimension in the middle. Each program handles BLOCK_M rows
# of one k-slice and scans N in BLOCK_N tiles.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("argmax"),
    key=[
        "M",
        "N",
    ],
    strategy=["log", "log"],
)
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    # set offset
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    if INT64_INDEX:
        pid_m = pid_m.to(tl.int64)
        pid_k = pid_k.to(tl.int64)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf"))
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmax = tl.argmax(inp_vals, 1)
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)

def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    logger.debug("GEMS_CAMBRICON ARGMAX")
    if dim is None:
        # Flat argmax over the whole tensor.
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype

        use_int64_index = not can_use_int32_index(inp)

        if keepdim:
            shape = list(inp.shape)
            for i in range(0, inp.dim()):
                shape[i] = 1
            out = torch.empty(shape, dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        if M <= 65530:
            # Small inputs: a single program reduces everything at once.
            with torch_device_fn.device(inp.device):
                argmax_kernel_once[(1, 1, 1)](inp, out, M)
        else:
            # Larger inputs: two-stage reduction across up to TOTAL_CORE_NUM programs.
            grid = lambda meta: (
                min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            mid_size = TOTAL_CORE_NUM
            real_size = torch.empty([], dtype=torch.int32, device=inp.device)
            mid_value = torch.empty((mid_size,), dtype=torch.float32, device=inp.device)
            mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
            with torch_device_fn.device(inp.device):
                argmax_kernel_1[grid](
                    inp,
                    mid_value,
                    mid_index,
                    real_size,
                    M,
                    INT64_INDEX=use_int64_index,
                )
                argmax_kernel_2[(1, 1, 1)](
                    mid_value, mid_index, out, real_size, mid_size
                )
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        if inp.numel() == 0:
            out_shape = list(shape)
            if keepdim:
                out_shape[dim] = 1
            else:
                del out_shape[dim]
            return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)
        # View the input as (M, N, K) with the reduced dimension as N.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()
        use_int64_index = not can_use_int32_index(inp)

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(inp.device):
            argmax_kernel[grid](inp, out_index, M, N, K, INT64_INDEX=use_int64_index)

        return out_index
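

# ---------------------------------------------------------------------------
# Hypothetical usage sketch: assuming a flag_gems installation with the
# Cambricon backend and an MLU device that PyTorch exposes as "mlu" (the
# device string is an assumption), the wrapper above can be cross-checked
# against torch.argmax:
#
#   import torch
#   from flag_gems.runtime.backend._cambricon.ops.argmax import argmax
#
#   x = torch.randn(4, 1025, 3, device="mlu")
#   # dim reduction path: the input is viewed as M=4, N=1025, K=3 by argmax_kernel.
#   assert torch.equal(argmax(x, dim=1, keepdim=True),
#                      torch.argmax(x, dim=1, keepdim=True))
#   # flat path: numel() = 12300 <= 65530, so argmax_kernel_once handles it.
#   assert torch.equal(argmax(x), torch.argmax(x))
# ---------------------------------------------------------------------------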