Coverage for src/flag_gems/runtime/backend/_ascend/ops/amax.py: 0%
100 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
# Module-scoped logger named after this op's backend path,
# e.g. "flag_gems.runtime._ascend.ops.amax".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.rsplit(".", 1)[-1]}')
@libentry()
@triton.jit
def amax_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the two-stage global amax: each program reduces one
    # BLOCK_SIZE slice of the flattened input to a single partial maximum
    # stored in mid[pid].
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes are padded with the dtype's minimum so they can
    # never win the max reduction.
    pad = get_dtype_min(inp.type.element_ty)
    values = tl.load(inp + offsets, mask=in_bounds, other=pad)
    tl.store(mid + block_id, tl.max(values))
@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the two-stage global amax: a single program folds every
    # partial maximum produced by amax_kernel_1 into one scalar.
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    # Padding lanes load the dtype minimum, which is the identity for max.
    pad = get_dtype_min(mid.type.element_ty)
    partials = tl.load(mid + idx, mask=valid, other=pad)
    tl.store(out, tl.max(partials))
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
@triton.jit
def amax_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise amax over an (M, N) contiguous view: out[m] = max(inp[m, :]).

    The launch grid may contain fewer programs than row-blocks (the host
    caps it at 4096), so each program walks multiple BLOCK_M row-blocks
    with a stride equal to the number of launched programs.
    """
    dtype = inp.type.element_ty
    min_value = get_dtype_min(dtype)

    # Map the program id to the rows of inp it should compute.
    workers = tle.num_programs(0)
    pid = tle.program_id(0)

    total_workloads = tl.cdiv(M, BLOCK_M)
    workloads = tl.cdiv(total_workloads, workers)

    for w in range(workloads):
        work_id = pid + w * workers
        rows = work_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        ninp = inp + rows * N
        nout = out + rows
        row_mask = rows < M

        # bf16 has few mantissa bits; accumulate in fp32 for accuracy.
        acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
        # Fix: the accumulator was `_all` and the reduced result shadowed the
        # builtin `all`; renamed to `acc` / `row_max` with no logic change.
        acc = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask
            a = tl.load(ninp + cols, mask, other=min_value)
            acc = tl.maximum(acc, a)
        # Collapse the per-column partial maxima, then write one value per row.
        row_max = tl.max(acc, axis=1)[:, None]
        tl.store(nout, row_max, row_mask)
def amax(inp, dim=None, keepdim=False):
    """Maximum of `inp` over dimensions `dim`, mirroring ``torch.amax``.

    Args:
        inp: input tensor.
        dim: ``None`` or an empty sequence for a full reduction, or an int /
            sequence of ints selecting the dimensions to reduce.
        keepdim: if True, keep each reduced dimension as size 1.

    Returns:
        Tensor of maxima (a 0-d tensor for a full reduction with
        ``keepdim=False``).
    """
    logger.debug("GEMS_ASCEND AMAX")
    # Normalize a bare-int dim up front: the original code only converted it
    # inside the else branch, so `len(dim)` below raised TypeError for ints.
    if isinstance(dim, int):
        dim = [dim]
    if dim is None or len(dim) == 0:
        # Full reduction via the two-stage kernels: sqrt(M)-sized blocks
        # balance the work between stage 1 and stage 2.
        M = inp.numel()
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        dtype = inp.dtype
        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        if not keepdim:
            out = torch.empty([], dtype=dtype, device=inp.device)
        else:
            # keepdim over all dims: every dimension collapses to size 1.
            out = torch.empty([1] * inp.dim(), dtype=dtype, device=inp.device)
        with torch_device_fn.device(inp.device):
            amax_kernel_1[(mid_size, 1)](
                inp,
                mid,
                M,
                block_size,
            )
            amax_kernel_2[(1, 1)](
                mid, out, mid_size, block_mid
            )  # max block size is 128k, so mid does not requires int64 index
        return out
    else:
        # Fix: the original asserted a generator expression, which is always
        # truthy, so invalid dims were never rejected. Validate for real.
        assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"
        dtype = inp.dtype

        shape = list(inp.shape)
        dim = [d % inp.ndim for d in dim]
        # Move the reduced dims to the end so the kernel sees an (M, N) view
        # with N the product of the reduced extents.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N

        out = torch.empty(shape, dtype=dtype, device=inp.device)

        def grid(meta):
            # Cap the launch at 4096 programs; amax_kernel loops over the
            # remaining row-blocks internally.
            axis0 = triton.cdiv(M, meta["BLOCK_M"])
            axis0 = axis0 if axis0 < 4096 else 4096
            return (axis0,)

        with torch_device_fn.device(inp.device):
            amax_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out