Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmin.py: 0%
101 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max
14logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flat argmin: each program reduces one BLOCK_SIZE chunk
    of the flattened input and writes its partial (min value, global index)
    pair to mid_value[pid] / mid_index[pid]."""
    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes read the dtype's max so they can never win the min.
    padding = get_dtype_max(inp.type.element_ty)
    vals = tl.load(inp + offsets, mask=in_bounds, other=padding)
    block_min, block_argmin = tl.min(vals, axis=0, return_indices=True)
    # Convert the block-local winner index into a global element index.
    tl.store(mid_value + pid, block_min)
    tl.store(mid_index + pid, block_argmin + pid * BLOCK_SIZE)
@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Stage 2 of the flat argmin: a single program scans the mid_size
    partial minima from stage 1 and stores the global argmin index to out."""
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    # Pad unused lanes with the dtype max so they lose the argmin.
    padding = get_dtype_max(mid_value.type.element_ty)
    partial_mins = tl.load(mid_value + lanes, mask=valid, other=padding)
    winner = tl.argmin(partial_mins, axis=0)
    # The winning slot already holds the global index computed in stage 1.
    result = tl.load(mid_index + winner)
    tl.store(out, result)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmin along the middle axis of the contiguous (M, N, K) view of the
    input: for every (m, k) pair, writes the index (in [0, N)) of the
    minimum over the N axis to out_index, ties broken toward the left."""
    pid_m = tle.program_id(0)
    # Each program owns BLOCK_M rows; the K slices are walked serially.
    row_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    dtype = inp.type.element_ty
    # tl.min promotes bf16 to f32, so the running minimum accumulates in f32.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    for pid_k in range(K):
        running_min = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
        running_arg = tl.full([BLOCK_M], dtype=tl.int64, value=0)
        for start_n in range(0, N, BLOCK_N):
            col_offset = start_n + tl.arange(0, BLOCK_N)
            ptrs = inp + row_offset[:, None] * N * K + col_offset[None, :] * K + pid_k
            mask = row_offset[:, None] < M and col_offset[None, :] < N
            tile = tl.load(ptrs, mask=mask, other=max_value)
            # Per-tile reduction over the N axis (left tie-break matches torch).
            tile_min, tile_arg = tl.min(
                tile, 1, return_indices=True, return_indices_tie_break_left=True
            )
            better = tile_min < running_min
            running_min = tl.where(better, tile_min, running_min)
            # Shift the tile-local index to the absolute N position.
            running_arg = tl.where(better, start_n + tile_arg, running_arg)
        tl.store(out_index + row_offset * K + pid_k, running_arg, mask=row_mask)
def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the index of the minimum of `inp`, either globally (dim=None)
    or along `dim`, mirroring torch.argmin's shape semantics.

    Args:
        inp: input tensor.
        dim: reduction dimension; None reduces over the flattened input.
        keepdim: keep the reduced dimension with size 1.
        dtype: dtype used for the intermediate minima buffer in the flat
            path; defaults to inp.dtype.

    Returns:
        An int64 tensor of indices.
    """
    logger.debug("GEMS_ASCEND ARGMIN")
    if dim is None:
        # Flat argmin via a two-stage block reduction over the numel elements.
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # ~sqrt(M)-sized blocks balance work between the two stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        # keepdim yields an all-ones shape, otherwise a 0-dim scalar tensor.
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmin_kernel_2[(1, 1, 1)](
                mid_value,
                mid_index,
                out,
                mid_size,
                block_mid,
            )
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        # Collapse the tensor to an (M, N, K) view with N on the reduced axis.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()

        result_shape = list(shape)
        result_shape[dim] = 1
        out_index = torch.empty(result_shape, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            argmin_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
            )
        return out_index