Torch Einops Utils
https://github.com/lucidrains/torch-einops-utils
Context Summary (auto-generated)
# Torch Einops Utils

Torch Einops Utils is a Python utility library that provides helper functions for PyTorch tensor manipulation, designed to accelerate machine learning and AI development workflows. Built by Phil Wang (lucidrains), it extends PyTorch and einops with convenient utilities for masking, padding, slicing, dimension handling, and tensor operations commonly needed in deep learning projects.

The library focuses on practical operations that are frequently required but tedious to implement: converting sequence lengths to masks, padding tensors at specific dimensions, safely stacking/concatenating tensors with potential None values, and managing tensor dimensions for broadcasting. It also includes a powerful `save_load` decorator for PyTorch modules that automatically handles model serialization with configuration preservation.

## Installation

```bash
pip install torch-einops-utils
```

## API Reference

### maybe - Null-Safe Function Wrapper

Wraps a function to handle None inputs gracefully. If the input is None, returns None instead of calling the function. Useful for chaining operations on potentially missing tensors.

```python
import torch
from torch_einops_utils import maybe

# Create a function that adds 1 to a tensor
add_one = lambda t: t + 1

# Wrap it with maybe to handle None safely
safe_add_one = maybe(add_one)

# Works normally with tensors
tensor = torch.tensor([1, 2, 3])
result = safe_add_one(tensor)
print(result)  # tensor([2, 3, 4])

# Returns None when input is None (no error)
result = safe_add_one(None)
print(result)  # None

# maybe(None) returns identity function
identity_fn = maybe(None)
result = identity_fn(tensor)
print(result)  # tensor([1, 2, 3])
```

### masked_mean - Compute Mean with Masking Support

Computes the mean of a tensor with optional masking. Supports broadcasting masks to match tensor dimensions and handles edge cases where mask is all False.

```python
import torch
from torch_einops_utils import masked_mean

# Simple mean without mask
t = torch.tensor([1., 2., 3., 4.])
result = masked_mean(t)
print(result)  # tensor(2.5)

# Mean with a boolean mask
mask = torch.tensor([True, False, True, False])
result = masked_mean(t, mask=mask)
print(result)  # tensor(2.0) - only values at True positions

# Mean along specific dimension with mask
t = torch.tensor([[1., 2.], [3., 4.]])
mask = torch.tensor([[True, False], [True, True]])
result = masked_mean(t, mask=mask, dim=0)
print(result)  # tensor([2.0, 4.0])

result = masked_mean(t, mask=mask, dim=1)
print(result)  # tensor([1.0, 3.5])

# Mask with fewer dimensions is auto-expanded
t = torch.randn(2, 3, 4)
mask = torch.tensor([True, False])  # Shape (2,)
result = masked_mean(t, mask=mask, dim=(1, 2))
print(result.shape)  # torch.Size([2])
```

### slice_at_dim - Slice Tensor at Any Dimension

Slices a tensor at a specified dimension using Python slice objects. Provides cleaner syntax than manual indexing for variable dimensions.

```python
import torch
from torch_einops_utils import slice_at_dim, slice_left_at_dim, slice_right_at_dim

t = torch.randn(3, 4, 5)

# Slice at last dimension (default dim=-1)
result = slice_at_dim(t, slice(1, 3))
print(result.shape)  # torch.Size([3, 4, 2])

# Slice at specific dimension
result = slice_at_dim(t, slice(None, 2), dim=1)
print(result.shape)  # torch.Size([3, 2, 5])

# Slice from start (left) at dimension
result = slice_left_at_dim(t, length=2, dim=1)
print(result.shape)  # torch.Size([3, 2, 5])
# Equivalent to t[:, :2, :]

# Slice from end (right) at dimension
result = slice_right_at_dim(t, length=2, dim=1)
print(result.shape)  # torch.Size([3, 2, 5])
# Equivalent to t[:, -2:, :]
```

### shape_with_replace - Modify Tensor Shape

Returns a modified shape by replacing values at specified indices. Useful for creating new shapes based on existing tensors.

```python
import torch
from torch_einops_utils import shape_with_replace

t = torch.randn(3, 4, 5)
print(t.shape)  # torch.Size([3, 4, 5])

# Replace dimension 1 with value 2
new_shape = shape_with_replace(t, {1: 2})
print(new_shape)  # torch.Size([3, 2, 5])

# Multiple replacements
new_shape = shape_with_replace(t, {0: 1, 2: 10})
print(new_shape)  # torch.Size([1, 4, 10])
```

### pad_ndim - Add Dimensions to Tensors

Adds singleton dimensions to the left and/or right of a tensor's shape. Essential for broadcasting operations.

```python
import torch
from torch_einops_utils import (
    pad_ndim,
    pad_left_ndim,
    pad_right_ndim,
    pad_left_ndim_to,
    pad_right_ndim_to,
    align_dims_left
)

t = torch.randn(3)

# Add 1 dim left and 2 dims right
result = pad_ndim(t, (1, 2))
print(result.shape)  # torch.Size([1, 3, 1, 1])

# Add dimensions only to right
result = pad_right_ndim(t, 2)
print(result.shape)  # torch.Size([3, 1, 1])

# Add dimensions only to left
result = pad_left_ndim(t, 2)
print(result.shape)  # torch.Size([1, 1, 3])

# Pad to specific number of dimensions
t = torch.randn(3)
result = pad_right_ndim_to(t, 3)
print(result.shape)  # torch.Size([3, 1, 1])

result = pad_left_ndim_to(t, 3)
print(result.shape)  # torch.Size([1, 1, 3])

# Align multiple tensors to same number of dimensions (left-aligned)
t = torch.randn(3)
u = torch.randn(3, 5, 2)
v = torch.randn(3, 5)
t, u, v = align_dims_left((t, u, v))
print(t.shape)  # torch.Size([3, 1, 1])
print(u.shape)  # torch.Size([3, 5, 2])
print(v.shape)  # torch.Size([3, 5, 1])
```

### pad_at_dim - Pad Tensor at Specific Dimension

Pads a tensor with a specified value at any dimension. More flexible than F.pad for arbitrary dimensions.

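
For context on why a per-dimension helper is convenient: `torch.nn.functional.pad` takes its padding as a flat tuple of (left, right) pairs ordered from the last dimension backwards, so padding an inner dimension means prefixing a zero pair for every dimension to its right. The pure-Python sketch below shows only that bookkeeping. The helper name `pad_tuple_for_dim` is hypothetical, and this is not necessarily how the library implements `pad_at_dim`.

```python
def pad_tuple_for_dim(ndim, pad, dim=-1):
    """Build the flat padding tuple F.pad would need to pad a single dimension.

    Hypothetical helper for illustration only (not part of torch-einops-utils).
    """
    left, right = pad

    # Normalize negative dimensions, e.g. -1 -> ndim - 1
    dim = dim if dim >= 0 else dim + ndim
    if not 0 <= dim < ndim:
        raise IndexError(f"dim out of range for a tensor with {ndim} dimensions")

    # F.pad reads pairs starting at the last dimension, so emit one (0, 0)
    # pair for each dimension to the right of the target dimension
    dims_from_right = ndim - dim - 1
    return (0, 0) * dims_from_right + (left, right)


# Padding dim=1 of a 3-dimensional tensor by (0, 1)
print(pad_tuple_for_dim(3, (0, 1), dim=1))  # (0, 0, 0, 1)
```

The library's own usage, which hides this bookkeeping, follows.
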
```python
import torch
from torch_einops_utils import (
    pad_at_dim,
    pad_left_at_dim,
    pad_right_at_dim,
    pad_left_at_dim_to,
    pad_right_at_dim_to
)

t = torch.randn(3, 6, 1)

# Pad with (left_pad, right_pad) at dimension 1
result = pad_at_dim(t, (0, 1), dim=1)
print(result.shape)  # torch.Size([3, 7, 1])

# Pad only on the right at dimension 1
result = pad_right_at_dim(t, 2, dim=1)
print(result.shape)  # torch.Size([3, 8, 1])

# Pad only on the left at dimension 1
result = pad_left_at_dim(t, 2, dim=1)
print(result.shape)  # torch.Size([3, 8, 1])

# Pad to reach a target length
result = pad_right_at_dim_to(t, length=10, dim=1)
print(result.shape)  # torch.Size([3, 10, 1])

result = pad_left_at_dim_to(t, length=10, dim=1)
print(result.shape)  # torch.Size([3, 10, 1])

# Custom padding value
result = pad_at_dim(t, (1, 1), dim=1, value=-1.0)
print(result[:, 0, :])  # All -1.0
```

### lens_to_mask - Convert Lengths to Boolean Mask

Converts a tensor of sequence lengths to a boolean mask. Essential for handling variable-length sequences in batch processing.

```python
import torch
from torch_einops_utils import lens_to_mask

# Sequence lengths
lens = torch.tensor([3, 2, 4])

# Create mask where True indicates valid positions
mask = lens_to_mask(lens)
print(mask)
# tensor([[ True,  True,  True, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True]])

# Specify maximum length explicitly
mask = lens_to_mask(lens, max_len=6)
print(mask.shape)  # torch.Size([3, 6])

# Works with batched lengths
lens = torch.tensor([[2, 3], [1, 4]])
mask = lens_to_mask(lens)
print(mask.shape)  # torch.Size([2, 2, 4])
```

### and_masks / or_masks - Combine Boolean Masks

Combines multiple boolean masks using logical AND or OR operations. Safely handles None values in the mask list.

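
The None-skipping behaviour described here follows a common pattern: drop the missing entries, return None if nothing is left, and otherwise fold a binary operator over the rest. Below is a minimal sketch of that pattern, demonstrated on integer bitmasks so it runs without PyTorch. The helper name `combine_masks` is hypothetical and the library's actual implementation may differ; with boolean tensors the operator could be `torch.logical_and` or `torch.logical_or`.

```python
import operator
from functools import reduce


def combine_masks(masks, op):
    """Fold `op` over the masks that are not None; return None if none remain.

    Hypothetical helper illustrating the pattern (not the library's code).
    """
    present = [m for m in masks if m is not None]
    if not present:
        return None
    return reduce(op, present)


# Integer bitmasks stand in for boolean tensors in this sketch
a, b = 0b1100, 0b1010

print(bin(combine_masks([a, None, b], operator.and_)))  # 0b1000
print(bin(combine_masks([a, b], operator.or_)))         # 0b1110
print(combine_masks([None, None], operator.and_))       # None
```

The library's tensor-based usage follows.
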
```python
import torch
from torch_einops_utils import and_masks, or_masks

mask1 = torch.tensor([True, True, False, False])
mask2 = torch.tensor([True, False, True, False])

# AND combination
result = and_masks([mask1, mask2])
print(result)  # tensor([True, False, False, False])

# OR combination
result = or_masks([mask1, mask2])
print(result)  # tensor([True, True, True, False])

# None values are ignored
result = and_masks([mask1, None, mask2])
print(result)  # tensor([True, False, False, False])

# Returns None if all masks are None
result = and_masks([None, None])
print(result)  # None
```

### safe_stack / safe_cat - Null-Safe Tensor Operations

Stack or concatenate tensors while safely handling None values. Both return None when the list is empty or contains only None. As the examples below show, `safe_cat` returns a lone remaining tensor as-is, while `safe_stack` still adds the new stacking dimension.

```python
import torch
from torch_einops_utils import safe_stack, safe_cat

t1 = torch.randn(2, 3)
t2 = torch.randn(2, 3)

# Safe stack - handles None values
result = safe_stack([t1, t2])
print(result.shape)  # torch.Size([2, 2, 3])

result = safe_stack([t1, None, t2])
print(result.shape)  # torch.Size([2, 2, 3]) - None ignored

result = safe_stack([t1])
print(result.shape)  # torch.Size([1, 2, 3])

result = safe_stack([])
print(result)  # None

result = safe_stack([None, None])
print(result)  # None

# Safe cat - similar behavior
result = safe_cat([t1, t2])
print(result.shape)  # torch.Size([4, 3])

result = safe_cat([t1, None])
print(result == t1)  # All True - returns single tensor as-is

result = safe_cat([])
print(result)  # None
```

### pad_sequence - Batch Variable-Length Sequences

Pads a list of tensors with different lengths to the same size and optionally stacks them. More flexible than torch.nn.utils.rnn.pad_sequence.

```python
import torch
from torch_einops_utils import pad_sequence, pad_sequence_and_cat

# Variable length sequences
x = torch.randn(2, 4, 5)  # batch=2, seq=4, features=5
y = torch.randn(2, 3, 5)  # batch=2, seq=3, features=5
z = torch.randn(2, 1, 5)  # batch=2, seq=1, features=5

# Pad and stack along sequence dimension
packed, lens = pad_sequence([x, y, z], dim=1, return_lens=True)
print(packed.shape)  # torch.Size([3, 2, 4, 5]) - stacked with batch dim first
print(lens)  # tensor([4, 3, 1])

# Left padding (useful for causal/autoregressive models)
packed = pad_sequence([x, y, z], dim=1, left=True)
print(packed.shape)  # torch.Size([3, 2, 4, 5])

# Return padding lengths instead of sequence lengths
packed, pad_lens = pad_sequence([x, y, z], dim=1, return_lens=True, pad_lens=True)
print(pad_lens)  # tensor([0, 1, 3])

# Without stacking (returns list)
padded = pad_sequence([x, y, z], dim=1, return_stacked=False)
print(len(padded))  # 3
print([t.shape for t in padded])  # All torch.Size([2, 4, 5])

# Padding uneven images
images = [
    torch.randn(3, 16, 17),  # C, H, W
    torch.randn(3, 15, 18),
    torch.randn(3, 17, 16)
]

# Pad height dimension, then concatenate
padded_height = pad_sequence(images, dim=-2, return_stacked=False)
stacked = pad_sequence_and_cat(padded_height, dim_cat=0)
print(stacked.shape)  # torch.Size([9, 17, 18])
```

### tree_flatten_with_inverse - Flatten PyTree with Reconstruction

Flattens a nested Python structure (PyTree) and returns an inverse function to reconstruct the original structure.

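
To illustrate what "flatten with an inverse" means, here is a small pure-Python sketch for nested tuples, lists, and dicts. It records the container layout while collecting leaves, then returns a closure that rebuilds the same layout from any list of replacement leaves. The function name `flatten_with_inverse` is hypothetical; this shows the general idea only and is not the library's implementation, which may handle more container types.

```python
def flatten_with_inverse(tree):
    """Flatten nested tuples/lists/dicts into a list of leaves plus an inverse.

    Hypothetical illustration (not the library's code).
    """
    leaves = []

    def build_spec(node):
        # Record the container layout, replacing each leaf with its index
        if isinstance(node, dict):
            return ("dict", [(key, build_spec(value)) for key, value in node.items()])
        if isinstance(node, tuple):
            return ("tuple", [build_spec(value) for value in node])
        if isinstance(node, list):
            return ("list", [build_spec(value) for value in node])
        leaves.append(node)
        return ("leaf", len(leaves) - 1)

    spec = build_spec(tree)

    def inverse(new_leaves):
        # Walk the recorded layout and plug the new leaves back in
        def rebuild(node_spec):
            kind, payload = node_spec
            if kind == "leaf":
                return new_leaves[payload]
            if kind == "dict":
                return {key: rebuild(value) for key, value in payload}
            items = [rebuild(value) for value in payload]
            return tuple(items) if kind == "tuple" else items

        return rebuild(spec)

    return leaves, inverse


leaves, inverse = flatten_with_inverse((1, (2, 3), {"a": [4]}))
print(leaves)                              # [1, 2, 3, 4]
print(inverse([x * 10 for x in leaves]))   # (10, (20, 30), {'a': [40]})
```

The library's version is used the same way, as shown next.
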
```python
import torch
from torch_einops_utils import tree_flatten_with_inverse

# Nested structure
tree = (1, (2, 3), 4)

# Flatten with inverse
flattened, inverse = tree_flatten_with_inverse(tree)
print(flattened)  # [1, 2, 3, 4]

# Modify and reconstruct
modified = [flattened[0] + 10] + flattened[1:]
reconstructed = inverse(modified)
print(reconstructed)  # (11, (2, 3), 4)

# Works with complex nested structures
nested = {'a': [1, 2], 'b': (3, {'c': 4})}
flattened, inverse = tree_flatten_with_inverse(nested)
reconstructed = inverse(flattened)
print(reconstructed == nested)  # True
```

### tree_map_tensor - Apply Function to Tensors in PyTree

Applies a function only to tensors within a nested Python structure, leaving other values unchanged.

```python
import torch
from torch_einops_utils import tree_map_tensor

# Mixed structure with tensors and non-tensors
tree = (1, torch.tensor([2, 3]), 'hello', torch.tensor([4.0]))

# Apply function only to tensors
result = tree_map_tensor(lambda t: t * 2, tree)
print(result[0])  # 1 (unchanged - not a tensor)
print(result[1])  # tensor([4, 6])
print(result[2])  # 'hello' (unchanged)
print(result[3])  # tensor([8.0])
```

### pack_with_inverse - Einops Pack with Undo

Packs tensors using einops and returns an inverse function to unpack them. Simplifies reshape operations that need to be reversed.

```python
import torch
from torch_einops_utils import pack_with_inverse

# Single tensor packing
t = torch.randn(3, 12, 2, 2)
packed, inverse = pack_with_inverse(t, 'b * d')
print(packed.shape)  # torch.Size([3, 24, 2]) - middle dims merged

# Unpack back to original shape
unpacked = inverse(packed)
print(unpacked.shape)  # torch.Size([3, 12, 2, 2])

# Pack multiple tensors
t = torch.randn(3, 12, 2, 2)
u = torch.randn(3, 4, 2)
packed, inverse = pack_with_inverse([t, u], 'b * d')
print(packed.shape)  # torch.Size([3, 28, 2])

# Unpack with different pattern
packed_sum = packed.sum(dim=-1)
t_out, u_out = inverse(packed_sum, 'b *')
print(t_out.shape)  # torch.Size([3, 12, 2])
print(u_out.shape)  # torch.Size([3, 4])
```

### save_load - Model Serialization Decorator

A decorator that adds `save`, `load`, and `init_and_load` methods to PyTorch modules. Automatically preserves constructor arguments for full model reconstruction.

```python
import torch
from torch import nn
from torch_einops_utils.save_load import save_load

@save_load(version='1.0.0')
class MyModel(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

# Create and save model
model = MyModel(dim=64, hidden_dim=256, dropout=0.2)
model.save('model.pt')

# Load into existing model
model2 = MyModel(dim=64, hidden_dim=256, dropout=0.2)
model2.load('model.pt')

# Initialize AND load from checkpoint (reconstructs with saved config)
model3 = MyModel.init_and_load('model.pt')
print(model3.dim)  # 64
print(model3.hidden_dim)  # 256

# Works with nested decorated modules
@save_load()
class Encoder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, dim)

@save_load()
class FullModel(nn.Module):
    def __init__(self, encoder, output_dim):
        super().__init__()
        self.encoder = encoder
        self.output = nn.Linear(encoder.dim, output_dim)

# Nested models are automatically serialized
encoder = Encoder(dim=32)
model = FullModel(encoder=encoder, output_dim=10)
model.save('full_model.pt')

# Reconstruct entire hierarchy
loaded = FullModel.init_and_load('full_model.pt')
print(loaded.encoder.dim)  # 32
```

### module_device - Get Module Device

Returns the device of a PyTorch module by checking its first parameter or buffer.

```python
import torch
from torch import nn
from torch_einops_utils.device import module_device

# Module with parameters
model = nn.Linear(3, 3)
device = module_device(model)
print(device)  # device(type='cpu')

# Module without parameters (e.g., Identity)
model = nn.Identity()
device = module_device(model)
print(device)  # None
```

### move_inputs_to_device - Device Transfer Decorator

Decorators that automatically move all tensor inputs to a specified device or the module's device before function execution.

```python
import torch
from torch import nn
from torch_einops_utils.device import move_inputs_to_device, move_inputs_to_module_device

# Move inputs to specific device
@move_inputs_to_device(torch.device('cpu'))
def process(x, y):
    return x + y

# Inputs are automatically moved to CPU
x = torch.randn(3)  # Could be on any device
y = torch.randn(3)
result = process(x, y)

# Move inputs to module's device (for nn.Module methods)
class MyModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    @move_inputs_to_module_device
    def forward(self, x, mask=None):
        # x and mask are automatically moved to same device as self.linear
        out = self.linear(x)
        if mask is not None:
            out = out * mask
        return out

model = MyModule(64)

# Inputs will be moved to model's device automatically
x = torch.randn(2, 64)
output = model(x)
```

## Summary

Torch Einops Utils provides essential building blocks for PyTorch tensor manipulation that are commonly needed in deep learning projects.
The library excels at handling variable-length sequences, masked operations, dimension management for broadcasting, and null-safe tensor operations. These utilities eliminate boilerplate code when working with attention mechanisms, sequence models, and batch processing of irregular data.

The `save_load` decorator goes beyond standard PyTorch checkpointing by preserving constructor arguments, so a model can be fully reconstructed from a single file. Together with the device management utilities, this gives the library a broad toolkit for building robust, maintainable ML/AI code. Integration is straightforward: import the needed functions and use them alongside standard PyTorch and einops operations.
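
To make the idea of preserving constructor arguments concrete, the following standard-library sketch shows a class decorator that records the arguments passed to `__init__` and stores them next to the object's state, so the object can be rebuilt from one file. It illustrates the general pattern only: `remember_init_args`, the `Counter` class, and the pickle-based file layout are all assumptions made for this example, and the real `save_load` decorator operates on `nn.Module` instances and may differ in every detail.

```python
import functools
import os
import pickle
import tempfile


def remember_init_args(cls):
    """Class decorator: record constructor args and add save / init_and_load.

    Hypothetical sketch of the pattern, not the library's implementation.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        # Stash the arguments before running the real constructor
        self._init_args = (args, kwargs)
        original_init(self, *args, **kwargs)

    def save(self, path):
        # Store the constructor arguments alongside the current state
        state = {k: v for k, v in vars(self).items() if k != "_init_args"}
        with open(path, "wb") as f:
            pickle.dump({"init_args": self._init_args, "state": state}, f)

    @classmethod
    def init_and_load(klass, path):
        # Only unpickle files from a trusted source
        with open(path, "rb") as f:
            payload = pickle.load(f)
        args, kwargs = payload["init_args"]
        obj = klass(*args, **kwargs)        # rebuild from the saved config
        vars(obj).update(payload["state"])  # then restore the saved state
        return obj

    cls.__init__ = __init__
    cls.save = save
    cls.init_and_load = init_and_load
    return cls


@remember_init_args
class Counter:
    def __init__(self, step, start=0):
        self.step = step
        self.value = start

    def tick(self):
        self.value += self.step


counter = Counter(step=3)
counter.tick()

path = os.path.join(tempfile.mkdtemp(), "counter.pkl")
counter.save(path)

restored = Counter.init_and_load(path)
print(restored.step, restored.value)  # 3 3
```

The caller of `init_and_load` never repeats the constructor arguments, which is the property the `save_load` examples above rely on.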