Coverage for src/flag_gems/runtime/backend/_cambricon/ops/cummin.py: 0%
242 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
2import math
3from typing import List, Tuple, Union
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max
# Short alias used in the public `cummin` signature below.
Tensor = torch.Tensor

# Child logger under the project-wide "flag_gems" logger; lstrip(".") drops any
# leading dot so the child name is always a valid dotted suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def tl_cummin(input, index, axis=0):
    """Inclusive cumulative-min scan over paired (value, index) tensors.

    Returns the scanned (values, indices) pair along `axis`, combining
    elements with `tle.minimum_with_index_tie_break_right` (per its name,
    ties keep the right-hand/later index — TODO confirm against the helper).
    """
    return tl.associative_scan(
        (input, index), axis, tle.minimum_with_index_tie_break_right
    )
@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    """Reduce paired (value, index) tensors to their minimum along `axis`.

    Uses the same combine function as `tl_cummin` so the reduced index obeys
    the same tie-break rule as the scan; used because `tl.min` does not
    support this tie-break (see the caller's comment).
    """
    return tl.reduce(
        (input, index),
        axis,
        tle.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan phase of the 1-D scan-then-fan cumulative min.

    Folds the running minimum of all preceding blocks — partial_min[pid - 1],
    which the caller has already turned into an inclusive prefix scan — into
    this block's local cummin results stored in `out` / `out_indices`.
    Block 0 has no predecessor and is left unchanged.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid > 0:
        # Minimum (and its index) over every element before this block.
        partial_min_ptrs = partial_min + pid - 1
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_indices_ptrs = partial_min_indices + pid - 1
        last_part_min_index_via_min = tl.load(partial_min_indices_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # `<=` keeps the in-block index on ties, i.e. the later (rightmost) one.
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        # Cast back: `final_vals` may have been promoted by tl.minimum.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan phase of the 1-D scan-then-fan cumulative min.

    Each program computes an independent inclusive cummin over its
    BLOCK_SIZE chunk, writing values to `out` and source indices to
    `out_indices`. When NEED_PARTIAL, it also writes the chunk's overall
    min/min-index to partial_min[pid] / partial_min_indices[pid] for the
    recursive scan performed by the host.  Note `in_indices` is unused:
    indices are either read from `out_indices` (USE_OUT_INDICES, i.e. the
    recursive call on partial buffers) or synthesized from `offset`.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes read the dtype max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Compute-dtype promotion: 64-bit types stay as-is, narrower ints go to
    # int32, everything else (narrow floats) to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive call: carry the original element indices through.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = offset
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_min_ptrs = partial_min + pid
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + pid
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)
def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Cumulative min with indices over a flat tensor of `n_ele` elements.

    Scan-then-fan: each BLOCK_SIZE chunk is scanned independently by
    `scan_part_min_kernel`; when there is more than one chunk, the per-chunk
    minima are themselves scanned (recursively, through this same function)
    and `add_base_min_kernel` folds each chunk's predecessor minimum back in.

    Args:
        inp: flat input tensor.
        out: output tensor for cummin values (same shape as inp).
        out_indices: int64 output tensor for argmin-so-far indices.
        n_ele: number of elements to scan.
        dtype: dtype used for the partial-min scratch buffer.
        use_out_indices: True on recursive calls, where `out_indices` already
            holds the original element indices to be carried through.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if n_ele <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / BLOCK_SIZE)
    # A single chunk needs no partial buffers and no fan phase.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_min_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Turn the per-chunk minima into an inclusive running min, then fan
        # each chunk's predecessor minimum into its local results.
        scan_then_fan_col(
            partial_min,
            partial_min,
            partial_min_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_kernel[grid](
                out, out_indices, partial_min, partial_min_indices, n_ele, BLOCK_SIZE
            )
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan phase of scan-then-fan cummin over the middle axis of (A, B, C).

    Grid is (A, part_num, C): program (a, b, c) scans one BLOCK_SIZE slice of
    row B for a fixed (a, c), striding by C between consecutive B elements.
    When NEED_PARTIAL, the slice's overall min/min-index is written to the
    (A, part_num, C) partial buffers at (a, pid_b, c).  As in the 1-D kernel,
    `in_indices` is unused: indices come from `out_indices` on recursive
    calls (USE_OUT_INDICES) or are synthesized from `b_idx`.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Element offsets into the flattened (A, B, C) input; stride C along B.
    offset = a_idx * B * C + b_idx * C + c_idx
    # Offsets into the (A, part_num, C) partial buffers.
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Out-of-range lanes read the dtype max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Compute-dtype promotion: 64-bit types stay as-is, narrower ints go to
    # int32, everything else (narrow floats) to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive call: carry the original element indices through.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = b_idx
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_min_ptrs = partial_min + part_offset
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + part_offset
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan phase of scan-then-fan cummin over the middle axis of (A, B, C).

    For program (a, pid_b, c) with pid_b > 0, folds the running minimum of
    all preceding B-slices — partial buffers at (a, pid_b - 1, c), already an
    inclusive prefix scan — into this slice's local cummin results.
    Slice 0 has no predecessor and is left unchanged.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Element offsets into the flattened (A, B, C) output; stride C along B.
    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # Offset of the predecessor slice in the (A, part_num, C) partial buffers.
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices_ptrs = out_indices + offset
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid_b > 0:
        partial_min_ptrs = partial_min + last_part_offset
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_index_ptrs = partial_min_indices + last_part_offset
        last_part_min_index_via_min = tl.load(partial_min_index_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # `<=` keeps the in-slice index on ties, i.e. the later (rightmost) one.
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        # Cast back: `final_vals` may have been promoted by tl.minimum.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Cumulative min with indices along the middle axis of an (A, B, C) view.

    Scan-then-fan: `scan_part_min_abc_kernel` scans each BLOCK_SIZE slice of
    the B axis independently; when B spans more than one slice, the per-slice
    minima are themselves scanned (recursively, through this same function)
    and `add_base_min_abc_kernel` folds each slice's predecessor minimum in.

    Args:
        inp: input tensor viewed as (A, B, C), contiguous.
        out: output tensor for cummin values (same shape as inp).
        out_indices: int64 output tensor for argmin-so-far indices along B.
        A, B, C: logical dimensions; the scan runs along B.
        dtype: dtype used for the partial-min scratch buffer.
        use_out_indices: True on recursive calls, where `out_indices` already
            holds the original element indices to be carried through.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    part_num = math.ceil(B / BLOCK_SIZE)
    # A single slice per row needs no partial buffers and no fan phase.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Turn the per-slice minima into an inclusive running min, then fan
        # each slice's predecessor minimum into its local results.
        scan_then_fan(
            partial_min,
            partial_min,
            partial_min_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_abc_kernel[grid](
                out,
                out_indices,
                partial_min,
                partial_min_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )
@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Single-pass cummin along the middle axis of (A, B, C), no fan phase.

    Grid is (A, C): each program walks its whole B row serially in
    `loop_num` chunks of BLOCK_SIZE, carrying the running minimum (and its
    index) from one chunk into the next, so no partial buffers or second
    kernel are needed.
    """
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # init
    # Masked-out lanes read the dtype max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    # Compute-dtype promotion: fp16/bf16 -> fp32, int8/int16 -> int32,
    # everything else computed in its own dtype.
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Carry across chunks: running min so far and its source index.
    prev_min_val = tl.full([], max_value, dtype=compute_dtype)
    prev_min_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane of a chunk (see the carry update below).
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # cummin
        result, cummin_indices = tl_cummin(vals, idxs, axis=0)

        # broadcast
        prev_min_val_b = tl.broadcast_to(prev_min_val, (BLOCK_SIZE,))
        prev_min_val_idx_b = tl.broadcast_to(prev_min_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right
            # (x != x) is the standard NaN test.
            prev_is_nan = prev_min_val != prev_min_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            use_result = result_is_nan | (~prev_nan_mask & (result <= prev_min_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result <= prev_min_val_b

        final_vals = tl.where(use_result, result, prev_min_val_b)
        final_indices = tl.where(use_result, cummin_indices, prev_min_val_idx_b)

        # update global min val and idx
        # Extract the last lane's value via a masked sum (select-last trick);
        # the last lane holds the chunk's inclusive min even when b_idx >= B,
        # because masked loads read max_value and cannot lower the min.
        prev_min_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_min_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)
def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Cummin along the middle axis of (A, B, C) via the serial-loop kernel.

    Launches one program per (a, c) pair; each walks its B row in chunks,
    so no partial buffers or fan kernel are required.  `dtype` is accepted
    for signature parity with `scan_then_fan` but is not used here.
    NOTE(review): the threshold uses `<` where the sibling launchers use
    `<=` — confirm whether that difference is intentional.
    """
    # TODO(all): tune on target board
    block_size = triton.next_power_of_2(B) if B < 1024 * 4 else 1024
    chunk_count = math.ceil(B / block_size)

    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[(A, C)](
            inp,
            out,
            out_indices,
            B,
            C,
            chunk_count,
            block_size,
        )
def cummin(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummin:
    """Cumulative minimum of `input` along `dim`, with argmin-so-far indices.

    The tensor is viewed as (M, N, K) with N = shape[dim]; the scan runs
    along N.  Dispatch: a flat 1-D scan when M == K == 1, the grid-based
    scan-then-fan for small M*K, and the serial-loop kernel otherwise.

    NOTE(review): the `out` keyword argument is accepted but never used —
    it is immediately shadowed by fresh allocations below; confirm whether
    callers rely on writing into a provided buffer.
    NOTE(review): the annotation promises torch.return_types.cummin, but a
    plain (values, indices) tuple is returned.
    """
    logger.debug("GEMS_CAMBRICON CUMMIN")
    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim  # normalize negative dims
    M = 1
    N = shape[dim]
    # M = product of dims before `dim`; K = product of dims after it.
    for i in range(dim):
        M *= shape[i]
    input = input.contiguous()
    K = input.numel() // M // N

    dtype = input.dtype
    # bool has no meaningful min scan dtype on device; promote to int64.
    if dtype is torch.bool:
        dtype = torch.int64
    out = torch.empty_like(input, dtype=dtype)
    out_indices = torch.empty_like(input, dtype=torch.int64)

    # Half-precision floats are scanned in float32 (matches the kernels'
    # internal promotion).
    compute_dtype = out.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        scan_then_fan_col(input, out, out_indices, N, compute_dtype)
    elif M * K <= 16:
        scan_then_fan(input, out, out_indices, M, N, K, compute_dtype)
    else:
        scan_then_fan_loop(input, out, out_indices, M, N, K, compute_dtype)
    return out, out_indices