# Music Source Separation Training

Music Source Separation Training is a universal framework for training and running inference on deep learning models that separate audio mixtures into individual stems (vocals, bass, drums, other, etc.). It is built on PyTorch and was originally based on the KUIELab code for the SDX23 challenge. It provides a modular, experiment-friendly codebase that supports a large catalog of state-of-the-art architectures. The project is developed and maintained by [MVSep.com](https://mvsep.com), and the source code is available at https://github.com/zfturbo/music-source-separation-training.

The framework covers the entire machine-learning lifecycle:

- dataset preparation (seven dataset types)
- augmentation (loudness, pitch shift, EQ, MP3 compression, reverb, and more)
- single-GPU and multi-GPU (DDP) training with gradient accumulation and EMA
- validation with multiple metrics (SDR, SI-SDR, L1-freq, AURA-STFT, bleedless, fullness)
- inference over audio folders
- LoRA fine-tuning
- model ensembling

Supported architectures include MDX23C, Demucs4HT, BS-RoFormer, Mel-Band RoFormer, SCNet, BandIt / BandIt v2, Apollo, BSMamba2, Conformer, Swin UperNet, and more.

---

## `get_model_from_config`: Load and instantiate a model from a config file

This function reads the YAML configuration file, builds the PyTorch model that corresponds to `model_type`, and returns both the model and the loaded config. It is the central entry point used by `train.py`, `valid.py`, and `inference.py`.

```python
from utils.settings import get_model_from_config

# Load a Band-Split RoFormer model
model, config = get_model_from_config(
    model_type='bs_roformer',
    config_path='configs/config_musdb18_bs_roformer.yaml'
)
print(model)  # BSRoformer(dim=192, depth=6, stereo=True, ...)

# Load a Mel-Band RoFormer model
model, config = get_model_from_config(
    model_type='mel_band_roformer',
    config_path='configs/config_musdb18_mel_band_roformer.yaml'
)

# Load a Demucs4HT model
model, config = get_model_from_config(
    model_type='htdemucs',
    config_path='configs/config_musdb18_htdemucs.yaml'
)

# Supported model_type values:
# 'mdx23c', 'htdemucs', 'segm_models', 'torchseg',
# 'mel_band_roformer', 'mel_band_conformer', 'bs_roformer', 'bs_conformer',
# 'bs_mamba2', 'swin_upernet', 'bandit', 'bandit_v2',
# 'scnet', 'scnet_tran', 'scnet_unofficial', 'scnet_masked',
# 'apollo', 'conformer', 'moises_light'
```

---

## `train.py` / `train_model`: Single-GPU training entry point

This script trains a model end-to-end for the configured number of epochs. It handles data loading, optimizer and scheduler setup, AMP, gradient clipping, EMA, optional LoRA, and layer freezing. It also validates after every epoch and saves a checkpoint automatically whenever the monitored metric improves.

```bash
# Basic single-GPU training (Mel-Band RoFormer, vocals separation)
python train.py \
  --model_type mel_band_roformer \
  --config_path configs/config_musdb18_mel_band_roformer.yaml \
  --results_path results/ \
  --data_path datasets/musdb18hq/train \
  --valid_path datasets/musdb18hq/test \
  --num_workers 4 \
  --device_ids 0

# Resume from a checkpoint and monitor multiple metrics
python train.py \
  --model_type bs_roformer \
  --config_path configs/config_musdb18_bs_roformer.yaml \
  --start_check_point results/model_bs_roformer_ep_50_sdr_9.1234.ckpt \
  --load_optimizer --load_scheduler --load_epoch --load_best_metric \
  --results_path results/ \
  --data_path datasets/musdb18hq/train \
  --valid_path datasets/musdb18hq/test \
  --metrics sdr si_sdr l1_freq \
  --metric_for_scheduler sdr \
  --device_ids 0 1 \
  --num_workers 4

# Training with multiple data paths and dataset type 2 (stem folders)
python train.py \
  --model_type scnet \
  --config_path configs/config_musdb18_scnet.yaml \
  --results_path results/scnet/ \
  --data_path datasets/musdb18hq/train datasets/extra_stems \
  --dataset_type 2 \
  --valid_path datasets/musdb18hq/test \
  --device_ids 0 \
  --wandb_key YOUR_WANDB_KEY

# Checkpoints are saved to --results_path as:
#   model_<type>_ep_<N>_sdr_<value>.ckpt   (best only)
#   last_model_<type>.ckpt                  (always saved)
```

---

## `train_ddp.py` / `train_model_ddp`: Multi-GPU Distributed Data Parallel training

This script wraps `train_model` with `torch.multiprocessing.spawn` and launches one process per available GPU. With two or more GPUs, it is usually much faster than single-GPU training.

```bash
# Launch DDP training across all available GPUs automatically
python train_ddp.py \
  --model_type mel_band_roformer \
  --config_path configs/config_musdb18_mel_band_roformer.yaml \
  --results_path results/ \
  --data_path datasets/musdb18hq/train \
  --valid_path datasets/musdb18hq/test \
  --num_workers 4
```

```python
# Programmatic DDP launch (e.g., from a notebook or script)
from train_ddp import train_model_ddp

# The __main__ guard keeps spawned worker processes from re-running this launch
if __name__ == '__main__':
    train_model_ddp({
        'model_type': 'bs_roformer',
        'config_path': 'configs/config_musdb18_bs_roformer.yaml',
        'results_path': 'results/',
        'data_path': ['datasets/musdb18hq/train'],
        'valid_path': ['datasets/musdb18hq/test'],
        'num_workers': 4,
        'metrics': ['sdr'],
        'metric_for_scheduler': 'sdr',
    })
```

---

## `inference.py` / `proc_folder`: Separate audio files in a folder

This script loads a trained model and processes every audio file found recursively inside `--input_folder`. Each separated stem is written to `--store_dir`. It supports test-time augmentation (TTA), BigShifts averaging, instrumental extraction, spectrogram rendering, and custom filename templates.
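The following is a minimal NumPy sketch of what the TTA and instrumental-extraction options do. It rests on two assumptions. First, TTA averages the model output over three variants of a stereo mixture: the original, one with the left and right channels swapped, and one with inverted polarity. Each transform is undone on the prediction before averaging. Second, the instrumental is the mixture minus the separated target stem. `separate_with_tta` and `extract_instrumental` are hypothetical helpers written for illustration. They are not functions exported by `inference.py`.

```python
from typing import Callable

import numpy as np

# Hypothetical separator type: maps a (channels, samples) mix to a same-shape target stem
Separator = Callable[[np.ndarray], np.ndarray]


def separate_with_tta(mix: np.ndarray, separate: Separator) -> np.ndarray:
    """Average predictions over three assumed TTA variants of a stereo mixture."""
    outputs = [separate(mix)]                        # original mixture
    outputs.append(separate(mix[::-1].copy())[::-1])  # swap channels, then swap the estimate back
    outputs.append(-separate(-mix))                  # invert polarity, then invert the estimate back
    return np.mean(outputs, axis=0)


def extract_instrumental(mix: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Assumed definition: instrumental = mixture minus the separated target stem."""
    return mix - target


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mix = rng.standard_normal((2, 44100))
    # Toy linear "separator" that keeps half of the signal
    vocals = separate_with_tta(mix, lambda x: 0.5 * x)
    instrumental = extract_instrumental(mix, vocals)
    print(np.allclose(vocals + instrumental, mix))  # True
```

Undoing each transform before averaging keeps the three estimates aligned. A separator that is sensitive to channel order or polarity therefore gets smoothed rather than corrupted.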
```bash
# Basic inference: separate vocals from a folder of tracks
python inference.py \
  --model_type mel_band_roformer \
  --config_path configs/config_musdb18_mel_band_roformer.yaml \
  --start_check_point weights/model_mel_band_roformer_ep_300_sdr_11.ckpt \
  --input_folder input/songs/ \
  --store_dir output/separated/

# With TTA, instrumental extraction, BigShifts, and a custom output template
python inference.py \
  --model_type bs_roformer \
  --config_path configs/config_musdb18_bs_roformer.yaml \
  --start_check_point weights/model_bs_roformer.ckpt \
  --input_folder input/songs/ \
  --store_dir output/ \
  --use_tta \
  --extract_instrumental \
  --bigshifts 3 \
  --filename_template "{file_name}/{model}/{instr}" \
  --device_ids 0
```

```python
# Programmatic usage
from inference import proc_folder

proc_folder({
    'model_type': 'bs_roformer',
    'config_path': 'configs/config_musdb18_bs_roformer.yaml',
    'start_check_point': 'weights/model_bs_roformer.ckpt',
    'input_folder': 'input/songs/',
    'store_dir': 'output/separated/',
    'use_tta': True,
    'device_ids': [0],
    'bigshifts': 1,
})
# Output files: output/separated/<song_name>/<instr>.flac (or .wav for float peaks)
```

---

## `valid.py` / `check_validation`: Evaluate a model against a validation dataset

This script runs separation on every `mixture.wav` found under `--valid_path` and computes the requested metrics for each stem. It prints the averaged results and can optionally save them. It also supports multi-GPU parallel validation.
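The two most common metrics fit in a few lines of NumPy. The sketch assumes MDX-challenge-style SDR, computed over the whole track: ten times the base-10 log of reference energy divided by the energy of the reference-minus-estimate error. It also assumes the usual scale-invariant SI-SDR, which first projects the estimate onto the reference. `sdr` and `si_sdr` here are illustrative re-implementations, not imports from the repository. The framework's own metric code may differ in details such as the epsilon value or per-channel handling.

```python
import numpy as np

EPS = 1e-8  # assumed small constant guarding against division by zero and log(0)


def sdr(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Signal-to-distortion ratio in dB, computed over all channels and samples."""
    num = np.sum(reference ** 2) + EPS
    den = np.sum((reference - estimate) ** 2) + EPS
    return float(10 * np.log10(num / den))


def si_sdr(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Scale-invariant SDR in dB: rescale the reference to best match the estimate."""
    ref = reference.reshape(-1)
    est = estimate.reshape(-1)
    alpha = np.dot(est, ref) / (np.dot(ref, ref) + EPS)  # optimal scale for the reference
    target = alpha * ref
    noise = est - target
    return float(10 * np.log10((np.sum(target ** 2) + EPS) / (np.sum(noise ** 2) + EPS)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    ref = rng.standard_normal((2, 44100))
    print(round(sdr(ref, 0.5 * ref), 2))  # 6.02
    print(si_sdr(ref, 0.5 * ref) > 50)    # True
```

In the example, halving the reference costs about 6 dB of SDR, which is 10 times the base-10 log of 4. SI-SDR, by contrast, treats it as a near-perfect estimate, because it ignores overall gain.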
```bash
# Validate a BS-RoFormer model with the SDR metric
python valid.py \
  --model_type bs_roformer \
  --config_path configs/config_musdb18_bs_roformer.yaml \
  --start_check_point weights/model_bs_roformer_ep_317_sdr_12.9755.ckpt \
  --valid_path datasets/musdb18hq/test \
  --device_ids 0

# Multi-metric validation, with results and spectrograms saved to a folder
python valid.py \
  --model_type mel_band_roformer \
  --config_path configs/config_musdb18_mel_band_roformer.yaml \
  --start_check_point weights/model_mel_band_roformer.ckpt \
  --valid_path datasets/musdb18hq/test \
  --store_dir results_valid/ \
  --draw_spectro 30 \
  --metrics sdr si_sdr l1_freq aura_stft bleedless fullness \
  --device_ids 0 1
```

```python
# Programmatic usage
from valid import check_validation

metrics_avg, all_metrics = check_validation({
    'model_type': 'mel_band_roformer',
    'config_path': 'configs/config_musdb18_mel_band_roformer.yaml',
    'start_check_point': 'weights/model.ckpt',
    'valid_path': ['datasets/musdb18hq/test'],
    'metrics': ['sdr', 'si_sdr'],
    'device_ids': [0],
})
# metrics_avg -> e.g. {'sdr': 11.43, 'si_sdr': 10.87}
# all_metrics -> e.g. {'sdr': {'vocals': [12.1, 11.9, ...], 'drums': [10.2, ...]}, ...}
```

---

## `ensemble.py` / `ensemble_files`: Ensemble multiple separated audio files

This script combines several separation outputs for the same stem, produced by different models, into a single, usually better, result. It uses waveform-domain or spectrogram-domain averaging strategies, with optional per-file weights.
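The waveform-domain strategies are easy to sketch with NumPy. `ensemble_waveforms` is a hypothetical illustration, not the code in `ensemble.py`. It assumes all inputs are already aligned arrays of identical shape and implements the strategies as follows:

- `avg_wave`: weighted mean across models.
- `median_wave`: unweighted median across models.
- `min_wave` / `max_wave`: at each sample, keep the value with the smallest or largest absolute magnitude, with its sign preserved.

```python
from typing import Optional, Sequence

import numpy as np


def ensemble_waveforms(waves: Sequence[np.ndarray], kind: str = "avg_wave",
                       weights: Optional[Sequence[float]] = None) -> np.ndarray:
    """Combine same-shape waveforms (e.g. (channels, samples)) from several models."""
    stack = np.stack(waves, axis=0)  # (num_models, channels, samples)
    if kind == "avg_wave":
        w = np.ones(len(waves)) if weights is None else np.asarray(weights, dtype=float)
        # Weighted mean over the model axis
        return np.tensordot(w, stack, axes=1) / w.sum()
    if kind == "median_wave":
        return np.median(stack, axis=0)
    if kind in ("min_wave", "max_wave"):
        pick = np.argmin if kind == "min_wave" else np.argmax
        idx = pick(np.abs(stack), axis=0)  # index of the winning model at each sample
        # Gather the signed value from the winning model at every position
        return np.take_along_axis(stack, idx[np.newaxis], axis=0)[0]
    raise ValueError(f"unsupported kind: {kind}")


if __name__ == "__main__":
    a = np.array([[1.0, -4.0, 2.0]])
    b = np.array([[3.0, 2.0, -1.0]])
    print(ensemble_waveforms([a, b]).tolist())                   # [[2.0, -1.0, 0.5]]
    print(ensemble_waveforms([a, b], weights=[3, 1]).tolist())   # [[1.5, -2.5, 1.25]]
    print(ensemble_waveforms([a, b], kind="min_wave").tolist())  # [[1.0, 2.0, -1.0]]
    print(ensemble_waveforms([a, b], kind="max_wave").tolist())  # [[3.0, -4.0, 2.0]]
```

The FFT-domain variants apply the same reductions to spectrogram magnitudes and then return to the time domain with an inverse STFT. They are omitted here to keep the sketch short.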
```bash
# Average two vocal outputs (equal weights)
python ensemble.py \
  --files results_model1/vocals.wav results_model2/vocals.wav \
  --type avg_wave \
  --output ensemble_vocals.wav

# Weighted FFT-domain max ensemble (most aggressive; can give sharper separation)
python ensemble.py \
  --files vocals_model1.wav vocals_model2.wav vocals_model3.wav \
  --weights 2 1 1 \
  --type max_fft \
  --output best_vocals.wav
```

```python
# Programmatic usage: ensemble_files takes a list of CLI-style arguments
from ensemble import ensemble_files

ensemble_files([
    '--files', 'vocals_a.wav', 'vocals_b.wav',
    '--type', 'avg_wave',
    '--weights', '1', '1',
    '--output', 'result.wav',
])

# Available --type options:
#   avg_wave    – weighted mean of waveforms (recommended, usually best SDR)
#   median_wave – median across waveforms
#   min_wave    – minimum absolute value per sample (conservative)
#   max_wave    – maximum absolute value per sample (aggressive)
#   avg_fft     – weighted mean of STFT spectrograms + iSTFT
#   median_fft  – median of spectrograms + iSTFT (most useful with ≥3 sources)
#   min_fft     – minimum spectrogram magnitude + iSTFT (most conservative)
#   max_fft     – maximum spectrogram magnitude + iSTFT (most aggressive)
```

---

## Model configuration YAML: Define model architecture and training hyperparameters

Every model is controlled by a YAML config file with four core top-level sections: `audio`, `model`, `training`, and `inference`. Optional sections such as `augmentations` and `lora` can be added alongside them. The config is loaded by `get_model_from_config` and then passed through training, validation, and inference.
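A few quantities follow directly from these sections and are worth checking before a long run. The sketch uses the example values from the Band-Split RoFormer config in this section. A plain Python dict stands in for the parsed YAML, so the code runs without a YAML parser, and the `model` section is left out on purpose to exercise the check. Two relationships are assumptions about how the values combine, not formulas quoted from the repository:

- The inference step size is `chunk_size // num_overlap`.
- Each optimizer step on one device sees `batch_size × gradient_accumulation_steps` examples.

`missing_sections` and `summarize` are hypothetical helpers.

```python
CORE_SECTIONS = ("audio", "model", "training", "inference")

# Example values from the Band-Split RoFormer config (model section omitted on purpose)
example_config = {
    "audio": {"chunk_size": 131584, "sample_rate": 44100},
    "training": {"batch_size": 10, "gradient_accumulation_steps": 1},
    "inference": {"num_overlap": 4},
}


def missing_sections(cfg: dict) -> list:
    """List the core top-level sections that a parsed config lacks."""
    return [name for name in CORE_SECTIONS if name not in cfg]


def summarize(cfg: dict) -> dict:
    """Derive a few sanity-check quantities from the audio/training/inference sections."""
    audio, training, inference = cfg["audio"], cfg["training"], cfg["inference"]
    return {
        # Duration of one chunk in seconds
        "chunk_seconds": audio["chunk_size"] / audio["sample_rate"],
        # Assumed hop between consecutive inference chunks
        "inference_step": audio["chunk_size"] // inference["num_overlap"],
        # Examples contributing to each optimizer step on one device
        "examples_per_optimizer_step": training["batch_size"] * training["gradient_accumulation_steps"],
    }


if __name__ == "__main__":
    print(missing_sections(example_config))                      # ['model']
    print(round(summarize(example_config)["chunk_seconds"], 3))  # 2.984
    print(summarize(example_config)["inference_step"])           # 32896
```

With these values, each training chunk is just under 3 seconds of 44.1 kHz audio. Raising `gradient_accumulation_steps` increases the examples per optimizer step without increasing per-step memory.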
```yaml
# configs/config_musdb18_bs_roformer.yaml: Band-Split RoFormer example
audio:
  chunk_size: 131584       # Training chunk length in samples
  dim_f: 1024              # Frequency dimension for STFT
  hop_length: 512
  n_fft: 2048
  num_channels: 2          # Stereo
  sample_rate: 44100
  min_mean_abs: 0.001      # Skip chunks whose mean absolute amplitude is below this (silence)

model:
  dim: 192
  depth: 6
  stereo: true
  num_stems: 1
  time_transformer_depth: 1
  freq_transformer_depth: 1
  heads: 8
  dim_head: 64
  attn_dropout: 0.1
  ff_dropout: 0.1
  flash_attn: true         # Use Flash Attention for memory efficiency
  stft_n_fft: 2048
  stft_hop_length: 512
  mask_estimator_depth: 2
  multi_stft_resolution_loss_weight: 1.0

training:
  batch_size: 10
  gradient_accumulation_steps: 1
  grad_clip: 0
  instruments: [vocals, bass, drums, other]
  lr: 5.0e-05
  patience: 2              # LR reduce patience (ReduceLROnPlateau)
  reduce_factor: 0.95
  target_instrument: vocals  # Primary separation target
  num_epochs: 1000
  num_steps: 1000
  ema_momentum: 0.999      # Exponential moving average decay
  optimizer: adam
  use_amp: true            # Mixed precision (float16)

augmentations:
  enable: true
  loudness: true
  loudness_min: 0.5
  loudness_max: 1.5

inference:
  batch_size: 1
  num_overlap: 4           # Overlap factor for chunked inference
```

---

## Augmentations config: On-the-fly data augmentation

Augmentations are defined inline in the YAML config under the `augmentations` key. Options in the `all` subsection apply to every stem, while subsections named after an instrument (e.g. `vocals`, `drums`) apply only to that stem.
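The split between the `all` subsection and per-instrument subsections can be illustrated with a small NumPy sketch. `augment_stem` is a hypothetical helper, not the repository's implementation. It makes three assumptions:

- Each probability key is an independent coin flip. Only `channel_shuffle` (swap left/right) and `random_polarity` are modelled here.
- The `all` options are applied first, followed by the options under the stem's own name.
- Loudness scaling draws one gain per stem, uniformly from `[loudness_min, loudness_max]`.

```python
import numpy as np


def _flip(p: float, rng: np.random.Generator) -> bool:
    """Return True with probability p."""
    return rng.random() < p


def augment_stem(stem: np.ndarray, name: str, augm: dict,
                 rng: np.random.Generator) -> np.ndarray:
    """Apply a tiny, illustrative subset of augmentations to one (channels, samples) stem."""
    if not augm.get("enable", False):
        return stem
    out = stem.copy()
    # Options under `all` apply to every stem, then options under the stem's own name
    for section in (augm.get("all", {}), augm.get(name, {})):
        if _flip(section.get("channel_shuffle", 0.0), rng):
            out = out[::-1].copy()  # swap left/right channels
        if _flip(section.get("random_polarity", 0.0), rng):
            out = -out              # invert polarity
    if augm.get("loudness", False):
        gain = rng.uniform(augm.get("loudness_min", 1.0), augm.get("loudness_max", 1.0))
        out = out * gain
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = np.array([[1.0, 2.0], [3.0, 4.0]])
    cfg = {"enable": True, "all": {"random_polarity": 1.0}, "vocals": {"channel_shuffle": 1.0}}
    print(augment_stem(x, "vocals", cfg, rng).tolist())  # [[-3.0, -4.0], [-1.0, -2.0]]
    print(augment_stem(x, "drums", cfg, rng).tolist())   # [[-1.0, -2.0], [-3.0, -4.0]]
```

Because each stem is augmented before the mixture is formed, the training targets and the mixture the model sees always stay consistent with each other.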
```yaml
# Full augmentation example (paste into any config)
augmentations:
  enable: true
  loudness: true
  loudness_min: 0.5
  loudness_max: 1.5
  mixup: true
  mixup_probs: !!python/tuple
    - 0.2    # Probability of mixing in 1 extra same-type stem
    - 0.02   # Probability of mixing in 2 extra same-type stems
  mixup_loudness_min: 0.5
  mixup_loudness_max: 1.5

  # Simulate MP3-compressed mixtures (e.g. tracks downloaded from the internet)
  mp3_compression_on_mixture: 0.01
  mp3_compression_on_mixture_bitrate_min: 32
  mp3_compression_on_mixture_bitrate_max: 320
  mp3_compression_on_mixture_backend: "lameenc"

  # Random chunk sizes (watch VRAM usage)
  chunk_size_augm: true
  chunk_size_min: 44100
  chunk_size_max: 661500

  all:       # Applied to every stem
    channel_shuffle: 0.5
    random_polarity: 0.5
    mp3_compression: 0.01
    mp3_compression_min_bitrate: 32
    mp3_compression_max_bitrate: 320
    pedalboard_reverb: 0.01
    pedalboard_chorus: 0.01
    pedalboard_pitch_shift: 0.01
    pedalboard_pitch_shift_semitones_min: -7
    pedalboard_pitch_shift_semitones_max: 7

  vocals:    # Applied only to the vocals stem
    pitch_shift: 0.1
    pitch_shift_min_semitones: -5
    pitch_shift_max_semitones: 5
    seven_band_parametric_eq: 0.25
    seven_band_parametric_eq_min_gain_db: -9
    seven_band_parametric_eq_max_gain_db: 9
    tanh_distortion: 0.1
    tanh_distortion_min: 0.1
    tanh_distortion_max: 0.7

  drums:
    pitch_shift: 0.33
    pitch_shift_min_semitones: -5
    pitch_shift_max_semitones: 5
```

---

## Dataset types: Supported training dataset layouts

Seven dataset formats are supported via `--dataset_type`. The validation dataset always uses the Type 1 layout, with an additional `mixture.wav` in each song folder.

```
# Type 1 (MUSDB, default): one folder per song, stems as named wav/flac files
datasets/musdb18hq/train/
  Song 1/
    vocals.wav
    bass.wav
    drums.wav
    other.wav
  Song 2/
    vocals.wav
    ...

# Type 2: one folder per instrument, each containing solo stem files
datasets/stems/
  vocals/
    vocals_001.wav
    vocals_002.wav
    ...
  bass/
    bass_001.wav
    ...

# Type 3: CSV file(s) mapping instrument names to file paths
instrum,path
vocals,/data/vocals_01.wav
drums,/data/drums_01.wav

# Type 4: same as Type 1, but all stems are read from the same time offset (aligned)
# Type 5: same as Type 1, but pre-split into 50%-overlap chunks
# Type 6: aligned stems plus an explicit mixture.wav (for distillation / consistency losses)
# Type 7: class-balanced aligned dataset (reduces dominant-class bias)
```

```bash
# Training invocation using Type 3 (CSV)
python train.py \
  --model_type mel_band_roformer \
  --config_path configs/config_musdb18_mel_band_roformer.yaml \
  --results_path results/ \
  --data_path datasets/my_stems.csv datasets/extra_stems.csv \
  --dataset_type 3 \
  --valid_path datasets/musdb18hq/test \
  --device_ids 0
```

---

## LoRA fine-tuning: Parameter-efficient fine-tuning with Low-Rank Adaptation

LoRA injects small trainable low-rank matrices into the model, which drastically reduces both the number of trainable parameters and the size of the saved checkpoint. Two backends are supported: `peft` and `loralib`.

```yaml
# Add this lora block to your config file to enable LoRA
lora:
  r: 8                   # Rank of the low-rank matrices (smaller = fewer params)
  lora_alpha: 16         # Scaling factor; LoRA updates are scaled by lora_alpha / r
  lora_dropout: 0.05     # Regularization dropout on the LoRA branch
  merge_weights: false   # Set true to merge LoRA into the base weights (e.g. for deployment)
  fan_in_fan_out: false
  enable_lora: [true]    # Which sub-projections get LoRA (e.g. [true, false, true] for Q/V only)
```

```bash
# Fine-tune BS-RoFormer with LoRA on a custom dataset
python train.py \
  --model_type bs_roformer \
  --config_path configs/config_musdb18_bs_roformer_with_lora.yaml \
  --start_check_point weights/model_bs_roformer_ep_17_sdr_9.6568.ckpt \
  --results_path results/lora_finetune/ \
  --data_path datasets/my_custom_vocals/train \
  --valid_path datasets/my_custom_vocals/test \
  --device_ids 0 \
  --metrics sdr \
  --train_lora_loralib

# Validate a LoRA-fine-tuned model
python valid.py \
  --model_type scnet \
  --config_path configs/config_musdb18_scnet.yaml \
  --start_check_point weights/base_scnet.ckpt \
  --lora_checkpoint_loralib weights/lora_scnet.ckpt \
  --valid_path datasets/musdb18hq/test \
  --metrics sdr si_sdr l1_freq

# Inference with a LoRA checkpoint
python inference.py \
  --model_type scnet \
  --config_path configs/config_musdb18_scnet.yaml \
  --start_check_point weights/base_scnet.ckpt \
  --lora_checkpoint_loralib weights/lora_scnet.ckpt \
  --input_folder input/songs/ \
  --store_dir output/lora_results/
```

---

## `demix` / `bigshifts_wrapper`: Core inference chunk-processing utilities

`demix` splits long audio into overlapping chunks and runs model forward passes on them in batches. It then reconstructs the full waveform, using fade-in/out windowing to minimize boundary artifacts. `bigshifts_wrapper` further improves quality by averaging several time-shifted `demix` passes.
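The chunking scheme can be illustrated without any neural network. `demix_sketch` and `bigshifts_sketch` are hypothetical NumPy re-implementations for illustration only. The following details are assumptions, not behavior confirmed from `utils/model_utils.py`:

- the step size `chunk_size // num_overlap`
- the fade length `chunk_size // 10`
- linear fades, normalized by the summed window weights
- evenly spaced circular shifts

With an identity "model", the overlap-add reconstruction returns the input unchanged, which makes a handy sanity check.

```python
from typing import Callable

import numpy as np

# A "model" here maps a (channels, chunk_size) chunk to a same-shape estimate
ChunkModel = Callable[[np.ndarray], np.ndarray]


def demix_sketch(mix: np.ndarray, model: ChunkModel, chunk_size: int,
                 num_overlap: int = 4) -> np.ndarray:
    """Overlap-add chunked separation with linear fades (assumed scheme)."""
    assert num_overlap >= 2, "chunks must overlap so every sample gets non-zero weight"
    channels, length = mix.shape
    step = chunk_size // num_overlap   # assumed hop between chunk starts
    fade = max(1, chunk_size // 10)    # assumed fade length in samples
    # Zero-pad so that the chunks tile the signal exactly
    padded_len = max(length, chunk_size)
    padded_len += (-(padded_len - chunk_size)) % step
    padded = np.zeros((channels, padded_len))
    padded[:, :length] = mix

    starts = list(range(0, padded_len - chunk_size + 1, step))
    result = np.zeros((channels, padded_len))
    weight = np.zeros(padded_len)
    for i, start in enumerate(starts):
        window = np.ones(chunk_size)
        if i > 0:                      # no fade-in on the very first chunk
            window[:fade] = np.linspace(0.0, 1.0, fade)
        if i < len(starts) - 1:        # no fade-out on the very last chunk
            window[-fade:] = np.linspace(1.0, 0.0, fade)
        chunk = padded[:, start:start + chunk_size]
        result[:, start:start + chunk_size] += model(chunk) * window
        weight[start:start + chunk_size] += window
    # Normalize by the summed window weights, then drop the padding
    return (result / weight)[:, :length]


def bigshifts_sketch(mix: np.ndarray, model: ChunkModel, chunk_size: int,
                     shifts: int, num_overlap: int = 4) -> np.ndarray:
    """Average demix passes over circularly shifted copies of the input (assumed scheme)."""
    length = mix.shape[-1]
    outputs = []
    for i in range(shifts):
        offset = i * length // shifts                # evenly spaced shifts
        shifted = np.roll(mix, offset, axis=-1)
        est = demix_sketch(shifted, model, chunk_size, num_overlap)
        outputs.append(np.roll(est, -offset, axis=-1))  # undo the shift
    return np.mean(outputs, axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mix = rng.standard_normal((2, 1000))
    print(np.allclose(demix_sketch(mix, lambda x: x, chunk_size=200), mix))  # True
```

Shifting moves chunk boundaries to different positions in each pass. Averaging the passes therefore dilutes any artifact tied to a particular boundary location.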
```python
import librosa
import numpy as np
import soundfile as sf
import torch

from utils.settings import get_model_from_config
from utils.model_utils import demix, bigshifts_wrapper

# Load the model and its weights
model, config = get_model_from_config('bs_roformer', 'configs/config_musdb18_bs_roformer.yaml')
checkpoint = torch.load('weights/model_bs_roformer.ckpt', map_location='cpu')
# Use the 'state_dict' entry if the checkpoint has one; otherwise treat it as a bare state dict
state_dict = checkpoint.get('state_dict', checkpoint)
model.load_state_dict(state_dict)
model.eval()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Load audio at 44100 Hz, keeping channels: shape (2, N) for a stereo file
mix, sr = librosa.load('song.wav', sr=44100, mono=False)
if mix.ndim == 1:
    # Mono file: duplicate to two channels, since the model is configured for stereo
    mix = np.stack([mix, mix], axis=0)

# Standard demix
sources = demix(config, model, mix, device, model_type='bs_roformer', pbar=True)
# sources maps each output instrument to an array of shape (2, N),
# e.g. {'vocals': ...} when target_instrument is vocals

# BigShifts demix (averages 3 time-shifted passes for better quality)
sources = bigshifts_wrapper(config, model, mix, device, model_type='bs_roformer', bigshifts=3, pbar=True)

sf.write('vocals.flac', sources['vocals'].T, sr, subtype='PCM_16')
```

---

## `wandb_init`: Weights & Biases experiment tracking integration

This function initializes a wandb run in online, offline, or disabled mode. The mode depends on the `wandb_offline` flag and on whether a valid API key is supplied. During training, loss and metric values are then logged automatically at every step and epoch.
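The mode selection described above can be summarized as a small pure function. `choose_wandb_mode` and `make_run_name` are hypothetical helpers written for illustration. Two things are inferred from the examples in this section rather than taken from the `wandb_init` source: the precedence rules (an explicit offline flag wins, and an empty or missing key means disabled) and the run-name format. The returned strings match the values accepted by the `mode` argument of `wandb.init`.

```python
from datetime import date
from typing import Optional, Sequence


def choose_wandb_mode(wandb_key: Optional[str], wandb_offline: bool) -> str:
    """Return the wandb mode implied by the CLI options (inferred precedence)."""
    if wandb_offline:
        return "offline"    # log locally; upload later with `wandb sync`
    if not wandb_key or not wandb_key.strip():
        return "disabled"   # no key supplied: tracking is switched off
    return "online"


def make_run_name(model_type: str, instruments: Sequence[str], day: date) -> str:
    """Build a run name such as 'bs_roformer_[vocals-bass-drums-other]_2024-06-01'."""
    return f"{model_type}_[{'-'.join(instruments)}]_{day.isoformat()}"


if __name__ == "__main__":
    print(choose_wandb_mode("YOUR_WANDB_API_KEY", False))  # online
    print(choose_wandb_mode("", False))                    # disabled
    print(make_run_name("bs_roformer", ["vocals", "bass", "drums", "other"], date(2024, 6, 1)))
```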
```python
import argparse

from utils.settings import wandb_init, get_model_from_config

_, config = get_model_from_config('bs_roformer', 'configs/config_musdb18_bs_roformer.yaml')

args = argparse.Namespace(
    model_type='bs_roformer',
    device_ids=[0],
    wandb_key='YOUR_WANDB_API_KEY',  # set '' or None to disable
    wandb_offline=False,
)

# The three modes below are alternatives: call wandb_init once per run.

# Online mode (requires a valid API key)
wandb_init(args, config, batch_size=10)

# Offline mode (no network required; sync later with `wandb sync`)
args.wandb_offline = True
wandb_init(args, config, batch_size=10)

# Disabled (no tracking at all)
args.wandb_key = ''
args.wandb_offline = False
wandb_init(args, config, batch_size=10)

# The run name is generated automatically, e.g. "bs_roformer_[vocals-bass-drums-other]_2024-06-01"
```

---

## Summary

Music Source Separation Training is used primarily in two scenarios. The first is **research and experimentation**. Practitioners train or fine-tune source separation models on custom datasets: they choose from 15+ supported architectures, tune YAML configs, enable on-the-fly augmentations, and monitor results through wandb and the SDR/SI-SDR/AURA metrics suite. LoRA support makes it practical to adapt large pre-trained models to new instrument types or recording conditions at a fraction of the compute cost. The seven dataset types, including CSV-based loading, let researchers combine heterogeneous data sources without reformatting them.

The second major use case is **production inference and model ensembling**. Once trained, models can separate large collections of audio files in batch via `inference.py`, optionally applying TTA and BigShifts to maximize quality. Multiple model outputs for the same stem can then be combined with `ensemble.py` using waveform- or FFT-domain strategies, weighted by model quality.
The framework integrates naturally into Python scripts or automation pipelines via its programmatic API (`proc_folder`, `check_validation`, `ensemble_files`, `get_model_from_config`), making it straightforward to embed into larger audio processing or music production workflows.