Coverage for src/flag_gems/runtime/backend/_metax/ops/min.py: 0%
100 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
13from flag_gems.utils.limits import get_dtype_max
# Module-level logger, placed under the "flag_gems." logger hierarchy.
logger = logging.getLogger(f"flag_gems.{__name__}")
@libentry()
@triton.jit
def min_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # First pass of the full reduction. Program `block_id` reduces the flat
    # slice [block_id * BLOCK_SIZE, (block_id + 1) * BLOCK_SIZE) of `inp` to
    # its minimum and writes it to mid[block_id]. Lanes past M are filled with
    # the dtype's maximum so they cannot become the minimum.
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    pad = get_dtype_max(inp.type.element_ty)
    chunk = tl.load(inp + idx, mask=idx < M, other=pad)
    tl.store(mid + block_id, tl.min(chunk))
@libentry()
@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Second pass: a single program reduces the `mid_size` partial minima in
    # `mid` to one value stored at `out`. Lanes past mid_size are padded with
    # the dtype's maximum.
    lanes = tl.arange(0, BLOCK_MID)
    pad = get_dtype_max(mid.type.element_ty)
    partials = tl.load(mid + lanes, mask=lanes < mid_size, other=pad)
    tl.store(out, tl.min(partials))
def heur_block_n(args):
    # BLOCK_N heuristic for min_kernel: the next power of two >= N. Because
    # BLOCK_N >= N, min_kernel's loop over the reduced axis runs a single tile.
    n = args["N"]
    return triton.next_power_of_2(n)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("min"),
    key=[
        "M",
        "N",
    ],
)
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Min/argmin over the middle axis of `inp` addressed as a dense (M, N, K)
    # buffer (element (m, n, k) at m * N * K + n * K + k).
    # Program (axis 0, axis 1) handles BLOCK_M values of m at k = axis-1 id and
    # walks n in BLOCK_N-wide tiles, keeping a running (min, argmin) per row.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    col_k = tle.program_id(1)
    row_ok = rows < M

    in_ty = inp.type.element_ty
    # bfloat16 is accumulated in float32; every other dtype keeps its own type.
    # This is written inline because a Triton helper cannot return a tl.dtype.
    acc_ty = tl.float32 if in_ty is tl.bfloat16 else in_ty
    sentinel = get_dtype_max(in_ty)
    best_val = tl.full([BLOCK_M], dtype=acc_ty, value=sentinel)
    best_idx = tl.full([BLOCK_M], dtype=tl.int64, value=0)

    for tile_start in range(0, N, BLOCK_N):
        cols = tile_start + tl.arange(0, BLOCK_N)
        ptrs = inp + rows[:, None] * N * K + cols[None, :] * K + col_k
        in_bounds = row_ok[:, None] & (cols[None, :] < N)
        tile = tl.load(ptrs, mask=in_bounds, other=sentinel)
        # If return_indices is unsupported on a backend, tl.argmin(tile, 1)
        # can supply the indices separately.
        tile_min, tile_arg = tl.min(tile, 1, return_indices=True)
        # Strict `<` keeps the earliest index when a later tile only ties.
        better = tile_min < best_val
        best_val = tl.where(better, tile_min, best_val)
        best_idx = tl.where(better, tile_start + tile_arg, best_idx)

    out_off = rows * K + col_k
    tl.store(out_value + out_off, best_val, mask=row_ok)
    tl.store(out_index + out_off, best_idx, mask=row_ok)
def min(inp):
    """Return the minimum over all elements of ``inp`` as a 0-dim tensor.

    Two-pass reduction: ``min_kernel_1`` reduces ``mid_size`` blocks of
    ``block_size`` elements (``block_size`` is the next power of two >=
    sqrt(numel)) into a ``mid`` buffer, then ``min_kernel_2`` reduces that
    buffer to a single value.

    Args:
        inp: input tensor of any shape; non-contiguous inputs are copied to a
            contiguous layout first.

    Returns:
        A 0-dim tensor with ``inp``'s dtype on ``inp``'s device.

    Raises:
        RuntimeError: if ``inp`` has no elements (mirrors ``torch.min``).
    """
    logger.debug("METAX GEMS MIN")
    M = inp.numel()
    if M == 0:
        # Without this check block_size would be 0 and triton.cdiv below
        # would divide by zero; torch.min rejects this case the same way.
        raise RuntimeError(
            "min(): Expected reduction dim to be specified for input.numel() == 0. "
            "Specify the reduction dim with the 'dim' argument."
        )
    # min_kernel_1 addresses the input as a flat buffer (inp + offset), which
    # only visits the right elements when the memory is contiguous.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
# Result type for min_dim, built once at import time instead of on every call.
_MinOut = namedtuple("min", ["values", "indices"])


def min_dim(inp, dim=None, keepdim=False):
    """Return the minimum values of ``inp`` along ``dim`` and their indices.

    The input is viewed as a dense ``(M, N, K)`` buffer where
    ``N = inp.shape[dim]``, ``M`` is the product of the leading sizes and
    ``K`` the product of the trailing sizes. ``min_kernel`` reduces over ``N``.

    Args:
        inp: input tensor; made contiguous before the kernel reads it.
        dim: dimension to reduce; negative values count from the end. Must be
            supplied, because ``None`` fails the range comparison with a
            TypeError.
        keepdim: if True, keep the reduced dimension with size 1.

    Returns:
        A ``(values, indices)`` namedtuple. ``values`` has ``inp``'s dtype and
        ``indices`` is int64.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
        IndexError: if the reduced dimension has size 0 (mirrors ``torch.min``).
    """
    logger.debug("METAX GEMS MIN DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    # Product of the trailing sizes. Computing it directly instead of
    # numel // M // N avoids dividing by zero when M or N is 0.
    K = math.prod(shape[dim + 1 :])
    if N == 0:
        raise IndexError(
            f"min(): Expected reduction dim {dim} to have non-zero size."
        )

    # min_kernel indexes the input as a dense (M, N, K) buffer.
    inp = inp.contiguous()

    out_shape = list(shape)
    out_shape[dim] = 1
    out_value = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    # An empty output (M == 0 or K == 0) needs no work; skip the launch rather
    # than submit a grid with a zero-sized dimension.
    if M * K > 0:

        def grid(meta):
            return (triton.cdiv(M, meta["BLOCK_M"]), K)

        with torch_device_fn.device(inp.device):
            min_kernel[grid](inp, out_value, out_index, M, N, K)
    return _MinOut(values=out_value, indices=out_index)