Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/amax.py: 0%
91 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
12from ..utils.block_size_utils import get_block_size_1d
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def amax_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of a two-stage global amax.

    Each program reduces one BLOCK_SIZE chunk of the flattened input
    and writes its partial maximum to ``mid[pid]``.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    # Out-of-range lanes read -inf, the identity element for max.
    vals = tl.load(inp + idx, mask=in_bounds, other=-float("inf"))
    block_max = tl.max(vals)
    tl.store(mid + block_id, block_max)
@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of a two-stage global amax: a single program reduces the
    per-block partial maxima in ``mid`` down to one scalar in ``out``."""
    lane = tl.arange(0, BLOCK_MID)
    valid = lane < mid_size
    # Inactive lanes contribute -inf so they never win the reduction.
    partials = tl.load(mid + lane, mask=valid, other=-float("inf"))
    tl.store(out, tl.max(partials))
def heur_m_block_size(args):
    """Rows per program (BLOCK_M): split M across the 12 clusters of the
    target device, rounded up to the next power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # 12 == cluster_num
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    """Columns per tile (BLOCK_N): cover N rounded up to a power of two,
    capped at 8192.

    The plain builtin ``min`` suffices: heuristics functions execute as
    ordinary Python (outside ``@triton.jit``), so ``triton.language``
    does not shadow it here and the local ``import builtins`` detour in
    the original was unnecessary.
    """
    return min(triton.next_power_of_2(args["N"]), 8192)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def amax_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise amax over an (M, N) view of ``inp``.

    Each program reduces BLOCK_M contiguous rows, sweeping the N axis in
    BLOCK_N-wide tiles, and writes one maximum per row to ``out``.
    """
    pid = tle.program_id(0)
    row_idx = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx
    row_ok = row_idx < M

    # Running maxima accumulated in fp32; -inf is the identity for max.
    acc = tl.full([BLOCK_M, BLOCK_N], value=-float("inf"), dtype=tl.float32)
    for col_start in range(0, N, BLOCK_N):
        col_idx = col_start + tl.arange(0, BLOCK_N)[None, :]
        col_ok = col_idx < N
        tile_ok = row_ok and col_ok
        tile = tl.load(row_base + col_idx, tile_ok, other=-float("inf")).to(tl.float32)
        # Re-mask after the cast so inactive lanes cannot win the max.
        tile = tl.where(tile_ok, tile, -float("inf"))
        acc = tl.maximum(acc, tile)
    row_max = tl.max(acc, axis=1)[:, None]
    tl.store(out_ptrs, row_max, row_ok)
def amax(inp, dim=None, keepdim=False):
    """Maximum of ``inp`` over ``dim`` (all elements when ``dim`` is
    None or an empty sequence), mirroring ``torch.amax``.

    Args:
        inp: input tensor.
        dim: None, an int, or a sequence of ints naming reduced axes.
        keepdim: keep reduced axes as size-1 dimensions.

    Returns:
        Tensor of maxima with the same dtype as ``inp``.
    """
    logger.debug("GEMS AMAX")
    # Normalize an int dim up front: the full-reduction check below
    # calls len(dim), which would raise TypeError on a bare int.
    if isinstance(dim, int):
        dim = [dim]
    if dim is None or len(dim) == 0:
        # Full reduction: two-stage block-max over the flattened input.
        M = inp.numel()
        block_size = get_block_size_1d(M, inp.element_size())
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        dtype = inp.dtype
        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        if not keepdim:
            out = torch.empty([], dtype=dtype, device=inp.device)
        else:
            # All axes collapse to size 1 when keepdim is requested.
            out = torch.empty([1] * inp.dim(), dtype=dtype, device=inp.device)
        with torch_device_fn.device(inp.device):
            amax_kernel_1[(mid_size, 1)](
                inp, mid, M, block_size, buffer_size_limit=2048
            )
            amax_kernel_2[(1, 1)](
                mid, out, mid_size, block_mid, buffer_size_limit=2048
            )  # max block size is 128k, so mid does not requires int64 index
        return out
    else:
        # all() is required here: asserting a bare generator expression
        # is always truthy and would never catch an invalid dim.
        assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"
        dtype = inp.dtype

        shape = list(inp.shape)
        dim = [d % inp.ndim for d in dim]
        # Move the reduced axes to the innermost positions so each output
        # element reads a contiguous run of N input elements.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N

        out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            amax_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out