Coverage for src/flag_gems/runtime/backend/_ascend/ops/cumsum.py: 0%
321 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7from torch._prims_common import is_boolean_dtype, is_integer_dtype
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13from ..utils import CORE_NUM
# Resolve the runtime device name once at import time.
# NOTE(review): this rebinds the imported `device` object to a plain string,
# shadowing the module-level import — any later `device.name` access in this
# module would fail on a str. Confirm intended usage.
device = device.name
# Per-op logger named after this module's basename (".../ops.cumsum").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@tl.constexpr
def get_scan_accum_type(inp_dtype: tl.dtype) -> tl.dtype:
    """Pick the accumulator dtype for scans/sums over `inp_dtype`.

    fp16/bf16 accumulate in float32 to limit rounding error; any integer
    type (signed, unsigned, or bool) accumulates in int64; all other
    dtypes (fp32/fp64) accumulate in themselves.
    """
    if inp_dtype.is_bf16() or inp_dtype.is_fp16():
        return tl.float32
    if inp_dtype.is_int():  # signed or not(including bool)
        return tl.int64
    else:
        return inp_dtype
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of scan-then-fan over a flat 1-D tensor.

    Each program scans one BLOCK_SIZE slice of `inp`: it writes the local
    inclusive cumsum to `out` and stores the slice total into
    partial_sum[pid] so a later pass can add cross-block prefixes.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen narrow dtypes before scanning: 64-bit int/uint/float stay as-is,
    # narrower ints go to int32, everything else to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)
    # Block total, used as this block's contribution to later prefixes.
    part_sum_via_sum = tl.sum(inp_vals)
    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)
    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan pass of scan-then-fan over a flat 1-D tensor.

    Assumes `partial_sum` already holds the inclusive cumsum of per-block
    totals. Block `pid` adds partial_sum[pid - 1] (the exclusive prefix of
    its block) to every element of its slice; block 0 needs no offset.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    if pid > 0:
        partial_sum_ptrs = partial_sum + pid - 1
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)
        final_vals = out_vals + last_part_sum_via_sum
        # Cast back so the stored dtype matches the existing output values.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of scan-then-fan along the middle axis of an (A, B, C) view.

    Grid is (A, part_num, C). Program (a, b, c) scans one BLOCK_SIZE slice of
    the B axis at fixed (a, c): it writes the local inclusive cumsum to `out`
    and stores the slice total into partial_sum[a, b, c], where partial_sum
    has shape (A, part_num, C).
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)
    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c
    # Flat offsets into the contiguous (A, B, C) input/output, and into the
    # (A, part_num, C) partial-sum buffer.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C
    mask = b_idx < B
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen narrow dtypes before scanning; 64-bit types are kept as-is.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)
    part_sum_via_sum = tl.sum(inp_vals)
    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)
    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan pass of scan-then-fan along the middle axis of an (A, B, C) view.

    Assumes partial_sum (shape (A, part_num, C)) already holds the inclusive
    cumsum of per-slice totals along its part_num axis. Program (a, b, c)
    adds partial_sum[a, b - 1, c] to its slice; b == 0 needs no offset.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)
    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c
    # Flat offsets into the (A, B, C) output and the (A, part_num, C) sums.
    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C
    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    if pid_b > 0:
        partial_sum_ptrs = partial_sum + last_part_offset
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)
        final_vals = out_vals + last_part_sum_via_sum
        # Cast back so the stored dtype matches the existing output values.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
def scan_then_fan_col(inp, out, n_ele, dtype):
    """Inclusive cumsum of a flat `n_ele`-element tensor via scan-then-fan.

    Scans fixed-size blocks, recursively scans the per-block totals, then
    fans the resulting prefixes back onto each block.
    """
    # TODO(all): tune on target board
    block = triton.next_power_of_2(n_ele) if n_ele <= 1024 * 4 else 1024
    part_num = math.ceil(n_ele / block)
    partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)
    launch_grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[launch_grid](inp, out, partial_sum, n_ele, part_num, block)
    if part_num >= 2:
        # Turn per-block totals into inclusive prefixes, then fan them out.
        scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_kernel[launch_grid](out, partial_sum, n_ele, part_num, block)
def scan_then_fan(inp, out, A, B, C, dtype):
    """Inclusive cumsum along the middle axis of an (A, B, C) contiguous view.

    Same scan-then-fan scheme as the 1-D variant, applied independently for
    every (a, c) pair via a 3-D launch grid.
    """
    # TODO(all): tune on target board
    block = triton.next_power_of_2(B) if B <= 1024 * 4 else 1024
    part_num = math.ceil(B / block)
    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
    launch_grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[launch_grid](
            inp, out, partial_sum, B, C, part_num, block
        )
    if part_num >= 2:
        # Recursively scan the per-block totals, then add them as prefixes.
        scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_abc_kernel[launch_grid](out, partial_sum, B, C, part_num, block)
def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Inclusive cumsum of `inp` along `dim`, written into `out`.

    Flattens the shape into (M, N, K) around the scan axis and dispatches to
    a row scan (K == 1) or a column scan. Integer/bool inputs promote to
    int64 when no dtype is given, matching torch.cumsum.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = inp.shape
    N = shape[dim]
    M = math.prod(shape[:dim])  # product of leading dims (1 when dim == 0)
    inp = inp.contiguous()
    K = inp.numel() // (M * N)  # product of trailing dims
    if dtype is None:
        dtype = inp.dtype
        if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
            dtype = torch.int64
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)
    # fp16/bf16 accumulate in fp32; everything else uses the output dtype.
    compute_dtype = (
        torch.float32
        if inp.dtype in (torch.float16, torch.bfloat16)
        else out.dtype
    )
    if K == 1:  # row scan
        reduce_then_scan_row(inp, out, M, N, compute_dtype)
    else:  # col scan
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out
def reduce_then_scan_row(x, out, M, N, compute_dtype):
    """Row-wise inclusive cumsum of an (M, N) contiguous view into `out`.

    Short rows (N <= 16384) use a single persistent scan kernel per row.
    Longer rows use a 3-kernel reduce-then-scan: per-block sums, a scan of
    those sums, then a block-wise scan that adds each block's exclusive
    prefix.
    """
    if N <= 16384:  # persistent: one program scans a whole row in one tile
        TILE_SIZE = triton.next_power_of_2(N)
        reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
            x,
            out,
            N,
            TILE_SIZE,
        )
        return out
    TILE_SIZE = min(4096, triton.next_power_of_2(N))
    num_tiles = triton.cdiv(N, TILE_SIZE)
    num_ctas = num_tiles
    ROOT_SCAN_TILE_SIZE = triton.next_power_of_2(num_ctas)
    # With num_ctas == num_tiles this is always 1; kept for future tuning of
    # the CTA count independently of the tile count.
    tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
    block_sums = torch.empty(
        (
            M,
            num_ctas,
        ),
        dtype=compute_dtype,
        device=x.device,
    )
    block_inclusive_prefix = torch.empty(
        (
            M,
            num_ctas,
        ),
        dtype=compute_dtype,
        device=x.device,
    )
    # 3-kernel implementation.
    # BUGFIX: the launch grid here was a 4-tuple (M, num_ctas, 1, 1); Triton
    # grids have at most 3 dimensions, so use (M, num_ctas, 1) like the other
    # launches in this file.
    reduce_then_scan_block_sum_kernel_row[(M, num_ctas, 1)](
        x,
        block_sums,
        N,
        tiles_per_cta,
        TILE_SIZE,
    )
    reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
        block_sums,
        block_inclusive_prefix,
        num_ctas,
        ROOT_SCAN_TILE_SIZE,
    )
    reduce_then_scan_block_scan_kernel_row[(M, num_ctas, 1)](
        x,
        block_inclusive_prefix,
        out,
        N,
        num_ctas,
        tiles_per_cta,
        TILE_SIZE,
    )
    return out
@triton.jit
def reduce_then_scan_block_sum_kernel_row(
    in_ptr,
    block_sum_ptr,
    N,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """The same kernel as the block sum in parallel reduce.

    Program (pid_m, pid_n) sums its tiles_per_cta * TILE_SIZE slice of row
    pid_m and stores the scalar total into block_sum_ptr[pid_m, pid_n].
    """
    pid_n = tl.program_id(1).to(tl.int64)
    pid_m = tl.program_id(0).to(tl.int64)
    num_programs_n = tl.num_programs(1)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)
    # Accumulate in the widened dtype chosen for this input element type.
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    acc = tl.zeros((TILE_SIZE,), dtype=acc_dtype)
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        x = tl.load(in_ptr + pid_m * N + offsets, mask=offsets < N).to(acc_dtype)
        acc += x
    # Reduce the per-lane partial sums to this block's scalar total.
    block_sum = tl.sum(acc, 0)
    tl.store(
        block_sum_ptr + pid_m * num_programs_n + pid_n, block_sum, cache_modifier=".cg"
    )
@triton.jit
def reduce_then_scan_root_scan_kernel_row(in_ptr, out_ptr, N, TILE_SIZE: tl.constexpr):
    """Almost The same kernel as the persistent scan kernel.

    One program per row (pid = row index): loads up to TILE_SIZE elements of
    the row, computes the inclusive cumsum in the widened accumulator dtype,
    and writes the result back. Requires N <= TILE_SIZE.
    """
    pid = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, TILE_SIZE)
    mask = offsets < N
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    x = tl.load(in_ptr + pid * N + offsets, mask=mask, other=0).to(acc_dtype)
    out = tl.cumsum(x, 0)
    tl.store(out_ptr + pid * N + offsets, out, mask=mask)
@triton.jit
def reduce_then_scan_block_scan_kernel_row(
    in_ptr,
    previous_sum_ptr,
    out_ptr,
    N,
    num_tiles_n,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """Final pass: per-block cumsum plus the block's exclusive prefix.

    Program (pid_m, pid_n) rescans its slice of row pid_m tile by tile,
    seeding the running prefix with previous_sum_ptr[pid_m, pid_n - 1]
    (0 for the first block, via the masked load).
    """
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1).to(tl.int64)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    # Exclusive prefix of this block; masked to 0 for the first block.
    prefix = tl.load(
        previous_sum_ptr + pid_m * num_tiles_n + pid_n - 1, mask=pid_n > 0, other=0
    ).to(acc_dtype)
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        mask = offsets < N
        x = tl.load(in_ptr + pid_m * N + offsets, mask=mask).to(acc_dtype)
        tile_scan = prefix + tl.cumsum(x, 0)
        # Carry the tile total into the next tile's prefix.
        prefix += tl.sum(x, 0)
        tl.store(
            out_ptr + pid_m * N + offsets, tile_scan, mask=mask, cache_modifier=".cg"
        )
def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of `inp` along `dim` (torch.cumsum analogue)."""
    logger.debug("GEMS_ASCEND CUMSUM")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype)
def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Out-variant of cumsum: writes the result into the provided `out`."""
    logger.debug("GEMS_ASCEND CUMSUM_OUT")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype, out=out)
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Single-program normalized cumsum of one K-element row.

    Stores cumsum(x) / sum(x) for the row handled by this program.
    Requires K <= BLOCK.
    """
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # NOTE(review): only fp16 is widened here; bf16 input would be scanned
    # in bf16 — confirm whether that is intended.
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Blocked inclusive cumsum over chunks of rows.

    Each program scans a t*TILE-column chunk of up to r rows, carrying a
    running float32 prefix across tiles. Optionally stores the chunk total
    per row (OUTPUT_SUMS) and/or normalizes the chunk by its total
    (NORMALIZE). HAS_OUT_LAYOUT switches the store to separate output
    strides.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)
    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running inclusive prefix carried across the t tiles of this chunk.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            # Widen low-precision floats for the accumulation.
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Rewritten every tile; after the last tile this holds the
                # chunk's inclusive total for this row.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep over the chunk: divide each stored cumsum by the
            # chunk total accumulated above.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Apply per-chunk base offsets and per-row scaling to partial cumsums.

    For each element in its chunk this program computes (x + d) / rscale,
    where d is the chunk's base offset (exclusive prefix) and rscale the
    row's scale factor. Used as the final pass of the multi-chunk
    normalized cumsum.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_gridx = tle.num_programs(1)
    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride
    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        # NOTE(review): advancing `base` by gridx (not the row pitch) and
        # iterating rows from gridy rather than gridy * r both look
        # suspicious for r > 1 — confirm against the multi-row-per-CTA path.
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE
# Maximum extent of the second (y) launch-grid dimension; rows are batched
# per program in normed_cumsum when n_rows exceeds this.
GRID_Y_LIMIT = 65535
def normed_cumsum(inp, dim=-1):
    """Normalized inclusive cumsum along `dim` for floating-point input.

    Equivalent to inp.cumsum(dim) / inp.sum(dim, keepdim=True). Uses a
    single-pass kernel when one chunk of CTAs covers the scan axis;
    otherwise a three-pass scheme: per-chunk cumsums + chunk sums, a scan
    of the chunk sums, then an update pass adding each chunk's exclusive
    prefix and dividing by the row total.
    """
    logger.debug("GEMS_ASCEND NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # inp = inp.contiguous()
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        TILE = 2048
        # Each row is split into n_chunks of chunks where each chunk is composed
        # of n_tiles of tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(CORE_NUM, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        # Batch rows per program when they exceed the y-grid limit.
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows
        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # One chunk covers the whole scan axis: cumsum and normalize in
            # a single kernel launch.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out
        # BUGFIX: acc_dtype was previously assigned only when
        # inp.dtype != torch.float64, raising NameError for float64 input
        # (which the assert above explicitly allows). Accumulate fp64 in fp64.
        acc_dtype = torch.float32 if inp.dtype != torch.float64 else torch.float64
        # BUGFIX: these buffers were allocated with device=device.name, but the
        # module-level `device` is already a plain string (rebound at import),
        # so `.name` raised AttributeError. Allocate on the input's device.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Pass three: add each chunk's exclusive prefix (cumsums - sums) and
        # divide by the row total (last inclusive prefix).
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out