https://github.com/lucidrains/bs-roformer
# BS-RoFormer

BS-RoFormer is a PyTorch implementation of the Band-Split RoPE Transformer (BS-RoFormer), a state-of-the-art deep learning architecture for music source separation developed by ByteDance AI Labs. The model achieved top benchmark performance by combining band-split frequency decomposition with axial attention across both the frequency and time dimensions. Rotary positional encoding (RoPE), applied separately to the time and frequency axes, proved critical to the architecture's quality gains over prior methods.

The library also ships a companion model, `MelBandRoformer`, which replaces the fixed band splits with a perceptually motivated Mel filterbank and adds an optional linear-attention layer per depth block. Both models share the same training and inference API: pass raw mono or stereo audio waveforms as a float tensor, optionally supply a target waveform to compute the combined L1 + multi-resolution STFT loss, and receive either the separated audio (no target) or the scalar loss (with a target).

The library supports multi-stem separation, stereo audio, FlashAttention (via PyTorch 2.0 `scaled_dot_product_attention`), hyper-connections for residual stream mixing, and Polar-coordinate Positional Embeddings (PoPE) as an alternative to RoPE. Installation is a single `pip install BS-RoFormer`; runtime dependencies are PyTorch ≥ 2.0, einops, rotary-embedding-torch, librosa, beartype, hyper-connections, and PoPE-pytorch.

---

## API Reference

### `BSRoformer`: Band-Split RoPE Transformer

The primary model class. It converts raw audio to a complex STFT representation, splits the frequency bins into configurable bands, and applies alternating time-axis and frequency-axis transformer blocks with rotary (or PoPE) positional encodings. A per-band mask estimator network predicts a mask for each stem; the masked spectrogram is converted back to a waveform with the iSTFT.
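To make the band-split and axial-attention steps concrete, here is a simplified, self-contained sketch in plain PyTorch. It is not the library's internal code: the band layout `FREQS_PER_BAND`, the helper `band_split_features`, the single `nn.Linear` per band (the library's `BandSplit` also applies RMSNorm), and the use of `nn.MultiheadAttention` without RoPE/PoPE are all illustrative assumptions. It shows how a mono waveform becomes a `(batch, frames, bands, dim)` tensor and how attention then runs along one axis at a time.

```python
import torch
from torch import nn

# Illustrative STFT settings and a hypothetical band layout (bins per band).
# The counts must sum to n_fft // 2 + 1 = 1025 bins for n_fft = 2048.
N_FFT, HOP = 2048, 512
FREQS_PER_BAND = (1, 32, 64, 128, 256, 544)
assert sum(FREQS_PER_BAND) == N_FFT // 2 + 1


def band_split_features(audio: torch.Tensor, projections: nn.ModuleList) -> torch.Tensor:
    """Mono audio (batch, samples) -> per-band features (batch, frames, bands, dim)."""
    window = torch.hann_window(N_FFT, device=audio.device)
    spec = torch.stft(audio, n_fft=N_FFT, hop_length=HOP, window=window, return_complex=True)
    # (batch, freqs, frames) complex -> (batch, frames, freqs, 2) real/imag pairs
    spec = torch.view_as_real(spec).permute(0, 2, 1, 3)
    bands = spec.split(FREQS_PER_BAND, dim=2)  # one chunk of bins per band
    feats = [
        proj(band.flatten(start_dim=2))  # (batch, frames, 2 * bins) -> (batch, frames, dim)
        for band, proj in zip(bands, projections)
    ]
    return torch.stack(feats, dim=2)


dim = 64
projections = nn.ModuleList(nn.Linear(2 * n, dim) for n in FREQS_PER_BAND)
audio = torch.randn(2, 44100)  # 1 s of mono audio at 44.1 kHz
features = band_split_features(audio, projections)  # (2, 87, 6, 64)

# Axial attention: attend along time within each band, then across bands within
# each frame. nn.MultiheadAttention stands in for the library's attention blocks,
# and positional encodings (RoPE/PoPE) are omitted for brevity.
b, t, f, d = features.shape
time_attn = nn.MultiheadAttention(d, num_heads=4, batch_first=True)
freq_attn = nn.MultiheadAttention(d, num_heads=4, batch_first=True)

x = features.permute(0, 2, 1, 3).reshape(b * f, t, d)  # one sequence per (batch, band)
x = x + time_attn(x, x, x, need_weights=False)[0]
x = x.reshape(b, f, t, d).permute(0, 2, 1, 3).reshape(b * t, f, d)  # one per (batch, frame)
x = x + freq_attn(x, x, x, need_weights=False)[0]
out = x.reshape(b, t, f, d)
print(out.shape)  # torch.Size([2, 87, 6, 64])
```

Factoring attention this way keeps each sequence short (frames or bands rather than frames × bands), which is the practical motivation for axial attention over a full 2-D attention map.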
```python
import torch
from bs_roformer import BSRoformer

# ── Model instantiation ──────────────────────────────────────────────────────
model = BSRoformer(
    dim = 512,                         # transformer hidden dimension
    depth = 12,                        # number of (time-attn + freq-attn) layer pairs
    time_transformer_depth = 1,        # sub-depth of the time transformer
    freq_transformer_depth = 1,        # sub-depth of the frequency transformer
    stereo = False,                    # set True for stereo (2-channel) audio
    num_stems = 1,                     # number of output stems (e.g. 4 for vocals/drums/bass/other)
    heads = 8,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    flash_attn = True,                 # use PyTorch 2.0 scaled_dot_product_attention
    num_residual_streams = 4,          # hyper-connections; set 1 to disable
    stft_n_fft = 2048,
    stft_hop_length = 512,
    stft_win_length = 2048,
    multi_stft_resolution_loss_weight = 1.0,
    use_pope = False                   # set True to replace RoPE with PoPE embeddings
)

# ── Training forward pass (returns scalar loss) ───────────────────────────────
# Audio tensors: (batch, time_samples) for mono or (batch, 2, time_samples) for stereo
x = torch.randn(2, 352800)             # batch=2, 8 s at 44 100 Hz
target = torch.randn(2, 352800)        # ground-truth separated stem

loss = model(x, target=target)         # loss = L1 + weighted multi-res STFT loss
loss.backward()

# ── Optional: inspect loss components ─────────────────────────────────────────
total_loss, (l1_loss, multires_loss) = model(x, target=target, return_loss_breakdown=True)
print(f"total={total_loss:.4f} l1={l1_loss:.4f} multires={multires_loss:.4f}")

# ── Inference (no target → returns reconstructed audio tensor) ────────────────
model.eval()
with torch.no_grad():
    out = model(x)

# Output layouts (confirm with out.shape for your configuration):
#   mono, single stem    → (batch, time_samples) or (batch, 1, time_samples)
#   stereo, single stem  → (batch, 2, time_samples)
#   multi-stem           → (batch, num_stems, channels, time_samples)
# The time axis can be slightly shorter than the input because of iSTFT framing.
print(out.shape)

# ── Multi-stem separation example ─────────────────────────────────────────────
model_4stem = BSRoformer(dim=512, depth=6, time_transformer_depth=1, freq_transformer_depth=1, num_stems=4)

x4 = torch.randn(1, 441000)            # mono mixture, 10 s at 44 100 Hz
target4 = torch.randn(1, 4, 1, 441000) # (batch, num_stems, channels, time)
loss4 = model_4stem(x4, target=target4)
```

---

### `MelBandRoformer`: Mel-Band Rotary Transformer

A follow-up architecture that replaces the fixed frequency bands with a Mel filterbank (via `librosa.filters.mel`), which gives overlapping band coverage, and adds an optional linear-attention global-context layer at each depth block. Hyper-connections and value residuals are enabled by default.

```python
import torch
from bs_roformer import MelBandRoformer

# ── Model instantiation ──────────────────────────────────────────────────────
model = MelBandRoformer(
    dim = 128,
    depth = 6,
    time_transformer_depth = 1,
    freq_transformer_depth = 1,
    linear_transformer_depth = 1,      # global linear-attention sub-layer per block (0 to disable)
    num_bands = 60,                    # number of Mel bands
    stereo = False,
    num_stems = 1,
    sample_rate = 44100,
    stft_n_fft = 2048,
    stft_hop_length = 512,
    stft_win_length = 2048,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    flash_attn = True,
    add_value_residual = True,         # learned value residual mixing across layers
    num_residual_streams = 4,
    match_input_audio_length = False,  # if True, pads output to the exact input length
    use_pope = False
)

# ── Training ──────────────────────────────────────────────────────────────────
x = torch.randn(2, 352800)
target = torch.randn(2, 352800)

loss = model(x, target=target)
loss.backward()

# ── Inspect loss breakdown ────────────────────────────────────────────────────
total, (l1, multires) = model(x, target=target, return_loss_breakdown=True)
print(f"total={total:.4f} l1={l1:.4f} multires={multires:.4f}")

# ── Inference ─────────────────────────────────────────────────────────────────
model.eval()
with torch.no_grad():
    separated = model(x)               # mono, single stem

# The time axis may be slightly shorter than 352800 samples unless
# match_input_audio_length=True.
print(separated.shape)

# ── Stereo training / inference ───────────────────────────────────────────────
stereo_model = MelBandRoformer(dim=64, depth=2, time_transformer_depth=1, freq_transformer_depth=1, stereo=True)

x_stereo = torch.randn(1, 2, 352800)   # (batch, channels=2, time)
target_stereo = torch.randn(1, 2, 352800)
loss_stereo = stereo_model(x_stereo, target=target_stereo)

stereo_model.eval()
with torch.no_grad():
    out_stereo = stereo_model(x_stereo)

print(out_stereo.shape)                # (1, 2, time_samples)
```

---

### `BSRoformer.forward` / `MelBandRoformer.forward`: Unified Inference and Training Entry Point

Both models expose an identical `forward` signature. When `target` is omitted, the model performs inference and returns the separated audio; when `target` is provided, it computes and returns the training loss.

```python
import torch
from bs_roformer import BSRoformer

model = BSRoformer(dim=512, depth=1, time_transformer_depth=1, freq_transformer_depth=1)
model.eval()

# ── Inference only ────────────────────────────────────────────────────────────
raw_audio = torch.randn(1, 176400)     # (batch, time): mono, 4 s at 44 100 Hz

with torch.no_grad():
    separated = model(raw_audio)       # separated waveform tensor

# ── Training with loss breakdown ──────────────────────────────────────────────
model.train()
target = torch.randn(1, 176400)

# Standard training step
total_loss = model(raw_audio, target=target)
total_loss.backward()

# Detailed breakdown for logging
total_loss, (l1_loss, multires_loss) = model(
    raw_audio,
    target=target,
    return_loss_breakdown=True
)
# l1_loss:       waveform-domain L1 loss
# multires_loss: multi-resolution STFT loss summed across 5 window sizes
#                (4096, 2048, 1024, 512, 256); it enters total_loss scaled by
#                multi_stft_resolution_loss_weight
```

---

### `Attend`: Scaled Dot-Product Attention with Optional FlashAttention

Internal attention primitive used by both models. It supports standard eager attention and the PyTorch 2.0 `scaled_dot_product_attention` (FlashAttention) path, and automatically selects kernel flags for A100 versus other CUDA GPUs.
```python
import torch
from bs_roformer.attend import Attend

# ── Standard (eager) attention ────────────────────────────────────────────────
attn = Attend(dropout=0.1, flash=False)

batch, heads, seq_len, dim_head = 2, 8, 128, 64
q = torch.randn(batch, heads, seq_len, dim_head)
k = torch.randn(batch, heads, seq_len, dim_head)
v = torch.randn(batch, heads, seq_len, dim_head)

out = attn(q, k, v)                    # (batch, heads, seq_len, dim_head)
print(out.shape)                       # torch.Size([2, 8, 128, 64])

# ── FlashAttention path (requires PyTorch >= 2.0) ─────────────────────────────
flash_attn = Attend(dropout=0.0, flash=True)
with torch.no_grad():
    out_flash = flash_attn(q, k, v)
print(out_flash.shape)                 # torch.Size([2, 8, 128, 64])

# ── Custom scale ──────────────────────────────────────────────────────────────
attn_scaled = Attend(scale=0.05)       # override the default 1/sqrt(dim_head)
out_scaled = attn_scaled(q, k, v)
```

---

### `BandSplit`: Per-Band Feature Projection

Projects each frequency band's flattened complex STFT features into the shared transformer dimension using independent RMSNorm + Linear layers. Used internally by both model classes.

```python
import torch
from bs_roformer.bs_roformer import BandSplit

# Each entry in dim_inputs is 2 * freqs_in_band * audio_channels (real/imag parts × channels)
dim_inputs = (4, 4, 8, 8, 24, 24, 48, 48, 256, 258)
band_split = BandSplit(dim=512, dim_inputs=dim_inputs)

# Input: (batch, time_frames, sum(dim_inputs))
x = torch.randn(2, 100, sum(dim_inputs))
out = band_split(x)                    # (batch, time_frames, num_bands, dim)
print(out.shape)                       # torch.Size([2, 100, 10, 512])
```

---

### `MaskEstimator`: Per-Band Mask MLP

Produces a complex-valued spectral mask for each frequency band from the transformer output features, using a GLU-activated MLP. Multiple stems are supported: `BSRoformer` and `MelBandRoformer` instantiate one `MaskEstimator` per stem.
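The per-band structure can be sketched with standard PyTorch modules. The class below, `TinyMaskEstimator`, is hypothetical and simplified, not the library's implementation: each band gets its own MLP whose last linear layer outputs twice the band's feature width, and `nn.GLU` halves it back by gating one half with the sigmoid of the other; the per-band outputs are then concatenated along the feature axis. The hidden width, the `Tanh` activation, and reading the output as interleaved (real, imag) pairs are assumptions made for illustration.

```python
import torch
from torch import nn


class TinyMaskEstimator(nn.Module):
    """Hypothetical per-band mask MLP: (batch, frames, bands, dim) -> (batch, frames, sum(dim_inputs))."""

    def __init__(self, dim: int, dim_inputs: tuple, hidden: int = 256):
        super().__init__()
        self.to_masks = nn.ModuleList(
            nn.Sequential(
                nn.Linear(dim, hidden),
                nn.Tanh(),
                nn.Linear(hidden, dim_in * 2),  # doubled width for the GLU gate
                nn.GLU(dim=-1),                 # a * sigmoid(b) -> width dim_in
            )
            for dim_in in dim_inputs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bands = x.unbind(dim=-2)  # one (batch, frames, dim) tensor per band
        return torch.cat([mlp(band) for band, mlp in zip(bands, self.to_masks)], dim=-1)


dim_inputs = (4, 4, 8, 8, 24)  # 2 * freqs_in_band (real/imag) per band, mono
estimator = TinyMaskEstimator(dim=512, dim_inputs=dim_inputs)
x = torch.randn(2, 100, len(dim_inputs), 512)
masks = estimator(x)
print(masks.shape)  # torch.Size([2, 100, 48])

# Reading the last axis as interleaved (real, imag) pairs gives one complex value per bin.
complex_mask = torch.view_as_complex(masks.reshape(2, 100, -1, 2))
print(complex_mask.shape)  # torch.Size([2, 100, 24])
```

Giving each band its own MLP lets low-frequency bands (few bins) and high-frequency bands (many bins) produce outputs of different widths from the same shared `dim`-sized features.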
```python
import torch
from bs_roformer.bs_roformer import MaskEstimator

dim_inputs = (4, 4, 8, 8, 24)          # freqs_per_band_with_complex
estimator = MaskEstimator(dim=512, dim_inputs=dim_inputs, depth=2)

# Input: (batch, time_frames, num_bands, dim), i.e. the transformer output
x = torch.randn(2, 100, len(dim_inputs), 512)
out = estimator(x)                     # (batch, time_frames, sum(dim_inputs))
print(out.shape)                       # torch.Size([2, 100, 48])
```

---

### PoPE positional embedding support (`use_pope=True`)

Both models can optionally replace the default RoPE embeddings with PoPE (Polar-coordinate Positional Embeddings), a more recent method that decouples content ("what") from position ("where") information.

```python
import torch
from bs_roformer import BSRoformer, MelBandRoformer

# BSRoformer with PoPE
model_pope = BSRoformer(
    dim=512,
    depth=1,
    time_transformer_depth=1,
    freq_transformer_depth=1,
    use_pope=True                      # replaces RotaryEmbedding with PoPE on both time & freq axes
)

x = torch.randn(1, 1, 35280)           # (batch, channels=1, time)
out = model_pope(x)
print(out.shape)                       # (1, 1, T) or (1, T), with T <= 35280

# MelBandRoformer with PoPE
mel_pope = MelBandRoformer(
    dim=512,
    depth=1,
    time_transformer_depth=1,
    freq_transformer_depth=1,
    use_pope=True
)

out_mel = mel_pope(x)
print(out_mel.shape)
```

---

### Running the test suite

```bash
# run from a clone of the repository
pip install BS-RoFormer pytest
pytest tests/test_roformer.py -v
# test_bs_roformer[True] PASSED
# test_bs_roformer[False] PASSED
# test_mel_band_roformer[True] PASSED
# test_mel_band_roformer[False] PASSED
```

```python
# Equivalent programmatic test
import torch
import pytest
from bs_roformer import BSRoformer, MelBandRoformer

@pytest.mark.parametrize('use_pope', [True, False])
def test_bs_roformer(use_pope):
    model = BSRoformer(dim=512, depth=1, time_transformer_depth=1, freq_transformer_depth=1, use_pope=use_pope)
    inp = torch.randn(1, 1, 35280)
    out = model(inp)
    # the iSTFT may drop a few trailing samples, so only an upper bound is checked
    assert out.shape[-1] <= inp.shape[-1]

@pytest.mark.parametrize('use_pope', [True, False])
def test_mel_band_roformer(use_pope):
    model = MelBandRoformer(dim=512, depth=1, time_transformer_depth=1, freq_transformer_depth=1, use_pope=use_pope)
    inp = torch.randn(1, 1, 35280)
    out = model(inp)
    assert out.shape[-1] <= inp.shape[-1]
```

---

## Summary

BS-RoFormer is best suited for two primary use cases: training a custom music source separator from scratch, and fine-tuning or extending community pre-trained checkpoints (such as those released by Roman Solovyev / ZFTurbo for vocals). The `BSRoformer` class suits tasks where band boundaries should be specified explicitly as a tuple of integer frequency-bin counts (`freqs_per_bands`, which must sum to the number of STFT frequency bins, `stft_n_fft // 2 + 1`), which is useful when prior knowledge about the spectral structure of the target source is available. `MelBandRoformer` is a reasonable default for general-purpose separation: its Mel-spaced bands follow a perceptual frequency scale, so no manual band engineering is needed. Both variants support multi-stem output, which makes them directly applicable to full music demixing pipelines (vocals, drums, bass, other) without separate per-stem models.

For integration into a training pipeline, the typical pattern is to wrap the model in a standard PyTorch optimizer loop, call `model(audio, target=target)` to get the combined loss, and call `loss.backward()`. Because the multi-resolution STFT loss (summed over five window sizes from 256 to 4096) is built in, no external loss function is required. For inference in production or post-processing workflows, call `model(audio)` without a target to obtain the time-domain separated waveform directly; no manual STFT/iSTFT handling is required. Community pre-trained vocal weights can be loaded with a standard `model.load_state_dict(torch.load(...))` call, provided the model is constructed with the same hyperparameters used to train the checkpoint, enabling drop-in vocal isolation, remixing, and karaoke-style applications.
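The training pattern and combined objective described above can be sketched without the library, which may help when checking or reusing the loss elsewhere. This is a simplified, hypothetical `combined_loss` helper: waveform L1 plus the L1 distance between complex STFTs (real and imaginary parts) at the five window sizes named in this document, scaled by `weight`. The shared hop length of 147, the Hann window, and `n_fft` equal to the window size are assumptions, not values confirmed here, and the loop trains a toy `nn.Conv1d` stand-in rather than `BSRoformer`.

```python
import torch
import torch.nn.functional as F
from torch import nn

WINDOW_SIZES = (4096, 2048, 1024, 512, 256)  # the five resolutions named above
HOP = 147  # assumption: one hop length shared by all resolutions


def combined_loss(pred: torch.Tensor, target: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
    """pred, target: (batch, samples). Waveform L1 + weight * multi-resolution STFT L1."""
    l1 = F.l1_loss(pred, target)
    multires = pred.new_zeros(())
    for win in WINDOW_SIZES:
        window = torch.hann_window(win, device=pred.device)
        spec_p = torch.stft(pred, n_fft=win, hop_length=HOP, window=window, return_complex=True)
        spec_t = torch.stft(target, n_fft=win, hop_length=HOP, window=window, return_complex=True)
        # compare real and imaginary parts, i.e. the full complex spectrogram
        multires = multires + F.l1_loss(torch.view_as_real(spec_p), torch.view_as_real(spec_t))
    return l1 + weight * multires


# Toy training loop: a single Conv1d stands in for the separator model.
torch.manual_seed(0)
model = nn.Conv1d(1, 1, kernel_size=9, padding=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

mixture = torch.randn(4, 8000)  # must be longer than n_fft // 2 for reflect padding
target = 0.5 * mixture          # toy "stem": a scaled copy of the mixture

losses = []
for step in range(20):
    pred = model(mixture.unsqueeze(1)).squeeze(1)  # (batch, samples)
    loss = combined_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first={losses[0]:.3f} last={losses[-1]:.3f}")
```

With `BSRoformer` or `MelBandRoformer`, the loop body reduces to `loss = model(audio, target=target)` followed by `loss.backward()` and `optimizer.step()`, since the model computes its combined loss internally.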