Coverage for src/flag_gems/ops/cumsum.py: 40%
328 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import functools
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
8from torch._prims_common import is_boolean_dtype, is_integer_dtype
10from flag_gems.runtime import device, torch_device_fn
11from flag_gems.utils import get_device_properties, libentry
12from flag_gems.utils import triton_lang_extension as tle
# Re-bind the runtime device object to its backend name string (e.g. "cuda").
# NOTE(review): after this line `device` is a plain string, not the runtime
# device object — later uses of `device.name` in this file would fail; verify.
device = device.name
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
@functools.lru_cache
def get_num_sms(idx: int) -> int:
    """Return the multiprocessor (SM) count of device ``idx``, memoized."""
    props = get_device_properties(idx)
    return props.multi_processor_count
@tl.constexpr
def get_scan_accum_type(inp_dtype: tl.dtype) -> tl.dtype:
    """Choose the accumulator dtype for scans/sums over ``inp_dtype``.

    Half-precision floats accumulate in fp32 for accuracy; all integer types
    (signed, unsigned, and bool) accumulate in int64 to avoid overflow; other
    dtypes (fp32, fp64, int64, ...) accumulate in themselves.
    Evaluated at Triton compile time via ``tl.constexpr``.
    """
    if inp_dtype.is_bf16() or inp_dtype.is_fp16():
        return tl.float32
    if inp_dtype.is_int():  # signed or not (including bool)
        return tl.int64
    else:
        return inp_dtype
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 1 of scan-then-fan over a flat buffer.

    Each program cumsums one BLOCK_SIZE-sized partition of ``inp`` into
    ``out`` and writes that partition's total into ``partial_sum[pid]`` so a
    later pass can propagate inter-partition prefixes.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen the accumulation type at compile time: 64-bit types keep their
    # width, narrower ints widen to int32, everything else to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)
    # Partition total (out-of-range lanes are masked at load, so they add 0).
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 2 of scan-then-fan over a flat buffer.

    Adds the (already scanned) cumulative total of all preceding partitions —
    ``partial_sum[pid - 1]`` — onto each element of partition ``pid``.
    Partition 0 has no predecessor and is left untouched.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid > 0:
        partial_sum_ptrs = partial_sum + pid - 1
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        # Cast back in case the prefix was held in a wider accumulator dtype.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 1 of scan-then-fan for a tensor viewed as (A, B, C), scanning
    along the middle B axis.

    Grid is (A, part_num, C): program (a, b, c) cumsums one BLOCK_SIZE slice
    of row (a, :, c) into ``out`` and stores the slice total into
    ``partial_sum`` (laid out as (A, part_num, C)).
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Linear offsets into the contiguous (A, B, C) input and the
    # (A, part_num, C) partial-sum buffer.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen the accumulation type at compile time: 64-bit types keep their
    # width, narrower ints widen to int32, everything else to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)
    # Slice total (masked lanes contribute 0).
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 2 of scan-then-fan along the B axis of an (A, B, C) view.

    Adds the scanned total of all preceding B-slices —
    ``partial_sum[a, pid_b - 1, c]`` — onto slice ``pid_b`` of row (a, :, c).
    Slice 0 has no predecessor and is left untouched.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # Offset of the previous slice's total in the (A, part_num, C) buffer.
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid_b > 0:
        partial_sum_ptrs = partial_sum + last_part_offset
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        # Cast back in case the prefix was held in a wider accumulator dtype.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
def scan_then_fan_col(inp, out, n_ele, dtype):
    """Cumulative sum of a flat ``n_ele``-element buffer into ``out``.

    Launches a partition-local scan, recursively scans the per-partition
    totals, then fans the prefixes back over the partitions.
    """
    # TODO(all): tune on target board
    block = triton.next_power_of_2(n_ele) if n_ele <= 1024 * 4 else 1024
    part_num = math.ceil(n_ele / block)
    part_totals = torch.empty(part_num, dtype=dtype, device=inp.device)

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[grid](inp, out, part_totals, n_ele, part_num, block)

    # A single partition is already a complete scan; otherwise scan the
    # totals in place and add each partition's exclusive prefix.
    if part_num >= 2:
        scan_then_fan_col(part_totals, part_totals, part_num, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_kernel[grid](out, part_totals, n_ele, part_num, block)
def scan_then_fan(inp, out, A, B, C, dtype):
    """Cumulative sum along the B axis of an (A, B, C)-shaped buffer.

    Same scan-then-fan scheme as ``scan_then_fan_col`` but with an
    (A, part_num, C) grid; recursion reduces the partition totals.
    """
    # TODO(all): tune on target board
    block = triton.next_power_of_2(B) if B <= 1024 * 4 else 1024
    part_num = math.ceil(B / block)
    part_totals = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[grid](inp, out, part_totals, B, C, part_num, block)

    # A single partition per row is already fully scanned; otherwise scan the
    # totals in place and propagate the exclusive prefixes.
    if part_num >= 2:
        scan_then_fan(part_totals, part_totals, A, part_num, C, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_abc_kernel[grid](out, part_totals, B, C, part_num, block)
def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Shared driver for cumsum/cumsum_out.

    Collapses ``inp`` to an (M, N, K) view with N on the scan axis, picks the
    output dtype (integers/bools promote to int64, matching torch.cumsum),
    and dispatches to the row-scan or column-scan implementation.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    # M = product of the leading dims, K = product of the trailing dims.
    M = math.prod(shape[:dim], start=1)
    inp = inp.contiguous()
    K = inp.numel() // M // N

    if dtype is None:
        dtype = inp.dtype
        if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
            dtype = torch.int64
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    # Half-precision inputs accumulate in fp32 for accuracy.
    compute_dtype = out.dtype
    if inp.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32

    if K == 1:  # row scan: the scan axis is innermost and contiguous
        reduce_then_scan_row(inp, out, M, N, compute_dtype)
    else:  # col scan: strided along the scan axis
        scan_then_fan(inp, out, M, N, K, compute_dtype)

    return out
def reduce_then_scan_row(x, out, M, N, compute_dtype):
    """Row-wise cumulative sum of a contiguous (M, N) tensor into ``out``.

    For short rows (N <= 16384) one persistent kernel scans each whole row.
    Otherwise a 3-kernel reduce-then-scan pipeline runs:
      1. each CTA sums its chunk of the row (block sums),
      2. a root kernel scans the per-CTA block sums,
      3. each CTA rescans its chunk, offset by its exclusive prefix.

    ``compute_dtype`` is used only for the intermediate block-sum buffers.
    Returns ``out``.
    """
    if N <= 16384:  # persistent: the whole row fits in one CTA tile
        TILE_SIZE = triton.next_power_of_2(N)
        num_warps = 8 if TILE_SIZE > 2048 else 4
        reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
            x, out, N, TILE_SIZE, num_warps=num_warps
        )
        return out

    TILE_SIZE = min(4096, triton.next_power_of_2(N))
    num_warps = 8 if TILE_SIZE > 2048 else 4
    num_tiles = triton.cdiv(N, TILE_SIZE)
    # Oversubscribe SMs 4x, but never launch more CTAs than there are tiles.
    max_ctas = get_num_sms(x.device.index) * 4
    num_ctas = min(num_tiles, max_ctas)
    ROOT_SCAN_TILE_SIZE = triton.next_power_of_2(num_ctas)
    tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
    block_sums = torch.empty((M, num_ctas), dtype=compute_dtype, device=x.device)
    block_inclusive_prefix = torch.empty(
        (M, num_ctas), dtype=compute_dtype, device=x.device
    )

    # 3-kernel implementation.
    # Fix: this grid was previously the 4-tuple (M, num_ctas, 1, 1); Triton
    # launch grids have at most 3 dimensions, so the trailing element was
    # dead weight (and misleading). Use the canonical 3-tuple.
    reduce_then_scan_block_sum_kernel_row[(M, num_ctas, 1)](
        x, block_sums, N, tiles_per_cta, TILE_SIZE, num_warps=num_warps
    )
    reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
        block_sums,
        block_inclusive_prefix,
        num_ctas,
        ROOT_SCAN_TILE_SIZE,
        num_warps=num_warps,
    )
    reduce_then_scan_block_scan_kernel_row[(M, num_ctas, 1)](
        x,
        block_inclusive_prefix,
        out,
        N,
        num_ctas,
        tiles_per_cta,
        TILE_SIZE,
        num_warps=num_warps,
    )
    return out
@triton.jit
def reduce_then_scan_block_sum_kernel_row(
    in_ptr,
    block_sum_ptr,
    N,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """Stage 1 of row reduce-then-scan: each CTA sums its chunk of a row.

    Grid is (M, num_ctas): program (pid_m, pid_n) reduces ``tiles_per_cta``
    tiles of row ``pid_m`` and writes one scalar into
    ``block_sum_ptr[pid_m, pid_n]``.
    """
    pid_n = tl.program_id(1).to(tl.int64)
    pid_m = tl.program_id(0).to(tl.int64)
    num_programs_n = tl.num_programs(1)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    # Clamp the chunk end so the last CTA does not run past the row.
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)

    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    acc = tl.zeros((TILE_SIZE,), dtype=acc_dtype)
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        x = tl.load(in_ptr + pid_m * N + offsets, mask=offsets < N).to(acc_dtype)
        acc += x
    # Lane-wise accumulation first, then one horizontal reduction.
    block_sum = tl.sum(acc, 0)
    # .cg: stream the scalar through L2 only — it is consumed by another CTA.
    tl.store(
        block_sum_ptr + pid_m * num_programs_n + pid_n, block_sum, cache_modifier=".cg"
    )
@triton.jit
def reduce_then_scan_root_scan_kernel_row(in_ptr, out_ptr, N, TILE_SIZE: tl.constexpr):
    """Persistent single-tile row scan: one program cumsums one whole row.

    Used both as the short-row fast path (row length N <= TILE_SIZE) and as
    stage 2 of the 3-kernel pipeline, where it scans the per-CTA block sums.
    """
    pid = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, TILE_SIZE)
    mask = offsets < N
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    # other=0 keeps out-of-range lanes neutral for the cumulative sum.
    x = tl.load(in_ptr + pid * N + offsets, mask=mask, other=0).to(acc_dtype)
    out = tl.cumsum(x, 0)
    tl.store(out_ptr + pid * N + offsets, out, mask=mask)
@triton.jit
def reduce_then_scan_block_scan_kernel_row(
    in_ptr,
    previous_sum_ptr,
    out_ptr,
    N,
    num_tiles_n,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """Stage 3 of row reduce-then-scan: rescan each chunk with its prefix.

    Program (pid_m, pid_n) cumsums its chunk of row ``pid_m`` tile by tile,
    seeding the running total with the *exclusive* prefix for this CTA, i.e.
    the scanned block sum of CTA ``pid_n - 1`` (0 for the first CTA).
    """
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1).to(tl.int64)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)

    # Masked load turns the first CTA's (nonexistent) predecessor into 0.
    prefix = tl.load(
        previous_sum_ptr + pid_m * num_tiles_n + pid_n - 1, mask=pid_n > 0, other=0
    ).to(acc_dtype)
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        mask = offsets < N
        x = tl.load(in_ptr + pid_m * N + offsets, mask=mask).to(acc_dtype)
        # Inclusive scan of this tile shifted by everything scanned so far.
        tile_scan = prefix + tl.cumsum(x, 0)
        prefix += tl.sum(x, 0)
        # .cg: results are not re-read by this kernel; bypass L1.
        tl.store(
            out_ptr + pid_m * N + offsets, tile_scan, mask=mask, cache_modifier=".cg"
        )
def cumsum(inp, dim=1, *, dtype=None):
    """Cumulative sum of ``inp`` along ``dim`` (torch.cumsum equivalent)."""
    logger.debug("GEMS CUMSUM")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype)
def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Cumulative sum of ``inp`` along ``dim``, written into ``out``."""
    logger.debug("GEMS CUMSUM_OUT")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype, out=out)
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Single-program normalized cumsum of one K-element row:
    out[i] = cumsum(x)[0..i] / sum(x). Requires BLOCK >= K.
    """
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # NOTE(review): only fp16 is upcast to fp32 here, while the sibling
    # block_cumsum_kernel also upcasts bf16 — confirm whether bf16 inputs
    # should be promoted in this kernel too.
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Chunk-local (normalized) cumsum used by ``normed_cumsum``.

    One CTA scans a (r, t*TILE) chunk: rows [gridy*r, (gridy+1)*r) of the
    R-row / K-column problem, columns [gridx*t*TILE, (gridx+1)*t*TILE).
    Optionally stores the running chunk total into ``sums`` (OUTPUT_SUMS,
    for a later cross-chunk fix-up pass) and/or divides the finished chunk
    by its total (NORMALIZE, valid only when one chunk covers the row).
    HAS_OUT_LAYOUT switches the store addressing to the out_* strides.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running inclusive total of this row's chunk, kept in fp32.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            # Tile-local inclusive scan shifted by the chunk prefix so far.
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Rewritten every tile; the final write is the chunk total.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep: divide the stored scan by the row/chunk total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Final pass of multi-chunk ``normed_cumsum``: for each element of its
    chunk, add the chunk's exclusive prefix (``base``) and divide by the
    row total (``rscale_ptr``), i.e. out = (inp + d) / rscale.

    HAS_OUT_LAYOUT switches the store addressing to the out_* strides.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_gridx = tle.num_programs(1)

    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    # NOTE(review): the row loop spans [gridy, gridy + r) rather than the
    # [gridy*r, (gridy+1)*r) used by block_cumsum_kernel, and the per-row
    # pointer advance below is `base += gridx` (not a row stride). Both look
    # suspicious for the batched case (r > 1, i.e. n_rows > GRID_Y_LIMIT in
    # the caller) — confirm against the intended sums layout (n_rows, n_chunks).
    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE
# Maximum launch-grid extent along y (the CUDA gridDim.y hardware limit is
# 65535); normed_cumsum batches multiple rows per CTA beyond this.
GRID_Y_LIMIT = 65535
def normed_cumsum(inp, dim=-1):
    """Cumulative sum along ``dim`` normalized by the per-row total:
    out = inp.cumsum(dim) / inp.sum(dim, keepdim=True).

    Only floating dtypes are supported. A middle (neither outermost- nor
    innermost-strided) scan dim is first transposed to the last axis. Rows
    are split into chunks across CTAs; when more than one chunk is needed,
    a three-pass scheme (chunk scan -> scan of chunk sums -> fix-up) runs.
    Returns a new tensor shaped like the (possibly transposed) input.
    """
    logger.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim
    # to the last.
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        num_sms = get_device_properties(device).multi_processor_count
        TILE = 2048
        # Each row is split into n_chunks of chunks where each chunk is
        # composed of n_tiles of tiles. Different chunks go to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            # gridDim.y is capped, so fold `batch` rows into each CTA.
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # Single chunk per row: scan and normalize in one kernel.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # Fix: acc_dtype was previously assigned only when inp.dtype was not
        # float64, leaving the float64 path broken; make it explicit.
        acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
        # Fix: these buffers were allocated with device=device.name, but the
        # module-level `device` is already the name string, so `.name` raised
        # AttributeError. Allocate on the input's device instead.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Row totals live in the last column of the scanned chunk sums.
        rscale = cumsums[..., -1]
        # cumsums - sums yields each chunk's *exclusive* prefix.
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out