Coverage for src/flag_gems/modules/rotary_embedding.py: 29%

79 statements  


# Copyright (c) 2025 FlagGems. All rights reserved.
#
# This module provides a unified interface for various Rotary Position Embedding
# (RoPE) implementations. Currently it includes only the YaRN-style RoPE used by
# DeepSeek, but support for other variants will be added progressively.
#
# The following components are adapted from DeepSeek-R1:
# - yarn_find_correction_dim
# - yarn_find_correction_range
# - yarn_get_mscale
# - yarn_linear_ramp_mask
# - _set_cos_sin_cache method in `GemsDeepseekYarnRoPE`
#
# Source: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
# License: Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0)

import logging
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

import flag_gems
from flag_gems.config import use_c_extension

logger = logging.getLogger(__name__)

__all__ = [
    "gems_rope_forward",
    "GemsDeepseekYarnRoPE",
    "GemsRope",
]


def gems_rope_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.IntTensor] = None,
    rotary_interleaved: bool = False,
    inplace: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_c_extension:
        logger.debug("GEMS CUSTOM ROPE FORWARD(C EXTENSION)")
        if inplace:
            torch.ops.flag_gems.rotary_embedding_inplace(
                query, key, cos, sin, position_ids, rotary_interleaved
            )
            return query, key
        else:
            return torch.ops.flag_gems.rotary_embedding(
                query, key, cos, sin, position_ids, rotary_interleaved
            )
    else:
        logger.debug("GEMS CUSTOM ROPE FORWARD")
        # Fall back to the pure Python implementation
        return flag_gems.apply_rotary_pos_emb(
            query, key, cos, sin, position_ids, rotary_interleaved, inplace
        )
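
# A minimal call sketch (the shapes and the CUDA device here are illustrative
# assumptions, not an API contract):
#
#   q = torch.randn(2, 16, 8, 64, device="cuda")  # (batch, seq_len, q_heads, head_dim)
#   k = torch.randn(2, 16, 8, 64, device="cuda")  # (batch, seq_len, k_heads, head_dim)
#   cos, sin = ...  # each (max_position_embeddings, head_dim // 2), see the caches below
#   q_out, k_out = gems_rope_forward(q, k, cos, sin, rotary_interleaved=False)
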

class GemsRope(nn.Module):
    """
    Base class for Rotary Position Embedding (RoPE) modules.
    This class is intended to be subclassed for specific RoPE implementations.

    Args:
        rotary_dim (int): The rotary embedding dimension (typically equal to head_dim).
        max_position_embeddings (int): Maximum sequence length for which the cos/sin cache is precomputed.
        base (float): Frequency base used to compute inverse frequencies (default 10000).
        rotary_interleaved (bool): Whether to use the interleaved rotary layout (GPT-J-style)
            instead of rotating the two halves of each head (GPT-NeoX-style).
        dtype (torch.dtype): Data type for the cos/sin cache buffers.
        device (torch.device or None): Device to place the initial buffers on.

    Inputs:
        query (torch.Tensor): Shape (..., q_heads, head_dim)
        key (torch.Tensor): Shape (..., k_heads, head_dim)
        position_ids (torch.IntTensor, optional): Shape (..., seq_len), positions for cos/sin lookup.
        inplace (bool): If True, modifies query and key in place (default False).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Transformed (query, key) tensors with RoPE applied.
    """

    def __init__(
        self,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        rotary_interleaved: bool,
        dtype: torch.dtype,
        device: Optional[torch.device],
    ):
        super().__init__()
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.rotary_interleaved = rotary_interleaved
        self.dtype = dtype
        self.device = device
        self._set_cos_sin_cache()

    def _compute_inv_freq(self) -> torch.Tensor:
        """
        Compute the inverse frequency tensor of shape [rotary_dim // 2]:
        inv_freq[i] = base ** (-2i / rotary_dim).
        """
        return 1.0 / (
            self.base
            ** (
                torch.arange(
                    0, self.rotary_dim, 2, dtype=torch.float32, device=self.device
                )
                / self.rotary_dim
            )
        )
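
    # For example, with rotary_dim=64 and base=10000 the 32 frequencies run from
    # inv_freq[0] = 1.0 down to inv_freq[31] = 10000 ** (-62/64) ≈ 1.33e-4, so
    # low-index dims rotate quickly and high-index dims rotate slowly.
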

    def _set_cos_sin_cache(self):
        """
        Default implementation of rotary embeddings (vanilla RoPE).
        Can be overridden in subclasses for NTK, YaRN, etc.
        """
        inv_freq = self._compute_inv_freq()
        t = torch.arange(
            self.max_position_embeddings, device=self.device, dtype=torch.float32
        )
        freqs = torch.outer(t, inv_freq)  # [max_position_embeddings, rotary_dim // 2]

        self.register_buffer(
            "cos_cached", freqs.cos().to(self.dtype), persistent=False
        )  # [max_position_embeddings, rotary_dim // 2]
        self.register_buffer(
            "sin_cached", freqs.sin().to(self.dtype), persistent=False
        )  # [max_position_embeddings, rotary_dim // 2]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        position_ids: Optional[torch.IntTensor] = None,
        inplace: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Rebuild the cache lazily if the buffers are missing.
        if not hasattr(self, "cos_cached") or not hasattr(self, "sin_cached"):
            self._set_cos_sin_cache()

        return gems_rope_forward(
            query,
            key,
            self.cos_cached,
            self.sin_cached,
            position_ids,
            self.rotary_interleaved,
            inplace,
        )
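
# A hedged usage sketch for the base class (the device, dtype, and shapes are
# assumptions; the Triton/C-extension kernels behind `gems_rope_forward` expect
# tensors on a device supported by flag_gems):
#
#   rope = GemsRope(
#       rotary_dim=64,
#       max_position_embeddings=2048,
#       base=10000.0,
#       rotary_interleaved=False,
#       dtype=torch.float16,
#       device=torch.device("cuda"),
#   )
#   q = torch.randn(1, 16, 8, 64, device="cuda", dtype=torch.float16)
#   k = torch.randn(1, 16, 2, 64, device="cuda", dtype=torch.float16)
#   q_out, k_out = rope(q, k)
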

# Inverse dimension formula: find the (fractional) dimension index whose
# frequency completes `num_rotations` full rotations over the original context.
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )
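
# Derivation: dimension pair i (0-based, over the dim/2 frequency pairs) has
# inverse frequency base ** (-2i / dim), so across max_position_embeddings
# positions it completes max_position_embeddings * base ** (-2i / dim) / (2*pi)
# rotations; setting that equal to num_rotations and solving for i gives the
# formula above. For instance, with dim=64, base=10000 over a 4096-token context:
#   yarn_find_correction_dim(32, 64, 10000, 4096)  # ≈ 10.47
#   yarn_find_correction_dim(1, 64, 10000, 4096)   # ≈ 22.51
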


# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
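
# With the DeepSeek-style defaults used below (beta_fast=32, beta_slow=1,
# dim=64, base=10000, original context 4096) this yields the pair-index range
# (10, 23): pairs below 10 rotate fast enough to be left untouched, pairs above
# 23 are fully interpolated, and the ramp mask blends the region in between.
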


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
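
# For example, yarn_get_mscale(40.0, 1.0) = 0.1 * ln(40) + 1.0 ≈ 1.369, i.e. the
# cached cos/sin values are gently scaled up as the context is stretched 40x.
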


def yarn_linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
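
# yarn_linear_ramp_mask(10, 23, 32), for instance, is 0 for indices <= 10,
# rises linearly to 1 across indices 10..23, and stays at 1 afterwards.
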


class GemsDeepseekYarnRoPE(GemsRope):
    """
    YaRN-based Rotary Position Embedding (RoPE) for DeepSeek models.

    Args:
        scaling_factor (float): Context-extension ratio used for YaRN interpolation.
        original_max_position_embeddings (int): Original pretraining context size.
        beta_fast (float): Rotation-count threshold above which frequencies are kept unchanged.
        beta_slow (float): Rotation-count threshold below which frequencies are fully interpolated.
        mscale (float): Multiplicative scale factor for selected frequencies.
        mscale_all_dim (float): Global multiplicative baseline.
    """

    def __init__(
        self,
        rotary_dim: int,
        max_position_embeddings: int = 2048,
        base: float = 10000,
        rotary_interleaved: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        scaling_factor: float = 1.0,
        original_max_position_embeddings: int = 4096,
        beta_fast: float = 32.0,
        beta_slow: float = 1.0,
        mscale: float = 1.0,
        mscale_all_dim: float = 0.0,
    ):
        # The YaRN-specific attributes must be set before super().__init__(),
        # which calls _set_cos_sin_cache() and therefore reads them.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            rotary_interleaved=rotary_interleaved,
            dtype=dtype,
            device=device,
        )

    def _compute_inv_freq(self) -> torch.Tensor:
        # Original (extrapolated) frequencies, as in vanilla RoPE.
        freq_extra = 1.0 / (
            self.base
            ** (
                torch.arange(
                    0, self.rotary_dim, 2, dtype=torch.float32, device=self.device
                )
                / self.rotary_dim
            )
        )
        # Position-interpolated frequencies, slowed down by scaling_factor.
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (
                torch.arange(
                    0, self.rotary_dim, 2, dtype=torch.float32, device=self.device
                )
                / self.rotary_dim
            )
        )

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.original_max_position_embeddings,
        )

        # Blend: mask == 1 keeps the original frequency (fast-rotating dims),
        # mask == 0 uses the interpolated frequency (slow-rotating dims).
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2).to(
            device=self.device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        return inv_freq
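
    # With the defaults (rotary_dim=64, base=10000, beta_fast=32, beta_slow=1,
    # original context 4096) the correction range is (10, 23): frequency pairs
    # 0..10 keep freq_extra, pairs 23..31 take freq_inter, and the ramp blends
    # the pairs in between.
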

    def _set_cos_sin_cache(self):
        inv_freq = self._compute_inv_freq()
        # self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_position_embeddings, device=self.device, dtype=torch.float32
        )
        freqs = torch.outer(t, inv_freq)  # [max_position_embeddings, rotary_dim // 2]

        # YaRN attention scaling: the cached cos/sin tables absorb the mscale
        # ratio so the kernels need no extra multiply at run time.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        self.register_buffer(
            "cos_cached", (freqs.cos() * _mscale).to(self.dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (freqs.sin() * _mscale).to(self.dtype), persistent=False
        )
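

# A minimal smoke test of the cache construction (an illustrative sketch: the
# hyperparameters below are assumptions, and the forward pass itself requires a
# device/backend supported by flag_gems; cache construction is plain torch).
if __name__ == "__main__":
    rope = GemsDeepseekYarnRoPE(
        rotary_dim=64,
        max_position_embeddings=8192,
        scaling_factor=2.0,
        original_max_position_embeddings=4096,
    )
    print(rope.cos_cached.shape)  # torch.Size([8192, 32])
    print(rope.sin_cached.shape)  # torch.Size([8192, 32])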