# PoPE PyTorch

PoPE-pytorch is an efficient PyTorch implementation of Polar Coordinate Positional Embeddings (PoPE), a novel approach to positional encoding for transformer architectures. Based on the research paper "Decoupling the 'What' and 'Where' With Polar Coordinate Positional Embeddings" by Gopalakrishnan et al., this library provides a clean interface for integrating polar positional embeddings into attention mechanisms.

PoPE represents positions using polar coordinates, which decouples positional information from content in the attention computation. The library offers several key components: the core `PoPE` class for sequence-based positional embeddings, `AxialPoPE` for multi-dimensional data such as images and videos, fused attention similarity computation via Triton kernels for improved performance, and flash attention integration with PoPE support. These components can be integrated into existing transformer architectures with minimal code changes and support both training and inference, including key-value caching.

## Installation

Install PoPE-pytorch with pip to add polar coordinate positional embeddings to your transformer models.

```bash
pip install PoPE-pytorch
```

## PoPE - Polar Positional Embedding

The `PoPE` class generates polar coordinate positional embeddings for sequences. It takes a feature dimension and the number of attention heads, and produces frequency-based positional information that can be applied to the queries and keys in attention.
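
For intuition, here is a minimal plain-PyTorch sketch of the polar mapping as we understand it from the paper; it is not the library's code. The helper `polar_expand`, the softplus magnitude, and the geometric frequency schedule (base 10000, one frequency per feature) are assumptions, and the `bias` term that `PoPE` returns is omitted. Content sets each feature's non-negative magnitude, position sets its phase, and the real and imaginary parts are concatenated, which is why the rotated dimension doubles.

```python
import torch
import torch.nn.functional as F


def polar_expand(x, positions, base=10000.0, to_magnitude=F.softplus):
    """Hypothetical helper: map (..., seq, dim) features to (..., seq, 2 * dim)."""
    dim = x.shape[-1]
    # One frequency per feature (assumed geometric schedule)
    freqs = base ** (-torch.arange(dim, dtype=x.dtype) / dim)
    phase = positions.to(x.dtype)[:, None] * freqs   # (seq, dim): position -> phase
    magnitude = to_magnitude(x)                      # (..., seq, dim): content -> magnitude
    # Concatenate real and imaginary parts, doubling the feature dimension
    return torch.cat((magnitude * phase.cos(), magnitude * phase.sin()), dim=-1)


torch.manual_seed(0)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
pos = torch.arange(16)

q_exp, k_exp = polar_expand(q, pos), polar_expand(k, pos)
print(q_exp.shape)  # torch.Size([2, 8, 16, 128])

# Scores depend on positions only through their difference: shifting every
# position by the same offset leaves them unchanged (up to float rounding)
sim = q_exp @ k_exp.transpose(-1, -2)
sim_shifted = polar_expand(q, pos + 10) @ polar_expand(k, pos + 10).transpose(-1, -2)
print(torch.allclose(sim, sim_shifted, atol=1e-3))  # True
```

The score between positions i and j reduces to a sum of magnitude products times cos((i - j) * theta), so content and position enter through separate factors. In practice, use the library's `PoPE` class; its exact frequency schedule and bias handling may differ from this sketch.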
```python
import torch
from PoPE_pytorch import PoPE

# Initialize PoPE with feature dimension and number of heads
pope = PoPE(dim=64, heads=8)

# Generate positional embeddings for a sequence of length 1024
# Returns a PolarEmbedReturn namedtuple with (freqs, bias)
pos_emb = pope(1024)

# Create query and key tensors (batch, heads, seq_len, dim)
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)

# Apply PoPE rotations to queries and keys for training
# Output shape doubles the rotation dimension: (2, 8, 1024, 128)
rotated_q, rotated_k = pope.apply_pope_to_qk(pos_emb, q, k)
print(f"Rotated Q shape: {rotated_q.shape}")  # torch.Size([2, 8, 1024, 128])
print(f"Rotated K shape: {rotated_k.shape}")  # torch.Size([2, 8, 1024, 128])

# For inference with KV-caching, apply to only the last query position
# while keeping the full key sequence
rotated_q_inf, rotated_k_inf = pope.apply_pope_to_qk(pos_emb, q[..., -1:, :], k)
print(f"Inference Q shape: {rotated_q_inf.shape}")  # torch.Size([2, 8, 1, 128])
print(f"Inference K shape: {rotated_k_inf.shape}")  # torch.Size([2, 8, 1024, 128])
```

## PoPE with Partial Rotation

PoPE supports partial rotation, where only a subset of the feature dimensions receives positional information. This is useful when you want to reserve some dimensions for content-only representation.

```python
import torch
from PoPE_pytorch import PoPE

# Initialize PoPE with only 32 dimensions for rotation (out of 64 total)
pope = PoPE(dim=32, heads=8)
pos_emb = pope(512)

# Query and key have 64 dimensions, but only the first 32 are rotated
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)

rotated_q, rotated_k = pope.apply_pope_to_qk(pos_emb, q, k)

# Output: 32 * 2 (rotated) + 32 (unrotated) = 96 dimensions
print(f"Partial rotation output shape: {rotated_q.shape}")  # torch.Size([1, 8, 512, 96])
```

## AxialPoPE - Multi-dimensional Positional Embeddings

`AxialPoPE` extends PoPE to support multi-dimensional data such as images and videos.
It splits the feature dimension across the axial dimensions and generates positional frequencies for each axis.

```python
import torch
from PoPE_pytorch import AxialPoPE

# AxialPoPE for images: split 64 dims into 32 (height) + 32 (width)
pope_image = AxialPoPE(
    dim=64,
    heads=8,
    axial_dims=(32, 32)
)

# Generate positional embeddings for a 16x16 image
# Automatically creates a grid of positions
pos_emb = pope_image((16, 16))

# pos_emb.freqs shape: (256, 64) - 256 positions, 64 frequency dims
print(f"Image positional embedding shape: {pos_emb.freqs.shape}")

# Query and key for flattened image patches (16 * 16 = 256 tokens)
q = torch.randn(1, 8, 256, 64)
k = torch.randn(1, 8, 256, 64)

rotated_q, rotated_k = AxialPoPE.apply_pope_to_qk(pos_emb, q, k)
print(f"Image attention Q shape: {rotated_q.shape}")  # torch.Size([1, 8, 256, 128])

# AxialPoPE for video: split 96 dims into 32 (time) + 32 (height) + 32 (width)
pope_video = AxialPoPE(
    dim=96,
    heads=8,
    axial_dims=(32, 32, 32)
)

# Generate embeddings for 8 frames of 16x16 video
pos_emb_video = pope_video((8, 16, 16))

# 8 * 16 * 16 = 2048 positions, 96 frequency dims
print(f"Video positional embedding shape: {pos_emb_video.freqs.shape}")  # torch.Size([2048, 96])

# Video queries and keys
q_video = torch.randn(1, 8, 2048, 96)
k_video = torch.randn(1, 8, 2048, 96)

rotated_q_video, rotated_k_video = AxialPoPE.apply_pope_to_qk(pos_emb_video, q_video, k_video)
print(f"Video attention Q shape: {rotated_q_video.shape}")  # torch.Size([1, 8, 2048, 192])
```

## compute_attn_similarity - Fused Attention Similarity

The `compute_attn_similarity` function computes PoPE attention scores in a memory-efficient manner. When Triton is available and the tensors are on a CUDA device, it uses a fused kernel that avoids materializing the expanded (doubled) query and key features, for example 64 to 128 dimensions, which reduces memory usage.
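
A fused kernel can skip the expansion because of an identity: for magnitudes μ and phases φ, the dot product of two expanded vectors `[μ cos φ, μ sin φ]` equals Σ_d μ_q μ_k cos(φ_q - φ_k), so scores can be formed from the original d-dimensional inputs with the trigonometric terms evaluated on the fly. The sketch below evaluates that formula directly by broadcasting. It assumes softplus magnitudes, a geometric frequency schedule, no phase bias, and that a shorter query block occupies the last key positions (as in cached decoding). The function name `pope_similarity_direct` is hypothetical, and its (n_q, n_k, d) intermediate makes it a small-scale reference, not a substitute for a tiled Triton kernel.

```python
import torch
import torch.nn.functional as F


def pope_similarity_direct(q, k, base=10000.0):
    """Hypothetical reference: sum_d mu_q * mu_k * cos((i - j) * theta_d), no 2x expansion."""
    n_q, n_k, dim = q.shape[-2], k.shape[-2], q.shape[-1]
    theta = base ** (-torch.arange(dim, dtype=q.dtype) / dim)
    # Assumption: a shorter query block occupies the last positions of the key sequence
    q_pos = torch.arange(n_k - n_q, n_k, dtype=q.dtype)
    k_pos = torch.arange(n_k, dtype=q.dtype)
    rel_phase = (q_pos[:, None] - k_pos[None, :])[..., None] * theta   # (n_q, n_k, dim)
    mu_q, mu_k = F.softplus(q), F.softplus(k)
    # Broadcast to (..., n_q, n_k, dim) and reduce over the feature dimension
    return (mu_q[..., :, None, :] * mu_k[..., None, :, :] * rel_phase.cos()).sum(dim=-1)


torch.manual_seed(0)
q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)

sim = pope_similarity_direct(q, k)
print(sim.shape)  # torch.Size([1, 2, 8, 8])

# Decoding: the last query alone reproduces the last row of the full score matrix
sim_last = pope_similarity_direct(q[..., -1:, :], k)
print(torch.allclose(sim_last, sim[..., -1:, :], atol=1e-5))  # True
```

For real workloads, call `compute_attn_similarity` itself, which uses the fused kernel when Triton and CUDA are available.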
```python
import torch
from PoPE_pytorch import PoPE, compute_attn_similarity

# Initialize PoPE on CUDA for fused kernel support
pope = PoPE(dim=64, heads=8).cuda()
pos_emb = pope(1024)

# Queries and keys on CUDA
q = torch.randn(2, 8, 1024, 64).cuda()
k = torch.randn(2, 8, 1024, 64).cuda()

# Compute fused attention similarity
# Returns (batch, heads, q_len, k_len) similarity scores
sim = compute_attn_similarity(q, k, pos_emb)
print(f"Similarity matrix shape: {sim.shape}")  # torch.Size([2, 8, 1024, 1024])

# Apply softmax to get attention weights
attn = sim.softmax(dim=-1)

# Use attention weights with values
v = torch.randn(2, 8, 1024, 64).cuda()
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
print(f"Attention output shape: {out.shape}")  # torch.Size([2, 8, 1024, 64])

# For inference with a single query position
q_single = torch.randn(2, 8, 1, 64).cuda()
sim_single = compute_attn_similarity(q_single, k, pos_emb)
print(f"Single query similarity shape: {sim_single.shape}")  # torch.Size([2, 8, 1, 1024])
```

## flash_attn_with_pope - Flash Attention with PoPE

The `flash_attn_with_pope` function provides a complete attention implementation with PoPE support, including optional causal masking, key padding masks, and automatic selection between the fused Triton kernel and standard PyTorch SDPA.
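
To make the masking semantics concrete, here is a small reference attention in plain PyTorch for a non-fused path, operating on queries and keys that have already been expanded (for example by `apply_pope_to_qk`). The function `reference_masked_attention` is hypothetical and is not how `flash_attn_with_pope` is implemented. It assumes the `True` = valid key-padding convention and aligns the causal mask to the end of the key sequence, so a shorter query block (as when decoding with a cache) is treated as the most recent positions. Values keep their original head dimension while queries and keys are doubled.

```python
import torch


def reference_masked_attention(q, k, v, key_mask=None, causal=False, scale=None):
    """Hypothetical non-fused reference.

    q: (b, h, n_q, d_qk), k: (b, h, n_k, d_qk), v: (b, h, n_k, d_v)
    key_mask: (b, n_k) bool, True = valid key
    """
    n_q, n_k = q.shape[-2], k.shape[-2]
    scale = q.shape[-1] ** -0.5 if scale is None else scale
    sim = (q @ k.transpose(-1, -2)) * scale                  # (b, h, n_q, n_k)

    allowed = torch.ones(n_q, n_k, dtype=torch.bool, device=q.device)
    if causal:
        # Query i sits at absolute position n_k - n_q + i and may see keys j <= it
        q_pos = torch.arange(n_k - n_q, n_k, device=q.device)
        allowed = q_pos[:, None] >= torch.arange(n_k, device=q.device)[None, :]
    if key_mask is not None:
        allowed = allowed & key_mask[:, None, None, :]       # broadcast over heads and queries

    # Rows with no allowed key would become NaN after the softmax
    sim = sim.masked_fill(~allowed, float('-inf'))
    return sim.softmax(dim=-1) @ v


torch.manual_seed(0)
b, h, n, d = 2, 4, 8, 16
q = torch.randn(b, h, n, 2 * d)   # stand-ins for PoPE-expanded queries and keys
k = torch.randn(b, h, n, 2 * d)
v = torch.randn(b, h, n, d)       # values keep their original head dimension

mask = torch.ones(b, n, dtype=torch.bool)
mask[:, n // 2:] = False          # pad out the second half of each sequence

out = reference_masked_attention(q, k, v, key_mask=mask, causal=True)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```

A query row in which every key is masked produces NaN after the softmax, so a padding mask should leave at least one valid key visible to each query.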
```python
import torch
from PoPE_pytorch import PoPE, flash_attn_with_pope

# Initialize PoPE (partial rotation: 32 of the 64 head dimensions)
pope = PoPE(dim=32, heads=8).cuda()

# Query, key, value tensors (batch, heads, seq_len, dim)
q = torch.randn(2, 8, 1024, 64).cuda()
k = torch.randn(2, 8, 1024, 64).cuda()
v = torch.randn(2, 8, 1024, 64).cuda()

pos_emb = pope(1024)

# Basic flash attention with PoPE
out = flash_attn_with_pope(q, k, v, pos_emb=pos_emb)
print(f"Basic attention output: {out.shape}")  # torch.Size([2, 8, 1024, 64])

# Causal attention for autoregressive models
out_causal = flash_attn_with_pope(q, k, v, pos_emb=pos_emb, causal=True)
print(f"Causal attention output: {out_causal.shape}")  # torch.Size([2, 8, 1024, 64])

# With key padding mask (True = valid, False = masked)
mask = torch.ones((2, 1024), dtype=torch.bool, device="cuda")
mask[:, 512:] = False  # Mask out the second half of the sequence

out_masked = flash_attn_with_pope(
    q, k, v,
    pos_emb=pos_emb,
    causal=True,
    mask=mask
)
print(f"Masked attention output: {out_masked.shape}")  # torch.Size([2, 8, 1024, 64])

# Control fused vs non-fused execution
out_manual = flash_attn_with_pope(
    q, k, v,
    pos_emb=pos_emb,
    fused=False  # Force the non-fused path using PyTorch SDPA
)

# Custom softmax scale
out_scaled = flash_attn_with_pope(
    q, k, v,
    pos_emb=pos_emb,
    softmax_scale=0.1
)

# Alternative tensor layout: (batch, seq, heads, dim)
q_nhd = torch.randn(2, 1024, 8, 64).cuda()
k_nhd = torch.randn(2, 1024, 8, 64).cuda()
v_nhd = torch.randn(2, 1024, 8, 64).cuda()

out_nhd = flash_attn_with_pope(
    q_nhd, k_nhd, v_nhd,
    pos_emb=pos_emb,
    head_dimension_at_first=False
)
print(f"NHD layout output: {out_nhd.shape}")  # torch.Size([2, 1024, 8, 64])
```

## apply_pope_to_qk - Manual PoPE Application

The standalone `apply_pope_to_qk` function provides fine-grained control over how PoPE is applied, supporting custom magnitude functions and complex-number output for specialized use cases.
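
The relationship between complex output and the default real output can be checked in plain PyTorch, assuming each feature becomes z = to_magnitude(x) * e^(i * phase). In this sketch `torch.polar` builds the complex tensor, concatenating its real and imaginary parts gives the doubled real layout, and the real part of the inner product with a conjugated key equals the ordinary dot product of the concatenated vectors. The phases are random stand-ins for positional phases, and nothing here calls the library.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, dim = 6, 8
x_q, x_k = torch.randn(n, dim), torch.randn(n, dim)
# Random stand-ins for positional phases
phase_q, phase_k = 6.2832 * torch.rand(n, dim), 6.2832 * torch.rand(n, dim)

results = {}
for name, to_magnitude in {"softplus": F.softplus, "relu": F.relu, "exp": torch.exp}.items():
    # Complex form: magnitude from content, angle from position
    zq = torch.polar(to_magnitude(x_q), phase_q)     # complex64, (n, dim)
    zk = torch.polar(to_magnitude(x_k), phase_k)

    # Real form: concatenate real and imaginary parts -> (n, 2 * dim)
    rq = torch.cat((zq.real, zq.imag), dim=-1)
    rk = torch.cat((zk.real, zk.imag), dim=-1)

    # Re(zq @ conj(zk)^T) equals the real dot product of the concatenations
    sim_complex = (zq @ zk.conj().T).real
    sim_real = rq @ rk.T
    results[name] = torch.allclose(sim_complex, sim_real, atol=1e-4)

print(results)  # {'softplus': True, 'relu': True, 'exp': True}
```

With `F.relu`, negative features get zero magnitude and drop out of the score entirely, whatever their position; softplus and exp keep every magnitude strictly positive.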
```python
import torch
import torch.nn.functional as F
from PoPE_pytorch import PoPE, apply_pope_to_qk

pope = PoPE(dim=64, heads=8)
pos_emb = pope(512)

q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)

# Default application with softplus magnitude
rotated_q, rotated_k = apply_pope_to_qk(pos_emb, q, k)

# Custom magnitude function (e.g., ReLU)
rotated_q_relu, rotated_k_relu = apply_pope_to_qk(
    pos_emb, q, k,
    to_magnitude=F.relu
)

# Custom magnitude with exponential
rotated_q_exp, rotated_k_exp = apply_pope_to_qk(
    pos_emb, q, k,
    to_magnitude=torch.exp
)

# Return complex numbers instead of concatenated real/imaginary parts
rotated_q_complex, rotated_k_complex = apply_pope_to_qk(
    pos_emb, q, k,
    return_complex=True
)
print(f"Complex Q dtype: {rotated_q_complex.dtype}")  # torch.complex64
print(f"Complex Q shape: {rotated_q_complex.shape}")  # torch.Size([1, 8, 512, 64])
```

## Building a Transformer with PoPE

A complete example that integrates PoPE into a causal transformer for language modeling, covering both training and inference with KV-caching.
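
The model below relies on one invariant of KV-caching: decoding one token at a time against cached keys and values must reproduce the matching rows of a full causal pass, provided each new query gets its absolute position (the number of tokens already cached). A compact, library-free check of that invariant follows. The helpers `polar` and `attend` are hypothetical, the polar expansion assumes softplus magnitudes and a geometric frequency schedule without a bias term, and the sketch is single-head and unbatched for brevity.

```python
import torch
import torch.nn.functional as F


def polar(x, pos, base=10000.0):
    # Hypothetical helper: softplus magnitude from content, phase from absolute position
    theta = base ** (-torch.arange(x.shape[-1], dtype=x.dtype) / x.shape[-1])
    phase = pos.to(x.dtype)[:, None] * theta
    mag = F.softplus(x)
    return torch.cat((mag * phase.cos(), mag * phase.sin()), dim=-1)


def attend(q, k, v, q_pos, k_pos):
    # Causal mask built from absolute positions, so it holds for any cache length
    sim = polar(q, q_pos) @ polar(k, k_pos).T * q.shape[-1] ** -0.5
    sim = sim.masked_fill(k_pos[None, :] > q_pos[:, None], float('-inf'))
    return sim.softmax(dim=-1) @ v


torch.manual_seed(0)
n, d = 10, 16
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
pos = torch.arange(n)

full = attend(q, k, v, pos, pos)            # full causal pass over all n tokens

# Prefill the first 4 tokens, then decode one token at a time with a cache
cache_k, cache_v = k[:4], v[:4]
outs = [attend(q[:4], cache_k, cache_v, pos[:4], pos[:4])]
for t in range(4, n):
    q_pos = torch.tensor([cache_k.shape[0]])  # absolute position = tokens already cached
    cache_k = torch.cat((cache_k, k[t:t + 1]))
    cache_v = torch.cat((cache_v, v[t:t + 1]))
    outs.append(attend(q[t:t + 1], cache_k, cache_v, q_pos, pos[:t + 1]))

print(torch.allclose(torch.cat(outs), full, atol=1e-5))  # True
```

This is why the model computes its positional embeddings for the full key length (cached plus new tokens) on every forward pass.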
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from PoPE_pytorch import PoPE, flash_attn_with_pope


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma


class CausalAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x, pos_emb, cache=None):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # Handle KV-cache for inference: prepend the cached keys and values
        if cache is not None:
            ck, cv = cache
            k = torch.cat([ck, k], dim=-2)
            v = torch.cat([cv, v], dim=-2)

        new_cache = (k, v)

        # A single decoding query may attend to every cached key, so the causal
        # mask is only needed when more than one query position is processed
        out = flash_attn_with_pope(
            q, k, v,
            pos_emb=pos_emb,
            causal=q.shape[-2] > 1,
            softmax_scale=self.scale,
            head_dimension_at_first=True
        )

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out), new_cache


class TransformerBlock(nn.Module):
    def __init__(self, dim, heads=8, ff_mult=4):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = CausalAttention(dim, heads)
        self.norm2 = RMSNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Linear(dim * ff_mult, dim)
        )

    def forward(self, x, pos_emb, cache=None):
        attn_out, new_cache = self.attn(self.norm1(x), pos_emb, cache)
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x, new_cache


class PoPETransformer(nn.Module):
    def __init__(self, vocab_size, dim=512, depth=6, heads=8):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pope = PoPE(dim // heads, heads=heads)
        self.layers = nn.ModuleList([
            TransformerBlock(dim, heads) for _ in range(depth)
        ])
        self.norm = RMSNorm(dim)
        self.to_logits = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x, cache=None, return_cache=False):
        seq_len = x.shape[1]
        x = self.token_emb(x)

        # Compute the full KV sequence length (cached + new) for positional embeddings
        kv_len = seq_len if cache is None else (cache[0][0].shape[-2] + seq_len)
        pos_emb = self.pope(kv_len)

        new_caches = []
        for i, layer in enumerate(self.layers):
            layer_cache = cache[i] if cache is not None else None
            x, new_cache = layer(x, pos_emb, layer_cache)
            new_caches.append(new_cache)

        logits = self.to_logits(self.norm(x))

        if return_cache:
            return logits, new_caches
        return logits

    @torch.no_grad()
    def generate(self, prompt, max_tokens=100, temperature=1.0):
        was_training = self.training
        self.eval()
        tokens = prompt.clone()
        cache = None

        for _ in range(max_tokens):
            # Use the cache: after the first pass, only process the last token
            curr_input = tokens if cache is None else tokens[:, -1:]
            logits, cache = self.forward(curr_input, cache=cache, return_cache=True)

            # Sample the next token
            logits = logits[:, -1] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=-1)

        self.train(was_training)  # restore the previous training/eval mode
        return tokens


# Example usage
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = PoPETransformer(vocab_size=256, dim=256, depth=4, heads=4).to(device)

# Training step
input_ids = torch.randint(0, 256, (2, 128)).to(device)
logits = model(input_ids[:, :-1])
loss = F.cross_entropy(
    logits.reshape(-1, 256),
    input_ids[:, 1:].reshape(-1)
)
print(f"Training loss: {loss.item():.4f}")

# Generation with KV-cache
prompt = torch.randint(0, 256, (1, 10)).to(device)
generated = model.generate(prompt, max_tokens=50)
print(f"Generated sequence shape: {generated.shape}")  # torch.Size([1, 60])
```

## Summary

PoPE-pytorch provides the building blocks for incorporating polar coordinate positional embeddings into transformer architectures.
The library's main use cases include enhancing sequence transformers with polar positional representations, building vision transformers with `AxialPoPE` for 2D spatial awareness, and creating video models that use 3D axial positional information. The fused Triton kernels enable efficient training and inference on CUDA devices, while the fallback to standard PyTorch operations keeps the library usable on other hardware.

Integration with existing transformer codebases is straightforward: replace the existing positional embeddings with PoPE, pass the generated positional embeddings to the attention layers, and use either `apply_pope_to_qk` or the complete `flash_attn_with_pope` function, depending on the level of control needed. The library supports both full and partial rotation of feature dimensions, allowing a choice between fully position-aware and partially content-only representations. KV-caching for autoregressive inference is supported by applying PoPE to only the newest query positions while keeping the full, cached key sequence.