Coverage for src/flag_gems/runtime/backend/_arm/ops/rope.py: 0%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

"""
ARM CPU RoPE (Rotary Position Embedding) — pure PyTorch, no LibEntry.

flag_gems.apply_rotary_pos_emb (fused/rotary_embedding.py) is decorated with
@libentry(), which indexes kernel_cache by the GPU DEVICE_COUNT and therefore
crashes on CPU-only ARM machines.

This module provides a pure-PyTorch drop-in replacement that:
  - works on CPU (no CUDA / LibEntry dependency);
  - is fast for decode (M=1): an indexed gather plus elementwise (NEON) ops;
  - supports both NeoX (non-interleaved) and GPT-J (interleaved) styles;
  - handles inplace=True (required by vLLM's custom_gems_rope_forward_cuda).

Benchmarks (CIX P1 CD8180, BF16, Qwen3-1.7B shapes, OMP=8,
prefault + 1000 runs, top 5% of timings dropped):

  NeoX style (rotary_interleaved=False):
    M=1   q[1,16,64]  k[1,8,64]:   ATen ~5 μs,  PyTorch ~4 μs  (similar)
    M=64  q[64,16,64] k[64,8,64]:  ATen ~30 μs, PyTorch ~25 μs (1.2x)
  Interleaved style:
    M=1:                           ATen ~6 μs,  PyTorch ~5 μs  (similar)

Triton is deliberately not used: its launch overhead (~17 μs) would dominate
at these small sizes. At every tested M the pure-PyTorch path matches or beats
ATen, so it is always safe to use.
"""

24 

25from typing import Optional, Tuple 

26 

27import torch 

28 

29 

30def _rotate_half(x: torch.Tensor) -> torch.Tensor: 

31 """Rotate the second half into the first (NeoX style).""" 

32 half = x.shape[-1] // 2 

33 x1, x2 = x[..., :half], x[..., half:] 

34 return torch.cat((-x2, x1), dim=-1) 

35 

36 

def _apply_rope_neox(
    x: torch.Tensor,
    cos_pos: torch.Tensor,
    sin_pos: torch.Tensor,
) -> torch.Tensor:
    """Apply NeoX-style (non-interleaved) RoPE.

    Computes ``x * cos_pos + _rotate_half(x) * sin_pos``. The first and second
    halves of the last dim are rotated together, so ``cos_pos``/``sin_pos``
    must already span the full rotary dim.

    Expected shapes (per the caller):
        x       : [n_tokens, heads, rotary_dim]
        cos_pos : [n_tokens, 1, rotary_dim]  (already gathered + broadcast)
        sin_pos : [n_tokens, 1, rotary_dim]
    """
    direct_term = x * cos_pos
    rotated_term = _rotate_half(x) * sin_pos
    return direct_term + rotated_term

49 

50 

51def _apply_rope_interleaved( 

52 x: torch.Tensor, 

53 cos_pos: torch.Tensor, 

54 sin_pos: torch.Tensor, 

55) -> torch.Tensor: 

56 """GPT-J (interleaved) RoPE: each pair (x[2i], x[2i+1]) is rotated. 

57 

58 x : [n_tokens, heads, rotary_dim] 

59 cos_pos : [n_tokens, 1, rotary_dim//2] 

60 sin_pos : [n_tokens, 1, rotary_dim//2] 

61 """ 

62 x1 = x[..., 0::2] # even indices 

63 x2 = x[..., 1::2] # odd indices 

64 out1 = x1 * cos_pos - x2 * sin_pos 

65 out2 = x1 * sin_pos + x2 * cos_pos 

66 # interleave back: [n, h, d//2, 2] → [n, h, d] 

67 return torch.stack([out1, out2], dim=-1).flatten(-2) 

68 

69 

def arm_apply_rotary_pos_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    rotary_interleaved: bool = False,
    inplace: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch RoPE — ARM CPU drop-in for flag_gems.apply_rotary_pos_emb.

    Args:
        query : [*, q_heads, rotary_dim], e.g. [n_tokens, q_heads, rotary_dim]
            or [batch, seq_len, q_heads, rotary_dim]
        key : [*, kv_heads, rotary_dim]; leading dims must equal query's
        cos : [max_pos, rotary_dim // 2]
        sin : [max_pos, rotary_dim // 2]
        position_ids : integer tensor (int32/int64) of row indices into
            cos/sin, one per token, normally shaped ``query.shape[:-2]``.
            When None, the token at index i along dim -3 of ``query`` uses
            row i (positions 0..seq_len-1).
        rotary_interleaved : False = NeoX style; True = GPT-J/interleaved style
        inplace : if True, write the result back into query/key and return them

    Returns:
        (q_out, k_out) — same shape and dtype as (query, key).

    Raises:
        ValueError: if the last dims of query/key/cos/sin are inconsistent,
            query/key leading dims differ, or (with position_ids=None) the
            query has fewer than 3 dims or more tokens than cos/sin rows.
    """
    rotary_dim = query.shape[-1]
    if key.shape[-1] != rotary_dim:
        raise ValueError(
            f"query and key must have the same last dimension, "
            f"got {tuple(query.shape)} and {tuple(key.shape)}"
        )
    if query.shape[:-2] != key.shape[:-2]:
        raise ValueError(
            f"query and key must have the same leading dimensions, "
            f"got {tuple(query.shape)} and {tuple(key.shape)}"
        )
    if cos.shape[-1] != sin.shape[-1] or cos.shape[-1] * 2 != rotary_dim:
        raise ValueError(
            f"cos/sin last dimension must be half of rotary_dim={rotary_dim}, "
            f"got cos {tuple(cos.shape)} and sin {tuple(sin.shape)}"
        )

    # Gather one cos/sin row per token: [*, half_dim]
    if position_ids is not None:
        cos_pos = cos[position_ids]
        sin_pos = sin[position_ids]
    else:
        # Default positions 0..seq_len-1 along the token dim (dim -3 of
        # query). Using the whole table here would mis-broadcast whenever
        # max_pos != seq_len.
        if query.dim() < 3:
            raise ValueError(
                f"query must be at least 3-D [*, heads, rotary_dim] when "
                f"position_ids is None, got {tuple(query.shape)}"
            )
        seq_len = query.shape[-3]
        if seq_len > cos.shape[0]:
            raise ValueError(
                f"seq_len={seq_len} exceeds the {cos.shape[0]} rows of cos/sin"
            )
        cos_pos = cos[:seq_len]
        sin_pos = sin[:seq_len]

    # Insert a heads axis just before the feature dim so the tables broadcast
    # over heads for any number of leading dims: [*, 1, half_dim]
    cos_pos = cos_pos.unsqueeze(-2)
    sin_pos = sin_pos.unsqueeze(-2)

    if rotary_interleaved:
        # Interleaved style consumes the half-width tables directly.
        q_out = _apply_rope_interleaved(query, cos_pos, sin_pos)
        k_out = _apply_rope_interleaved(key, cos_pos, sin_pos)
    else:
        # NeoX: each half-dim entry multiplies both x and rotate_half(x), so
        # tile the tables to the full rotary_dim: [*, 1, rotary_dim]
        cos_full = torch.cat((cos_pos, cos_pos), dim=-1)
        sin_full = torch.cat((sin_pos, sin_pos), dim=-1)
        q_out = _apply_rope_neox(query, cos_full, sin_full)
        k_out = _apply_rope_neox(key, cos_full, sin_full)

    if inplace:
        # q_out/k_out are fresh tensors, so copying back cannot alias inputs;
        # copy_ also casts to the destination dtype.
        query.copy_(q_out)
        key.copy_(k_out)
        return query, key
    # Type promotion (e.g. fp32 tables with bf16 activations) can widen the
    # result; cast back so outputs match the input dtypes like the inplace path.
    return q_out.to(query.dtype), k_out.to(key.dtype)