Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmax.py: 0%
148 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
14logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the flattened argmax: each program scans one BLOCK_SIZE
    # slice of the (contiguous, flattened) input and writes its block-local
    # maximum value plus the global index of that maximum into the mid
    # buffers at slot `block_id`.
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes are padded with the dtype minimum so they never win.
    pad = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + offsets, mask=in_bounds, other=pad)
    block_max, local_idx = tl.max(vals, axis=0, return_indices=True)
    global_idx = local_idx + block_start
    tl.store(mid_value + block_id, block_max)
    tl.store(mid_index + block_id, global_idx)
@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the flattened argmax: reduce the per-block partial maxima
    # produced by argmax_kernel_1 down to the single global argmax.
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    # Invalid lanes load the dtype minimum so they cannot be selected.
    pad = get_dtype_min(mid_value.type.element_ty)
    partial_max = tl.load(mid_value + lanes, mask=valid, other=pad)
    # Which block held the overall maximum ...
    winner = tl.argmax(partial_max, axis=0)
    # ... and the global element index that block recorded for it.
    tl.store(out, tl.load(mid_index + winner))
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_non_inner"))
@triton.jit
def argmax_kernel_non_inner(
    inp,
    out_index,
    M,
    N,
    K,
    TILE_K: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Argmax along the middle axis of an input viewed as (M, N, K) with
    # K > 1, i.e. the reduced axis N is NOT the innermost one.
    # Grid layout (see the `argmax` wrapper): axis 0 -> M, axis 1 ->
    # ceil(K / TILE_K); each program handles one m-slice and one TILE_K-wide
    # strip of K.  Indices are written to out_index at (pid_m, k).
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)

    # Compare fp16/bf16 inputs in fp32 (compile-time specialization); all
    # other dtypes are compared natively.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Padding value for masked lanes: the dtype minimum can never win.
    min_value = get_dtype_min(cdtype)

    if ONE_TILE_PER_CTA:
        # Fast path: all of N fits in a single TILE_N tile, so one masked
        # 2D load plus one reduction over axis 0 suffices.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset
        mask = k_offset < K and n_offset[:, None] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        # tie_break_left: on equal values the smallest index wins.
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, local_argmax, mask=mask1)
    else:
        # General path: walk N in TILE_N chunks, carrying a running
        # (max, argmax) pair per k-lane.
        max_values = tl.full([TILE_K], dtype=cdtype, value=min_value)
        argmax_values = tl.full([TILE_K], dtype=tl.int64, value=0)

        for start_n in range(0, N, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset
            mask = k_offset < K and n_offset[:, None] < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict '>' so earlier chunks keep winning on ties across
            # chunk boundaries (first occurrence is preserved).
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(update, start_n + local_argmax, argmax_values)
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, argmax_values, mask=mask1)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Argmax over the reduced axis N of an input viewed as (M, N, K); the
    # caller dispatches here when K == 1 (innermost reduction).  Grid:
    # axis 0 -> ceil(M / BLOCK_M); the K axis is walked sequentially inside
    # the kernel rather than via a second grid axis (see the commented-out
    # pid_k below).
    # set offset
    pid_m = tle.program_id(0)
    # pid_k = tle.program_id(1)
    for pid_k in range(K):
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

        dtype = inp.type.element_ty
        # Accumulate bf16 comparisons in fp32; other dtypes natively.
        acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
        # Masked lanes load the dtype minimum so they can never win.
        min_value = get_dtype_min(dtype)
        max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
        argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
        # Fold one BLOCK_N chunk of the reduced axis per iteration, carrying
        # a running (max, argmax) pair per row of the BLOCK_M block.
        for start_n in range(0, N, BLOCK_N):
            n_offset = start_n + tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
            mask = m_offset[:, None] < M and n_offset[None, :] < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            # tie_break_left: on equal values the smallest index wins.
            local_max, local_argmax = tl.max(
                inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
            )
            # if return indices is not supported, call a tl.argmax in addition
            # local_argmax = tl.argmax(inp_vals, 1)
            # Strict '>' preserves the first occurrence across chunks.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

        offset_index = m_offset * K + pid_k
        out_index_ptrs = out_index + offset_index
        mask1 = m_offset < M
        tl.store(out_index_ptrs, argmax_values, mask=mask1)
def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp``.

    Mirrors ``torch.argmax``: when ``dim`` is None the reduction runs over
    the flattened input, otherwise along ``dim`` (negative dims allowed).
    ``keepdim`` retains the reduced dimension(s) with size 1.  ``dtype``
    optionally overrides the dtype of the intermediate value buffer used by
    the flattened path (defaults to ``inp.dtype``).  Always returns an
    int64 index tensor on ``inp``'s device.
    """
    logger.debug("GEMS_ASCEND ARGMAX")
    if dim is None:
        # Flattened reduction in two stages: argmax_kernel_1 produces one
        # (max, argmax) pair per block, argmax_kernel_2 reduces the partials.
        # BUGFIX: argmax_kernel_1 indexes with raw linear offsets, which is
        # only valid for a contiguous layout — force contiguity first (the
        # dim-path below already does this).
        inp = inp.contiguous()
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # ~sqrt(M)-sized blocks balance work between the two stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        if keepdim:
            # Every dimension is reduced, so each collapses to size 1.
            out = torch.empty([1] * inp.dim(), dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmax_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        if inp.numel() == 0:
            # Empty input: nothing to reduce; return a zero-filled index
            # tensor with the correct output shape.
            out_shape = list(shape)
            if keepdim:
                out_shape[dim] = 1
            else:
                del out_shape[dim]
            return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)
        # View the input as (M, N, K): N is the reduced axis, M the product
        # of leading dims, K the product of trailing dims.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()  # kernels index with raw (M, N, K) offsets

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Non-inner reduction: one program per m-slice per K-tile.
                grid = lambda meta: (
                    M,
                    triton.cdiv(K, meta["TILE_K"]),
                )
                argmax_kernel_non_inner[grid](
                    inp,
                    out_index,
                    M,
                    N,
                    K,
                )
            else:
                # Inner (last-axis) reduction: one program per BLOCK_M rows.
                grid = lambda meta: (
                    triton.cdiv(M, meta["BLOCK_M"]),
                    # K,
                )
                argmax_kernel[grid](
                    inp,
                    out_index,
                    M,
                    N,
                    K,
                )
        return out_index