Coverage for src/flag_gems/runtime/backend/_cambricon/ops/max.py: 0%
177 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils.limits import get_dtype_min
14from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op, prune_reduce_config
16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def max_kernel_float_once(
    inp,
    out,
    M: tl.constexpr,
):
    # Single-shot reduction: one unmasked load of M consecutive elements
    # starting at `inp`, reduced along axis 0, with the scalar result stored
    # at `out`. M is a compile-time constant because it sizes tl.arange.
    lanes = tl.arange(0, M)
    values = tl.load(inp + lanes)
    tl.store(out, tl.max(values, 0))
@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def max_kernel_float(
    inp, out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    # Global max of M floating-point elements.
    #
    # inp: pointer to the flattened input.
    # out: pointer to a scalar that every program merges into with
    #      tl.atomic_max; the host wrapper pre-fills it with -inf.
    # M:   number of elements.
    # ONE_TILE_PER_CTA: set by the heuristic above when
    #      M <= BLOCK_SIZE * TOTAL_CORE_NUM, i.e. one tile per program
    #      is enough to cover the input.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = -float("inf")

    if ONE_TILE_PER_CTA:
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        # Out-of-range lanes read -inf so they can never win the max.
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf"))
        # Only the value is needed, so do not request indices from tl.max.
        res = tl.max(inp_val, 0)
    else:
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        # Per-lane running maximum over every tile this program visits.
        _tmp = tl.full([BLOCK_SIZE], value=-float("inf"), dtype=inp.dtype.element_ty)
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=-float("inf"))
            _tmp = tl.where((inp_val > _tmp), inp_val, _tmp)
        res = tl.max(_tmp, 0)

    # Single merge point for both paths (same shape as max_kernel_int).
    tl.atomic_max(out, res)
@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def max_kernel_int(
    inp, out, M, FILL_VALUE, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    # Global max of M integer elements (the non-int64 integer path).
    #
    # inp:        pointer to the flattened input.
    # out:        pointer to a scalar merged into with tl.atomic_max; the
    #             host wrapper allocates it as int32 pre-filled with -(2**31).
    # M:          number of elements.
    # FILL_VALUE: value substituted for masked-off lanes; the host passes
    #             torch.iinfo(dtype).min of the input dtype.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = FILL_VALUE
    if ONE_TILE_PER_CTA:
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=FILL_VALUE)
        res = tl.max(inp_val)
    else:
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        # Accumulate in int32, the same width as the `out` buffer that is
        # updated atomically below. An int64 sentinel of -(2**63) has
        # all-zero low 32 bits, so it would stop being a minimum once
        # narrowed to int32, and it would pull 64-bit integer arithmetic
        # into a kernel that is launched without enable_soft_i64.
        _tmp = tl.full([BLOCK_SIZE], value=-(2**31), dtype=tl.int32)
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=FILL_VALUE)
            _tmp = tl.where((inp_val > _tmp), inp_val, _tmp)
        res = tl.max(_tmp)
    tl.atomic_max(out, res)
@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def max_kernel_int64_1(
    inp, mid, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    # First stage of the two-kernel int64 reduction: each program computes
    # the maximum of the elements it covers and stores it at mid[pid].
    # Lanes past the end of the input read -(2**63).
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    fill = -(2**63)
    res = fill
    if ONE_TILE_PER_CTA:
        lanes = block_start + tl.arange(0, BLOCK_SIZE)
        tile = tl.load(inp + lanes, mask=lanes < M, other=fill)
        res = tl.max(tile)
    else:
        hop = tl.num_programs(axis=0) * BLOCK_SIZE
        # Loop bounds are evaluated with a 64-bit start offset.
        block_start = block_start.to(tl.int64)
        running = tl.full([BLOCK_SIZE], value=fill, dtype=tl.int64)
        for base in range(block_start, M, hop):
            lanes = base + tl.arange(0, BLOCK_SIZE)
            tile = tl.load(inp + lanes, mask=lanes < M, other=fill)
            running = tl.where((tile > running), tile, running)
        res = tl.max(running)
    tl.store(mid + pid, res)
@libentry()
@triton.jit
def max_kernel_int64_2(mid, out, BLOCK_NUM: tl.constexpr):
    # Second stage of the int64 reduction: read the BLOCK_NUM partial
    # results at `mid` (unmasked) and store their maximum at `out`.
    slots = tl.arange(0, BLOCK_NUM)
    partials = tl.load(mid + slots)
    tl.store(out, tl.max(partials))
def heur_block_n(args):
    """Return ``triton.next_power_of_2`` applied to ``args["N"]``."""
    n = args["N"]
    return triton.next_power_of_2(n)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("max"),
    key=[
        "M",
        "N",
    ],
    strategy=["log", "log"],
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    UPCAST: tl.constexpr = False,
):
    # Max and argmax along the middle axis of an input addressed as
    # (M, N, K): element (m, n, k) lives at m * N * K + n * K + k.
    # Program axis 0 tiles M in chunks of BLOCK_M; program axis 1 picks k.
    # Results are written to out_value / out_index at m * K + k.

    # With UPCAST the program ids are widened to int64, so the offset
    # arithmetic derived from them is carried out in 64 bits.
    if UPCAST:
        pid_m = tl.program_id(0).to(tl.int64)
        pid_k = tl.program_id(1).to(tl.int64)
    else:
        pid_m = tl.program_id(0)
        pid_k = tl.program_id(1)

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    best_val = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
    pad = get_dtype_min(inp.type.element_ty)

    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        flat = rows[:, None] * N * K + cols[None, :] * K + pid_k
        # Lanes outside the (M, N) extent load the dtype minimum instead.
        in_bounds = rows[:, None] < M and cols[None, :] < N
        tile = tl.load(inp + flat, mask=in_bounds, other=pad)
        tile_max, tile_arg = tl.max(tile, axis=1, return_indices=True)
        # Strict '>' keeps the earlier chunk's index when values are equal.
        improved = tile_max > best_val
        best_val = tl.where(improved, tile_max, best_val)
        best_idx = tl.where(improved, start + tile_arg, best_idx)

    row_ok = rows < M
    dst = rows * K + pid_k
    tl.store(out_value + dst, best_val, mask=row_ok)
    tl.store(out_index + dst, best_idx, mask=row_ok)
def max(inp):
    """Return the maximum of all elements of ``inp`` as a 0-dim tensor.

    The input is made contiguous and dispatched by dtype:

    * floating point: a single-program kernel when there are at most
      65536 elements, otherwise a tuned kernel that merges per-program
      results into a float32 scalar with an atomic max;
    * int64: two kernels (per-program partial maxima, then a final
      reduction), because atomics do not support i64;
    * other integer dtypes: a tuned kernel that merges into an int32
      scalar with an atomic max.

    The result is converted back to the dtype of ``inp`` before returning.
    """
    logger.debug("GEMS_CAMBRICON MAX")
    inp = inp.contiguous()
    numel = inp.numel()
    src_dtype = inp.dtype
    dev = inp.device
    n_partials = TOTAL_CORE_NUM

    def launch_grid(meta):
        # One program per tile, capped at the number of cores.
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), n_partials),)

    with torch_device_fn.device(dev):
        if torch.is_floating_point(inp):
            if numel <= 65536:
                out = torch.empty([], dtype=src_dtype, device=dev)
                max_kernel_float_once[(1, 1, 1)](inp, out, numel)
            else:
                # Seed with -inf so the atomic max has a neutral start value.
                out = torch.full([], float("-inf"), dtype=torch.float32, device=dev)
                max_kernel_float[launch_grid](inp, out, numel)
        elif src_dtype == torch.int64:
            partials = torch.empty([n_partials], dtype=src_dtype, device=dev)
            out = torch.empty([], dtype=src_dtype, device=dev)
            # Because atomic op don't support i64, use two kernels.
            max_kernel_int64_1[(n_partials, 1, 1)](
                inp, partials, numel, enable_soft_i64=True
            )
            max_kernel_int64_2[(1, 1, 1)](
                partials, out, BLOCK_NUM=n_partials, enable_soft_i64=True
            )
        else:
            dtype_min = torch.iinfo(src_dtype).min
            out = torch.full([], -(2**31), dtype=torch.int32, device=dev)
            max_kernel_int[launch_grid](inp, out, numel, dtype_min)
    return out.to(src_dtype)
def max_dim(inp, dim=None, keepdim=False):
    """Return the maximum values and their indices along ``dim``.

    Args:
        inp: input tensor; it is made contiguous before the kernel runs.
        dim: dimension to reduce, in ``[-inp.ndim, inp.ndim)``. It must be
            an int: the ``None`` default is not handled and fails at the
            range check below.
        keepdim: when False, the reduced dimension is squeezed out of both
            outputs; when True it is kept with size 1.

    Returns:
        A ``max(values, indices)`` namedtuple. ``values`` has the dtype of
        ``inp`` and ``indices`` is int64.

    Raises:
        AssertionError: if ``dim`` is out of range.
        IndexError: if the reduced dimension has size 0.
    """
    logger.debug("GEMS_CAMBRICON MAX DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    # View the input as (M, N, K): leading dims, reduced dim, trailing dims.
    N = shape[dim]
    M = math.prod(shape[:dim])
    # Take K from the trailing dims directly; numel // M // N would divide
    # by zero as soon as any leading dim or the reduced dim has size 0.
    K = math.prod(shape[dim + 1 :])

    if N == 0:
        # There is no element to pick a maximum from.
        raise IndexError(
            f"max(): Expected reduction dim {dim} to have non-zero size."
        )

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)

    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    Max_out = namedtuple("max", ["values", "indices"])
    if M == 0 or K == 0:
        # Empty output: nothing to compute, and the launch grid would be empty.
        return Max_out(values=out_value, indices=out_index)

    # Ask the kernel for 64-bit offsets once the input spans 2**31 or more elements.
    UPCAST = inp.shape[0] * inp.stride(0) >= 1 << 31

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        max_kernel[grid](inp, out_value, out_index, M, N, K, UPCAST=UPCAST)
    return Max_out(values=out_value, indices=out_index)