Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/argmin.py: 0%
96 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
# Module logger, namespaced under the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Maps a torch dtype to (triton element dtype, sentinel "max" value).
# The max value is used as the `other=` fill for masked loads so that
# out-of-bounds lanes can never win a min-reduction.
# NOTE(review): bfloat16 deliberately maps to tl.float32 and the float32
# max — presumably the reduction is carried in fp32 for bf16 inputs;
# confirm against the kernels' load/compute dtypes.
torch_dtype_to_tl_dtype_and_max_value = {
    torch.int16: (tl.int16, torch.iinfo(torch.int16).max),
    torch.int32: (tl.int32, torch.iinfo(torch.int32).max),
    torch.float16: (tl.float16, torch.finfo(torch.float16).max),
    torch.float32: (tl.float32, torch.finfo(torch.float32).max),
    torch.bfloat16: (tl.float32, torch.finfo(torch.float32).max),
}
@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
    dtype_max_value: tl.constexpr,
):
    """Stage 1 of the flat argmin: each program reduces one BLOCK_SIZE
    chunk of `inp` and records its local minimum and that minimum's
    global flat index at slot `pid` of `mid_value` / `mid_index`.
    """
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-bounds lanes load the dtype's max so they never win the min.
    vals = tl.load(inp + offsets, mask=in_bounds, other=dtype_max_value)
    local_min, local_idx = tl.min(vals, axis=0, return_indices=True)
    tl.store(mid_value + block_id, local_min)
    # Promote the block-local index to a global flat index before storing.
    tl.store(mid_index + block_id, local_idx + block_start)
@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
    dtype_max_value: tl.constexpr,
):
    """Stage 2 of the flat argmin: a single program reduces the per-block
    minima produced by argmin_kernel_1 and writes the winning block's
    recorded global index to `out`.
    """
    lane = tl.arange(0, BLOCK_MID)
    valid = lane < mid_size
    # Invalid lanes read the dtype's max so they cannot be selected.
    minima = tl.load(mid_value + lane, mask=valid, other=dtype_max_value)
    winner = tl.argmin(minima, axis=0)
    # Indirect load: fetch the global flat index saved by the winning block.
    tl.store(out, tl.load(mid_index + winner))
def heur_block_n(args):
    """Heuristic for BLOCK_N: next power of two covering N, capped at 4096."""
    padded = triton.next_power_of_2(args["N"])
    return padded if padded < 4096 else 4096
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    tl_dtype: tl.constexpr,
    dtype_max_value: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Dimensional argmin over the middle axis of an (M, N, K) view of a
    contiguous tensor: out_index[m, k] = argmin_n inp[m, n, k].

    Each program handles BLOCK_M rows of one k-slice; the whole reduced
    axis must fit in a single BLOCK_N tile (the heuristic picks BLOCK_N
    covering N, capped at 4096 — NOTE(review): N > 4096 is not looped
    over by this kernel, confirm upstream guarantees N <= 4096).
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offset = tl.arange(0, BLOCK_N)
    # Flat offsets into the contiguous (M, N, K) layout.
    offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
    offset_index = m_offset * K + pid_k
    # Row mask for the store; full 2-D mask for the load.
    mask1 = m_offset < M
    mask = m_offset[:, None] < M and n_offset[None, :] < N
    inp_ptrs = inp + offset
    # BUG FIX: masked-out lanes must load the dtype's MAX value so they can
    # never win the min-reduction. The previous fill of -inf made padded
    # columns (n >= N whenever BLOCK_N > N, e.g. N not a power of two) the
    # argmin of every valid row, producing out-of-range indices.
    inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)
    _, result_index = tl.min(inp_vals, axis=1, return_indices=True)
    # (A dead `tl_dtype` int16→int32 promotion was removed: the variable was
    # never read after the reassignment.)
    out_index_ptrs = out_index + offset_index
    tl.store(out_index_ptrs, result_index, mask=mask1)
def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """GEMS implementation of torch.argmin.

    Args:
        inp: input tensor (any dtype present in
            torch_dtype_to_tl_dtype_and_max_value).
        dim: dimension to reduce; None reduces over all elements.
        keepdim: retain the reduced dimension(s) with size 1.
        dtype: dtype of the intermediate per-block minima buffer used by
            the flat reduction (defaults to inp.dtype). The returned
            indices are always int64.

    Returns:
        int64 tensor of argmin indices (scalar/ones-shaped for dim=None).
    """
    logger.debug("GEMS argmin")
    if dim is None:
        # Flat (global) argmin via a two-stage block reduction.
        # BUG FIX: the kernels index `inp` as a flat contiguous buffer,
        # so a non-contiguous input must be compacted first (the `dim`
        # branch below already does this).
        inp = inp.contiguous()
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # ~sqrt(M)-sized blocks balance work between the two stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        if keepdim:
            # Reducing all dims with keepdim yields an all-ones shape.
            out = torch.empty([1] * inp.dim(), dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]
        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
                dtype_max_value,
            )
            argmin_kernel_2[(1, 1, 1)](
                mid_value,
                mid_index,
                out,
                mid_size,
                block_mid,
                dtype_max_value,
            )
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        # Collapse the tensor to an (M, N, K) view; N is the reduced axis.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        # Kernel assumes contiguous (M, N, K) strides.
        inp = inp.contiguous()

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]

        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(inp.device):
            argmin_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
                tl_dtype,
                dtype_max_value,
            )
        return out_index