Coverage for src/flag_gems/runtime/backend/_ascend/ops/max.py: 0%
107 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
13from flag_gems.utils.limits import get_dtype_min
# Module logger, named after this op module's last dotted component so that
# all Ascend backend ops share the "flag_gems.runtime._ascend.ops." prefix.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)
@libentry()
@triton.jit
def max_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the two-stage full reduction used by ``max``.

    Program ``pid`` loads ``inp[pid * BLOCK_SIZE : (pid + 1) * BLOCK_SIZE]``,
    padding lanes at or past ``M`` with ``get_dtype_min`` of the input dtype,
    and stores the block maximum to ``mid[pid]``.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M
    dtype = inp.type.element_ty
    min_value = get_dtype_min(dtype)
    inp_val = tl.load(inp + offset, mask=mask, other=min_value)
    if dtype is tl.int64:
        # Same guard as max_kernel: re-apply the padding explicitly for int64,
        # presumably because this backend's masked load does not honour
        # ``other`` for that dtype -- confirm. Without it, garbage in the
        # padded tail lanes could win the reduction.
        inp_val = tl.where(mask, inp_val, min_value)
    max_val = tl.max(inp_val)
    tl.store(mid + pid, max_val)
@libentry()
@triton.jit
def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction used by ``max``.

    A single program loads the ``mid_size`` partial maxima from ``mid``
    (lanes at or past ``mid_size`` padded with ``get_dtype_min``) and stores
    their maximum to ``out``. ``BLOCK_MID`` must be >= ``mid_size``.
    """
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    dtype = mid.type.element_ty
    min_value = get_dtype_min(dtype)
    mid_val = tl.load(mid + offset, mask=mask, other=min_value)
    if dtype is tl.int64:
        # Same guard as max_kernel: re-apply the padding explicitly for int64,
        # presumably because this backend's masked load does not honour
        # ``other`` for that dtype -- confirm.
        mid_val = tl.where(mask, mid_val, min_value)
    max_val = tl.max(mid_val)
    tl.store(out, max_val)
def heur_block_n(args):
    """Heuristic for BLOCK_N: ``args["N"]`` rounded via ``triton.next_power_of_2``."""
    n = args["N"]
    return triton.next_power_of_2(n)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("max"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Reduce a contiguous (M, N, K) view of ``inp`` over its N axis.

    For each row ``m`` and column ``k`` it writes the maximum of
    ``inp[m, :, k]`` to ``out_value[m * K + k]`` and its position along N to
    ``out_index[m * K + k]``. Across N-chunks ties keep the earlier chunk
    (strict ``>`` update).

    The host may launch fewer programs than there are row tiles / columns
    (``max_dim`` caps the grid size), so each program strides over row tiles
    by ``num_programs(0)`` tiles and over columns by ``num_programs(1)``.
    Previously each program handled exactly one tile, so any clamped tiles
    were silently left unwritten.
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    num_pid_m = tl.num_programs(0)
    num_pid_k = tl.num_programs(1)

    dtype = inp.type.element_ty
    # bfloat16 values are compared in float32; other dtypes in their own type.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    min_value = get_dtype_min(dtype)

    for m_start in range(pid_m * BLOCK_M, M, num_pid_m * BLOCK_M):
        m_offset = m_start + tl.arange(0, BLOCK_M)
        row_mask = m_offset < M
        for k in range(pid_k, K, num_pid_k):
            result_value = tl.full([BLOCK_M], value=min_value, dtype=acc_type)
            result_index = tl.zeros([BLOCK_M], dtype=tl.int64)
            for i in range(0, N, BLOCK_N):
                n_offset = i + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N * K + n_offset[None, :] * K + k
                mask = row_mask[:, None] & (n_offset[None, :] < N)
                inp_vals = tl.load(inp + offset, mask=mask, other=min_value)
                if dtype is tl.int64:
                    # Re-apply the padding explicitly for int64; presumably the
                    # masked load's ``other`` is unreliable for it here -- confirm.
                    inp_vals = tl.where(mask, inp_vals, min_value)
                max_value, max_index = tl.max(inp_vals, axis=1, return_indices=True)
                update_mask = max_value > result_value
                result_value = tl.where(update_mask, max_value, result_value)
                result_index = tl.where(update_mask, i + max_index, result_index)
            offset_index = m_offset * K + k
            tl.store(out_value + offset_index, result_value, mask=row_mask)
            tl.store(out_index + offset_index, result_index, mask=row_mask)


def max(inp):
    """Return the maximum over all elements of ``inp`` as a 0-dim tensor.

    Two-stage reduction: ``max_kernel_1`` reduces blocks of ``block_size``
    elements into ``mid`` (one value per block), then ``max_kernel_2``
    reduces ``mid`` into ``out``. ``block_size`` is sqrt(numel) rounded up to
    a power of two, so both stages handle comparable sizes.

    Raises:
        RuntimeError: if ``inp`` has no elements (mirrors ``torch.max``).
    """
    logger.debug("GEMS_ASCEND MAX")
    inp = inp.contiguous()
    M = inp.numel()
    if M == 0:
        # A full reduction of an empty tensor has no identity; without this
        # check the block-size arithmetic below fails with ZeroDivisionError.
        raise RuntimeError(
            "max(): Expected reduction dim to be specified for "
            "input.numel() == 0. Specify the reduction dim with the 'dim' "
            "argument."
        )
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out


# Result type of max_dim; created once instead of on every call.
_MaxOut = namedtuple("max", ["values", "indices"])


def max_dim(inp, dim=None, keepdim=False):
    """Return ``(values, indices)`` of the maximum of ``inp`` along ``dim``.

    ``inp`` is viewed as a contiguous (M, N, K) tensor: N is the size of
    ``dim``, M the product of the leading dims and K the product of the
    trailing dims. ``max_kernel`` reduces over N.

    Args:
        inp: input tensor.
        dim: dimension to reduce; negative values count from the end.
        keepdim: keep the reduced dimension (size 1) in the outputs.

    Returns:
        A ``max(values, indices)`` namedtuple; ``indices`` is int64.

    Raises:
        AssertionError: if ``dim`` is out of range.
        IndexError: if the reduced dimension has size 0 (mirrors torch.max).
    """
    logger.debug("GEMS_ASCEND MAX")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    # Product of the trailing dims. The previous numel() // M // N divided by
    # zero whenever a leading or reduced dimension was empty.
    K = math.prod(shape[dim + 1 :])
    if N == 0:
        raise IndexError(
            f"max(): Expected reduction dim {dim} to have non-zero size."
        )

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)

    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    if M * K == 0:
        # Empty output: nothing to compute and no grid to launch.
        return _MaxOut(values=out_value, indices=out_index)

    def grid(meta):
        # Keep the total number of programs below 65536, as the original
        # launcher did (presumably a grid-size limit of this backend --
        # confirm). max_kernel strides over whatever the grid does not cover,
        # so shrinking either axis is safe. Never shrink an axis to 0.
        limit = 65535
        axis1 = K if K < limit else limit
        axis0 = triton.cdiv(M, meta["BLOCK_M"])
        cap0 = limit // axis1  # >= 1 because axis1 <= limit
        if axis0 > cap0:
            axis0 = cap0
        return (
            axis0,
            axis1,
        )

    with torch_device_fn.device(inp.device):
        max_kernel[grid](inp, out_value, out_index, M, N, K)
    return _MaxOut(values=out_value, indices=out_index)