Coverage for src/flag_gems/ops/cummax.py: 37%
242 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
2import math
3from typing import List, Tuple, Union
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
14Tensor = torch.Tensor
16logger = logging.getLogger(__name__)
@triton.jit
def tl_cummax(input, index, axis=0):
    # Inclusive scan of (value, index) pairs along `axis`.
    # The combine function keeps the maximum value and, on ties, the
    # right-hand (later) index — matching torch.cummax's tie behavior.
    return tl.associative_scan(
        (input, index), axis, tle.maximum_with_index_tie_break_right
    )
@triton.jit
def tl_max_tie_break_right(input, index, axis=None, keep_dims=False):
    # Reduction counterpart of tl_cummax: returns the (max value, index)
    # pair along `axis`, resolving ties toward the later index.
    # Used instead of tl.max because tl.max has no tie-break-right variant
    # (see the callers' comments).
    return tl.reduce(
        (input, index),
        axis,
        tle.maximum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_max_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # "Fan" phase of the 1D scan-then-fan cummax: each block except the
    # first merges the running maximum of all preceding blocks
    # (partial_max[pid - 1], already prefix-scanned by the host) into its
    # local scan results, updating `out`/`out_indices` in place.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    # NOTE: this load shadows the `out_indices` pointer argument; the raw
    # pointer is not needed again after `out_indices_ptrs` is formed.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    # Block 0 has no preceding prefix, so it needs no update.
    if pid > 0:
        partial_max_ptrs = partial_max + pid - 1
        last_part_max_via_max = tl.load(partial_max_ptrs)
        partial_max_indices_ptrs = partial_max_indices + pid - 1
        last_part_max_index_via_max = tl.load(partial_max_indices_ptrs)

        final_vals = tl.maximum(out_vals, last_part_max_via_max)
        # `>=` keeps the in-block index on ties, i.e. the later position wins,
        # consistent with maximum_with_index_tie_break_right.
        final_indices = tl.where(
            out_vals >= last_part_max_via_max, out_indices, last_part_max_index_via_max
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_max_kernel(
    inp,
    out,
    in_indices,  # NOTE(review): unused — the host passes out_indices here too
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # "Scan" phase of the 1D scan-then-fan cummax: each block computes the
    # cummax of its own BLOCK_SIZE chunk and (when NEED_PARTIAL) also the
    # chunk's overall max + argmax, stored at partial_max[pid] for the
    # host-side recursive scan over per-block maxima.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes read the dtype's minimum so they never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # Promote low-precision types for the scan; 64-bit types stay as-is.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)

    # On recursive calls (USE_OUT_INDICES) the seed indices come from
    # out_indices so they keep referring to positions in the original input;
    # on the first pass they are simply the element offsets.
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = offset
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.max do not support max_indices_tie_break_right
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_max_ptrs = partial_max + pid
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + pid
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)
def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Compute a 1D cummax of `inp` into `out`/`out_indices` (scan-then-fan).

    Each Triton block scans its own chunk; when more than one block is
    needed, the per-block maxima are recursively scanned and then "fanned"
    back into every block's local results.

    Args:
        inp: flat input tensor with `n_ele` elements.
        out: output tensor receiving the running maxima.
        out_indices: int64 output tensor receiving the argmax indices.
        n_ele: number of elements to scan.
        dtype: dtype used for the per-block partial-max buffer.
        use_out_indices: when True (recursive calls), seed indices are read
            from `out_indices` instead of element offsets, so indices keep
            referring to positions in the original input.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if n_ele <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / BLOCK_SIZE)
    # Idiom fix: the comparison is already a bool (was `True if ... else False`).
    need_partial = part_num >= 2
    if need_partial:
        partial_max = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_max_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_max = None
        partial_max_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_max_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices slot (unused by the kernel)
            out_indices,
            partial_max,
            partial_max_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Prefix-scan the per-block maxima, then merge each block's prefix
        # max back into its local results.
        scan_then_fan_col(
            partial_max,
            partial_max,
            partial_max_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_max_kernel[grid](
                out, out_indices, partial_max, partial_max_indices, n_ele, BLOCK_SIZE
            )
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_max_abc_kernel(
    inp,
    out,
    in_indices,  # NOTE(review): unused — the host passes out_indices here too
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # Scan phase for cummax over the middle axis of an (A, B, C)-shaped
    # contiguous tensor. Grid is (A, part_num, C): each program scans one
    # BLOCK_SIZE chunk of the B axis for a fixed (a, c), stepping with
    # stride C through memory. Partial per-chunk maxima land in an
    # (A, part_num, C) buffer for the host's recursive pass.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Element offsets within the flattened (A, B, C) input.
    offset = a_idx * B * C + b_idx * C + c_idx
    # Offsets within the flattened (A, part_num, C) partial buffer.
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Out-of-range lanes read the dtype's minimum so they never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # Promote low-precision types for the scan; 64-bit types stay as-is.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)

    # On recursive calls, seed indices come from out_indices so they keep
    # referring to positions along the original B axis.
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = b_idx
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.max do not support max_indices_tie_break_right
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_max_ptrs = partial_max + part_offset
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + part_offset
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_max_abc_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Fan phase for the (A, B, C) cummax: each B-axis chunk except the first
    # merges the running max of all preceding chunks for its (a, c) slice
    # (partial_max at pid_b - 1, already prefix-scanned by the host),
    # updating `out`/`out_indices` in place.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # Offset of the previous chunk's prefix max in the (A, part_num, C) buffer.
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices_ptrs = out_indices + offset
    # NOTE: this load shadows the `out_indices` pointer argument; the raw
    # pointer is not needed again after `out_indices_ptrs` is formed.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    # The first chunk along B has no preceding prefix, so it needs no update.
    if pid_b > 0:
        partial_max_ptrs = partial_max + last_part_offset
        last_part_max_via_max = tl.load(partial_max_ptrs)
        partial_max_index_ptrs = partial_max_indices + last_part_offset
        last_part_max_index_via_max = tl.load(partial_max_index_ptrs)

        final_vals = tl.maximum(out_vals, last_part_max_via_max)
        # `>=` keeps the in-chunk index on ties (tie-break to the later
        # position), consistent with maximum_with_index_tie_break_right.
        final_indices = tl.where(
            out_vals >= last_part_max_via_max, out_indices, last_part_max_index_via_max
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Compute cummax along the middle axis of an (A, B, C) contiguous tensor.

    Scan-then-fan: each (a, c) slice of the B axis is scanned in BLOCK_SIZE
    chunks; when more than one chunk is needed, the per-chunk maxima are
    recursively scanned and then fanned back into the chunk-local results.

    Args:
        inp: contiguous input viewed as (A, B, C).
        out: output tensor receiving the running maxima.
        out_indices: int64 output tensor receiving the argmax indices
            (positions along the B axis).
        A, B, C: logical dimensions; the scan runs along B.
        dtype: dtype used for the per-chunk partial-max buffer.
        use_out_indices: when True (recursive calls), seed indices are read
            from `out_indices` so they keep referring to the original B axis.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    part_num = math.ceil(B / BLOCK_SIZE)
    # Idiom fix: the comparison is already a bool (was `True if ... else False`).
    need_partial = part_num >= 2
    if need_partial:
        partial_max = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_max_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_max = None
        partial_max_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices slot (unused by the kernel)
            out_indices,
            partial_max,
            partial_max_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Prefix-scan the per-chunk maxima along their part_num axis, then
        # merge each chunk's prefix max back into its local results.
        scan_then_fan(
            partial_max,
            partial_max,
            partial_max_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_max_abc_kernel[grid](
                out,
                out_indices,
                partial_max,
                partial_max_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )
@libentry()
@triton.jit()
def scan_part_max_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-pass cummax along the B axis of an (A, B, C) tensor, one
    # program per (a, c) pair (grid is (A, C)). Instead of a fan phase, the
    # program walks the B axis in BLOCK_SIZE chunks and carries the running
    # max (value + index) across iterations in scalar state.
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # init, promote low precision types
    min_value = get_dtype_min(inp.type.element_ty)
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Loop-carried running max (value, index), seeded with the dtype minimum
    # so the first chunk always wins.
    prev_max_val = tl.full([], min_value, dtype=compute_dtype)
    prev_max_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane, whose scanned value is the chunk's overall max.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        inp_vals = tl.load(inp + offset, mask=mask, other=min_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # cummax
        result, cummax_indices = tl_cummax(vals, idxs, axis=0)

        # broadcast
        prev_max_val_b = tl.broadcast_to(prev_max_val, (BLOCK_SIZE,))
        prev_max_val_idx_b = tl.broadcast_to(prev_max_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right
            # (x != x is the standard NaN test).
            prev_is_nan = prev_max_val != prev_max_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            use_result = result_is_nan | (~prev_nan_mask & (result >= prev_max_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result >= prev_max_val_b

        final_vals = tl.where(use_result, result, prev_max_val_b)
        final_indices = tl.where(use_result, cummax_indices, prev_max_val_idx_b)

        # update global max val and idx
        # (sum over a one-hot mask extracts the last lane as a scalar)
        prev_max_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_max_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)
def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Compute cummax along the B axis of an (A, B, C) tensor with the
    single-program-per-(a, c) loop kernel (no fan phase needed).

    Args:
        inp: contiguous input viewed as (A, B, C).
        out: output tensor receiving the running maxima.
        out_indices: int64 output tensor receiving the argmax indices.
        A, B, C: logical dimensions; the scan runs along B.
        dtype: unused here (kept for signature parity with scan_then_fan).
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    # Consistency fix: use `<=` like scan_then_fan/scan_then_fan_col so
    # B == 4096 runs as a single chunk instead of four loop iterations.
    # Results are identical either way.
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    loop_num = math.ceil(B / BLOCK_SIZE)

    grid = (A, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_loop_kernel[grid](
            inp,
            out,
            out_indices,
            B,
            C,
            loop_num,
            BLOCK_SIZE,
        )
def cummax(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummax:
    """Cumulative maximum of `input` along `dim`.

    Returns a (values, indices) pair where indices are the positions (along
    `dim`) of the running maxima. Bool inputs produce int64 values. The
    tensor is viewed as (M, N, K) with N the scanned dimension, and one of
    three scan-then-fan strategies is dispatched based on M and K.

    Args:
        input: input tensor (made contiguous internally).
        dim: dimension to scan; negative values are normalized.
        out: optional pair of (values, indices) tensors to write into.

    Returns:
        (values, indices) — the provided `out` tensors when given,
        otherwise freshly allocated ones.
    """
    logger.debug("GEMS cummax")
    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim
    M = 1
    N = shape[dim]
    for i in range(dim):
        M *= shape[i]
    input = input.contiguous()
    K = input.numel() // M // N

    dtype = input.dtype
    if dtype is torch.bool:
        dtype = torch.int64
    values = torch.empty_like(input, dtype=dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    # Promote half-precision floats so the scan accumulates in float32.
    compute_dtype = values.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        # Pure 1D scan.
        scan_then_fan_col(input, values, indices, N, compute_dtype)
    elif M * K <= 16:
        # Few rows: grid-parallel scan-then-fan over the (M, N, K) view.
        scan_then_fan(input, values, indices, M, N, K, compute_dtype)
    else:
        # Many rows: one program per (m, k) loops over N — no fan phase.
        scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)

    # Bug fix: the `out` argument was previously accepted but silently
    # ignored (the local result tensors shadowed it). Honor it by copying
    # the results into the caller-provided (values, indices) pair.
    if out is not None:
        out_values, out_index = out
        out_values.copy_(values)
        out_index.copy_(indices)
        return out_values, out_index
    return values, indices