# Rotary Embedding Torch

Rotary Embedding Torch is a standalone PyTorch library for adding rotary positional embeddings to transformer architectures. Originally introduced in the RoFormer paper, rotary embeddings are an effective relative positional encoding that improves transformer performance. The library makes it easy and efficient to rotate positional information into any axis of a tensor, whether the positions are fixed or learned.

Rotary embeddings give strong positional-encoding results with minimal computational overhead. The library supports standard rotary embeddings, length-extrapolatable embeddings via XPos, axial rotary embeddings for video transformers, position interpolation for extending context length, and NTK-aware rescaling for longer sequences without fine-tuning.

## RotaryEmbedding Class

The main class for creating rotary position embeddings that can be applied to queries and keys in transformer attention layers. It supports three frequency generation modes: `'lang'` for language modeling, `'pixel'` for images and video, and `'constant'` for constant frequencies.

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# Basic initialization for language modeling
rotary_emb = RotaryEmbedding(
    dim=32,                     # Dimension for rotary embeddings (typically head_dim // 2)
    freqs_for='lang',           # 'lang' for text, 'pixel' for images/video, 'constant'
    theta=10000,                # Base theta for frequency computation
    max_freq=10,                # Maximum frequency (used with freqs_for='pixel')
    num_freqs=1,                # Number of frequencies (used with freqs_for='constant')
    learned_freq=False,         # Whether frequencies are learnable parameters
    use_xpos=False,             # Enable XPos for length extrapolation
    xpos_scale_base=512,        # Scale base for XPos
    interpolate_factor=1.,      # Factor for position interpolation (>1 extends context)
    theta_rescale_factor=1.,    # NTK-aware rescaling factor
    seq_before_head_dim=False,  # Set True if tensor shape is (batch, seq, heads, dim)
    cache_if_possible=True,     # Cache computed frequencies for efficiency
    cache_max_seq_len=8192      # Maximum sequence length to cache
)

# Create sample queries and keys (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 512, 64)
k = torch.randn(2, 8, 512, 64)

# Apply rotary embeddings to queries and keys
q_rotated = rotary_emb.rotate_queries_or_keys(q)
k_rotated = rotary_emb.rotate_queries_or_keys(k)

print(f"Original shape: {q.shape}")         # torch.Size([2, 8, 512, 64])
print(f"Rotated shape: {q_rotated.shape}")  # torch.Size([2, 8, 512, 64])
```

## rotate_queries_or_keys Method

Applies rotary embeddings to a single tensor (either queries or keys). This is the primary method for standard rotary embedding usage without XPos length extrapolation.

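To make the rotation concrete, here is a minimal pure-Python sketch, independent of the library, of what rotating queries and keys achieves. It assumes consecutive feature pairs are rotated by an angle of `position * freq` with the usual `theta`-based frequencies; the helper names `rotate_pairs` and `dot` are hypothetical and exist only for this illustration.

```python
import math

def rotate_pairs(vec, position, theta=10000.0):
    """Rotate consecutive (even, odd) feature pairs by position-dependent angles."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        freq = 1.0 / (theta ** (i / dim))  # one frequency per feature pair
        angle = position * freq
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5]
k = [1.1, 0.4, -0.7, 0.9]

# Same relative distance (3) at two different absolute positions
score_near = dot(rotate_pairs(q, 5), rotate_pairs(k, 2))
score_far = dot(rotate_pairs(q, 105), rotate_pairs(k, 102))

print(abs(score_near - score_far) < 1e-9)  # the two scores agree up to rounding
```

Because the attention score depends only on the distance between the two positions, the encoding is relative. This is also why a plain `offset` is enough when the rotated tensor does not start at position 0.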
```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim=32)

# Standard attention tensors: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)

# Apply rotary embeddings with default sequence dimension (-2)
q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

# With custom sequence dimension for different tensor layouts
# Shape: (batch, seq_len, heads, head_dim) -> use seq_dim=-3
x = torch.randn(1, 1024, 8, 64)
x_rotated = rotary_emb.rotate_queries_or_keys(x, seq_dim=-3)

# With position offset (useful for incremental decoding):
# a fresh, unrotated single-token query is treated as position 100
q_new = torch.randn(1, 8, 1, 64)
q_offset = rotary_emb.rotate_queries_or_keys(q_new, offset=100)
```

## rotate_queries_with_cached_keys Method

Handles key-value caching during inference by automatically computing the correct query offset from the cached key sequence length. Essential for efficient autoregressive generation.

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim=32)

# During inference with KV cache:
# Single new query token
q = torch.randn(1, 8, 1, 64)

# Keys with cached history concatenated
k = torch.randn(1, 8, 1024, 64)

# Automatically handles offset calculation (k_len - q_len)
q_rotated, k_rotated = rotary_emb.rotate_queries_with_cached_keys(q, k)

# Equivalent manual approach:
# q_manual = rotary_emb.rotate_queries_or_keys(q, offset=k.shape[-2] - q.shape[-2])
# k_manual = rotary_emb.rotate_queries_or_keys(k)

print(f"Query shape: {q_rotated.shape}")  # torch.Size([1, 8, 1, 64])
print(f"Key shape: {k_rotated.shape}")    # torch.Size([1, 8, 1024, 64])
```

## rotate_queries_and_keys Method (XPos)

Applies XPos rotary embeddings to queries and keys together, adding decay-based scaling for length extrapolation. This method requires `use_xpos=True` during initialization and is designed for autoregressive transformers.

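The complementary scaling can be illustrated with a deliberately simplified, library-independent sketch. It assumes a single scalar decay `zeta` (the XPos formulation uses one decay value per feature pair), and the function names `xpos_factor` and `pair_scale` are hypothetical. Queries are multiplied by the factor and keys by its inverse, so the product seen by attention depends only on the distance between the two positions.

```python
def xpos_factor(position, zeta=0.9, scale_base=512):
    """Per-position scale; zeta and scale_base are illustrative values."""
    return zeta ** (position / scale_base)

def pair_scale(q_pos, k_pos, zeta=0.9, scale_base=512):
    """Net scale on a query-key score: query factor times inverse key factor."""
    q_scale = xpos_factor(q_pos, zeta, scale_base)
    k_scale = 1.0 / xpos_factor(k_pos, zeta, scale_base)
    return q_scale * k_scale

# Same distance (500) at different absolute positions gives the same net scale
same_distance_a = pair_scale(700, 200)
same_distance_b = pair_scale(1500, 1000)

# In the causal case (query after key), larger distances are damped more
near = pair_scale(512, 0)
far = pair_scale(1024, 0)

print(abs(same_distance_a - same_distance_b) < 1e-12, far < near < 1.0)
```

This is why the two tensors have to be rotated together: scaling only one of them would leave a position-dependent factor in the score instead of a purely distance-dependent decay.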
```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# Initialize with XPos for length extrapolation
rotary_emb = RotaryEmbedding(
    dim=32,
    use_xpos=True,       # Enable XPos
    xpos_scale_base=512  # Controls decay rate
)

# Queries and keys must be rotated together with XPos
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)

# XPos applies complementary scaling to q and k
q_rotated, k_rotated = rotary_emb.rotate_queries_and_keys(q, k)

# The scaling ensures better extrapolation to longer sequences
# than what was seen during training
print(f"Q rotated shape: {q_rotated.shape}")  # torch.Size([1, 8, 1024, 64])
print(f"K rotated shape: {k_rotated.shape}")  # torch.Size([1, 8, 1024, 64])
```

## get_axial_freqs Method

Generates n-dimensional axial rotary frequencies for video transformers or any multi-dimensional positional encoding. Each dimension gets its own set of frequencies, which are then concatenated.

```python
import torch
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

# Initialize for pixel/spatial data
pos_emb = RotaryEmbedding(
    dim=16,             # Dimension per axis
    freqs_for='pixel',  # Use pixel frequencies
    max_freq=256        # Maximum frequency for pixel mode
)

# Video tensor: (batch, frames, height, width, feature_dim)
q = torch.randn(1, 8, 64, 32, 64)
k = torch.randn(1, 8, 64, 32, 64)

# Get axial frequencies for 3D: frames=8, height=64, width=32
# Output shape: (8, 64, 32, 16*3=48) - dim*num_axes frequencies
freqs = pos_emb.get_axial_freqs(8, 64, 32)

# Apply frequencies - automatically handles partial rotation
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

# With position offsets for each axis
freqs_offset = pos_emb.get_axial_freqs(8, 64, 32, offsets=(0, 10, 5))

print(f"Axial freqs shape: {freqs.shape}")  # torch.Size([8, 64, 32, 48])
```

## apply_rotary_emb Function

Low-level function that applies precomputed rotary frequencies to a tensor.
Useful when you need direct control over the rotation process or when working with custom frequency computations.

```python
import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# Generate frequencies manually or from RotaryEmbedding
rotary_emb = RotaryEmbedding(dim=32)

# Get raw frequencies for a sequence
seq_positions = torch.arange(1024, dtype=torch.float32)
freqs = rotary_emb(seq_positions, seq_len=1024)

# Apply to tensor with full control
t = torch.randn(1, 8, 1024, 64)

# Basic application
t_rotated = apply_rotary_emb(freqs, t)

# With custom parameters
t_rotated = apply_rotary_emb(
    freqs,
    t,
    start_index=0,      # Start rotating from this index in feature dim
    scale=1.0,          # Scaling factor (used with XPos)
    seq_dim=-2,         # Sequence dimension in tensor
    freqs_seq_dim=None  # Sequence dimension in frequencies (auto-detected)
)

# The last dim of freqs equals the rotary dim (32), not head_dim (64)
print(f"Frequencies shape: {freqs.shape}")  # torch.Size([1024, 32])
print(f"Output shape: {t_rotated.shape}")   # torch.Size([1, 8, 1024, 64])
```

## apply_learned_rotations Function

Applies learned rotation angles to a tensor. Useful for custom positional encoding schemes where the rotation angles are learned parameters rather than fixed sinusoidal frequencies.

```python
import torch
from rotary_embedding_torch import apply_learned_rotations

# Learned rotation angles (e.g., from a neural network)
# Shape: (batch, seq_len, num_rotations)
rotations = torch.randn(1, 1024, 16)

# Input tensor to rotate
t = torch.randn(1, 1024, 64)

# Apply learned rotations
t_rotated = apply_learned_rotations(rotations, t, start_index=0)

# With frequency ranges for multi-scale rotations
freq_ranges = torch.tensor([1., 2., 4., 8.])  # Different frequency scales
t_rotated = apply_learned_rotations(
    rotations[:, :, :4],  # 4 rotation groups
    t,
    start_index=0,
    freq_ranges=freq_ranges
)

print(f"Output shape: {t_rotated.shape}")  # torch.Size([1, 1024, 64])
```

## broadcat Function

Broadcasts tensors to compatible shapes and concatenates them along a specified dimension. A utility function used internally and available for custom implementations.

```python
import torch
from rotary_embedding_torch import broadcat

# Tensors whose leading dims differ but can be broadcast together.
# The concat dim is kept equal so every dim is broadcast-compatible.
t1 = torch.randn(1, 1, 16)
t2 = torch.randn(1, 8, 16)
t3 = torch.randn(4, 1, 16)

# Broadcast all tensors to a common shape and concatenate
result = broadcat([t1, t2, t3], dim=-1)
print(f"Result shape: {result.shape}")  # torch.Size([4, 8, 48])

# Useful for combining frequencies from different sources
freq1 = torch.randn(1024, 16)
freq2 = torch.randn(1024, 16)
combined_freqs = broadcat([freq1, freq2], dim=-1)
print(f"Combined freqs: {combined_freqs.shape}")  # torch.Size([1024, 32])
```

## Position Interpolation Configuration

Extends the context length of pretrained models by fine-tuning on interpolated sequence positions. Based on Meta AI research showing that this works much better than fine-tuning on positions extended past the trained range.

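The arithmetic behind position interpolation is small enough to sketch without the library. The snippet below assumes that positions are simply divided by the interpolation factor before the rotation angles are computed; `interpolated_positions` is a hypothetical helper for illustration, not part of the library's API.

```python
def interpolated_positions(seq_len, interpolate_factor=1.0, offset=0):
    """Squeeze integer positions by the interpolation factor."""
    return [(p + offset) / interpolate_factor for p in range(seq_len)]

trained_len = 2048   # context length assumed for the pretrained model
extended_len = 4096  # target context length

positions = interpolated_positions(extended_len, interpolate_factor=2.0)

# Every position of the longer sequence falls inside the trained range,
# at the cost of a finer (fractional) spacing between neighbours.
print(positions[:3])                  # [0.0, 0.5, 1.0]
print(max(positions) < trained_len)   # True
```

The model never sees an angle beyond what it was trained on, which is why a short fine-tuning run on the squeezed positions is typically enough to adapt it.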
```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# Original model trained on max 2048 tokens
# To extend to 4096, use interpolate_factor=2.0
rotary_emb = RotaryEmbedding(
    dim=32,
    interpolate_factor=2.0  # Extends 2048 -> 4096 context
)

# For 8x extension (2048 -> 16384)
rotary_emb_8x = RotaryEmbedding(
    dim=32,
    interpolate_factor=8.0
)

# Use normally - interpolation is handled internally
q = torch.randn(1, 8, 4096, 64)  # Longer sequence
k = torch.randn(1, 8, 4096, 64)

q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

print(f"Extended context shape: {q.shape}")  # torch.Size([1, 8, 4096, 64])
```

## NTK-Aware Theta Rescaling

Rescales the base theta parameter using NTK-aware scaling so that models can handle longer sequences without fine-tuning. This is a community-discovered technique that reduces perplexity degradation at extended lengths.

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# Standard initialization
standard_emb = RotaryEmbedding(dim=32, theta=10000)

# NTK-aware rescaling for longer context
# A larger factor raises the effective theta, stretching the
# rotation wavelengths to cover longer contexts
ntk_emb = RotaryEmbedding(
    dim=32,
    theta=10000,
    theta_rescale_factor=2.0  # Rescale for ~2x context length
)

# For even longer contexts
ntk_emb_large = RotaryEmbedding(
    dim=32,
    theta=10000,
    theta_rescale_factor=4.0  # Rescale for ~4x context length
)

# Apply to longer sequences without fine-tuning
q = torch.randn(1, 8, 8192, 64)
k = torch.randn(1, 8, 8192, 64)

q_rotated = ntk_emb_large.rotate_queries_or_keys(q)
k_rotated = ntk_emb_large.rotate_queries_or_keys(k)

print(f"Long context shape: {q_rotated.shape}")  # torch.Size([1, 8, 8192, 64])
```

## Summary

Rotary Embedding Torch is designed for seamless integration into transformer architectures, providing drop-in positional encoding that works with any attention mechanism. The primary use case is adding `rotate_queries_or_keys` calls after splitting heads but before computing attention scores.
For autoregressive models requiring length extrapolation, XPos via `rotate_queries_and_keys` provides decay-based scaling. Video and image transformers benefit from `get_axial_freqs` for multi-dimensional positional encoding.

Integration typically involves instantiating `RotaryEmbedding` once in the model constructor and passing it to all attention layers. The library handles caching automatically for efficiency, supports both standard and sequence-first tensor layouts via `seq_before_head_dim`, and provides multiple strategies for context length extension, including position interpolation and NTK-aware rescaling. For inference with KV caching, `rotate_queries_with_cached_keys` automatically handles offset computation for correct positional alignment.
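As a closing illustration of the NTK-aware strategy, the following library-independent sketch shows how rescaling theta changes the per-pair frequencies. It assumes the commonly cited form `theta * factor ** (dim / (dim - 2))`; treat that formula and the helper names `ntk_rescaled_theta` and `inv_freqs` as assumptions made for this example rather than a statement of the library's internals.

```python
def ntk_rescaled_theta(theta, rescale_factor, dim):
    """Assumed NTK-aware rescaling of the base theta."""
    return theta * rescale_factor ** (dim / (dim - 2))

def inv_freqs(theta, dim):
    """One rotation frequency per feature pair."""
    return [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]

dim = 32
base = inv_freqs(10000.0, dim)
scaled = inv_freqs(ntk_rescaled_theta(10000.0, 4.0, dim), dim)

# How much each frequency is slowed down by the rescaling
ratios = [b / s for b, s in zip(base, scaled)]

print(round(ratios[0], 6))   # fastest pair is left untouched
print(round(ratios[-1], 6))  # slowest pair is slowed by the full factor
```

Under this assumption the high-frequency pairs, which encode local order, are nearly unchanged, while the low-frequency pairs are stretched by up to the rescale factor. Position interpolation, by contrast, slows every pair by the same amount.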