Coverage for src/flag_gems/runtime/backend/_metax/ops/amax.py: 0%
92 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
14logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@triton.jit
def amax_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of a full reduction: each program reduces one BLOCK_SIZE
    chunk of the flattened input and writes its partial max to mid[pid]."""
    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes load the dtype's minimum so they can never win the max.
    padding = get_dtype_min(inp.type.element_ty)
    block_vals = tl.load(inp + offsets, mask=in_bounds, other=padding)
    tl.store(mid + pid, tl.max(block_vals))
@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of a full reduction: a single program folds all partial
    maxima in `mid` into the final scalar stored at `out`."""
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    # Pad inactive lanes with the dtype minimum so they cannot affect the max.
    padding = get_dtype_min(mid.type.element_ty)
    partials = tl.load(mid + idx, mask=valid, other=padding)
    tl.store(out, tl.max(partials))
@libentry()
@triton.heuristics(runtime.get_heuristic_config("amax"))
@triton.jit
def amax_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise amax over an (M, N) view: each program reduces BLOCK_M rows,
    sweeping the N axis in BLOCK_N-wide tiles."""
    dtype = inp.type.element_ty
    min_value = get_dtype_min(dtype)

    # Map the program id to the rows of inp it should reduce.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # bfloat16 accumulates in fp32; every other dtype accumulates in-dtype.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    running_max = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        tile = tl.load(inp + cols, mask, other=min_value)
        running_max = tl.maximum(running_max, tile)
    # Collapse the tile-wise running max along the column axis per row.
    row_max = tl.max(running_max, axis=1)[:, None]
    tl.store(out, row_max, row_mask)
def amax(inp, dim=None, keepdim=False):
    """Return the maximum value of `inp`, optionally reduced over `dim`.

    Args:
        inp: input tensor.
        dim: None (or an empty list) for a full reduction, or an int /
            list of ints naming the dimensions to reduce. Negative
            indices are accepted.
        keepdim: if True, retain each reduced dimension with size 1.

    Returns:
        A tensor holding the amax result, shaped per `dim`/`keepdim`.
    """
    logger.debug("METAX GEMS AMAX")
    if dim is None or len(dim) == 0:
        # Full reduction: two-stage block reduce over the flattened input.
        M = inp.numel()
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        dtype = inp.dtype
        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        if not keepdim:
            out = torch.empty([], dtype=dtype, device=inp.device)
        else:
            # keepdim on a full reduction yields an all-ones shape.
            out = torch.empty([1] * inp.dim(), dtype=dtype, device=inp.device)
        with torch_device_fn.device(inp.device):
            amax_kernel_1[(mid_size, 1)](
                inp,
                mid,
                M,
                block_size,
            )
            amax_kernel_2[(1, 1)](
                mid, out, mid_size, block_mid
            )  # max block size is 128k, so mid does not require int64 index
        return out
    else:
        if isinstance(dim, int):
            dim = [dim]
        # BUG FIX: the original asserted a bare generator expression, which is
        # always truthy, so invalid dims were never rejected. Use all(...) to
        # actually evaluate the per-dim bounds check.
        assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"
        dtype = inp.dtype

        shape = list(inp.shape)
        dim = [d % inp.ndim for d in dim]
        # Move the reduced dims to the innermost positions so each output row
        # covers N contiguous input elements.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N

        out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            amax_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out