Coverage for src/flag_gems/runtime/backend/_cambricon/ops/amax.py: 0%
133 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils.shape_utils import can_use_int32_index
12from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def amax_kernel_once(
    inp,
    out,
    M: tl.constexpr,
):
    """Reduce ``M`` consecutive elements at ``inp`` to their maximum.

    ``M`` is a compile-time constant used directly as the ``tl.arange``
    extent. The load is unmasked, so exactly ``M`` elements starting at
    ``inp`` are read. The scalar maximum is stored at ``out``.
    """
    idx = tl.arange(0, M)
    values = tl.load(inp + idx)
    tl.store(out, tl.max(values, 0))
@libentry()
@triton.autotune(configs=cfggen_reduce_op(), key=["M"])
@triton.jit
def amax_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    """Fold ``M`` elements at ``inp`` into a single maximum held at ``out``.

    Each program starts at ``program_id * BLOCK_SIZE`` and advances by
    ``num_programs * BLOCK_SIZE`` until it passes ``M``, keeping a running
    float32 maximum. The per-program result is merged into ``out`` with
    ``tl.atomic_max``, so ``out`` must already hold a valid lower bound
    when the kernel is launched.

    ``INT64_INDEX`` switches the program id (and hence the derived offsets)
    to 64-bit integers.
    """
    job_id = tl.program_id(0)
    if INT64_INDEX:
        job_id = job_id.to(tl.int64)
    job_count = tl.num_programs(axis=0)
    stride = job_count * BLOCK_SIZE
    first = job_id * BLOCK_SIZE

    running_max = -float("inf")
    for base in range(first, M, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < M
        # Out-of-range lanes read -inf so they can never win the reduction.
        block = tl.load(inp + idx, mask=in_bounds, other=-float("inf"))
        # NOTE(review): only the first element of the return_indices result
        # is used; presumably this form is intentional on this backend.
        (block_max,) = tl.max(block, 0, return_indices=True)
        if block_max > running_max:
            running_max = block_max.to(tl.float32)
    tl.atomic_max(out, running_max)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("amax_opt"), key=["N"])
@triton.jit
def amax_kernel_opt(
    inp,
    out,
    M: tl.constexpr,
    N: tl.constexpr,
    TILE_NUM_N: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    """Row-wise maximum of an ``M x N`` row-major buffer using a 2-D grid.

    Grid axis 0 splits the ``M`` rows into contiguous chunks, one chunk per
    program. Grid axis 1 selects which of the ``TILE_NUM_N`` column tiles of
    width ``ceil(N / TILE_NUM_N)`` the program handles. Every tile maximum is
    merged into ``out[row]`` with ``tl.atomic_max``, so ``out`` must already
    hold a valid lower bound for each row.

    ``INT64_INDEX`` switches both program ids to 64-bit integers.
    """
    row_job = tl.program_id(0)
    col_tile = tl.program_id(1)
    if INT64_INDEX:
        row_job = row_job.to(tl.int64)
        col_tile = col_tile.to(tl.int64)

    # Contiguous range of rows owned by this program along grid axis 0.
    job_count = tl.num_programs(0)
    rows_each = (M + job_count - 1) // job_count
    first_row = row_job * rows_each
    last_row = min(first_row + rows_each, M)

    BLOCK_N: tl.constexpr = (N + TILE_NUM_N - 1) // TILE_NUM_N

    for row in range(first_row, last_row):
        cols = col_tile * BLOCK_N + tl.arange(0, BLOCK_N)
        in_row = cols < N
        # Lanes past the end of the row read -inf and cannot win.
        tile = tl.load(inp + row * N + cols, in_row, other=-float("inf"))
        # NOTE(review): only the first element of the return_indices result
        # is used; presumably this form is intentional on this backend.
        (tile_max,) = tl.max(tile, 0, return_indices=True)
        tl.atomic_max(out + row, tile_max)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
@triton.jit
def amax_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    """Row-wise maximum of an ``M x N`` row-major buffer, one store per row.

    Each program handles ``BLOCK_M`` rows at a time, starting at
    ``program_id * BLOCK_M`` and advancing by ``num_programs * BLOCK_M``.
    For every row block it sweeps the columns in ``BLOCK_N``-wide tiles,
    keeps an element-wise float32 running maximum, reduces that along the
    column axis and writes the result to ``out[row]`` with a plain store.

    ``INT64_INDEX`` switches the program id to a 64-bit integer.
    """
    job_id = tl.program_id(0)
    if INT64_INDEX:
        job_id = job_id.to(tl.int64)

    job_count = tl.num_programs(axis=0)
    first_row = job_id * BLOCK_M
    row_stride = job_count * BLOCK_M
    for row_base in range(first_row, M, row_stride):
        # Column vector of row indices, shape [BLOCK_M, 1].
        row_ids = row_base + tl.arange(0, BLOCK_M)[:, None]
        row_ok = row_ids < M
        row_ptrs = inp + row_ids * N
        out_ptrs = out + row_ids

        acc = tl.full([BLOCK_M, BLOCK_N], value=-float("inf"), dtype=tl.float32)
        for col_base in range(0, N, BLOCK_N):
            # Row vector of column indices, shape [1, BLOCK_N].
            col_ids = col_base + tl.arange(0, BLOCK_N)[None, :]
            col_ok = col_ids < N
            tile_ok = row_ok and col_ok

            # Masked-out lanes read -inf so they never affect the maximum.
            tile = tl.load(row_ptrs + col_ids, tile_ok, other=-float("inf"))
            acc = tl.maximum(tile, acc)

        row_max = tl.max(acc, axis=1)[:, None]
        tl.store(out_ptrs, row_max, row_ok)
def amax(inp, dim=None, keepdim=False):
    """Return the maximum of ``inp`` over the given dimension(s).

    Args:
        inp: Input tensor.
        dim: ``None`` or an empty sequence to reduce over every element; an
            int or a sequence of ints to reduce over those dimensions.
            Negative values count from the last dimension.
        keepdim: When true, reduced dimensions are kept with size 1;
            otherwise they are removed from the result.

    Returns:
        A tensor with the same dtype as ``inp`` holding the maxima. The
        atomic-based paths accumulate in float32 and cast back at the end.

    Raises:
        AssertionError: If an entry of ``dim`` lies outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_CAMBRICON AMAX")

    # Normalise a bare int before the emptiness check below: calling len()
    # on an int would raise TypeError.
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == 0:
        # The flat kernels address the buffer linearly, so make sure the
        # logical element order matches memory order (no-op if contiguous).
        inp = inp.contiguous()
        M = inp.numel()
        dtype = inp.dtype
        use_int64_index = not can_use_int32_index(inp)
        out_shape = [1] * inp.dim() if keepdim else []

        with torch_device_fn.device(inp.device):
            if M <= 65536:
                # Small input: a single program reduces everything at once.
                out = torch.empty(out_shape, dtype=dtype, device=inp.device)
                amax_kernel_once[(1, 1, 1)](inp, out, M)
                return out

            # Large input: programs merge partial maxima via atomic_max, so
            # the float32 accumulator is seeded with a lower bound.
            # NOTE(review): the seed is finfo.min rather than -inf, so an
            # input made only of -inf would presumably yield finfo.min for
            # float32 -- confirm whether -inf is safe for the device atomics.
            outdtype = torch.float32
            out = torch.full(
                out_shape,
                torch.finfo(outdtype).min,
                dtype=outdtype,
                device=inp.device,
            )

            def flat_grid(meta):
                return (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

            amax_kernel_1[flat_grid](inp, out, M, INT64_INDEX=use_int64_index)
        return out.to(dtype)

    # all() is required here: a bare generator expression is always truthy
    # and would make the check a no-op.
    assert all(-inp.ndim <= d < inp.ndim for d in dim), "Invalid dim"
    dtype = inp.dtype

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    use_int64_index = not can_use_int32_index(inp)

    # N is the number of elements reduced per output value, M the number of
    # output values; `shape` becomes the keepdim output shape.
    N = 1
    for d in dim:
        N *= shape[d]
        shape[d] = 1
    M = inp.numel() // N

    with torch_device_fn.device(inp.device):
        if N > 1048576:
            # Very long rows: each program sweeps whole rows and stores the
            # result directly, so no initial value is needed.
            out = torch.empty(shape, dtype=dtype, device=inp.device)

            def row_grid(meta):
                return (min(triton.cdiv(M, meta["BLOCK_M"]), TOTAL_CORE_NUM),)

            amax_kernel[row_grid](inp, out, M, N, INT64_INDEX=use_int64_index)
        else:
            # Shorter rows: column tiles are merged with atomic_max into a
            # float32 buffer seeded with a lower bound.
            out = torch.full(
                shape,
                torch.finfo(torch.float32).min,
                dtype=torch.float32,
                device=inp.device,
            )

            def tiled_grid(meta):
                return (
                    min(triton.cdiv(TOTAL_CORE_NUM, meta["TILE_NUM_N"]), M),
                    meta["TILE_NUM_N"],
                )

            amax_kernel_opt[tiled_grid](
                inp, out, M, N, INT64_INDEX=use_int64_index
            )

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out.to(dtype)