Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cummin.py: 0%
243 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max
14# from typing import List, Tuple, Union
# (values, indices) pair returned by cummin, mirroring torch.return_types.cummin.
CumminResult = namedtuple("CumminResult", ["values", "indices"])

# Alias of torch.Tensor for annotation use.
Tensor = torch.Tensor

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def tl_cummin(input, index, axis=0):
    # Inclusive cumulative-min scan over (value, index) pairs along `axis`,
    # using the tie-break-right combiner so equal values keep the later index.
    return tl.associative_scan(
        (input, index), axis, tle.minimum_with_index_tie_break_right
    )
@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    # Reduction counterpart of tl_cummin: min over (value, index) pairs with
    # tie-break right (tl.min itself does not support this tie-break rule).
    return tl.reduce(
        (input, index),
        axis,
        tle.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Fan phase of the 1-D scan-then-fan: fold the running minimum of all
    # preceding blocks (stored at partial_min[pid - 1]) into this block's
    # local cummin results. Block 0 has no predecessor and is left untouched.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    # NOTE: rebinds the `out_indices` parameter name to the loaded values.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid > 0:
        # partial_min[pid - 1] holds the (already scanned) min over blocks [0, pid).
        partial_min_ptrs = partial_min + pid - 1
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_indices_ptrs = partial_min_indices + pid - 1
        last_part_min_index_via_min = tl.load(partial_min_indices_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # Tie-break right: on equality keep the in-block (later) index.
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # Scan phase for the flattened (1-D) case: each program computes the
    # cummin of its BLOCK_SIZE slice and, when NEED_PARTIAL, also stores the
    # slice-wide min/argmin for the later fan phase.
    # NOTE(review): the `in_indices` parameter is unused — when
    # USE_OUT_INDICES is set, indices are read from `out_indices` instead.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes load the dtype max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Promote narrow types for the scan: small ints -> int32,
    # fp16/bf16/fp32 -> float32; 64-bit types are kept as-is.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive pass over partial results: indices were materialized by
        # an earlier pass and live in `out_indices`.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: indices are the element positions themselves.
        in_indices_vals = offset
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One partial (min, argmin) per block, consumed by add_base_min_kernel.
        partial_min_ptrs = partial_min + pid
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + pid
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)
def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Cumulative min over a flat 1-D view using scan-then-fan.

    Each block scans its own slice; when more than one block is needed, the
    per-block minima are recursively scanned and fanned back in via
    ``add_base_min_kernel``.

    Args:
        inp: contiguous 1-D input tensor.
        out: output tensor for the running minima (written in place).
        out_indices: int64 output tensor for the argmin positions; also
            serves as the index input on recursive passes.
        n_ele: number of elements to scan.
        dtype: dtype for the intermediate partial-min buffer.
        use_out_indices: True on recursive calls, where indices are read
            from ``out_indices`` instead of being generated from positions.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if n_ele <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / BLOCK_SIZE)
    # A fan phase (and its partial buffers) is only needed with >= 2 blocks.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_min_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices slot; kernel ignores it (reads out_indices)
            out_indices,
            partial_min,
            partial_min_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Recursively cummin the per-block minima, then fold them back in.
        scan_then_fan_col(
            partial_min,
            partial_min,
            partial_min_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_kernel[grid](
                out, out_indices, partial_min, partial_min_indices, n_ele, BLOCK_SIZE
            )
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # Scan phase over the middle axis of an (A, B, C) view. Grid is
    # (A, part_num, C): each program computes the cummin of one BLOCK_SIZE
    # chunk of one (a, c) scan line along B.
    # NOTE(review): the `in_indices` parameter is unused — when
    # USE_OUT_INDICES is set, indices come from `out_indices` instead.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Row-major (A, B, C) addressing for the data...
    offset = a_idx * B * C + b_idx * C + c_idx
    # ...and (A, part_num, C) addressing for the per-block partial results.
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Out-of-range lanes load the dtype max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Promote narrow types for the scan: small ints -> int32,
    # fp16/bf16/fp32 -> float32; 64-bit types are kept as-is.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive pass: indices were materialized by an earlier pass.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: indices are the positions along B.
        in_indices_vals = b_idx
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One partial (min, argmin) per chunk, consumed by add_base_min_abc_kernel.
        partial_min_ptrs = partial_min + part_offset
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + part_offset
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Fan phase for the (A, B, C) case: fold the scanned minimum of all
    # preceding B-chunks (at part index pid_b - 1) into this chunk's local
    # cummin results. Chunk 0 of each scan line is left untouched.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Row-major (A, B, C) addressing for the data...
    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # ...and (A, part_num, C) addressing for the preceding chunk's partial.
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices_ptrs = out_indices + offset
    # NOTE: rebinds the `out_indices` parameter name to the loaded values.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid_b > 0:
        partial_min_ptrs = partial_min + last_part_offset
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_index_ptrs = partial_min_indices + last_part_offset
        last_part_min_index_via_min = tl.load(partial_min_index_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # Tie-break right: on equality keep the in-chunk (later) index.
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Cumulative min along the middle axis of an (A, B, C) view.

    Scan-then-fan over B: each chunk is scanned independently, then the
    per-chunk minima are recursively scanned and fanned back in via
    ``add_base_min_abc_kernel``.

    Args:
        inp: contiguous input viewed as (A, B, C).
        out: output tensor for the running minima (written in place).
        out_indices: int64 output tensor for the argmin positions; also
            serves as the index input on recursive passes.
        A, B, C: collapsed sizes before / along / after the scan dim.
        dtype: dtype for the intermediate partial-min buffer.
        use_out_indices: True on recursive calls, where indices are read
            from ``out_indices`` instead of being generated from positions.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    part_num = math.ceil(B / BLOCK_SIZE)
    # A fan phase (and its partial buffers) is only needed with >= 2 chunks.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices slot; kernel ignores it (reads out_indices)
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Recursively cummin the per-chunk minima, then fold them back in.
        scan_then_fan(
            partial_min,
            partial_min,
            partial_min_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_abc_kernel[grid](
                out,
                out_indices,
                partial_min,
                partial_min_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )
@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-pass cummin along B for an (A, B, C) view. Grid is (A, C):
    # each program walks its whole scan line in `loop_num` chunks of
    # BLOCK_SIZE, carrying the running (min, argmin) across chunks, so no
    # separate fan/partial phase is needed.
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # init
    # Out-of-range lanes load the dtype max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    # Widened compute dtype: fp16/bf16 -> fp32, int8/int16 -> int32,
    # everything else kept as-is.
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Scalar carry: running min (and its index) over all chunks so far.
    prev_min_val = tl.full([], max_value, dtype=compute_dtype)
    prev_min_val_idx = tl.full([], 0, dtype=tl.int64)
    # One-hot mask selecting the last lane of a chunk.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # cummin within this chunk only; the carry is merged below
        result, cummin_indices = tl_cummin(vals, idxs, axis=0)

        # broadcast the scalar carry to the chunk width
        prev_min_val_b = tl.broadcast_to(prev_min_val, (BLOCK_SIZE,))
        prev_min_val_idx_b = tl.broadcast_to(prev_min_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right
            # (x != x detects NaN).
            prev_is_nan = prev_min_val != prev_min_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            use_result = result_is_nan | (~prev_nan_mask & (result <= prev_min_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result <= prev_min_val_b

        final_vals = tl.where(use_result, result, prev_min_val_b)
        final_indices = tl.where(use_result, cummin_indices, prev_min_val_idx_b)

        # update global min val and idx: summing over the one-hot last_mask
        # extracts the last lane's value as a scalar
        prev_min_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_min_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)
def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Cumulative min along B of an (A, B, C) view via a single looping kernel.

    One program per (a, c) pair walks the whole scan line, so no fan phase
    or partial buffers are needed. ``dtype`` is accepted for signature
    parity with the other launchers but is not used here.
    """
    # TODO(all): tune on target board
    block = triton.next_power_of_2(B) if B < 1024 * 4 else 1024
    loops = math.ceil(B / block)

    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[(A, C)](
            inp,
            out,
            out_indices,
            B,
            C,
            loops,
            block,
            # NOTE(review): `is_use_mask_zero` is not a parameter of
            # scan_part_min_abc_loop_kernel; presumably a Kunlunxin-Triton
            # launch option — confirm against the backend's triton fork.
            is_use_mask_zero=True,
        )
def cummin(input, dim=1):
    """Cumulative minimum of ``input`` along ``dim``.

    Mirrors ``torch.cummin``: returns the running minima and the int64
    positions (along ``dim``) where each minimum was first achieved.

    Args:
        input: input tensor; made contiguous internally.
        dim: dimension to scan along; may be negative.
            NOTE(review): ``torch.cummin`` has no default for ``dim``; the
            ``dim=1`` default here fails the range check for 1-D inputs —
            confirm intended callers always pass ``dim`` explicitly.

    Returns:
        CumminResult(values, indices) namedtuple. ``values`` has the input
        dtype (bool is promoted to int64), ``indices`` is int64.
    """
    logger.debug("GEMS cummin")
    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim
    # Collapse the tensor to (M, N, K): M = product of dims before `dim`,
    # N = size of the scanned dim, K = product of dims after it.
    M = 1
    N = shape[dim]
    for i in range(dim):
        M *= shape[i]
    input = input.contiguous()
    K = input.numel() // M // N

    dtype = input.dtype
    if dtype is torch.bool:
        # Promote bool results to int64 (matching torch.cummin's value dtype
        # rules for bool input).
        dtype = torch.int64
    out = torch.empty_like(input, dtype=dtype)
    out_indices = torch.empty_like(input, dtype=torch.int64)

    # Partial/intermediate buffers use a widened dtype for fp16/bf16 so the
    # recursive scan does not lose precision.
    compute_dtype = out.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    # Dispatch: flat 1-D scan, few scan lines (3-D scan-then-fan), or many
    # scan lines (one looping program per line).
    if M == 1 and K == 1:
        scan_then_fan_col(input, out, out_indices, N, compute_dtype)
    elif M * K <= 16:
        scan_then_fan(input, out, out_indices, M, N, K, compute_dtype)
    else:
        scan_then_fan_loop(input, out, out_indices, M, N, K, compute_dtype)
    return CumminResult(out, out_indices)