Coverage for src/flag_gems/runtime/backend/_ascend/fla/wy_fast.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9 

10# ruff: noqa: E501 

11# mypy: ignore-errors 

12import torch 

13import triton 

14import triton.language as tl 

15 

16from .utils import prepare_chunk_indices 

17 

18 

@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Recompute ``w = A @ (k * beta * exp(g))`` and ``u = A @ (v * beta)`` for one chunk.

    ``recompute_w_u_fwd`` launches this kernel on a ``(NT, B)`` grid.

    * Program axis 0 selects the chunk. In varlen mode it is an index into
      ``chunk_indices``, which holds ``(sequence id, chunk id)`` pairs. Otherwise
      it is the chunk position inside the sequence.
    * Program axis 1 selects the batch row (non-varlen mode only).
    * All ``H`` heads are processed sequentially inside one program.

    Memory layouts, as implied by the pointer arithmetic below:

    * ``k``: token-major ``[.., T, Hg, K]``. Value head ``i_h`` reads key head
      ``i_h // (H // Hg)``.
    * ``v`` / ``u``: ``[.., T, H, V]``.
    * ``w``: ``[.., T, H, K]``.
    * ``A``: ``[.., T, H, BT]``.
    * ``beta`` / ``g``: head-major ``[B, H, T]`` (the wrapper transposes them).

    In varlen mode ``bos``/``eos`` are global token offsets read from
    ``cu_seqlens``, and the batch dimension is expected to be 1.
    """
    T_max = T
    i_t_o = tl.program_id(0)

    # The sequence bounds do not depend on the head, so resolve them once per
    # program rather than once per head.
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t_o * 2).to(tl.int32),
            tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        # Single batch row: g/beta are [1, H, T_max], so this sequence starts
        # at column ``bos`` of every head row.
        o_gb = bos
    else:
        i_t = i_t_o
        i_b = tl.program_id(1)
        bos, eos = i_b * T, i_b * T + T
        # g/beta are [B, H, T]: skip the H * T elements of the earlier batch rows.
        o_gb = i_b * H * T

    # Row (time) offsets of this chunk; these are identical for every head.
    offs_t = tl.arange(0, BT)
    global_offs_t = i_t * BT + offs_t
    mask_t = global_offs_t < T
    offs_t_2d = global_offs_t[:, None]
    offs_bt = tl.arange(0, BT)[None, :]

    for i_h in range(H):
        # BT x BT block of A for this chunk and head; rows past T read as 0.
        ptr_A = A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt
        b_A = tl.load(ptr_A, mask=mask_t[:, None], other=0.0).to(tl.float32)

        ptr_g = g + o_gb + i_h * T_max + global_offs_t
        b_g = tl.exp(tl.load(ptr_g, mask=mask_t, other=0.0)).to(tl.float32)

        ptr_beta = beta + o_gb + i_h * T_max + global_offs_t
        b_beta = tl.load(ptr_beta, mask=mask_t, other=0.0).to(tl.float32)

        # u = A @ (v * beta), tiled over the V dimension in BV-wide slices.
        for i_v in range(tl.cdiv(V, BV)):
            offs_v = i_v * BV + tl.arange(0, BV)[None, :]
            mask_v = mask_t[:, None] & (offs_v < V)

            ptr_v = v + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v
            b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32)

            b_vb = b_v * b_beta[:, None]
            b_u = tl.dot(b_A, b_vb, allow_tf32=False)

            ptr_u = u + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v
            tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v)

        # w = A @ (k * beta * exp(g)), tiled over the K dimension in BK-wide slices.
        for i_k in range(tl.cdiv(K, BK)):
            offs_k = i_k * BK + tl.arange(0, BK)[None, :]
            mask_k = mask_t[:, None] & (offs_k < K)

            # Grouped heads: H // Hg value heads share one key head.
            ptr_k = (
                k
                + (bos * Hg + i_h // (H // Hg)) * K
                + offs_t_2d * (Hg * K)
                + offs_k
            )
            b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32)

            b_kb = b_k * b_beta[:, None] * b_g[:, None]
            b_w = tl.dot(b_A, b_kb)

            ptr_w = w + (bos * H + i_h) * K + offs_t_2d * (H * K) + offs_k
            tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k)

104 

105 

def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute ``w = A @ (k * beta * exp(g_cumsum))`` and ``u = A @ (v * beta)`` chunk-wise.

    Args:
        k: Keys of shape ``[B, T, Hg, K]``.
        v: Values of shape ``[B, T, H, V]``. ``H`` must be a multiple of ``Hg``.
        beta: Per-token, per-head scale. Assumed to be ``[B, T, H]``; it is
            transposed to ``[B, H, T]`` before the launch (TODO confirm with
            callers).
        g_cumsum: Cumulative gates, with the same layout assumption as ``beta``.
        A: Chunk matrices whose last dimension is the chunk size ``BT``.
            Assumed to be ``[B, T, H, BT]``, matching the kernel's addressing.
        cu_seqlens: Optional cumulative sequence lengths for variable-length
            packing. When given, ``B`` must be 1.

    Returns:
        ``(w, u)``, where ``w`` has shape ``[B, T, H, K]`` and ``u`` has the
        same shape and dtype as ``v``.

    Raises:
        ValueError: If ``H`` is not a multiple of ``Hg``, or if ``cu_seqlens``
            is given with a batch size other than 1.
    """
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    # The kernel maps value head i_h to key head i_h // (H // Hg); a
    # non-multiple (or Hg > H) would mis-map heads or divide by zero.
    if H % Hg != 0:
        raise ValueError(
            f"number of value heads ({H}) must be a multiple of key heads ({Hg})"
        )
    # In varlen mode the kernel addresses every tensor by global token offset
    # only; extra batch rows would be left uninitialized.
    if cu_seqlens is not None and B != 1:
        raise ValueError(
            f"batch size must be 1 when cu_seqlens is provided, got {B}"
        )

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    BK = 64
    BV = 64

    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    # The kernel reads beta/g head-major, i.e. one contiguous row of T values
    # per (batch, head).
    beta = beta.transpose(1, 2).contiguous()
    g_cumsum = g_cumsum.transpose(1, 2).contiguous()
    recompute_w_u_fwd_kernel[(NT, B)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        num_warps=4,
        num_stages=3,
    )
    return w, u