Coverage for src/flag_gems/runtime/backend/_arm/ops/rope.py: 0%
34 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
"""
ARM CPU RoPE (Rotary Position Embedding) — pure PyTorch, no LibEntry.

flag_gems.apply_rotary_pos_emb (fused/rotary_embedding.py) uses @libentry(),
which indexes kernel_cache by GPU DEVICE_COUNT and therefore crashes on
CPU-only ARM machines.

This module provides a pure-PyTorch drop-in that:
  - works correctly on CPU (no CUDA/LibEntry dependency);
  - is fast for decode (M=1): indexed gather + elementwise NEON;
  - supports NeoX (non-interleaved) and GPT-J (interleaved) styles;
  - handles inplace=True (required by vLLM custom_gems_rope_forward_cuda).

Benchmarks (CIX P1 CD8180, BF16, Qwen3-1.7B shapes, OMP=8,
prefault + 1000 runs, top 5% dropped):
  NeoX style (rotary_interleaved=False):
    M=1   q[1,16,64]  k[1,8,64]:   ATen ~5μs   PyTorch ~4μs   (similar)
    M=64  q[64,16,64] k[64,8,64]:  ATen ~30μs  PyTorch ~25μs  (1.2x)
  Interleaved style:
    M=1:                           ATen ~6μs   PyTorch ~5μs   (similar)

No Triton is used: its launch overhead (~17μs) would dominate at these small
sizes. ATen and pure PyTorch stay within ~1.2x of each other at every tested
M, so the pure-PyTorch path is always a safe choice.
"""
25from typing import Optional, Tuple
27import torch
30def _rotate_half(x: torch.Tensor) -> torch.Tensor:
31 """Rotate the second half into the first (NeoX style)."""
32 half = x.shape[-1] // 2
33 x1, x2 = x[..., :half], x[..., half:]
34 return torch.cat((-x2, x1), dim=-1)
def _apply_rope_neox(
    x: torch.Tensor,
    cos_pos: torch.Tensor,
    sin_pos: torch.Tensor,
) -> torch.Tensor:
    """Apply NeoX-style (non-interleaved) RoPE to ``x``.

    Element i of the first half of the last dim is rotated together with
    element i of the second half:
    ``out = x * cos + rotate_half(x) * sin``.

    Args:
        x: tensor whose last dim is the rotary dim, e.g.
            [n_tokens, heads, rotary_dim].
        cos_pos: cosines already gathered per token and expanded to the full
            rotary dim. Must be broadcastable against ``x``
            (e.g. [n_tokens, 1, rotary_dim]).
        sin_pos: sines, same layout as ``cos_pos``.

    Returns:
        The rotated tensor. Its shape and dtype follow the usual torch
        broadcasting and type-promotion rules.
    """
    rotated = _rotate_half(x)
    cos_term = x * cos_pos
    sin_term = rotated * sin_pos
    return cos_term + sin_term
51def _apply_rope_interleaved(
52 x: torch.Tensor,
53 cos_pos: torch.Tensor,
54 sin_pos: torch.Tensor,
55) -> torch.Tensor:
56 """GPT-J (interleaved) RoPE: each pair (x[2i], x[2i+1]) is rotated.
58 x : [n_tokens, heads, rotary_dim]
59 cos_pos : [n_tokens, 1, rotary_dim//2]
60 sin_pos : [n_tokens, 1, rotary_dim//2]
61 """
62 x1 = x[..., 0::2] # even indices
63 x2 = x[..., 1::2] # odd indices
64 out1 = x1 * cos_pos - x2 * sin_pos
65 out2 = x1 * sin_pos + x2 * cos_pos
66 # interleave back: [n, h, d//2, 2] → [n, h, d]
67 return torch.stack([out1, out2], dim=-1).flatten(-2)
def arm_apply_rotary_pos_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    rotary_interleaved: bool = False,
    inplace: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch RoPE — ARM CPU drop-in for flag_gems.apply_rotary_pos_emb.

    Args:
        query: [*, q_heads, rotary_dim], e.g. [n_tokens, q_heads, rotary_dim].
        key: [*, kv_heads, rotary_dim]. Its leading dims must broadcast with
            the gathered cos/sin, just like ``query``.
        cos: [max_pos, rotary_dim // 2] cosine table.
        sin: [max_pos, rotary_dim // 2] sine table.
        position_ids: optional int32/int64 tensor of row indices into cos/sin,
            shaped like the token dims of ``query`` (e.g. [n_tokens] or
            [batch, seq]). When None, token j along ``query``'s sequence axis
            (dim -3) uses row j, so only the first ``seq_len`` rows of cos/sin
            are read.
        rotary_interleaved: False = NeoX style; True = GPT-J/interleaved style.
        inplace: if True, write the results back into the query/key buffers
            and return those buffers.

    Returns:
        (q_out, k_out) with the same shape and dtype as ``query`` / ``key``.
    """
    if position_ids is not None:
        # Gather the per-token rows: [max_pos, half] indexed by [*] -> [*, half].
        cos_pos = cos[position_ids]
        sin_pos = sin[position_ids]
    else:
        # Implicit positions 0..seq_len-1. Trim the table so that a longer
        # [max_pos, half] cache lines up with the tokens. Otherwise it would
        # fail to broadcast or, for a single token, silently expand the output
        # to max_pos rows.
        seq_len = query.shape[-3] if query.dim() >= 3 else cos.shape[0]
        cos_pos = cos[:seq_len]
        sin_pos = sin[:seq_len]

    # Insert a size-1 heads axis just before the feature dim:
    # [*, half] -> [*, 1, half]. Using dim -2 (instead of dim 1) keeps this
    # correct for multi-dimensional position_ids such as [batch, seq].
    cos_pos = cos_pos.unsqueeze(-2)
    sin_pos = sin_pos.unsqueeze(-2)

    if rotary_interleaved:
        # Interleaved pairs consume half-width cos/sin directly.
        q_out = _apply_rope_interleaved(query, cos_pos, sin_pos)
        k_out = _apply_rope_interleaved(key, cos_pos, sin_pos)
    else:
        # NeoX: each half-width value multiplies both x and rotate_half(x), so
        # tile it to the full rotary dim: [*, 1, half] -> [*, 1, rotary_dim].
        cos_full = torch.cat([cos_pos, cos_pos], dim=-1)
        sin_full = torch.cat([sin_pos, sin_pos], dim=-1)
        q_out = _apply_rope_neox(query, cos_full, sin_full)
        k_out = _apply_rope_neox(key, cos_full, sin_full)

    if inplace:
        # copy_ casts back to the buffers' own dtype.
        query.copy_(q_out)
        key.copy_(k_out)
        return query, key
    # Keep the output dtype equal to the input dtype (as flag_gems does) even
    # when cos/sin are stored in a wider dtype. .to() is a no-op when the
    # dtypes already agree.
    return q_out.to(query.dtype), k_out.to(key.dtype)