Coverage for src/flag_gems/experimental_ops/amin.py: 0%
135 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
import math
from functools import reduce
from operator import mul

import torch
import triton
import triton.language as tl
@triton.jit
def amin_reduce_last_kernel(
    x_ptr,  # pointer to the [M, K] input
    out_ptr,  # pointer to the length-M output (one min per row)
    M,  # number of rows (outer size)
    K,  # reduction length (last-axis size)
    stride_xm,  # stride between rows of x
    stride_xk,  # stride between elements within a row of x
    init,  # identity value for min (same dtype as x)
    BLOCK_SIZE: tl.constexpr,
):
    # One program per output row: program `pid` reduces row `pid` of the
    # [M, K] input down to a single minimum written at out_ptr[pid].
    pid = tl.program_id(0)
    # Defensive guard; the host launches a grid of exactly (M,) programs,
    # so this is normally always true.
    mask_m = pid < M
    # Running minimum, seeded with the identity so masked lanes and
    # out-of-range tail elements cannot affect the result.
    acc = init
    k = 0
    # Walk the row in BLOCK_SIZE chunks. `other=init` fills lanes past the
    # end of the row with the identity value.
    while k < K:
        offs = k + tl.arange(0, BLOCK_SIZE)
        mask = mask_m & (offs < K)
        vals = tl.load(
            x_ptr + pid * stride_xm + offs * stride_xk, mask=mask, other=init
        )
        # Reduce this chunk to a scalar, then fold it into the running min.
        block_min = tl.min(vals, axis=0)
        acc = tl.minimum(acc, block_min)
        k += BLOCK_SIZE
    tl.store(out_ptr + pid, acc, mask=mask_m)
36def _prod(seq):
37 return int(reduce(mul, seq, 1))
40def _parse_dims(dim, ndim):
41 if dim is None:
42 return list(range(ndim))
43 if isinstance(dim, (list, tuple)):
44 dims = [int(d) for d in dim]
45 else:
46 dims = [int(dim)]
47 # normalize negatives and remove duplicates preserving order
48 seen = set()
49 norm = []
50 for d in dims:
51 dd = d if d >= 0 else d + ndim
52 if dd < 0 or dd >= ndim:
53 raise IndexError("Dimension out of range in amin")
54 if dd not in seen:
55 norm.append(dd)
56 seen.add(dd)
57 return norm
def _amin_impl(
    x: torch.Tensor, dim=None, keepdim: bool = False, out: torch.Tensor = None
):
    """Compute the minimum of ``x`` over ``dim`` with a Triton kernel.

    The reduction is lowered to a 2-D problem: non-reduced axes are
    permuted to the front and flattened into M rows, reduced axes are
    flattened into K columns, and one kernel program takes the min of
    each row.

    Args:
        x: CUDA tensor to reduce.
        dim: None (all axes), an int, or a list/tuple of ints.
        keepdim: keep reduced axes as size-1 dimensions in the result.
        out: optional pre-allocated output tensor; must have exactly the
            number of elements of the result shape.

    Returns:
        A new tensor of the reduced shape, or ``out`` when it was given.

    Raises:
        RuntimeError: if ``x`` is not on CUDA, if a reduced axis has size
            zero, or if ``out`` has the wrong number of elements.
        IndexError: if ``dim`` contains an out-of-range axis.
    """
    if not x.is_cuda:
        raise RuntimeError("Triton amin kernel requires CUDA tensors")
    ndim = x.ndim
    reduce_dims = _parse_dims(dim, ndim)
    if len(reduce_dims) == 0:
        # No reduction dims specified, return input (or copy into out)
        if out is None:
            return x.clone()
        if out.numel() != x.numel():
            raise RuntimeError(
                "out tensor has incorrect number of elements for amin with empty dims"
            )
        out.copy_(x)
        return out

    # Determine output shape: keep_sizes has reduced axes collapsed to 1
    # (keepdim=True form); non_reduce_sizes drops them (keepdim=False form).
    input_sizes = list(x.size())
    keep_sizes = input_sizes.copy()
    for d in reduce_dims:
        keep_sizes[d] = 1
    non_reduce_dims = [i for i in range(ndim) if i not in reduce_dims]
    non_reduce_sizes = [input_sizes[i] for i in non_reduce_dims]

    final_shape = keep_sizes if keepdim else non_reduce_sizes

    # Prepare permutation: move non-reduced dims first, reduced dims last.
    # contiguous() materializes that layout so the later .view(M, K) is valid.
    perm = non_reduce_dims + reduce_dims
    x_perm = x.permute(perm)
    x_perm = x_perm.contiguous()

    # Flatten into [M, K]: M rows survive, each reduced over K elements.
    M = _prod(non_reduce_sizes) if len(non_reduce_sizes) > 0 else 1
    K = _prod([input_sizes[i] for i in reduce_dims]) if len(reduce_dims) > 0 else 1

    if K == 0:
        raise RuntimeError(
            "amin reduction has an empty dimension (no identity for min)"
        )

    x_2d = x_perm.view(M, K)

    # Identity/initial value for min based on dtype: +inf for floats,
    # True for bool (min over {False, True}), dtype max for integers.
    dt = x.dtype
    if dt.is_floating_point:
        init_val = float("inf")
    elif dt == torch.bool:
        init_val = True
    else:
        # integer types
        info = torch.iinfo(dt)
        init_val = int(info.max)

    # Prepare output row vector of length M.
    # NOTE(review): when `out` is supplied, its dtype/device are not
    # validated against x here — presumably callers pass a matching out;
    # confirm against the dispatch layer.
    if out is None:
        out_row = torch.empty((M,), dtype=x.dtype, device=x.device)
        out_target = None
    else:
        # Ensure out shape matches final_shape
        expected_numel = _prod(final_shape) if len(final_shape) > 0 else 1
        if out.numel() != expected_numel:
            raise RuntimeError("out tensor has incorrect number of elements")
        # We will write into a contiguous view; if out isn't contiguous, use a temp and then reshape/copy back
        if out.is_contiguous():
            # The kernel writes straight into out's storage via this view.
            out_row = out.view(M)
            out_target = out
        else:
            out_row = torch.empty((M,), dtype=out.dtype, device=out.device)
            out_target = out

    # Strides for x_2d (contiguous row-major)
    stride_xm = x_2d.stride(0)
    stride_xk = x_2d.stride(1)

    # Launch kernel: one program per output row; the kernel loops over K
    # in BLOCK_SIZE chunks internally.
    grid = lambda meta: (M,)
    BLOCK_SIZE = 1024
    amin_reduce_last_kernel[grid](
        x_2d,
        out_row,
        M,
        K,
        stride_xm,
        stride_xk,
        init_val,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Reshape to target final shape (0-d scalar when all axes reduced
    # without keepdim).
    if len(final_shape) == 0:
        result = out_row.view(())
    else:
        result = out_row.view(final_shape)

    if out_target is not None:
        # If original 'out' was non-contiguous, copy result into it respecting shape
        if not out_target.is_contiguous():
            # Copy into the provided 'out' tensor
            out_target.copy_(result)
            return out_target
        # Contiguous out was already filled in place through out_row.
        return out_target
    return result
def amin(*args, **kwargs):
    """Functional entry point mirroring ``aten.amin``.

    Accepts ``amin(x)``, ``amin(x, dim)``, ``amin(x, dim, keepdim)`` with
    ``dim``/``keepdim`` also allowed as keywords.

    Raises:
        RuntimeError: if no tensor argument is given.
    """
    if len(args) == 0:
        raise RuntimeError("amin requires at least one tensor argument")
    x = args[0]
    dim = kwargs.get("dim", None)
    keepdim = kwargs.get("keepdim", False)

    # Positional handling: amin(x, dim), amin(x, dim, keepdim).
    if len(args) >= 2:
        # BUGFIX: bool is a subclass of int, so the bool test must come
        # before the int test — otherwise a positional keepdim=True would
        # be consumed as dim=1 and the bool branch would be unreachable.
        if isinstance(args[1], bool):
            keepdim = args[1]
        elif isinstance(args[1], (int, list, tuple)):
            dim = args[1]
    if len(args) >= 3:
        if isinstance(args[2], bool):
            keepdim = args[2]

    return _amin_impl(x, dim=dim, keepdim=keepdim, out=None)
def amin_out(*args, **kwargs):
    """Out-variant entry point mirroring ``aten.amin.out``.

    Expected signature: ``amin_out(x, dim, keepdim, out)``; ``dim``,
    ``keepdim`` and ``out`` may also be passed as keywords, and ``out``
    is detected as the trailing positional tensor.

    Raises:
        RuntimeError: if no tensor argument is given or no ``out`` tensor
            can be found.
    """
    if len(args) == 0:
        raise RuntimeError("amin_out requires at least one tensor argument")
    x = args[0]

    # Extract keyword forms first; positionals below only fill gaps.
    out = kwargs.get("out", None)
    dim = kwargs.get("dim", None)
    keepdim = kwargs.get("keepdim", False)

    # Positional arguments — try to detect out as the last positional.
    if len(args) >= 2:
        # BUGFIX: bool is a subclass of int, so the bool test must come
        # before the int test — otherwise a positional keepdim would be
        # consumed as dim and the bool branch would be unreachable.
        if isinstance(args[1], bool):
            keepdim = args[1]
        elif isinstance(args[1], (int, list, tuple)):
            dim = args[1]
        elif isinstance(args[1], torch.Tensor):
            out = args[1]
    if len(args) >= 3:
        if isinstance(args[2], bool):
            keepdim = args[2]
        elif isinstance(args[2], torch.Tensor) and out is None:
            out = args[2]
    if len(args) >= 4 and out is None and isinstance(args[3], torch.Tensor):
        out = args[3]

    if out is None:
        raise RuntimeError("amin_out requires an 'out' tensor argument")

    return _amin_impl(x, dim=dim, keepdim=keepdim, out=out)