Coverage for src/flag_gems/ops/amax.py: 46%
92 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
14logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def amax_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the whole-tensor reduction: one partial maximum per program.

    Each program loads one ``BLOCK_SIZE``-wide slice of the first ``M``
    elements behind ``inp``, reduces it with ``tl.max`` and stores the single
    result at ``mid + program_id``.
    """
    block_id = tle.program_id(0)

    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M

    # Masked-off lanes are filled with get_dtype_min(...) -- presumably the
    # lowest value of the element type, so padding can never win the max
    # (TODO confirm against flag_gems.utils.limits).
    fill = get_dtype_min(inp.type.element_ty)
    block_vals = tl.load(inp + idx, mask=in_bounds, other=fill)

    tl.store(mid + block_id, tl.max(block_vals))
@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the whole-tensor reduction: fold the partial maxima.

    Loads the first ``mid_size`` values behind ``mid`` in a single
    ``BLOCK_MID``-wide read, reduces them with ``tl.max`` and writes the
    scalar result to ``out``. No program id is read, so every launched
    program does the same work.
    """
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < mid_size

    # Lanes beyond mid_size are filled with get_dtype_min(...) -- presumably
    # the lowest value of the element type (TODO confirm), so they do not
    # influence the maximum.
    fill = get_dtype_min(mid.type.element_ty)
    partials = tl.load(mid + idx, mask=in_bounds, other=fill)

    tl.store(out, tl.max(partials))
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def amax_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise maximum: reduce each of ``M`` rows of length ``N`` to one value.

    The indexing ``inp + row * N + col`` treats ``inp`` as a dense row-major
    ``(M, N)`` buffer. Each program handles ``BLOCK_M`` rows, walks their
    columns in ``BLOCK_N``-wide tiles while keeping an element-wise running
    maximum, then reduces that tile along axis 1 and stores one value per row
    at ``out + row``.
    """
    elem_ty = inp.type.element_ty
    # Fill value for masked lanes and the accumulator's initial value --
    # presumably the lowest value of the element type (TODO confirm).
    fill = get_dtype_min(elem_ty)
    # bfloat16 inputs are accumulated in float32; every other dtype
    # accumulates in its own type.
    acc_ty = tl.float32 if elem_ty is tl.bfloat16 else elem_ty

    # Rows owned by this program, shaped (BLOCK_M, 1) for broadcasting.
    block_id = tle.program_id(0)
    row_idx = block_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ok = row_idx < M
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx

    running = tl.full([BLOCK_M, BLOCK_N], value=fill, dtype=acc_ty)
    for start in range(0, N, BLOCK_N):
        # Columns of the current tile, shaped (1, BLOCK_N).
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        col_ok = col_idx < N
        tile_ok = row_ok and col_ok
        tile = tl.load(row_base + col_idx, tile_ok, other=fill)
        running = tl.maximum(running, tile)

    # Collapse the column axis; keep the result as (BLOCK_M, 1) to match
    # the shape of out_ptrs and row_ok.
    row_max = tl.max(running, axis=1)[:, None]
    tl.store(out_ptrs, row_max, row_ok)
def amax(inp, dim=None, keepdim=False):
    """Return the maximum of ``inp`` over ``dim``, or over every element.

    Args:
        inp: Input tensor.
        dim: ``None`` or an empty sequence to reduce over all elements, a
            single int, or a sequence of ints naming the dimensions to
            reduce. Negative values count from the last dimension.
        keepdim: If true, reduced dimensions are kept with size 1;
            otherwise they are removed from the result.

    Returns:
        A tensor with the dtype and device of ``inp`` holding the maxima.

    Raises:
        AssertionError: If any entry of ``dim`` lies outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS AMAX")

    # Wrap a bare int *before* the len() test below: len(int) raises
    # TypeError, which previously made the int form of `dim` unusable.
    if isinstance(dim, int):
        dim = [dim]

    dtype = inp.dtype

    if dim is None or len(dim) == 0:
        # Whole-tensor reduction in two launches: roughly sqrt(M) programs
        # each reduce roughly sqrt(M) elements into `mid`, then a single
        # program reduces `mid` into `out`.
        M = inp.numel()
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        # keepdim keeps the input's rank with every dimension of size 1;
        # otherwise the result is a 0-d tensor.
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            amax_kernel_1[(mid_size, 1)](
                inp,
                mid,
                M,
                block_size,
            )
            # max block size is 128k, so mid does not require int64 index
            amax_kernel_2[(1, 1)](mid, out, mid_size, block_mid)
        return out

    # The original asserted on a bare generator expression, which is always
    # truthy, so out-of-range dims were never rejected. all() actually
    # evaluates the per-dimension range check.
    assert all(-inp.ndim <= d < inp.ndim for d in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]  # normalise negative indices
    # NOTE(review): dim_compress presumably moves the reduced dimensions to
    # the innermost positions and makes the data contiguous, which is what
    # amax_kernel's `row * N + col` indexing relies on -- confirm.
    inp = dim_compress(inp, dim)

    # N = number of elements reduced per output value; M = number of outputs.
    N = 1
    for d in dim:
        N *= shape[d]
        shape[d] = 1
    M = inp.numel() // N

    out = torch.empty(shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        amax_kernel[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out