Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/max.py: 0%
127 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2import math
3import os
4from collections import namedtuple
6import torch
7import triton
8import triton.language as tl
10# from flag_gems import runtime
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry
13from flag_gems.utils import triton_lang_extension as tle
14from flag_gems.utils.limits import get_dtype_min
16from ..utils.block_size_utils import get_block_size_1d
18logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def max_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the full reduction: each program reduces one BLOCK_SIZE-wide
    # slice of `inp` and writes that slice's maximum to mid[program id].
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    # Lanes past the end of the input read the dtype minimum so that they can
    # never win the reduction.
    lowest = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + idx, mask=in_bounds, other=lowest)
    block_max = tl.max(vals)
    tl.store(mid + block_id, block_max)
@libentry()
@triton.jit
def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full reduction: a single program loads the first
    # `mid_size` entries of `mid` and stores their maximum to `out`.
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < mid_size
    # Padding lanes (idx >= mid_size) are filled with the dtype minimum.
    lowest = get_dtype_min(mid.type.element_ty)
    partials = tl.load(mid + idx, mask=in_bounds, other=lowest)
    tl.store(out, tl.max(partials))
def heur_m_block_size(args):
    """Pick BLOCK_M: ``M`` split across 12 clusters, rounded up to a power of 2.

    ``args`` is the kernel-argument mapping supplied by ``triton.heuristics``;
    only ``args["M"]`` is read.
    """
    cluster_num = 12
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    """Pick BLOCK_N: ``N`` rounded up to a power of 2, capped at 8192.

    ``args`` is the kernel-argument mapping supplied by ``triton.heuristics``;
    only ``args["N"]`` is read.
    """
    padded_n = triton.next_power_of_2(args["N"])
    # Written as a conditional expression so the cap does not depend on the
    # name ``min`` (this module already rebinds ``max`` to a tensor op).
    return padded_n if padded_n < 8192 else 8192
62# def heur_m_block_size(args):
63# # if triton.next_power_of_2(triton.cdiv(args["M"], cluster_num)) < core_num:
64# # return triton.next_power_of_2(triton.cdiv(args["M"], cluster_num))
65# # else:
66# return (
67# triton.cdiv(triton.cdiv(2048, args["ELEMENT_SIZE"]), args["N"])
68# * 64
69# )
72# def heur_n_block_size(args):
73# return min(args["N"], triton.cdiv(2048, args["ELEMENT_SIZE"]))
@libentry()
# NOTE: autotuning through runtime.get_tuned_config("max") keyed on (M, N) is
# disabled here; BLOCK_M / BLOCK_N come from the heuristics below instead.
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Max + argmax along the middle axis of an input addressed as
    # inp[m * N * K + n * K + k]. Program (pid_m, pid_k) handles BLOCK_M
    # consecutive m-rows for one fixed k and writes to out[m * K + k].
    # ELEMENT_SIZE is accepted but not used in the body.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    elem_ty = inp.type.element_ty
    # bfloat16 inputs keep their running maximum in float32.
    acc_ty = tl.float32 if elem_ty is tl.bfloat16 else elem_ty
    lowest = get_dtype_min(elem_ty)
    best_val = tl.full([BLOCK_M], value=lowest, dtype=acc_ty)
    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)

    # Sweep the reduction axis in BLOCK_N-wide tiles.
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        flat = rows[:, None] * N * K + cols[None, :] * K + pid_k
        # Lanes outside the (M, N) extent read `lowest` instead of memory.
        valid = rows[:, None] < M and cols[None, :] < N
        tile = tl.load(inp + flat, mask=valid, other=lowest)
        tile_max, tile_arg = tl.max(tile, axis=1, return_indices=True)
        # Strict '>' means a later tile only replaces the running result when
        # it is strictly larger, so ties keep the earlier tile's index.
        improved = tile_max > best_val
        best_val = tl.where(improved, tile_max, best_val)
        best_idx = tl.where(improved, start + tile_arg, best_idx)

    row_ok = rows < M
    out_off = rows * K + pid_k
    tl.store(out_value + out_off, best_val, mask=row_ok)
    tl.store(out_index + out_off, best_idx, mask=row_ok)
def max(inp):
    """Return the maximum over all elements of ``inp`` as a 0-dim tensor.

    The reduction runs in two stages: ``max_kernel_1`` launches ``mid_size``
    programs that each write one block maximum into ``mid``, then
    ``max_kernel_2`` reduces ``mid`` to the scalar ``out`` in a single
    program. Inputs with one element, or that fit in a single block, return
    early without the second launch.

    The ``TRITONXPU_*`` environment flags are set only for the duration of
    the call. They are removed in ``finally`` blocks so that the early-return
    paths and failed kernel launches no longer leave them set for later,
    unrelated kernels.

    Args:
        inp: input tensor; it is made contiguous before the reduction.

    Returns:
        A 0-dim tensor with the dtype and device of ``inp``.
    """
    logger.debug("GEMS MAX")
    os.environ["TRITONXPU_IS_SCATTER_SLICE"] = "1"
    try:
        inp = inp.contiguous()
        M = inp.numel()
        # Alternative sizing once used here:
        # triton.next_power_of_2(math.ceil(math.sqrt(M)))
        block_size = get_block_size_1d(M, inp.element_size())
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        dtype = inp.dtype
        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        out = torch.empty([], dtype=dtype, device=inp.device)
        if M == 1:
            # A single element is its own maximum.
            return inp.reshape([])
        with torch_device_fn.device(inp.device):
            # buffer_size_limit is a launch option passed through to the
            # backend; its exact meaning is not visible in this file.
            max_kernel_1[(mid_size, 1, 1)](
                inp, mid, M, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                # Stage 1 already produced the final answer.
                return mid.reshape([])

            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            try:
                max_kernel_2[(1, 1, 1)](
                    mid, out, mid_size, block_mid, buffer_size_limit=2048
                )
            finally:
                os.environ.pop("TRITONXPU_OTHER_SIM", None)
                os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
        return out
    finally:
        os.environ.pop("TRITONXPU_IS_SCATTER_SLICE", None)
def max_dim(inp, dim=None, keepdim=False):
    """Return the maximum values and their indices of ``inp`` along ``dim``.

    The input is treated as an ``(M, N, K)`` array, where ``N`` is the size
    of ``dim``, ``M`` is the product of the sizes before it and ``K`` is the
    number of remaining elements per ``(m, n)`` pair. ``max_kernel`` is
    launched on a ``(cdiv(M, BLOCK_M), K)`` grid.

    The ``TRITONXPU_OTHER_SIM`` / ``TRITONXPU_STORE_MASK_SIM`` environment
    flags are set only around the launch. They are removed in a ``finally``
    block so that a failed launch does not leave them set for later kernels.

    Args:
        inp: input tensor; it is made contiguous before the launch.
        dim: dimension to reduce, in ``[-inp.ndim, inp.ndim)``.
            NOTE: despite the ``None`` default, an integer is required; the
            range check below compares ``dim`` numerically.
        keepdim: if true, the reduced dimension is kept with size 1;
            otherwise it is squeezed out of both outputs.

    Returns:
        A named tuple ``max(values, indices)``; ``values`` has the dtype of
        ``inp`` and ``indices`` is ``torch.int64``.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS MAX DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N
    ELEMENT_SIZE = inp.element_size()

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)

    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )

    # Special case kept from the original: this launch option is enabled only
    # for integer inputs with exactly M == 4096 and N == 256. Its effect is
    # backend-specific and not visible in this file.
    isCloseCoreTiling = (
        inp.dtype in [torch.int16, torch.int32, torch.int64]
        and M == 4096
        and N == 256
    )

    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
    try:
        with torch_device_fn.device(inp.device):
            max_kernel[grid](
                inp,
                out_value,
                out_index,
                M,
                N,
                K,
                ELEMENT_SIZE,
                isCloseCoreTiling=isCloseCoreTiling,
            )
    finally:
        os.environ.pop("TRITONXPU_OTHER_SIM", None)
        os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)

    Max_out = namedtuple("max", ["values", "indices"])
    return Max_out(values=out_value, indices=out_index)