Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/min.py: 0%
114 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9# from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
13from flag_gems.utils.limits import get_dtype_max
15from ..utils.block_size_utils import get_block_size_1d
17logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def min_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of a two-stage global min.

    Each program reduces one BLOCK_SIZE chunk of the flattened input
    (M total elements) and stores its partial minimum at mid[pid].
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Out-of-range lanes load the dtype's max value so they can never
    # win the min reduction.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_val = tl.load(inp_ptrs, mask=mask, other=max_value)
    min_val = tl.min(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, min_val)
@libentry()
@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of a two-stage global min.

    A single program reduces the mid buffer of partial minima
    (mid_size entries, BLOCK_MID >= mid_size) to one scalar in out.
    """
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # Padding lanes beyond mid_size load dtype max, neutral for min.
    max_value = get_dtype_max(mid.type.element_ty)
    mid_val = tl.load(mid_ptrs, mask=mask, other=max_value)
    min_val = tl.min(mid_val)
    tl.store(out, min_val)
def heur_m_block_size(args):
    """Heuristic BLOCK_M: rows per cluster, rounded up to a power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # 12 = cluster_num on this backend
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    """Heuristic BLOCK_N: next power of two covering N, capped at 8192."""
    # builtins.min is required: this module shadows `min` with the op below.
    import builtins

    candidate = triton.next_power_of_2(args["N"])
    return builtins.min(candidate, 8192)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("min"),
#     key=[
#         "M",
#         "N",
#     ],
# )
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Min + argmin over the middle axis of an (M, N, K) contiguous view.

    Each program handles BLOCK_M rows for one fixed k (program axis 1),
    scanning the N axis in BLOCK_N tiles while keeping a running minimum
    and the N-index where it occurred, then stores both at (m, k).
    """
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # you just cannot create a function that return a tl.dtype in triton lang
    # bf16 accumulates in fp32; every other dtype accumulates in itself.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    # Running minima start at the dtype max so any real element replaces them.
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset into the (M, N, K) layout: ((m * N) + n) * K + k.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        # Masked-out elements read dtype max, neutral for the min.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
        local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmin = tl.argmin(inp_vals, 1)
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        # local_argmin is tile-relative; add start_n for the global N index.
        argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

    offset_index = m_offset * K + pid_k
    out_value_ptrs = out_value + offset_index
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_value_ptrs, min_values, mask=mask1)
    tl.store(out_index_ptrs, argmin_values, mask=mask1)
def min(inp):
    """Global minimum over all elements of *inp*.

    Two-stage block reduction: min_kernel_1 produces one partial min per
    block, min_kernel_2 reduces the partials to a scalar.

    Args:
        inp: input tensor; reduced over all of its elements.

    Returns:
        A 0-d tensor with the same dtype as ``inp`` holding the minimum.
    """
    import os

    logger.debug("GEMS MIN")
    M = inp.numel()
    # block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            # A single partial block already is the global minimum.
            return mid.reshape([])

        # XPU simulator workarounds for the second kernel. Set the flags
        # in a try/finally so the process-global environment is restored
        # even when the kernel launch raises.
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        try:
            min_kernel_2[(1, 1, 1)](
                mid, out, mid_size, block_mid, buffer_size_limit=2048
            )
        finally:
            os.environ.pop("TRITONXPU_OTHER_SIM", None)
            os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
        return out
def min_dim(inp, dim=None, keepdim=False):
    """Minimum values and their indices along dimension *dim*.

    Mirrors ``torch.min(inp, dim, keepdim)``: returns a namedtuple
    ``(values, indices)`` with the reduced dimension kept as size 1
    when *keepdim* is true, squeezed out otherwise.
    """
    logger.debug("GEMS MIN DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    # View the tensor as (M, N, K): N is the reduced axis, M the product
    # of leading axes, K the product of trailing axes.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    out_shape = list(shape)
    out_shape[dim] = 1
    out_value = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), K)
    # Backend-specific tiling workaround for one known problematic shape.
    isCloseCoreTiling = (
        inp.dtype in [torch.int16, torch.int32, torch.int64]
        and M == 4096
        and N == 256
    )
    with torch_device_fn.device(inp.device):
        min_kernel[grid](
            inp, out_value, out_index, M, N, K, isCloseCoreTiling=isCloseCoreTiling
        )
    Min_out = namedtuple("min", ["values", "indices"])
    return Min_out(values=out_value, indices=out_index)