Coverage for src/flag_gems/experimental_ops/maximum.py: 0%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import torch
2import triton
3import triton.language as tl
MAX_DIMS = 8  # fixed dimensionality supported by the kernel's unrolled index math
BLOCK_SIZE = 1024  # linear elements processed per Triton program instance
@triton.jit
def maximum_kernel(
    a_ptr,
    b_ptr,
    out_ptr,
    n_elements,
    s0, s1, s2, s3, s4, s5, s6, s7,          # logical shape, padded to 8 dims
    sa0, sa1, sa2, sa3, sa4, sa5, sa6, sa7,  # strides of a (in elements)
    sb0, sb1, sb2, sb3, sb4, sb5, sb6, sb7,  # strides of b (in elements)
    so0, so1, so2, so3, so4, so5, so6, so7,  # strides of out (in elements)
    BLOCK_SIZE: tl.constexpr,
):
    # Element-wise maximum over up-to-8-dim strided tensors. Each program
    # instance handles BLOCK_SIZE consecutive linear indices.
    block = tl.program_id(axis=0)
    idx = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Do address arithmetic in int64 so large tensors cannot overflow
    # 32-bit offsets.
    rem = idx.to(tl.int64)

    # Unravel the linear index into per-dimension coordinates, last dim
    # varying fastest (row-major).
    i7 = rem % s7
    rem = rem // s7
    i6 = rem % s6
    rem = rem // s6
    i5 = rem % s5
    rem = rem // s5
    i4 = rem % s4
    rem = rem // s4
    i3 = rem % s3
    rem = rem // s3
    i2 = rem % s2
    rem = rem // s2
    i1 = rem % s1
    rem = rem // s1
    i0 = rem % s0

    # Per-tensor element offsets via each tensor's own strides; 0-strides
    # produced by broadcasting fall out of this naturally.
    a_off = (
        i0 * sa0 + i1 * sa1 + i2 * sa2 + i3 * sa3
        + i4 * sa4 + i5 * sa5 + i6 * sa6 + i7 * sa7
    )
    b_off = (
        i0 * sb0 + i1 * sb1 + i2 * sb2 + i3 * sb3
        + i4 * sb4 + i5 * sb5 + i6 * sb6 + i7 * sb7
    )
    o_off = (
        i0 * so0 + i1 * so1 + i2 * so2 + i3 * so3
        + i4 * so4 + i5 * so5 + i6 * so6 + i7 * so7
    )

    lhs = tl.load(a_ptr + a_off, mask=in_bounds, other=0)
    rhs = tl.load(b_ptr + b_off, mask=in_bounds, other=0)
    tl.store(out_ptr + o_off, tl.maximum(lhs, rhs), mask=in_bounds)
113def _as_tensor_on_device(x, device, dtype=None):
114 if torch.is_tensor(x):
115 return (
116 x.to(device=device, dtype=dtype)
117 if (dtype is not None and x.dtype != dtype) or (x.device != device)
118 else x
119 )
120 return torch.tensor(x, device=device, dtype=dtype)
123def _broadcast_to_common(a, b):
124 a_b, b_b = torch.broadcast_tensors(a, b)
125 return a_b, b_b
def _pad_shape_strides(shape, strides):
    """Pad *shape* and *strides* out to MAX_DIMS entries for the kernel.

    The kernel has a fixed 8-dim signature, so shapes are padded with 1
    (a size-1 dim is a no-op in the index math and avoids division by
    zero) and strides with 0. Values are coerced to plain Python ints,
    which is what Triton expects for scalar arguments.

    Returns:
        (shape_list, strides_list): two lists of exactly MAX_DIMS ints.

    Raises:
        ValueError: if *shape* has more than MAX_DIMS dimensions.
        (Was a bare `assert`, which is stripped under `python -O`.)
    """
    shape_list = [int(s) for s in shape]
    strides_list = [int(s) for s in strides]
    nd = len(shape_list)
    if nd > MAX_DIMS:
        raise ValueError(
            f"tensor has {nd} dimensions; at most {MAX_DIMS} are supported"
        )
    pad = MAX_DIMS - nd
    return shape_list + [1] * pad, strides_list + [0] * pad
def _launch_maximum_kernel(a, b, out):
    """Broadcast a/b and launch maximum_kernel writing into *out*.

    Assumes a and b are broadcastable, already cast to out.dtype, and on
    the same CUDA device as *out*; *out* must already have the broadcast
    shape. NOTE(review): *out* is also assumed to have non-negative
    strides — callers handle negative-stride outputs via a scratch buffer.
    """
    # Assumes a and b are broadcastable and already cast to out.dtype and on same device
    a_b, b_b = _broadcast_to_common(a, b)
    # Make inputs contiguous to avoid negative/irregular strides complications
    # Broadcasting uses 0-stride for broadcasted dims; keeping 0-stride is fine
    # but handle potential negative/non-standard strides by materializing.
    if any(s < 0 for s in a_b.stride()):
        a_b = a_b.contiguous()
    if any(s < 0 for s in b_b.stride()):
        b_b = b_b.contiguous()
    out_shape = a_b.shape  # == b_b.shape
    n_elements = int(a_b.numel())
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return
    # Prepare shape and strides for kernel
    # Shape is identical across all three; only the strides differ per tensor.
    shp, sa = _pad_shape_strides(out_shape, a_b.stride())
    _, sb = _pad_shape_strides(out_shape, b_b.stride())
    _, so = _pad_shape_strides(out_shape, out.stride())
    # One program instance per BLOCK_SIZE linear elements.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    maximum_kernel[grid](
        a_b,
        b_b,
        out,
        n_elements,
        shp[0],
        shp[1],
        shp[2],
        shp[3],
        shp[4],
        shp[5],
        shp[6],
        shp[7],
        sa[0],
        sa[1],
        sa[2],
        sa[3],
        sa[4],
        sa[5],
        sa[6],
        sa[7],
        sb[0],
        sb[1],
        sb[2],
        sb[3],
        sb[4],
        sb[5],
        sb[6],
        sb[7],
        so[0],
        so[1],
        so[2],
        so[3],
        so[4],
        so[5],
        so[6],
        so[7],
        BLOCK_SIZE=BLOCK_SIZE,
    )
def maximum(a, b):
    """Element-wise maximum of *a* and *b* via the Triton kernel.

    Inputs may be tensors or Python scalars; at least one must be a CUDA
    tensor. The result dtype follows torch.result_type promotion, the
    shape follows broadcasting rules, and the result is a new tensor on
    the chosen CUDA device.

    Raises:
        ValueError: if neither input is a CUDA tensor.
    """
    # Pick the CUDA device from whichever input is a CUDA tensor. The
    # previous logic took a's device whenever a was a tensor, so a CPU
    # tensor `a` paired with a CUDA tensor `b` was wrongly rejected.
    dev = None
    for t in (a, b):
        if torch.is_tensor(t) and t.device.type == "cuda":
            dev = t.device
            break
    if dev is None:
        raise ValueError("maximum expects at least one CUDA tensor as input")

    # Determine result dtype per PyTorch promotion rules
    res_dtype = torch.result_type(a, b)
    a_t = _as_tensor_on_device(a, dev, dtype=res_dtype)
    b_t = _as_tensor_on_device(b, dev, dtype=res_dtype)

    # Broadcast only to determine the output shape.
    a_b, _ = _broadcast_to_common(a_t, b_t)
    out = torch.empty(a_b.shape, device=dev, dtype=res_dtype)

    # torch.empty always returns a contiguous tensor with non-negative
    # strides, so no scratch-buffer indirection is needed here (unlike
    # maximum_out, whose caller supplies `out`).
    _launch_maximum_kernel(a_t, b_t, out)
    return out
def maximum_out(a, b, out):
    """Element-wise maximum of *a* and *b*, written into *out*.

    Inputs are cast to out.dtype (the usual ``.out``-variant convention)
    and moved to out's CUDA device. The broadcast shape of the inputs
    must equal out's shape exactly.

    Raises:
        TypeError: if *out* is not a tensor.
        ValueError: if *out* is not on CUDA or its shape does not match.
    """
    if not torch.is_tensor(out):
        raise TypeError("out must be a torch.Tensor")
    if out.device.type != "cuda":
        raise ValueError("out tensor must be on CUDA device")

    dev = out.device

    # Cast inputs to out dtype (following typical .out behavior)
    lhs = _as_tensor_on_device(a, dev, dtype=out.dtype)
    rhs = _as_tensor_on_device(b, dev, dtype=out.dtype)

    # Broadcast only to validate the shape against out.
    lhs_b, _ = _broadcast_to_common(lhs, rhs)
    if tuple(lhs_b.shape) != tuple(out.shape):
        raise ValueError(
            f"out shape {tuple(out.shape)} is not broadcast-compatible with inputs shape {tuple(lhs_b.shape)}"
        )

    # The kernel can only write straight into a dense, non-negatively
    # strided output; otherwise compute into a scratch buffer and copy.
    writable_directly = out.is_contiguous() and all(s >= 0 for s in out.stride())
    if writable_directly:
        _launch_maximum_kernel(lhs, rhs, out)
    else:
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_maximum_kernel(lhs, rhs, scratch)
        out.copy_(scratch)

    return out