Coverage for src/flag_gems/modules/rotary_embedding.py: 29% (79 statements)

# Copyright (c) 2025 FlagGems. All rights reserved.
#
# This module provides a unified interface for various Rotary Position
# Embedding (RoPE) implementations. Currently it includes only the
# YaRN-style RoPE used by DeepSeek, but support for other variants will be
# added progressively.
#
# The following components are adapted from DeepSeek-R1:
# - yarn_find_correction_dim
# - yarn_find_correction_range
# - yarn_get_mscale
# - yarn_linear_ramp_mask
# - the _set_cos_sin_cache method of `GemsDeepseekYarnRoPE`
#
# Source: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
# License: Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0)

import logging
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

import flag_gems
from flag_gems.config import use_c_extension

logger = logging.getLogger(__name__)

__all__ = [
    "gems_rope_forward",
    "GemsDeepseekYarnRoPE",
    "GemsRope",
]


def gems_rope_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.IntTensor] = None,
    rotary_interleaved: bool = False,
    inplace: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_c_extension:
        logger.debug("GEMS CUSTOM ROPE FORWARD(C EXTENSION)")
        if inplace:
            torch.ops.flag_gems.rotary_embedding_inplace(
                query, key, cos, sin, position_ids, rotary_interleaved
            )
            return query, key
        else:
            return torch.ops.flag_gems.rotary_embedding(
                query, key, cos, sin, position_ids, rotary_interleaved
            )
    else:
        logger.debug("GEMS CUSTOM ROPE FORWARD")
        # Fall back to the pure Python implementation
        return flag_gems.apply_rotary_pos_emb(
            query, key, cos, sin, position_ids, rotary_interleaved, inplace
        )
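

# A minimal usage sketch (shapes and values here are illustrative
# assumptions, not taken from the FlagGems test suite). The cos/sin tables
# carry head_dim // 2 entries per position, matching _compute_inv_freq below:
#
# >>> seq_len, head_dim = 128, 64
# >>> q = torch.randn(1, seq_len, 8, head_dim)  # (batch, seq, q_heads, head_dim)
# >>> k = torch.randn(1, seq_len, 1, head_dim)  # (batch, seq, k_heads, head_dim)
# >>> inv = 1.0 / 10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim)
# >>> freqs = torch.outer(torch.arange(seq_len).float(), inv)
# >>> q_rot, k_rot = gems_rope_forward(q, k, freqs.cos(), freqs.sin())

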
class GemsRope(nn.Module):
    """
    Base class for Rotary Position Embedding (RoPE) modules.
    This class is intended to be subclassed for specific RoPE implementations.

    Args:
        rotary_dim (int): The rotary embedding dimension (typically equal to head_dim).
        max_position_embeddings (int): Initial maximum position length for the precomputed cache.
        base (float): Frequency base used to compute inverse frequencies (default 10000).
        rotary_interleaved (bool): Whether to use the interleaved rotary layout (GPT-J-style)
            instead of the half-and-half layout (GPT-NeoX-style).
        dtype (torch.dtype): Data type for the cos/sin cache buffers.
        device (torch.device or None): Device to place the initial buffers on.

    Inputs:
        query (torch.Tensor): Shape (..., q_heads, head_dim)
        key (torch.Tensor): Shape (..., k_heads, head_dim)
        position_ids (torch.IntTensor, optional): Shape (..., seq_len), positions for cos/sin lookup.
        inplace (bool): If True, modifies query and key in place (default False).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Transformed (query, key) tensors with RoPE applied.
    """

    def __init__(
        self,
        rotary_dim,
        max_position_embeddings,
        base,
        rotary_interleaved,
        dtype,
        device,
    ):
        super().__init__()
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.rotary_interleaved = rotary_interleaved
        self.dtype = dtype
        self.device = device
        self._set_cos_sin_cache()

    def _compute_inv_freq(self) -> torch.Tensor:
        """
        Compute the inverse frequency tensor of shape [rotary_dim // 2].
        """
        return 1.0 / (
            self.base
            ** (
                torch.arange(
                    0, self.rotary_dim, 2, dtype=torch.float32, device=self.device
                )
                / self.rotary_dim
            )
        )

    def _set_cos_sin_cache(self):
        """
        Default implementation of rotary embeddings (vanilla RoPE).
        Can be overridden in subclasses for NTK, YaRN, etc.
        """
        inv_freq = self._compute_inv_freq()
        t = torch.arange(
            self.max_position_embeddings, device=self.device, dtype=torch.float32
        )
        freqs = torch.outer(t, inv_freq)  # [max_position_embeddings, rotary_dim // 2]

        self.register_buffer(
            "cos_cached", freqs.cos().to(self.dtype), persistent=False
        )  # [max_position_embeddings, rotary_dim // 2]
        self.register_buffer(
            "sin_cached", freqs.sin().to(self.dtype), persistent=False
        )  # [max_position_embeddings, rotary_dim // 2]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        position_ids: Optional[torch.IntTensor] = None,
        inplace: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not hasattr(self, "cos_cached") or not hasattr(self, "sin_cached"):
            self._set_cos_sin_cache()

        return gems_rope_forward(
            query,
            key,
            self.cos_cached,
            self.sin_cached,
            position_ids,
            self.rotary_interleaved,
            inplace,
        )
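

# A minimal sketch of direct use (argument values are illustrative
# assumptions): vanilla RoPE with a 4096-position cache.
#
# >>> rope = GemsRope(
# ...     rotary_dim=64,
# ...     max_position_embeddings=4096,
# ...     base=10000.0,
# ...     rotary_interleaved=False,
# ...     dtype=torch.float32,
# ...     device=None,
# ... )
# >>> q = torch.randn(2, 16, 8, 64)
# >>> k = torch.randn(2, 16, 1, 64)
# >>> q_rot, k_rot = rope(q, k)  # positions 0..15 read from the cached tables

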
# Inverse dim formula: find the dimension index at which a given number of
# rotations occurs over the training context
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )
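

# Worked example (approximate, hand-computed): with dim=64, base=10000 and
# max_position_embeddings=2048, the dimension index completing
# num_rotations=32 full rotations over the training context is
#   64 * ln(2048 / (32 * 2 * pi)) / (2 * ln(10000)) ≈ 8.06,
# i.e. only the first ~8 of the 32 frequency pairs rotate at least 32 times.

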
# Find the dim range bounds based on rotation counts
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
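

# Worked example (approximate, hand-computed) with the defaults used by
# DeepSeek-style YaRN (beta_fast=32, beta_slow=1):
#
# >>> yarn_find_correction_range(32, 1, 64, 10000, 2048)
# (8, 21)
#
# Frequency indices below 8 keep their original (extrapolated) frequencies,
# indices above 21 are fully interpolated, and the ramp blends in between.

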
def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
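

# Worked example (hand-computed): a 16x context extension with mscale=1
# rescales attention magnitudes by 0.1 * 1 * ln(16) + 1 ≈ 1.277, while any
# scale <= 1 leaves them untouched.

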
def yarn_linear_ramp_mask(min_val, max_val, dim):
    if min_val == max_val:
        max_val += 0.001  # Prevent singularity
    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (
        max_val - min_val
    )
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
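

# A small illustration (hand-computed): the ramp is 0 below min_val, 1 above
# max_val, and linear in between:
#
# >>> yarn_linear_ramp_mask(2, 6, 8)
# tensor([0.0000, 0.0000, 0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000])

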
class GemsDeepseekYarnRoPE(GemsRope):
    """
    YaRN-based Rotary Position Embedding (RoPE) for DeepSeek models.

    Args:
        scaling_factor (float): Context-extension factor for YaRN
            (extended length / original length).
        original_max_position_embeddings (int): Original pretraining context size.
        beta_fast (float): Rotation count above which dimensions are fully
            extrapolated (original frequencies kept).
        beta_slow (float): Rotation count below which dimensions are fully
            interpolated (frequencies divided by scaling_factor).
        mscale (float): Coefficient of the logarithmic magnitude correction
            applied to the cos/sin cache.
        mscale_all_dim (float): Coefficient of the baseline magnitude correction;
            the cache is scaled by the ratio of the two corrections.
    """

    def __init__(
        self,
        rotary_dim: int,
        max_position_embeddings: int = 2048,
        base: float = 10000,
        rotary_interleaved: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        scaling_factor: float = 1.0,
        original_max_position_embeddings: int = 4096,
        beta_fast: float = 32.0,
        beta_slow: float = 1.0,
        mscale: float = 1.0,
        mscale_all_dim: float = 0.0,
    ):
        # YaRN parameters must be set before super().__init__, which builds
        # the cos/sin cache via _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            rotary_interleaved=rotary_interleaved,
            dtype=dtype,
            device=device,
        )

    def _compute_inv_freq(self) -> torch.Tensor:
        # Extrapolated frequencies: the original RoPE frequencies
        freq_extra = 1.0 / (
            self.base
            ** (
                torch.arange(
                    0, self.rotary_dim, 2, dtype=torch.float32, device=self.device
                )
                / self.rotary_dim
            )
        )
        # Interpolated frequencies: original frequencies divided by scaling_factor
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (
                torch.arange(
                    0, self.rotary_dim, 2, dtype=torch.float32, device=self.device
                )
                / self.rotary_dim
            )
        )

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.original_max_position_embeddings,
        )

        # Blend the two: indices below `low` (fast-rotating) keep the
        # extrapolated frequencies, indices above `high` (slow-rotating) use
        # the interpolated ones, with a linear ramp in between.
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2).to(
            device=self.device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        return inv_freq

    def _set_cos_sin_cache(self):
        inv_freq = self._compute_inv_freq()

        t = torch.arange(
            self.max_position_embeddings, device=self.device, dtype=torch.float32
        )
        freqs = torch.outer(t, inv_freq)  # [max_position_embeddings, rotary_dim // 2]

        # Magnitude correction: ratio of the mscale and mscale_all_dim corrections
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        self.register_buffer(
            "cos_cached", (freqs.cos() * _mscale).to(self.dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (freqs.sin() * _mscale).to(self.dtype), persistent=False
        )
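

# A minimal sketch with YaRN-style settings (values are illustrative
# assumptions, not taken from a released DeepSeek config): a 4x extension of
# a model pretrained at 4096 positions.
#
# >>> rope = GemsDeepseekYarnRoPE(
# ...     rotary_dim=64,
# ...     max_position_embeddings=16384,
# ...     scaling_factor=4.0,
# ...     original_max_position_embeddings=4096,
# ...     beta_fast=32.0,
# ...     beta_slow=1.0,
# ...     mscale=1.0,
# ...     mscale_all_dim=0.0,
# ... )
# >>> q = torch.randn(1, 8192, 8, 64)
# >>> k = torch.randn(1, 8192, 1, 64)
# >>> q_rot, k_rot = rope(q, k)  # cos/sin already carry the mscale correction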