Coverage for src/flag_gems/runtime/backend/_ascend/fla/chunk_o.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9 

10# ruff: noqa: E501 

11# mypy: ignore-errors 

12import torch 

13import triton 

14import triton.language as tl 

15 

16from .utils import prepare_chunk_offsets, safe_exp 

17 

18 

19@triton.heuristics( 

20 { 

21 "USE_G": lambda args: args["g"] is not None, 

22 "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, 

23 } 

24) 

25@triton.jit(do_not_specialize=["T"]) 

26def chunk_fwd_kernel_o( 

27 q, 

28 k, 

29 v, 

30 h, 

31 g, 

32 o, 

33 cu_seqlens, 

34 chunk_offsets, 

35 scale, 

36 T, 

37 H: tl.constexpr, 

38 Hg: tl.constexpr, 

39 K: tl.constexpr, 

40 V: tl.constexpr, 

41 BT: tl.constexpr, 

42 BK: tl.constexpr, 

43 BV: tl.constexpr, 

44 USE_G: tl.constexpr, 

45 IS_VARLEN: tl.constexpr, 

46): 

47 i_v, i_nh = tl.program_id(0), tl.program_id(1) 

48 i_n, i_h = i_nh // H, i_nh % H 

49 T_max = T 

50 

51 if IS_VARLEN: 

52 bos, eos = ( 

53 tl.load(cu_seqlens + i_n).to(tl.int32), 

54 tl.load(cu_seqlens + i_n + 1).to(tl.int32), 

55 ) 

56 T = eos - bos 

57 NT = tl.cdiv(T, BT) 

58 boh = tl.load(chunk_offsets + i_n).to(tl.int64) 

59 else: 

60 bos, eos = i_n * T, i_n * T + T 

61 NT = tl.cdiv(T, BT) 

62 boh = i_n * NT 

63 

64 # offset calculation 

65 q += (bos * Hg + i_h // (H // Hg)) * K 

66 k += (bos * Hg + i_h // (H // Hg)) * K 

67 v += (bos * H + i_h) * V 

68 o += (bos * H + i_h) * V 

69 

70 for i_t in range(NT): 

71 i_tg = boh + i_t 

72 h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V 

73 b_o = tl.zeros([BT, BV], dtype=tl.float32) 

74 b_A = tl.zeros([BT, BT], dtype=tl.float32) 

75 

76 for i_k in range(tl.cdiv(K, BK)): 

77 p_q = tl.make_block_ptr( 

78 q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) 

79 ) 

80 p_k = tl.make_block_ptr( 

81 k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1) 

82 ) 

83 p_h = tl.make_block_ptr( 

84 h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0) 

85 ) 

86 # [BT, BK] 

87 b_q = tl.load(p_q, boundary_check=(0, 1)) 

88 # [BK, BT] 

89 b_k = tl.load(p_k, boundary_check=(0, 1)) 

90 # [BK, BV] 

91 b_h = tl.load(p_h, boundary_check=(0, 1)) 

92 

93 # [BT, BK] @ [BK, BV] -> [BT, BV] 

94 b_o += tl.dot(b_q, b_h) 

95 # [BT, BK] @ [BK, BT] -> [BT, BT] 

96 b_A += tl.dot(b_q, b_k) 

97 

98 if USE_G: 

99 offs_t = i_t * BT + tl.arange(0, BT) 

100 mask_t = offs_t < T 

101 g_ptr = g + bos + i_h * T_max 

102 b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0) 

103 

104 b_o = b_o * tl.exp(b_g)[:, None] 

105 b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) 

106 

107 o_i = tl.arange(0, BT).to(tl.float32) 

108 m_A = o_i[:, None] >= o_i[None, :] 

109 b_A = tl.where(m_A, b_A, 0) 

110 

111 p_v = tl.make_block_ptr( 

112 v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) 

113 ) 

114 p_o = tl.make_block_ptr( 

115 o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) 

116 ) 

117 

118 b_v = tl.load(p_v, boundary_check=(0, 1)) 

119 # to fix mma -> mma layout conversion 

120 # already solved by fla v3.2 or higher 

121 b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale 

122 tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) 

123 

124 

125def chunk_fwd_o( 

126 q: torch.Tensor, 

127 k: torch.Tensor, 

128 v: torch.Tensor, 

129 h: torch.Tensor, 

130 g: torch.Tensor | None = None, 

131 scale: float | None = None, 

132 cu_seqlens: torch.LongTensor | None = None, 

133 chunk_size: int = 64, 

134) -> torch.Tensor: 

135 B, T, Hg, K, V = *q.shape, v.shape[-1] 

136 H = v.shape[-2] 

137 BT = chunk_size 

138 

139 if scale is None: 

140 scale = k.shape[-1] ** -0.5 

141 

142 o = torch.empty_like(v) 

143 if cu_seqlens is None: 

144 N, chunk_offsets = B, None 

145 else: 

146 N, chunk_offsets = ( 

147 len(cu_seqlens) - 1, 

148 prepare_chunk_offsets(cu_seqlens, BT), 

149 ) 

150 

151 def grid(meta): 

152 return (triton.cdiv(V, meta["BV"]), N * H) 

153 

154 g = g.transpose(1, 2).contiguous() 

155 chunk_fwd_kernel_o[grid]( 

156 q=q, 

157 k=k, 

158 v=v, 

159 h=h, 

160 g=g, 

161 o=o, 

162 cu_seqlens=cu_seqlens, 

163 chunk_offsets=chunk_offsets, 

164 scale=scale, 

165 T=T, 

166 H=H, 

167 Hg=Hg, 

168 K=K, 

169 V=V, 

170 BT=BT, 

171 BK=128, 

172 BV=128, 

173 num_warps=4, 

174 num_stages=2, 

175 ) 

176 return o