Coverage for src/flag_gems/runtime/backend/_ascend/fla/cumsum.py: 0%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
4#
5# This file contains code copied from the flash-linear-attention project.
6# The original source code was licensed under the MIT license and included
7# the following copyright notice:
8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
9# ruff: noqa: E501
10# mypy: ignore-errors
11import torch
12import triton
13import triton.language as tl
15from .utils import prepare_chunk_indices
@triton.heuristics(
    {
        # Compile-time specialization on whether the optional runtime
        # arguments were actually provided.
        "HAS_SCALE": lambda args: args["scale"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,  # input pointer, layout (B, H, T) if HEAD_FIRST else (B, T, H)
    o,  # output pointer, same layout as `s`
    scale,  # optional multiplier applied after the cumsum (see HAS_SCALE)
    cu_seqlens,  # optional cumulative sequence lengths (varlen mode)
    chunk_indices,  # (seq id, block id) int pairs, one per program (varlen mode)
    T,  # time length (per sequence in varlen mode)
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # number of heads
    BLOCK_T: tl.constexpr,  # time steps per program; a multiple of CHUNK_SIZE
    REVERSE: tl.constexpr,  # accumulate from the end of each chunk
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,  # True: (H, T) tiles; False: (T, H) tiles
    CHUNK_SIZE: tl.constexpr = 64,  # cumsum restarts at every CHUNK_SIZE boundary
):
    """Chunk-local cumulative sum over the time axis.

    Each program handles BLOCK_T time steps (BLOCK_T // CHUNK_SIZE chunks)
    of one (batch, all-heads) slice; the cumsum is computed independently
    inside every CHUNK_SIZE-long window.
    """
    i_block, i_b = tl.program_id(0), tl.program_id(1)
    N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
    if IS_VARLEN:
        # chunk_indices maps the flat program id to (sequence id, block id
        # within that sequence); cu_seqlens gives the sequence's [bos, eos).
        i_s, i_block = (
            tl.load(chunk_indices + i_block * 2).to(tl.int32),
            tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_s).to(tl.int32),
            tl.load(cu_seqlens + i_s + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        ptr_s = tl.make_block_ptr(
            s + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)
        )
        ptr_o = tl.make_block_ptr(
            o + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)
        )
        # FIX: time is axis 1 here, so it must be boundary-checked; the
        # previous (0,) only checked the H axis, which always fits exactly,
        # and let the last partial block read/write past T.
        b_s = tl.load(ptr_s, boundary_check=(0, 1)).to(tl.float32)
        # (H, BLOCK_T) -> (H, N_CHUNKS, CHUNK_SIZE) -> (CHUNK_SIZE, H, N_CHUNKS)
        # so the within-chunk time position becomes axis 0 for the cumsum.
        b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE))
        b_s = tl.trans(b_s, (2, 0, 1))
        b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
        if HAS_SCALE:
            b_o *= scale
        # FIX: undo the (2, 0, 1) permutation with its inverse (1, 2, 0);
        # re-applying (2, 0, 1) (which is not self-inverse, unlike the
        # (1, 0, 2) used in the non-head-first branch) left the tensor as
        # (N_CHUNKS, CHUNK_SIZE, H) and the reshape scrambled heads/time.
        b_o = tl.trans(b_o, (1, 2, 0))
        b_o = tl.reshape(b_o, (H, BLOCK_T))
    else:
        ptr_s = tl.make_block_ptr(
            s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)
        )
        ptr_o = tl.make_block_ptr(
            o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)
        )
        # Time is axis 0 here, so checking (0,) is sufficient.
        b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
        # (BLOCK_T, H) -> (N_CHUNKS, CHUNK_SIZE, H) -> (CHUNK_SIZE, N_CHUNKS, H);
        # (1, 0, 2) is its own inverse, so the same perm restores the layout.
        b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
        b_s = tl.trans(b_s, (1, 0, 2))
        b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
        if HAS_SCALE:
            b_o *= scale
        b_o = tl.trans(b_o, (1, 0, 2))
        b_o = tl.reshape(b_o, (BLOCK_T, H))
    if HEAD_FIRST:
        tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, 1))
    else:
        tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0,))
    return
def chunk_local_cumsum_scalar(
    g,
    chunk_size,
    reverse: bool = False,
    scale: float | None = None,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
):
    """Launch the chunk-local cumsum kernel on a 3-D tensor.

    Args:
        g: input of shape (B, H, T) if ``head_first`` else (B, T, H).
        chunk_size: window length; the cumsum restarts at every chunk
            boundary. Must be a power of two.
        reverse: accumulate from the end of each chunk instead of the start.
        scale: optional multiplier applied to the result inside the kernel.
        cu_seqlens: optional cumulative sequence lengths for varlen batches.
        head_first: layout flag, see ``g``.
        output_dtype: dtype of the returned tensor; falls back to ``g.dtype``
            when None (default float32).

    Returns:
        A newly allocated tensor with the same shape as ``g``.
    """
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    # Process several chunks per program, targeting roughly 2**18 elements
    # per block. Clamp to at least one chunk: for large H * chunk_size the
    # heuristic could otherwise produce BLOCK_T < chunk_size, which would
    # make the kernel-side BLOCK_T // CHUNK_SIZE reshape invalid.
    OPTIM_BLOCK_SIZE = max(
        chunk_size, triton.next_power_of_2((2**18) // (H * chunk_size))
    )
    block_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE)
        if cu_seqlens is not None
        else None
    )
    num_blocks = (
        len(block_indices)
        if cu_seqlens is not None
        else triton.cdiv(T, OPTIM_BLOCK_SIZE)
    )
    # Keep the original tensor as the kernel input; `g` becomes the output.
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (num_blocks, B)
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=block_indices,
        T=T,
        B=B,
        H=H,
        BLOCK_T=OPTIM_BLOCK_SIZE,
        CHUNK_SIZE=chunk_size,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
        num_warps=8,
        num_stages=3,
    )
    return g
140def chunk_local_cumsum(
141 g: torch.Tensor,
142 chunk_size: int,
143 reverse: bool = False,
144 scale: float = None,
145 cu_seqlens: torch.Tensor | None = None,
146 head_first: bool = False,
147 output_dtype: torch.dtype | None = torch.float,
148 **kwargs,
149) -> torch.Tensor:
150 if cu_seqlens is not None:
151 assert (
152 g.shape[0] == 1
153 ), "Only batch size 1 is supported when cu_seqlens are provided"
154 if len(g.shape) == 3:
155 return chunk_local_cumsum_scalar(
156 g=g,
157 chunk_size=chunk_size,
158 reverse=reverse,
159 scale=scale,
160 cu_seqlens=cu_seqlens,
161 head_first=head_first,
162 output_dtype=output_dtype,
163 )
164 else:
165 raise ValueError(
166 f"Unsupported input shape {g.shape}, "
167 f"which should be (B, T, H, D) if `head_first=False` "
168 f"or (B, H, T, D) otherwise"
169 )