Coverage for src/flag_gems/runtime/backend/_cambricon/ops/min.py: 0%
170 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils.limits import get_dtype_max
14from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op, prune_reduce_config
16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def min_kernel_float_once(
    inp,
    out,
    M: tl.constexpr,
):
    """Single-launch global min: load all M elements in one tile and reduce.

    The host wrapper uses this path only for small inputs (M <= 65536).
    NOTE(review): tl.arange(0, M) normally requires M to be a power of two,
    but the caller passes an arbitrary numel — confirm this backend's triton
    relaxes that constraint.
    """
    offset = tl.arange(0, M)
    inp_val = tl.load(inp + offset)
    min_val = tl.min(inp_val, 0)
    tl.store(out, min_val)
@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def min_kernel_float(
    inp, out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Global min over M float elements, combined across programs.

    Each program reduces its share of the input to a scalar, then folds it
    into *out* (pre-initialized to +inf by the caller) via tl.atomic_min.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = float("inf")
    if ONE_TILE_PER_CTA:
        # Whole input covered by one tile per program: single masked load.
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=float("inf"))
        # NOTE(review): return_indices=True usually yields (values, indices);
        # unpacking into a 1-tuple looks suspicious — confirm against this
        # backend's tl.min signature.
        (res,) = tl.min(inp_val, 0, return_indices=True)
    else:
        # Grid-stride loop: keep an elementwise running minimum in _tmp,
        # then reduce the tile to a scalar once at the end.
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        _tmp = tl.full([BLOCK_SIZE], value=float("inf"), dtype=inp.dtype.element_ty)
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=float("inf"))
            _tmp = tl.where((inp_val < _tmp), inp_val, _tmp)
        (res,) = tl.min(_tmp, 0, return_indices=True)
    tl.atomic_min(out, res)
@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def min_kernel_int(
    inp, out, FILL_VALUE, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Global min over M integer elements, combined via tl.atomic_min.

    FILL_VALUE is the input dtype's max value (supplied by the caller) and
    serves as the identity element for masked-out lanes.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = FILL_VALUE
    if ONE_TILE_PER_CTA:
        # One tile per program: single masked load, single reduction.
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=FILL_VALUE)
        res = tl.min(inp_val)
    else:
        # Grid-stride loop with an elementwise running minimum.
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        block_start = block_start.to(tl.int64)
        # Accumulator starts at INT32_MAX; assumes inputs are at most 32-bit
        # (int64 is routed to min_kernel_int64_* instead) — TODO confirm.
        _tmp = tl.full([BLOCK_SIZE], value=2**31 - 1, dtype=tl.int32)
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=FILL_VALUE)
            _tmp = tl.where((inp_val < _tmp), inp_val, _tmp)
        res = tl.min(_tmp)
    tl.atomic_min(out, res)
@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def min_kernel_int64_1(
    inp, mid, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """First stage of the int64 global min.

    Atomic ops do not support i64, so instead of atomic_min each program
    writes its partial minimum to mid[pid]; min_kernel_int64_2 combines them.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    # FILL_VALUE is the maximum value of a 64-bit integer, used as the initial value for calculations.
    FILL_VALUE = 2**63 - 1
    res = FILL_VALUE
    if ONE_TILE_PER_CTA:
        # One tile per program: single masked load, single reduction.
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=FILL_VALUE)
        res = tl.min(inp_val)
    else:
        # Grid-stride loop with an elementwise running minimum in _tmp.
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        block_start = block_start.to(tl.int64)
        _tmp = tl.full([BLOCK_SIZE], value=FILL_VALUE, dtype=tl.int64)
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=FILL_VALUE)
            _tmp = tl.where((inp_val < _tmp), inp_val, _tmp)
        res = tl.min(_tmp)
    tl.store(mid + pid, res)
@libentry()
@triton.jit
def min_kernel_int64_2(mid, out, BLOCK_NUM: tl.constexpr):
    """Second stage of the int64 global min: reduce the BLOCK_NUM partial
    minima in *mid* down to a single scalar stored at *out*."""
    offset = tl.arange(0, BLOCK_NUM)
    mid_val = tl.load(mid + offset)
    out_val = tl.min(mid_val)
    tl.store(out, out_val)
def heur_block_n(args):
    """Heuristic for BLOCK_N: round the reduction length N up to a power of two."""
    reduce_len = args["N"]
    return triton.next_power_of_2(reduce_len)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("min"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def min_kernel(
    inp,        # input viewed as a contiguous (M, N, K) tensor
    out_value,  # minimum values along the N axis
    out_index,  # int64 argmin indices along the N axis
    M,          # product of dims before the reduced axis
    N,          # length of the reduced axis
    K,          # product of dims after the reduced axis
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Min + argmin along the middle axis of an (M, N, K) view of the input.

    Grid is (cdiv(M, BLOCK_M), K); each program handles BLOCK_M rows for one
    k-slice, sweeping the N axis in BLOCK_N-sized tiles.
    """
    # set offset
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Running minima tracked in float32; indices tracked in int64.
    min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
    argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    max_value = get_dtype_max(inp.type.element_ty)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # NOTE(review): `and` on tensor masks relies on triton lowering it to
        # an elementwise logical-and (`&` is the usual spelling) — confirm
        # for this backend.
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        # Masked lanes read the dtype's max so they never win the min.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
        local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmin = tl.argmin(inp_vals, 1)
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        # local_argmin is relative to the current tile; rebase by start_n.
        argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

    offset_index = m_offset * K + pid_k
    out_value_ptrs = out_value + offset_index
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_value_ptrs, min_values, mask=mask1)
    tl.store(out_index_ptrs, argmin_values, mask=mask1)
def min(inp):
    """Global minimum over all elements of *inp*, returned as a 0-d tensor.

    Dispatch: floats use a single-tile kernel when small, otherwise an
    atomic-min kernel; int64 needs a two-stage reduction (no i64 atomics);
    other integer dtypes use the int atomic-min kernel. The result is cast
    back to the input dtype at the end.
    """
    logger.debug("GEMS_CAMBRICON MIN")
    numel = inp.numel()
    num_ctas = TOTAL_CORE_NUM
    src_dtype = inp.dtype
    dev = inp.device

    with torch_device_fn.device(dev):
        if torch.is_floating_point(inp):
            if numel <= 65536:
                # Small input: one program loads everything in a single tile.
                out = torch.empty([], dtype=src_dtype, device=dev)
                min_kernel_float_once[(1, 1, 1)](inp, out, numel)
            else:
                # Accumulate via atomic_min into a +inf-initialized f32 scalar.
                out = torch.full([], float("inf"), dtype=torch.float32, device=dev)
                min_kernel_float[(num_ctas, 1, 1)](inp, out, numel)
        elif src_dtype == torch.int64:
            # Because atomic op don't support i64, use two kernels.
            mid = torch.empty([num_ctas], dtype=src_dtype, device=dev)
            out = torch.empty([], dtype=src_dtype, device=dev)
            min_kernel_int64_1[(num_ctas, 1, 1)](inp, mid, numel, enable_soft_i64=True)
            min_kernel_int64_2[(1, 1, 1)](
                mid, out, BLOCK_NUM=num_ctas, enable_soft_i64=True
            )
        else:
            # Seed the i32 accumulator with INT32_MAX; masked lanes read the
            # input dtype's max.
            fill_value = torch.iinfo(src_dtype).max
            out = torch.full([], 2**31 - 1, dtype=torch.int32, device=dev)
            min_kernel_int[(num_ctas, 1, 1)](inp, out, fill_value, numel)
    return out.to(src_dtype)
def min_dim(inp, dim=None, keepdim=False):
    """Minimum values and int64 indices of *inp* along *dim*.

    Returns a namedtuple ``(values, indices)``; the reduced dimension is
    kept (size 1) when *keepdim* is true, squeezed away otherwise.
    """
    logger.debug("GEMS_CAMBRICON MIN DIM")
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    dims = inp.shape
    axis = dim % inp.ndim
    # Collapse the tensor into an (outer, reduce_len, inner) view.
    reduce_len = dims[axis]
    outer = math.prod(dims[:axis])
    inner = inp.numel() // outer // reduce_len

    inp = inp.contiguous()

    out_shape = list(dims)
    out_shape[axis] = 1
    values = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    indices = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        values = torch.squeeze(values, axis)
        indices = torch.squeeze(indices, axis)

    def grid(meta):
        return (triton.cdiv(outer, meta["BLOCK_M"]), inner)

    with torch_device_fn.device(inp.device):
        min_kernel[grid](inp, values, indices, outer, reduce_len, inner)
    Min_out = namedtuple("min", ["values", "indices"])
    return Min_out(values=values, indices=indices)