# Torch Einops Kit

`torch_einops_kit` is a typed PyTorch utility library providing tensor-shaping, masking, padding, device-routing, and checkpoint helpers for use in deep learning model repositories. It is a strictly typed, extensively tested superset of `lucidrains/torch-einops-utils`, designed to centralize small tensor utility functions that appear repeatedly across lucidrains-style model repositories (dreamer4, metacontroller, pi-zero-pytorch, sdft-pytorch, locoformer, and others). The package ships a `py.typed` marker, uses strict Pyright type checking, and is annotated with overloads so type checkers preserve useful return-type information.

The library is organized into a root package plus four focused submodules: `torch_einops_kit.einops` for einops pack/unpack with paired inverses, `torch_einops_kit.device` for automatic device routing, `torch_einops_kit.save_load` for checkpoint save/load decorators, and `torch_einops_kit.scaleValues` for normalization layers. Most tensor helpers are importable directly from `torch_einops_kit`; submodule imports are used for the specialized utilities.

All functions follow a `None`-safe convention: helpers that take sequences of tensors filter `None` values before operating and return `None` for empty effective input.

---

## Installation

```bash
pip install torch_einops_kit
# or
uv add torch_einops_kit
```

---

## Optional-value and structure helpers

### `exists(v)` — None guard with TypeGuard narrowing

Returns `True` when `v is not None`. Falsy values such as `0`, `False`, and `[]` still count as existing. The return type is `TypeGuard[TVar]`, so static analyzers narrow the type after the check.
```python
import torch
from torch_einops_kit import exists, lens_to_mask

def build_mask(lens: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
    if not exists(max_len):
        max_len = int(lens.amax().item())
    return lens_to_mask(lens, max_len)

lens = torch.tensor([4, 3, 1])
mask = build_mask(lens)
assert mask.shape == (3, 4)

mask_wide = build_mask(lens, max_len=6)
assert mask_wide.shape == (3, 6)
```

---

### `default(v, d)` — fallback for optional values

Returns `v` when `v` is not `None`, otherwise returns the fallback `d`.

```python
from torch_einops_kit import default

def attention_scale(scale: float | None, dim: int) -> float:
    return default(scale, dim ** -0.5)

assert attention_scale(0.5, 64) == 0.5
assert attention_scale(None, 64) == 64 ** -0.5
```

---

### `compact(arr)` — filter None from an iterable

Removes every `None` value from `arr` and returns the remaining elements as a typed `list`.

```python
import torch
from torch_einops_kit import compact

tensors = [torch.randn(3), None, torch.randn(3), None]
result = compact(tensors)
assert len(result) == 2
# result is list[Tensor]; None values are removed
```

---

### `maybe(fn)` — conditional function application

Wraps `fn` so the wrapped callable returns `None` when its first argument is `None`. When `fn` itself is `None`, returns `identity`.

```python
import torch
from torch_einops_kit import maybe, lens_to_mask

# episode_lens may be None at the start of training
episode_lens: torch.Tensor | None = torch.tensor([4, 3, 1])
seq_len = 5

# mask is None when episode_lens is None, otherwise lens_to_mask is applied
mask = maybe(lens_to_mask)(episode_lens, seq_len)
assert mask is not None
assert mask.shape == (3, 5)

# When episode_lens is None, mask is None
mask_none = maybe(lens_to_mask)(None, seq_len)
assert mask_none is None

# Pass None as fn to get identity behavior
result = maybe(None)(42)
assert result == 42
```

---

### `once(fn)` — execute a callable at most once

Wraps `fn` so it runs only on the first call.
Subsequent calls return `None`.

```python
from torch_einops_kit import once

print_once = once(print)
for _ in range(5):
    print_once("A100 GPU detected, using flash attention")
# Prints exactly once; subsequent calls silently return None
```

---

### `map_values(fn, v)` — recursive leaf transformation

Applies `fn` to every leaf value in a nested `list`, `tuple`, or `dict` structure, preserving container shape.

```python
from torch_einops_kit import map_values

nested = {"a": [1, 2, {"b": 3}], "c": (4, 5)}
result = map_values(lambda x: x * 2, nested)
# result == {"a": [2, 4, {"b": 6}], "c": (8, 10)}
```

---

### `safe(fn)` — decorator for None-tolerant tensor sequences

Wraps a function that accepts `Sequence[Tensor]` to tolerate `None` values in the sequence. Returns `None` when no non-`None` tensors remain.

```python
import torch
from collections.abc import Sequence
from torch import Tensor
from torch_einops_kit import safe

@safe
def my_sum(tensors: Sequence[Tensor]) -> Tensor | None:
    return torch.stack(tensors).sum(dim=0)

t1, t2 = torch.randn(3), torch.randn(3)

result = my_sum([t1, None, t2])  # None filtered before call
assert result is not None

result_empty = my_sum([None, None])  # returns None
assert result_empty is None
```

---

## Shape and slicing helpers

### `shape_with_replace(t, replace_dict)` — derive a shape with substituted dimensions

Returns a `torch.Size` based on `t.shape` with selected dimension sizes replaced. Does not modify `t`. Keys in `replace_dict` must be non-negative integers less than `t.ndim`.
```python
import torch
from torch_einops_kit import shape_with_replace

generated_video = torch.randn(2, 3, 10, 64, 64)  # (batch, channels, time, H, W)
real_len, gen_len = 12, 10

# Build a zero-padding shape matching the video except along the time axis
pad_shape = shape_with_replace(generated_video, {2: real_len - gen_len})
padding = generated_video.new_zeros(pad_shape)
extended = torch.cat((generated_video, padding), dim=2)
assert extended.shape == (2, 3, 12, 64, 64)

# Compute a future-frame noise shape
latents = torch.randn(2, 3, 8, 32)
pred_shape = shape_with_replace(latents, {2: 4})
future_noise = torch.randn(pred_shape, device=latents.device)
assert future_noise.shape == (2, 3, 4, 32)
```

---

### `slice_at_dim(t, slc, dim)` — apply a slice to one dimension

Applies `slc` to a single dimension and preserves every other dimension. Negative `dim` values are normalized before indexing.

```python
import torch
from torch_einops_kit import slice_at_dim

t = torch.randn(3, 4, 5)

# Slice the last dimension (default dim=-1)
res = slice_at_dim(t, slice(1, 3))
assert res.shape == (3, 4, 2)

# Keep the first two elements of dimension 1
res = slice_at_dim(t, slice(None, 2), dim=1)
assert res.shape == (3, 2, 5)

# Windowed attention: shift-and-concatenate along a sequence axis
left = slice_at_dim(t, slice(None, -1), dim=1)
right = slice_at_dim(t, slice(1, None), dim=1)
assert left.shape == (3, 3, 5)
assert right.shape == (3, 3, 5)
```

---

### `slice_left_at_dim(t, length, dim)` / `slice_right_at_dim(t, length, dim)` — prefix and suffix slicing

`slice_left_at_dim` retains the first `length` elements; `slice_right_at_dim` retains the last `length` elements.
```python
import torch
from torch_einops_kit import slice_left_at_dim, slice_right_at_dim

t = torch.randn(3, 4, 5)

prefix = slice_left_at_dim(t, 2, dim=1)
assert prefix.shape == (3, 2, 5)

suffix = slice_right_at_dim(t, 2, dim=1)
assert suffix.shape == (3, 2, 5)

# Typical use: trim precomputed positional frequencies to query length
freqs = torch.randn(1, 512, 64)  # (1, max_seq, head_dim)
q_len = 128
trimmed_freqs = slice_right_at_dim(freqs, q_len, dim=-2)
assert trimmed_freqs.shape == (1, 128, 64)
```

---

## Rank-alignment and singleton-dimension helpers

### `pad_ndim(t, (left, right))` / `pad_left_ndim(t, ndims)` / `pad_right_ndim(t, ndims)` — insert singleton dimensions

Insert singleton dimensions on either side of a tensor shape to increase its rank for broadcasting. No data is copied.

```python
import torch
from torch_einops_kit import pad_ndim, pad_left_ndim, pad_right_ndim

t = torch.randn(4, 8)  # shape: (4, 8)

both = pad_ndim(t, (1, 2))    # (1, 4, 8, 1, 1)
left = pad_left_ndim(t, 2)    # (1, 1, 4, 8)
right = pad_right_ndim(t, 2)  # (4, 8, 1, 1)

assert both.shape == (1, 4, 8, 1, 1)
assert left.shape == (1, 1, 4, 8)
assert right.shape == (4, 8, 1, 1)
```

---

### `pad_left_ndim_to(t, ndims)` / `pad_right_ndim_to(t, ndims)` — pad to a target rank

Prepend or append singleton dimensions until a tensor has at least `ndims` dimensions. Returns `t` unchanged when `t.ndim >= ndims`.

```python
import torch
from torch_einops_kit import pad_right_ndim_to

# Broadcast a 1-D time scalar against a 5-D video tensor (b, c, t, h, w)
video = torch.randn(2, 3, 10, 64, 64)
time = torch.tensor([0.5, 0.3])  # shape: (2,)

aligned_time = pad_right_ndim_to(time, video.ndim)  # (2, 1, 1, 1, 1)
noised = video * aligned_time  # broadcasts correctly
assert noised.shape == (2, 3, 10, 64, 64)
```

---

### `align_dims_left(tensors, ndim)` — align multiple tensors to the same rank

Pads trailing singleton dimensions across several tensors to a common rank, enabling broadcasting in element-wise operations.
When `ndim` is `None`, the target is the maximum rank among the inputs.

```python
import torch
from torch_einops_kit import align_dims_left

# PPO surrogate loss: align ratio (b, n) with advantages (b, n, d)
ratio = torch.randn(4, 32)
advantages = torch.randn(4, 32, 128)

ratio_aligned, advantages_aligned = align_dims_left((ratio, advantages))
assert ratio_aligned.shape == (4, 32, 1)
assert advantages_aligned.shape == (4, 32, 128)

surr1 = ratio_aligned * advantages_aligned  # broadcasts over d
assert surr1.shape == (4, 32, 128)
```

---

## Mask helpers

### `lens_to_mask(lens, max_len)` — convert lengths to boolean masks

Converts a tensor of integer lengths into a boolean mask. Position `i` in the last axis is `True` when `i < lens[...]`.

```python
import torch
from torch_einops_kit import lens_to_mask, pad_sequence

lens = torch.tensor([4, 3, 1])
mask = lens_to_mask(lens)
assert mask.shape == (3, 4)
assert (mask.sum(dim=-1) == lens).all()

# Supply an explicit max_len wider than the longest sequence
wide_mask = lens_to_mask(lens, max_len=6)
assert wide_mask.shape == (3, 6)

# Use with pad_sequence to mask padded positions in a batch
x, y, z = torch.randn(2, 4, 5), torch.randn(2, 3, 5), torch.randn(2, 1, 5)
packed, seq_lens = pad_sequence([x, y, z], dim=1, return_lens=True)
seq_mask = lens_to_mask(seq_lens)
assert seq_mask.shape == (3, 4)
```

---

### `and_masks(masks)` / `or_masks(masks)` / `reduce_masks(masks, op)` — combine boolean masks

Reduce a sequence of boolean mask tensors to a single mask. `None` values are treated as absent and filtered out. Returns `None` when no non-`None` masks remain.
```python
import torch
from torch_einops_kit import and_masks, or_masks, reduce_masks

mask1 = torch.tensor([True, True, False])
mask2 = torch.tensor([True, False, False])

# AND: True only where all masks are True
result_and = and_masks([mask1, mask2])
assert result_and.tolist() == [True, False, False]

# OR: True where at least one mask is True
result_or = or_masks([mask1, mask2])
assert result_or.tolist() == [True, True, False]

# None values are silently skipped
result_skip = and_masks([None, mask1, None])
assert result_skip.tolist() == [True, True, False]

# All-None input returns None
result_none = and_masks([None, None])
assert result_none is None

# Custom operator: True where exactly one mask is active (XOR)
result_custom = reduce_masks([mask1, mask2], torch.logical_xor)
assert result_custom.tolist() == [False, True, False]
```

---

## Concatenation and stacking helpers

### `safe_cat(tensors, dim)` — concatenate while skipping None values

Concatenates a mixed sequence of `Tensor | None` values along an existing dimension. Returns `None` when no non-`None` tensors remain.

```python
import torch
from torch_einops_kit import safe_cat

t1 = torch.randn(2, 3)
t2 = torch.randn(2, 3)

# Basic concatenation along dim 0
result = safe_cat([t1, t2])
assert result.shape == (4, 3)

# None values are skipped; single surviving tensor is returned unchanged
result = safe_cat([t1, None])
assert result.shape == (2, 3)

# All-None or empty input returns None
assert safe_cat([]) is None
assert safe_cat([None]) is None

# Iterative accumulation pattern
accumulator: torch.Tensor | None = None
for step_output in [t1, None, t2]:
    accumulator = safe_cat([accumulator, step_output], dim=1)
assert accumulator.shape == (2, 6)
```

---

### `safe_stack(tensors, dim)` — stack while skipping None values

Stacks a mixed sequence of `Tensor | None` values along a new dimension. Returns `None` when no non-`None` tensors remain.
```python
import torch
from torch_einops_kit import safe_stack

t1 = torch.randn(2, 3)
t2 = torch.randn(2, 3)

result = safe_stack([t1, t2])
assert result.shape == (2, 2, 3)

# A single surviving tensor is still stacked (gains a new dimension)
result = safe_stack([t1, None])
assert result.shape == (1, 2, 3)

assert safe_stack([]) is None
assert safe_stack([None]) is None
```

---

### `broadcast_cat(tensors, dim)` — broadcast then concatenate

Broadcasts the tensors to a compatible shape before concatenating along `dim`.

```python
import torch
from torch_einops_kit import broadcast_cat

a = torch.randn(4, 1, 8)
b = torch.randn(1, 6, 8)

# a and b both broadcast to (4, 6, 8), then cat along the last dim
result = broadcast_cat([a, b], dim=-1)
assert result.shape == (4, 6, 16)
```

---

## Padding helpers

### `pad_at_dim(t, (left, right), dim, value)` — pad or trim one dimension

Pads `t` along one dimension. Positive values add elements; negative values trim them.

```python
import torch
from torch_einops_kit import pad_at_dim

t = torch.randn(3, 6, 1)

# Add one element at the right of dimension 1
padded = pad_at_dim(t, (0, 1), dim=1)
assert padded.shape == (3, 7, 1)

# Insert one leading zero to shift action tokens (dreamer4 pattern)
action_tokens = torch.randn(2, 8, 32)
shifted = pad_at_dim(action_tokens[:, :-1], (1, 0), value=0.0, dim=1)
assert shifted.shape == (2, 8, 32)

# Create a one-step-delayed copy (locoformer pattern)
past_action = pad_at_dim(action_tokens, (1, -1), dim=-2)
assert past_action.shape == (2, 8, 32)
```

---

### `pad_left_at_dim(t, pad)` / `pad_right_at_dim(t, pad)` — one-sided padding

Prepend or append a fixed number of fill elements along one dimension.
```python
import torch
from torch_einops_kit import pad_left_at_dim, pad_right_at_dim

t = torch.randn(2, 5)

left_padded = pad_left_at_dim(t, 2, dim=1)    # (2, 7)
right_padded = pad_right_at_dim(t, 2, dim=1)  # (2, 7)
assert left_padded.shape == (2, 7)
assert right_padded.shape == (2, 7)

# Prepend a start token (pi-zero-pytorch pattern)
discrete_ids = torch.randint(0, 100, (4, 10))
with_start = pad_left_at_dim(discrete_ids + 1, 1)
assert with_start.shape == (4, 11)
```

---

### `pad_left_at_dim_to(t, length)` / `pad_right_at_dim_to(t, length)` — conditional padding to target length

Pad to a minimum target length; return `t` unchanged when it is already long enough.

```python
import torch
from torch_einops_kit import pad_left_at_dim_to, pad_right_at_dim_to

t = torch.randn(3, 6, 1)

# Pad to length 7 along dim 1
result = pad_left_at_dim_to(t, 7, dim=1)
assert result.shape == (3, 7, 1)

# Already at length 6; tensor returned unchanged
result_same = pad_left_at_dim_to(t, 6, dim=1)
assert result_same is t

# Align variable-length action sequences to max_time before batching (dreamer4)
sequences = [torch.randn(2, length, 32) for length in [4, 7, 5]]
max_time = 8
aligned = [pad_right_at_dim_to(s, max_time, dim=1) for s in sequences]
assert all(s.shape == (2, 8, 32) for s in aligned)
```

---

### `pad_sequence(tensors, ...)` — pad a sequence to a shared length and optionally stack

Pads every tensor in `tensors` to the maximum length along `dim`. The overloaded return type depends on `return_stacked` and `return_lens`. Returns `None` for empty input.
```python
import torch
from torch_einops_kit import pad_sequence, lens_to_mask

# Variable-length feature tensors
x = torch.randn(2, 4, 5)
y = torch.randn(2, 3, 5)
z = torch.randn(2, 1, 5)

# Stack and recover original lengths
packed, lens = pad_sequence([x, y, z], dim=1, return_lens=True)
assert packed.shape == (3, 2, 4, 5)  # (num_tensors, batch, max_len, features)
assert lens.tolist() == [4, 3, 1]

# Build a mask for the padded positions
seq_mask = lens_to_mask(lens)
assert seq_mask.shape == (3, 4)

# Left-pad and retrieve padding widths (sdft-pytorch pattern)
prompt_ids = [torch.randint(0, 1000, (t,)) for t in [6, 4, 8]]
padded_ids, start_pos = pad_sequence(prompt_ids, return_lens=True, left=True, pad_lens=True)
assert padded_ids.shape == (3, 8)
assert start_pos.tolist() == [2, 4, 0]  # padding widths for each sequence

# Return a list of individually padded tensors instead of stacking
padded_list = pad_sequence([x, y, z], dim=1, return_stacked=False)
assert isinstance(padded_list, list)
assert all(t.shape[1] == 4 for t in padded_list)
```

---

### `pad_sequence_and_cat(tensors, dim_cat, dim, value, left)` — pad and concatenate

Pads tensors to a shared length along `dim`, then concatenates along `dim_cat`. Returns `None` for empty input.

```python
import torch
from torch_einops_kit import pad_sequence_and_cat

# Images with different heights, same number of channels and width
images = [torch.randn(3, 16, 17), torch.randn(3, 15, 17), torch.randn(3, 17, 17)]

# Pad heights to the shared maximum, then cat along the channel dim
padded = pad_sequence_and_cat(images, dim=-2, dim_cat=0)
assert padded.shape == (9, 17, 17)
```

---

## Normalization and masked reduction helpers (`torch_einops_kit.scaleValues`)

### `l2norm(t)` — normalize to unit length

Normalizes each vector in `t` to unit length along the last dimension.
```python
import torch
from torch_einops_kit.scaleValues import l2norm

q = torch.randn(4, 8, 64)  # (batch, heads, dim)
k = torch.randn(4, 8, 64)

# Normalize query and key before computing dot-product attention scores
q, k = map(l2norm, (q, k))
assert torch.allclose(q.norm(dim=-1), torch.ones(4, 8), atol=1e-5)
scores = torch.einsum("bhd, bhd -> bh", q, k)
```

---

### `masked_mean(t, mask, dim, eps)` — compute mean over selected positions

Computes the mean of `t` over positions where `mask` is `True`. Falls back to `t.mean()` when `mask` is `None`.

```python
import torch
from torch_einops_kit.scaleValues import masked_mean

t = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [5.0, 6.0, 7.0, 8.0]])
mask = torch.tensor([[True, True, False, False],
                     [True, False, True, False]])

# Reduce over all dimensions (global masked mean)
result = masked_mean(t, mask=mask)
# (1 + 2 + 5 + 7) / 4 = 3.75
assert torch.allclose(result, torch.tensor(3.75))

# Reduce along dim=1 per row
per_row = masked_mean(t, mask=mask, dim=1)
assert per_row.shape == (2,)
# row 0: mean([1, 2]) = 1.5
# row 1: mean([5, 7]) = 6.0
assert torch.allclose(per_row, torch.tensor([1.5, 6.0]))

# Unmasked mean (mask=None)
full = masked_mean(t)
assert torch.allclose(full, t.mean())
```

---

### `RMSNorm(dim)` — learned root-mean-square normalization layer

`torch.nn.Module` that normalizes the last feature axis to unit length, multiplies by `√dim`, and applies a learned `gamma` parameter. Use as a pre-normalization layer in transformer-style modules.
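Functionally, that description corresponds to the minimal sketch below. It is illustrative only and uses plain PyTorch; the library's actual `RMSNorm` may differ in details such as epsilon handling.

```python
import torch
from torch import nn
import torch.nn.functional as F

class RMSNormSketch(nn.Module):
    """Illustrative stand-in: unit-normalize the last axis,
    rescale by sqrt(dim), then apply a learned per-feature gamma."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.normalize divides by the L2 norm along the last axis
        return F.normalize(x, dim=-1) * self.scale * self.gamma

x = torch.randn(4, 32, 128)
out = RMSNormSketch(128)(x)
# with gamma at its initial value of 1, each output vector has RMS 1
assert torch.allclose(out.pow(2).mean(dim=-1), torch.ones(4, 32), atol=1e-4)
```

Because `‖x‖ / √dim` equals the root mean square of `x`, unit-normalizing and multiplying by `√dim` is the standard RMSNorm computation.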
```python
import torch
from torch import nn
from torch_einops_kit.scaleValues import RMSNorm

dim = 128
norm = RMSNorm(dim)

x = torch.randn(4, 32, dim)  # (batch, seq, features)
normalized = norm(x)
assert normalized.shape == (4, 32, 128)

# Use inside a transformer block
class TransformerBlock(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.norm2 = RMSNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x
```

---

## `einops` submodule (`torch_einops_kit.einops`)

### `pack_with_inverse(t, pattern)` — pack and retain an inverse function

Packs a tensor (or list of tensors) using an einops pattern and returns a paired inverse function for shape restoration.

```python
import torch
from torch_einops_kit.einops import pack_with_inverse

# Pack a single tensor: flatten spatial dims
t = torch.randn(3, 12, 2, 2)
packed, inverse = pack_with_inverse(t, "b * d")
assert packed.shape == (3, 24, 2)

# Recover original shape
restored = inverse(packed)
assert restored.shape == (3, 12, 2, 2)

# Pack a list of tensors and unpack with an overriding pattern
t = torch.randn(3, 12, 2)
u = torch.randn(3, 4, 2)
packed, inverse = pack_with_inverse([t, u], "b * d")
assert packed.shape == (3, 16, 2)

# Apply an operation on the packed tensor, then unpack
reduced = packed.sum(dim=-1)  # (3, 16)
t_out, u_out = inverse(reduced, "b *")
assert t_out.shape == (3, 12)
assert u_out.shape == (3, 4)
```

---

### `pack_one(t, pattern)` / `unpack_one(t, ps, pattern)` — single-tensor pack/unpack pair

Pack a single tensor and store shape metadata; recover the original shape from that metadata.
```python
import torch
from torch_einops_kit.einops import pack_one, unpack_one

t = torch.randn(4, 8, 3, 3)
packed, ps = pack_one(t, "b * d")
assert packed.shape == (4, 24, 3)

restored = unpack_one(packed, ps, "b * d")
assert restored.shape == (4, 8, 3, 3)
```

---

## `device` submodule (`torch_einops_kit.device`)

### `module_device(m)` — infer a module's device

Returns the `torch.device` of the first parameter or registered buffer in `m`. Returns `None` for stateless modules.

```python
import torch
from torch import nn
from torch_einops_kit.device import module_device

linear = nn.Linear(3, 5)
assert module_device(linear) == torch.device("cpu")

# GPU module
if torch.cuda.is_available():
    gpu_linear = nn.Linear(3, 5).cuda()
    assert module_device(gpu_linear) == torch.device("cuda", 0)

# Stateless module returns None
identity = nn.Identity()
assert module_device(identity) is None
```

---

### `move_inputs_to_device(device)` — decorator for explicit device routing

Creates a decorator that moves all tensor arguments to a fixed target device before each call.

```python
import torch
from torch_einops_kit.device import move_inputs_to_device

target = torch.device("meta")

@move_inputs_to_device(target)
def compute(
    x: torch.Tensor,
    nested: tuple[torch.Tensor, str],
    *,
    kw: torch.Tensor,
) -> tuple[torch.device, torch.device, torch.device]:
    return x.device, nested[0].device, kw.device

cpu_t = torch.tensor([1.0, 2.0])
result = compute(cpu_t, (cpu_t, "tag"), kw=cpu_t)

# All tensors moved to the "meta" device; non-tensor values pass through unchanged
assert all(d == torch.device("meta") for d in result)
```

---

### `move_inputs_to_module_device(fn)` — auto-route tensors to the module's device

Decorator for methods of `torch.nn.Module` subclasses. Infers the target device from `self` and moves all tensor arguments to it before each call.
```python
import torch
from torch import nn, Tensor
from torch_einops_kit.device import move_inputs_to_module_device

class EchoModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.tensor([2.0], device=torch.device("meta")))

    @move_inputs_to_module_device
    def forward(self, x: Tensor) -> Tensor:
        return x

module = EchoModule()
cpu_tensor = torch.tensor([1.0, 2.0])
result = module.forward(cpu_tensor)
assert result.device == torch.device("meta")

# Standalone function whose first argument is a module
@move_inputs_to_module_device
def policy_loss(model: nn.Module, states: Tensor, actions: Tensor) -> Tensor:
    return (states + actions).mean()
```

---

## `save_load` submodule (`torch_einops_kit.save_load`)

### `save_load(...)` — checkpoint decorator for `nn.Module` subclasses

Class decorator that records constructor arguments and adds `save`, `load`, and `init_and_load` methods to a `torch.nn.Module` subclass.

```python
from pathlib import Path

import torch
from torch import nn, Tensor
from torch_einops_kit.save_load import save_load

@save_load()
class SimpleNet(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.net = nn.Linear(dim, hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

path = Path("model.pt")
model = SimpleNet(10, 20)
model.save(str(path))

# Reconstruct from checkpoint without knowing constructor args
restored = SimpleNet.init_and_load(str(path))
assert restored.dim == 10
assert restored.hidden_dim == 20

# Custom method names and version tagging
@save_load(
    save_method_name="store",
    load_method_name="restore",
    config_instance_var_name="stored_config",
    init_and_load_classmethod_name="create_and_restore",
    version="1.0.0",
)
class VersionedNet(nn.Module):
    def __init__(self, width: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(width))

net = VersionedNet(13)
net.store("versioned_model.pt")
restored_net = VersionedNet.create_and_restore("versioned_model.pt")
assert restored_net.weight.shape == (13,)
```

---

### `dehydrate_config` / `rehydrate_config` — nested module checkpoint serialization

`dehydrate_config` replaces nested decorated module instances with checkpoint-safe reconstruction records. `rehydrate_config` reconstructs those modules from the stored records.

```python
import torch
from torch import nn
from torch_einops_kit.save_load import save_load, dehydrate_config, rehydrate_config

@save_load()
class InnerNet(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

@save_load()
class OuterNet(nn.Module):
    def __init__(self, inner: InnerNet, scale: float) -> None:
        super().__init__()
        self.inner = inner
        self.scale = nn.Parameter(torch.tensor([scale]))

    def forward(self, x):
        return self.inner.proj(x) * self.scale

inner = InnerNet(64)
outer = OuterNet(inner, scale=2.0)
outer.save("outer_model.pt")

# init_and_load rehydrates the nested InnerNet automatically
restored_outer = OuterNet.init_and_load("outer_model.pt")
assert isinstance(restored_outer.inner, InnerNet)
```

---

## PyTree helpers

### `tree_map_tensor(fn, tree)` — apply a function to all tensor leaves

Applies `fn` to every `torch.Tensor` leaf in a nested PyTree, leaving non-tensor leaves unchanged.
```python
import torch
from torch_einops_kit import tree_map_tensor

tree = (1, torch.tensor(2.0), {"a": torch.tensor(3.0), "b": "hello"})
result = tree_map_tensor(lambda t: t * 2, tree)
# result == (1, tensor(4.0), {"a": tensor(6.0), "b": "hello"})
assert result[0] == 1
assert result[1].item() == 4.0
assert result[2]["b"] == "hello"

# Detach all tensors in a state container (dreamer4 / fast-weight-attention pattern)
memory = {"hidden": torch.randn(4, 32), "cell": torch.randn(4, 32), "step": 5}
detached_memory = tree_map_tensor(lambda t: t.detach(), memory)
assert not detached_memory["hidden"].requires_grad
assert detached_memory["step"] == 5  # non-tensor unchanged
```

---

### `tree_flatten_with_inverse(tree)` — flatten and reconstruct a PyTree

Flattens a nested PyTree into a list of leaves and returns a paired inverse function to reconstruct the original structure.

```python
import torch
from torch_einops_kit import tree_flatten_with_inverse

tree = (1, (torch.tensor(2.0), 3), {"x": 4})
leaves, inverse = tree_flatten_with_inverse(tree)
# leaves == [1, tensor(2.0), 3, 4] (in left-to-right order)

# Modify one leaf and reconstruct
leaves[0] = leaves[0] + 10
restored = inverse(leaves)
# restored == (11, (tensor(2.0), 3), {"x": 4})
assert restored[0] == 11
assert restored[1][1] == 3
```

---

## Summary

`torch_einops_kit` is designed as a shared, typed tensor utility layer for PyTorch researchers who work in the style of the lucidrains ecosystem. The most common integration pattern is to import the root-level helpers (`align_dims_left`, `lens_to_mask`, `pad_sequence`, `safe_cat`, `pack_with_inverse`, `l2norm`) at the top of a model file and call them directly in `forward` methods, replacing inline ad-hoc implementations.
The `save_load` decorator is the entry point for checkpoint management: apply `@save_load()` to any `nn.Module` subclass once, and `save`, `load`, and `init_and_load` are injected automatically, with support for nested decorated modules through `dehydrate_config`/`rehydrate_config`. The `device` submodule eliminates boilerplate `.to(device)` calls in multi-GPU or device-agnostic code by centralizing device routing in a single decorator. The `scaleValues` submodule (`RMSNorm`, `l2norm`, `masked_mean`) provides the normalization primitives needed in transformer pre-norm patterns. The `einops` submodule (`pack_with_inverse`, `pack_one`, `unpack_one`) wraps the einops pack/unpack API with paired inverses, simplifying code that needs to merge and later split tensor dimensions.

Together, the package's `None`-safe conventions (`safe_cat`, `safe_stack`, `and_masks`, `or_masks`, `pad_sequence`) reduce error-prone None-handling boilerplate throughout training loops that process variable-length or optional intermediate tensors.
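As a concrete illustration of the pad-mask-reduce pattern these helpers centralize, the sketch below batches variable-length sequences and masked-mean-pools them. It uses minimal stand-in definitions of `lens_to_mask` and `pad_right_at_dim_to` (written here so the snippet runs without the library installed); the real helpers are more general, supporting arbitrary dims, fill values, and `None`-safety.

```python
import torch
from torch import Tensor

# Stand-in helpers for illustration only -- the real torch_einops_kit
# versions are more general than these minimal definitions.
def lens_to_mask(lens: Tensor, max_len: int) -> Tensor:
    # position i along the last axis is True when i < lens[...]
    return torch.arange(max_len) < lens[:, None]

def pad_right_at_dim_to(t: Tensor, length: int, dim: int = 0) -> Tensor:
    pad = length - t.shape[dim]
    if pad <= 0:
        return t  # already long enough
    zeros_shape = list(t.shape)
    zeros_shape[dim] = pad
    return torch.cat((t, t.new_zeros(zeros_shape)), dim=dim)

# The forward-method pattern: pad, stack, mask, reduce
seqs = [torch.randn(n, 16) for n in (5, 3, 7)]  # variable-length (n, d)
lens = torch.tensor([s.shape[0] for s in seqs])
batch = torch.stack([pad_right_at_dim_to(s, 7) for s in seqs])  # (3, 7, 16)
mask = lens_to_mask(lens, 7)                                    # (3, 7) boolean

# Mean over valid positions only, ignoring the zero padding
pooled = (batch * mask[..., None]).sum(dim=1) / lens[:, None]
assert batch.shape == (3, 7, 16)
assert pooled.shape == (3, 16)
```

With the library installed, the two stand-ins would simply be imported from `torch_einops_kit`, and the final reduction could use `masked_mean` from `torch_einops_kit.scaleValues` instead of the manual sum-and-divide.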