Hunter Forms BS
https://github.com/hunterhogan/hunterformsbs
Hunter Forms BS is a flexible frequency-band splitter for music source separation using PyTorch,
...
# hunterFormsBS

`hunterFormsBS` is a fully typed, modular PyTorch library for music source separation built around a unified band-split RoFormer architecture. Instead of maintaining separate implementations for BS-RoFormer and Mel-Band RoFormer, the package treats both as different band-layout configurations of one core design centered on `BandSplitRotator`. The model accepts raw waveform audio, converts it to an STFT representation, groups frequency bins into bands, runs a hierarchical time-and-frequency attention core, estimates per-stem complex masks, and reconstructs separated waveforms via inverse STFT. All of this happens within a single forward pass, which also serves as the training-loop entry point.

The package exposes stable entry points (`BandSplitRotator`, `BSRoformer`, `MelBandRoformer`) alongside reusable typed building blocks (`BandSplit`, `MaskEstimator`, `Transformer`, `lossComputation`) and typed configuration records (`KwargsSTFT`, `KwargsTransformer`, `ComputeLoss`). Positional embeddings are selectable between RoPE and PoPE. An optional XCiT-style linear-attention pre-block, activation checkpointing, skip connections, and stereo support are all configurable. Experimental variants of the attention stack and separator classes are kept in separate modules, leaving the stable API uncluttered. The package also ships drop-in compatibility modules for existing `BSRoformer` and `MelBandRoformer` configuration files and checkpoints from external training frameworks such as Music-Source-Separation-Training.

---

## `BandSplitRotator`: unified waveform separator

`BandSplitRotator` is the primary separator class. Passing `sample_rate` and `num_bands` automatically builds the overlapped mel-band front end; omitting them and relying on the default `freqs_per_bands` builds the non-overlapping BS-style front end. A custom `mask_filter_bank` Boolean tensor of shape `(bands, freqs)` bypasses automatic construction entirely.
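To make the `(bands, freqs)` Boolean layout concrete, here is a minimal pure-Python sketch. It is not the library's construction code: `bank_from_widths` and `bands_per_bin` are hypothetical helpers, and the tiny widths are chosen only for readability. It shows how non-overlapping widths map to rows of a Boolean matrix, and how an overlapped layout differs by letting a bin belong to more than one band.

```python
def bank_from_widths(freqs_per_bands):
    """Build a (bands, freqs) Boolean matrix from non-overlapping band widths."""
    total = sum(freqs_per_bands)
    bank, start = [], 0
    for width in freqs_per_bands:
        # Row is True exactly on the bins [start, start + width).
        bank.append([start <= f < start + width for f in range(total)])
        start += width
    return bank


def bands_per_bin(bank):
    """Count how many bands claim each frequency bin."""
    return [sum(column) for column in zip(*bank)]


bank = bank_from_widths((2, 3, 4))      # 3 bands over 9 bins (illustrative sizes)
print(bands_per_bin(bank))              # every bin belongs to exactly one band

# An overlapped layout lets neighbouring bands share bins.
overlapped = [row[:] for row in bank]
overlapped[0][2] = True                 # band 0 now also covers bin 2
print(bands_per_bin(overlapped))
```

In a real model the matrix would be a `torch.bool` tensor rather than nested lists; the per-bin coverage count is a convenient sanity check before passing a custom bank.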
```python
import torch
from hunterFormsBS import BandSplitRotator

# --- BS-style (non-overlapping) separator, mono, 2 stems (e.g. vocals + drums) ---
model_bs = BandSplitRotator(
    dim=128,
    depth=6,
    num_stems=2,
    stereo=False,
    final_norm=True,
    norm_output=False,
    zero_dc=True,
    time_transformer_depth=2,
    freq_transformer_depth=2,
    heads=8,
    dim_head=64,
    flash_attn=True,
    use_torch_checkpoint=False,
)

# Inference: raw_audio shape (batch, time) or (batch, 1, time) for mono
mixture = torch.randn(2, 44100)  # batch=2, 1 second at 44100 Hz
recon = model_bs(mixture)        # -> (batch=2, stems=2, channels=1, time=44100)
print(recon.shape)               # torch.Size([2, 2, 1, 44100])

# Training: pass target of shape (batch, stems, channels, time)
target = torch.randn(2, 2, 1, 44100)
loss: torch.Tensor = model_bs(mixture, target=target)
loss.backward()

# Training with loss breakdown
total_loss, (l1_loss, stft_loss) = model_bs(
    mixture, target=target, return_loss_breakdown=True
)
print(f"total={total_loss.item():.4f} l1={l1_loss.item():.4f} stft={stft_loss.item():.4f}")

# --- Mel-band (overlapping) separator, stereo ---
model_mel = BandSplitRotator(
    dim=128,
    depth=6,
    num_stems=1,
    stereo=True,
    sample_rate=44100.0,
    num_bands=60,
    final_norm=False,
    norm_output=True,
    zero_dc=False,
)
stereo_mix = torch.randn(1, 2, 88200)  # batch=1, stereo, 2 seconds
recon_stereo = model_mel(stereo_mix)   # -> (1, 1, 2, 88200)
print(recon_stereo.shape)              # torch.Size([1, 1, 2, 88200])

# --- Custom filter bank ---
custom_bank = torch.randint(0, 2, (48, 1025), dtype=torch.bool)
custom_bank[:, 0] = True  # DC bin must belong to at least one band
model_custom = BandSplitRotator(
    dim=128,
    depth=4,
    mask_filter_bank=custom_bank,
    final_norm=True,
    norm_output=False,
    zero_dc=False,
)
```

---

## `BandSplitRotator.forward`: inference and training pass

`forward` accepts a raw-audio mixture and optionally a target tensor.
Without a target it returns separated waveform stems; with a target it returns the combined L1 + multi-resolution STFT loss (or an expanded breakdown).

```python
import torch
from hunterFormsBS import BandSplitRotator

model = BandSplitRotator(
    dim=128,
    depth=6,
    num_stems=4,
    stereo=True,
    sample_rate=44100.0,
    num_bands=60,
    final_norm=False,
    norm_output=True,
    zero_dc=False,
)
model.eval()

# Inference: no gradients needed
with torch.no_grad():
    mixture = torch.randn(1, 2, 262144)  # batch=1, stereo, ~6 s
    stems = model(mixture)               # (batch, num_stems, channels, time)

vocals, drums, bass, other = stems.unbind(dim=1)
print(vocals.shape)  # torch.Size([1, 2, 262144])

# Selective stem training with active_stem_ids
model.train()
mixture_batch = torch.randn(4, 2, 131072)
# Only supervise stem 0 (vocals) in this batch
target_vocals = torch.randn(4, 4, 2, 131072)  # full target tensor, only stem 0 used
loss = model(mixture_batch, target=target_vocals, active_stem_ids=[0])
loss.backward()

# Return loss breakdown for logging
total, (wav_loss, spec_loss) = model(
    mixture_batch,
    target=target_vocals,
    active_stem_ids=[0],
    return_loss_breakdown=True,
)
```

---

## `BSRoformer`: non-overlapping BS-RoFormer compatibility class

`BSRoformer` is a drop-in compatibility wrapper that preserves the non-overlapping band-split front end and constructor defaults from the upstream BS-RoFormer implementation. It has the same `forward` signature as `BandSplitRotator`.

```python
import torch
from hunterFormsBS.bs_roformer import BSRoformer
from hunterFormsBS import DEFAULT_FREQS_PER_BANDS

# Drop-in replacement for upstream BS-RoFormer configs
# Note: mask_estimator_depth defaults to 1 here; upstream used effective depth 1
# after an internal subtraction. Set mask_estimator_depth=1 to match checkpoints.
model = BSRoformer(
    dim=384,
    depth=12,
    num_stems=1,
    stereo=True,
    heads=12,
    dim_head=64,
    mask_estimator_depth=1,  # important: upstream effective depth was 1
    freqs_per_bands=DEFAULT_FREQS_PER_BANDS,
    flash_attn=True,
    use_torch_checkpoint=True,  # saves memory during training
)

mixture = torch.randn(2, 2, 261120)
with torch.no_grad():
    separated = model(mixture)  # (2, 1, 2, 261120)

# Loading an existing checkpoint
state = torch.load("checkpoint.pth", weights_only=True)
model.load_state_dict(state)
```

---

## `MelBandRoformer`: overlapping mel-band compatibility class

`MelBandRoformer` targets the overlapped mel-band front end from the Mel-Band RoFormer paper. It defaults to `num_bands=60` and `sample_rate=44100` when neither is supplied.

```python
import torch
from hunterFormsBS.mel_band_roformer import MelBandRoformer

model = MelBandRoformer(
    dim=256,
    depth=8,
    num_stems=1,
    stereo=True,
    num_bands=64,
    sample_rate=44100.0,
    heads=8,
    dim_head=64,
    mask_estimator_depth=1,
)

mixture = torch.randn(2, 2, 131072)
with torch.no_grad():
    vocals = model(mixture)  # (2, 1, 2, 131072)
print(vocals.shape)  # torch.Size([2, 1, 2, 131072])
```

---

## `BandSplit`: front-end band projection

`BandSplit` slices a concatenated STFT band representation along the last axis according to a per-band width sequence and projects each slice to a shared feature width. It is the shared front-end projector used by all three separator classes.
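The slice-then-project step can be pictured without PyTorch. The following is a minimal sketch under stated assumptions, not the library's implementation: `split_bands` and `project` are hypothetical helpers, the all-ones weights are made up for illustration, and any normalization that `BandSplit` may apply before projecting is left out.

```python
from itertools import accumulate


def split_bands(frame, dim_inputs):
    """Slice one flat feature frame into per-band chunks of the given widths."""
    if sum(dim_inputs) != len(frame):
        raise ValueError("band widths must sum to the frame width")
    offsets = [0, *accumulate(dim_inputs)]
    return [frame[start:stop] for start, stop in zip(offsets, offsets[1:])]


def project(chunk, weights):
    """Apply a per-band linear map; weights has shape (dim, len(chunk))."""
    return [sum(w * x for w, x in zip(row, chunk)) for row in weights]


frame = list(range(10))    # one time frame with 10 features
dim_inputs = [2, 3, 5]     # hypothetical per-band widths
chunks = split_bands(frame, dim_inputs)

dim = 4                    # shared feature width
# Hypothetical weights: each output feature is the sum of the band's inputs.
features = [project(chunk, [[1.0] * len(chunk)] * dim) for chunk in chunks]
print(chunks)
print(features)            # one length-4 vector per band
```

Bands of different widths end up as vectors of the same length, which is what lets the attention core treat them as a uniform sequence of band tokens.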
```python
import torch
from hunterFormsBS import BandSplit

# Each band may have a different number of frequency bins (× 2 for complex, × channels)
dim_inputs = [4, 4, 8, 8, 24, 24, 48, 48, 256, 258]  # example per-band widths
band_split = BandSplit(dim=128, dim_inputs=dim_inputs)

# x: (batch, time_frames, sum(dim_inputs))
x = torch.randn(2, 64, sum(dim_inputs))
features = band_split(x)  # (2, 64, num_bands=10, dim=128)
print(features.shape)     # torch.Size([2, 64, 10, 128])
```

---

## `MaskEstimator`: per-stem complex mask estimation head

`MaskEstimator` maps a stack of band tokens to a concatenated complex mask representation. One `MaskEstimator` is created per output stem inside `BandSplitRotator`.

```python
import torch
from hunterFormsBS import MaskEstimator

dim_inputs = [4, 4, 8, 8, 24, 24, 48, 48, 256, 258]
estimator = MaskEstimator(
    dim=128,
    dim_inputs=dim_inputs,
    depth=1,  # paper-default: one hidden layer
    mlp_expansion_factor=4,
)

# x: (batch, time_frames, num_bands, dim)
x = torch.randn(2, 64, len(dim_inputs), 128)
mask = estimator(x)  # (2, 64, sum(dim_inputs))
print(mask.shape)    # torch.Size([2, 64, 682])
```

---

## `lossComputation`: combined waveform + multi-resolution STFT loss

`lossComputation` computes `L_total = L1_waveform + weight × Σ_w L1_STFT(window=w)` for selected stems. It is called automatically inside `forward` when a `target` is supplied but can also be used standalone.
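The formula can be traced by hand on a tiny signal. The sketch below is a dependency-free illustration of the same structure, not the library's code: `naive_stft` and `combined_loss` are hypothetical names, the Hann window and mean reduction are assumptions (the library example uses a half-sine window), and the window sizes are shrunk so the direct DFT stays cheap.

```python
import cmath
import math


def l1_mean(a, b):
    """Mean absolute difference between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def naive_stft(signal, window_size, hop_length):
    """Hann-windowed STFT via a direct DFT; slow but dependency-free."""
    window = [0.5 - 0.5 * math.cos(2 * math.pi * n / window_size)
              for n in range(window_size)]
    bins = []
    for start in range(0, len(signal) - window_size + 1, hop_length):
        frame = [signal[start + n] * window[n] for n in range(window_size)]
        for k in range(window_size // 2 + 1):
            bins.append(sum(frame[n] * cmath.exp(-2j * math.pi * k * n / window_size)
                            for n in range(window_size)))
    return bins


def combined_loss(recon, target, window_sizes=(16, 8), hop_length=4, weight=1.0):
    """L1 on the waveform plus a weighted sum of L1 STFT terms, one per window size."""
    waveform_term = l1_mean(recon, target)
    stft_term = sum(
        l1_mean(naive_stft(recon, w, hop_length), naive_stft(target, w, hop_length))
        for w in window_sizes
    )
    return waveform_term + weight * stft_term, (waveform_term, stft_term)


target_signal = [math.sin(0.3 * n) for n in range(32)]
recon_signal = [0.5 * s for s in target_signal]   # deliberately imperfect estimate
total, (wav, spec) = combined_loss(recon_signal, target_signal, weight=2.0)
print(total, wav, spec)
```

Summing over several window sizes penalizes errors at more than one time-frequency resolution, which a single STFT term would not do.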
```python
import torch
from hunterFormsBS import lossComputation, ComputeLoss
from Z0Z_tools import halfsineTensor

multi_stft_config = ComputeLoss(
    hop_length=147,
    loss_weight=1.0,
    window_sizes=(4096, 2048, 1024, 512, 256),
    n_fft=2048,
    normalized=False,
    window_fn=halfsineTensor,
)

batch, num_stems, channels, time = 2, 2, 1, 44100
recon_audio = torch.randn(batch, num_stems, channels, time)
target = torch.randn(batch, num_stems, channels, time)

# Compute loss for all stems
total_loss = lossComputation(
    recon_audio=recon_audio,
    target=target,
    stem_ids=[0, 1],
    multi_stft=multi_stft_config,
)
print(total_loss.item())

# Compute loss with breakdown
total_loss, (l1, stft_loss) = lossComputation(
    recon_audio=recon_audio,
    target=target,
    stem_ids=[0, 1],
    multi_stft=multi_stft_config,
    return_loss_breakdown=True,
)
print(f"L1={l1.item():.4f} STFT={stft_loss.item():.4f} Total={total_loss.item():.4f}")
```

---

## `Transformer`: hierarchical attention stack block

`Transformer` stacks `depth` pairs of attention and feedforward sublayers with residual connections. Setting `linear_attn=True` switches each attention sublayer to the XCiT-style cross-covariance `LinearAttention` branch. Positional encoders `rotary_embed` (RoPE) or `pope_embed` (PoPE) are injected into each `Attention` block.
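For readers unfamiliar with RoPE, the property that makes it useful in attention is that the query-key dot product depends only on the relative offset between positions. The sketch below demonstrates this in plain Python. It is a generic illustration of rotary embeddings, not the code path used by `rotary_embedding_torch` or this library; `rotate_pairs` is a hypothetical helper and the interleaved pairing is an assumption.

```python
import math


def rotate_pairs(vector, position, base=10000.0):
    """Rotate consecutive (even, odd) pairs by position-dependent angles."""
    dim = len(vector)
    rotated = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)   # lower pairs rotate faster
        x, y = vector[i], vector[i + 1]
        rotated += [x * math.cos(theta) - y * math.sin(theta),
                    x * math.sin(theta) + y * math.cos(theta)]
    return rotated


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# Same relative offset (2) at different absolute positions.
score_near = dot(rotate_pairs(q, 5), rotate_pairs(k, 3))
score_far = dot(rotate_pairs(q, 12), rotate_pairs(k, 10))
print(abs(score_near - score_far) < 1e-9)   # scores match up to rounding
```

Because each pair is only rotated, vector norms are unchanged; position information enters purely through the angle between queries and keys.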
```python
import torch
from hunterFormsBS import Transformer
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim=64)
time_transformer = Transformer(
    dim=128,
    depth=2,
    dim_head=64,
    heads=8,
    attn_dropout=0.1,
    ff_dropout=0.1,
    flash_attn=True,
    norm_output=True,
    linear_attn=False,
    rotary_embed=rotary,
)

# x: (batch × num_bands, time_frames, dim)
x = torch.randn(16, 64, 128)
out = time_transformer(x)  # same shape (16, 64, 128)
print(out.shape)           # torch.Size([16, 64, 128])

# Linear-attention (XCiT) variant: no positional encoder
lin_transformer = Transformer(
    dim=128,
    depth=1,
    dim_head=32,
    heads=8,
    flash_attn=False,
    norm_output=False,
    linear_attn=True,
)
out_lin = lin_transformer(x)
print(out_lin.shape)  # torch.Size([16, 64, 128])
```

---

## `Attend`: core scaled dot-product attention

`Attend` computes attention from precomputed `q`, `k`, `v` tensors. With `flash=True` and PyTorch ≥ 2.0 it delegates to `torch.nn.functional.scaled_dot_product_attention`; otherwise it uses an explicit softmax path.

```python
import torch
from hunterFormsBS.attend import Attend

attend = Attend(dropout=0.0, flash=True)  # requires PyTorch >= 2.0

batch, heads, seq, dim_head = 2, 8, 64, 64
q = torch.randn(batch, heads, seq, dim_head)
k = torch.randn(batch, heads, seq, dim_head)
v = torch.randn(batch, heads, seq, dim_head)

out = attend(q, k, v)  # (2, 8, 64, 64)
print(out.shape)       # torch.Size([2, 8, 64, 64])
```

---

## Typed configuration records (`theTypes`)

All configuration records are `TypedDict` or `NamedTuple` instances exported from `hunterFormsBS.theTypes` and re-exported from the top-level namespace.
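The two record kinds behave differently at runtime, which matters when forwarding them to constructors. The sketch below uses hypothetical stand-in classes (`StftOptions`, `LossOptions`) rather than the library's own records, and it does not claim which library record is which kind. It shows only the standard-library behaviour: a `TypedDict` instance is a plain `dict` that can be unpacked with `**`, while a `NamedTuple` is immutable and accessed by attribute.

```python
from typing import NamedTuple, TypedDict


class StftOptions(TypedDict):      # hypothetical stand-in for a TypedDict record
    n_fft: int
    hop_length: int


class LossOptions(NamedTuple):     # hypothetical stand-in for a NamedTuple record
    loss_weight: float
    window_sizes: tuple[int, ...]


def describe(n_fft: int, hop_length: int) -> str:
    """Toy consumer that takes the record's fields as keyword arguments."""
    return f"{n_fft}/{hop_length}"


stft = StftOptions(n_fft=2048, hop_length=512)
loss = LossOptions(loss_weight=1.0, window_sizes=(4096, 2048))

print(type(stft) is dict)          # TypedDict instances are ordinary dicts
print(describe(**stft))            # so they unpack directly into keyword arguments
print(loss.window_sizes)           # NamedTuple fields are read by attribute
halved = loss._replace(loss_weight=0.5)   # "modification" returns a new tuple
```

Both kinds accept the keyword-call syntax used in the library examples, so the calling code looks the same regardless of which one a given record is.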
```python
from hunterFormsBS import (
    ComputeLoss,
    FlashAttentionConfig,
    KwargsOfAttention,
    KwargsSTFT,
    KwargsTransformer,
)
from Z0Z_tools import halfsineTensor

# KwargsSTFT: shared STFT keyword arguments
stft_kwargs = KwargsSTFT(n_fft=2048, hop_length=512, win_length=1024, normalized=False)

# KwargsTransformer: shared transformer configuration
transformer_kwargs = KwargsTransformer(
    dim=128,
    dim_head=64,
    heads=8,
    attn_dropout=0.0,
    ff_dropout=0.0,
    flash_attn=True,
    norm_output=True,
)

# ComputeLoss: multi-resolution STFT loss settings
loss_config = ComputeLoss(
    hop_length=147,
    loss_weight=1.0,
    window_sizes=(4096, 2048, 1024, 512, 256),
    n_fft=2048,
    normalized=False,
    window_fn=halfsineTensor,
)

# FlashAttentionConfig: backend flags for sdp_kernel
cuda_flags = FlashAttentionConfig(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
)

# KwargsOfAttention: attention block configuration
attn_kwargs = KwargsOfAttention(
    dim=128,
    dim_head=64,
    heads=8,
    dropout=0.1,
    flash=True,
)
```

---

## `DEFAULT_FREQS_PER_BANDS`: standard BS-RoFormer frequency partition

`DEFAULT_FREQS_PER_BANDS` is the canonical non-overlapping frequency-bin count tuple matching the BS-RoFormer paper partition. It can be passed directly to `BandSplitRotator` or `BSRoformer` as `freqs_per_bands`.

```python
from hunterFormsBS import DEFAULT_FREQS_PER_BANDS, BandSplitRotator

print(len(DEFAULT_FREQS_PER_BANDS))  # 62 bands
print(sum(DEFAULT_FREQS_PER_BANDS))  # 1025 frequency bins (n_fft//2 + 1 for n_fft=2048)

model = BandSplitRotator(
    dim=128,
    depth=6,
    freqs_per_bands=DEFAULT_FREQS_PER_BANDS,
    final_norm=True,
    norm_output=False,
    zero_dc=True,
)
```

---

## Summary

`hunterFormsBS` covers the full pipeline from raw stereo or mono audio to separated stem waveforms and training loss, making it suitable for both research prototyping and production training loops.
The primary entry point `BandSplitRotator` subsumes all band-layout modes (non-overlapping BS-style, overlapping mel-band, or any custom Boolean filter bank), so a single model class suffices for the most common source separation configurations. The shared `lossComputation` helper, typed configuration records, and reusable `BandSplit`/`MaskEstimator`/`Transformer` building blocks make it straightforward to compose novel architectures or plug the library into external frameworks that expect the `BSRoformer` or `MelBandRoformer` namespace.

For integration with external training frameworks such as Music-Source-Separation-Training, the compatibility modules `hunterFormsBS.bs_roformer.BSRoformer` and `hunterFormsBS.mel_band_roformer.MelBandRoformer` accept the same constructor kwargs and `forward` signatures as the upstream implementations, allowing checkpoint and configuration file reuse with minimal changes. The `active_stem_ids` parameter in `forward` enables per-batch stem selection without rebuilding the model, `use_torch_checkpoint=True` reduces activation memory at training time, and the MPS fallback path means the same code runs on Apple Silicon without modification.