# Value Residual Learning

Value Residual Learning is the official research codebase accompanying the paper [Value Residual Learning](https://arxiv.org/abs/2410.17897). It introduces two novel transformer architectures, **Resformer** and **SVformer**, that improve language model training by modifying how value states are computed across attention layers. In Resformer, each layer's value states are blended (50/50) with those from the first layer, acting as a residual connection through the value pathway. In SVformer, all layers share the value states computed at layer 0, which drastically reduces the number of value-projection parameters. Both architectures are implemented as drop-in replacements for the standard LLaMA and GPT-2 architectures and are trained on 20B tokens sampled from the SlimPajama-627B dataset.

The codebase provides a full end-to-end pipeline: raw data download, domain-based reorganization, tokenization/preprocessing, distributed multi-GPU training via DeepSpeed ZeRO-2 with Hugging Face Accelerate, and analysis tools for measuring per-layer attention entropy and hidden-state token similarity. The training harness (`src/train.py`) accepts model architecture, data, and optimizer arguments, and delegates to a custom `UpdatableTrainer` that supports three custom learning-rate schedules (linear-warmup-exponential, linear-warmup-cosine, linear-warmup-linear) with configurable warmup and decay endpoints.

---

## Data Pipeline

### `src_data/reorganize_data.py` — Re-organize SlimPajama by domain

Reads the raw SlimPajama-627B dataset (compressed `.jsonl.zst` chunks) and writes one JSONL file per domain per chunk, enabling efficient per-domain sampling later.
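The per-domain split can be illustrated without any of the repository's code. The snippet below is a minimal sketch, not the script itself: it assumes each record stores its domain under `meta.redpajama_set_name` (an assumption about the record layout), and `group_by_domain` is a hypothetical helper that works on in-memory JSONL lines rather than on `.jsonl.zst` chunks.

```python
import json
from collections import defaultdict


def group_by_domain(jsonl_lines):
    """Group JSONL records by domain (hypothetical helper, for illustration only)."""
    buckets = defaultdict(list)
    for line in jsonl_lines:
        record = json.loads(line)
        # Assumed layout: the source domain is stored in the record's metadata.
        domain = record["meta"]["redpajama_set_name"]
        buckets[domain].append(record["text"])
    return dict(buckets)


# Three toy records standing in for one decompressed chunk.
sample_chunk = [
    json.dumps({"text": "web page", "meta": {"redpajama_set_name": "RedPajamaCommonCrawl"}}),
    json.dumps({"text": "def f(): pass", "meta": {"redpajama_set_name": "RedPajamaGithub"}}),
    json.dumps({"text": "another page", "meta": {"redpajama_set_name": "RedPajamaCommonCrawl"}}),
]

grouped = group_by_domain(sample_chunk)
for domain, texts in sorted(grouped.items()):
    print(domain, len(texts))
```

In a real run each bucket would be written to its own per-domain, per-chunk JSONL file instead of being held in memory.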
```bash
# reorganize_data.py is run as a script; set Data_Path and Write_Dir at the top of the file:
#   Data_Path = "Value-Residual-Learning/data/SlimPajama-627B"
#   Write_Dir = "Value-Residual-Learning/data/slimpajama_all/{}"
#
# Produces files like:
#   slimpajama_all/RedPajamaCommonCrawl/train_chunk1.jsonl
#   slimpajama_all/RedPajamaC4/train_chunk1.jsonl
#   ...
#   slimpajama_all/RedPajamaWikipedia/validation_chunk_all.jsonl
python src_data/reorganize_data.py
```

---

### `src_data/selection.py` — Sample N tokens from a domain

Randomly samples text from a reorganized domain directory until a target token budget is reached, then writes a single `train.jsonl` output file for that domain.

```bash
# --data_size is the token budget for this domain (here 10 B tokens).
# Keep comments off the continuation lines: text after a trailing backslash breaks the command.
python src_data/selection.py \
    --input_dir data/slimpajama_all \
    --output_dir data/slimpajama_20B \
    --meta_name RedPajamaCommonCrawl \
    --data_size 10000000000 \
    --tokenizer data/tokenizer/RedPajama-INCITE-Base-7B \
    --cache_dir cache/

# Output: data/slimpajama_20B/RedPajamaCommonCrawl/train.jsonl
# Run for each domain; use run_selection.sh / scp_valid.sh to batch all domains.
```

---

### `src_data/preprocess.py` — Tokenize and chunk a domain

Tokenizes raw text files for a single domain and writes fixed-length (e.g. 2048-token) chunks to disk in HuggingFace Arrow format, ready for fast streaming during training.

```bash
python src_data/preprocess.py \
    --dataset_dir data/slimpajama_20B \
    --output_dir data/processed_slimpajama_20B \
    --domain RedPajamaCommonCrawl \
    --max_length 2048 \
    --nproc 8 \
    --tokenizer data/tokenizer/RedPajama-INCITE-Base-7B \
    --cache_dir cache/

# Or batch-process all domains:
bash src_data/run_all_preprocess.sh  # pass max_length as first argument

# Output layout:
#   processed_slimpajama_20B/
#     train/
#       RedPajamaCommonCrawl_length2048/   ← Arrow shards
#       RedPajamaC4_length2048/
#       ...
#     validation/
#       RedPajamaCommonCrawl_length2048/
```

---

## Data Loading

### `get_preprocessed_mixed_dataset` — Load a weighted multi-domain dataset

Returns a single `IterableDataset` that streams preprocessed shards from multiple domains, interleaving them according to specified sampling probabilities. Used directly by the training script.

```python
from dataloaders import get_preprocessed_mixed_dataset

# Sampling probability per domain; the weights sum to 1.0.
domain_weight_train = {
    "RedPajamaCommonCrawl_length2048": 0.50,
    "RedPajamaC4_length2048": 0.20,
    "RedPajamaGithub_length2048": 0.10,
    "RedPajamaStackExchange_length2048": 0.05,
    "RedPajamaWikipedia_length2048": 0.05,
    "RedPajamaBook_length2048": 0.05,
    "RedPajamaArXiv_length2048": 0.05,
}

train_dataset = get_preprocessed_mixed_dataset(
    preprocessed_dir="data/processed_slimpajama_20B",
    dataset_name="slimpajama",
    domain_weight=domain_weight_train,
    split="train",
    max_samples=10_240_000,  # total examples cap
    seed=42,
    no_interleave=False,     # True → sequential domain order
    shuffle=True,
    keep_in_memory=False,
)

# Iterate
for batch in train_dataset.take(3):
    print(batch.keys())  # dict_keys(['input_ids', 'attention_mask', 'domain_id'])
```

---

### `get_data_collator` — Build a batching collator

Returns a collator function compatible with HuggingFace `Trainer`. Optionally pads sequences and always sets `labels = input_ids` (with pad tokens masked to `-100`).
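To make the label rule concrete, here is a dependency-free sketch of the same idea on plain Python lists. It is not the repository's collator: `collate` is a hypothetical helper, and it masks labels by padded position rather than by token id, an assumption chosen here so that a pad id that doubles as a real token is never masked by accident.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def collate(examples, pad_token_id, max_length=None):
    """Pad to a common length and build labels (hypothetical helper)."""
    target = max_length or max(len(ex["input_ids"]) for ex in examples)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        ids = ex["input_ids"][:target]  # truncate overly long sequences
        n_pad = target - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # Labels copy the inputs; padded positions are excluded from the loss.
        batch["labels"].append(ids + [IGNORE_INDEX] * n_pad)
    return batch


batch = collate([{"input_ids": [1, 2, 3, 4]}, {"input_ids": [5, 6]}], pad_token_id=0)
print(batch["input_ids"])
print(batch["labels"])
```

With fixed-length preprocessed chunks no padding is needed, so the labels end up identical to the inputs.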
```python
from transformers import AutoTokenizer
from dataloaders import get_data_collator

tokenizer = AutoTokenizer.from_pretrained("data/tokenizer/RedPajama-INCITE-Base-7B", use_fast=True)

collator = get_data_collator(
    tokenizer,
    do_padding=False,  # sequences are already fixed-length after preprocessing
    max_length=2048,
)

# collator receives a list of dicts with 'input_ids', 'attention_mask'
batch = collator([
    {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1]},
    {"input_ids": [5, 6, 7, 8], "attention_mask": [1, 1, 1, 1]},
])
# batch["input_ids"]      → torch.LongTensor (2, 4)
# batch["labels"]         → torch.LongTensor (2, 4), same as input_ids
# batch["attention_mask"] → torch.LongTensor (2, 4)
```

---

## Model Architectures

### `modeling_llama_resformer` — Resformer: value residual blending

Extends the standard LLaMA attention so that every layer after the first blends its own value projection with the first layer's value states at a 50/50 ratio (`value = 0.5 * layer_0_values + 0.5 * current_values`). The first-layer values are accumulated in a list (`formal_layer_values`) passed through the decoder stack.
```python
# Key forward logic inside LlamaAttention (modeling_llama_resformer.py).
# Excerpt only: `self`, `hidden_states`, and `formal_layer_values` come from the surrounding method.
value_states = self.v_proj(hidden_states)
if self.layer_idx == 0:
    formal_layer_values.append(value_states)  # store layer-0 values
else:
    value_states = 0.5 * formal_layer_values[0] + 0.5 * value_states  # residual blend

# Instantiate and train via train.py:
from modeling.modeling_llama_resformer import LlamaForCausalLM
from transformers import LlamaConfig

config = LlamaConfig(
    vocab_size=50277,
    hidden_size=512,
    intermediate_size=1792,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=8,
    max_position_embeddings=2048,
    hidden_act="silu",
)
model = LlamaForCausalLM(config)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M params")  # ~82 M
```

---

### `modeling_llama_svformer` — SVformer: shared value projection

A more extreme variant where **only layer 0 has a `v_proj` weight**. All subsequent layers reuse the exact value states produced at layer 0 without any blending or additional projection. This eliminates N−1 value projection matrices from an N-layer model.

```python
# Key forward logic (modeling_llama_svformer.py); excerpt only, not standalone.

# __init__: v_proj is only created when layer_idx == 0
if self.layer_idx == 0:
    self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)

# forward:
if self.layer_idx == 0:
    value_states = self.v_proj(hidden_states)
    formal_layer_values.append(value_states)
else:
    value_states = formal_layer_values[0]  # direct reuse, no blending

# Train SVformer (82 M):
#   bash scripts/run_llama_svformer_82M.sh
```

---

## Training

### `src/train.py` — Main training entry point

Parses three argument groups (`ModelArguments`, `DataTrainingArguments`, `FullTrainingArguments`), dynamically imports the chosen model architecture, loads the tokenizer and dataset, then launches `UpdatableTrainer` for training and/or evaluation.
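As a quick sanity check on the 82M launch settings used in this section (8 processes, per-device batch size 32, 4 gradient-accumulation steps, 2048-token sequences, 10,000 steps), the arithmetic below shows how they add up to the stated token budget. It is a back-of-the-envelope sketch that assumes pure data parallelism and that every sequence is a full 2048 tokens.

```python
num_processes = 8        # GPUs / data-parallel workers
per_device_batch = 32    # --per_device_train_batch_size
grad_accum = 4           # --gradient_accumulation_steps
seq_len = 2048           # tokens per preprocessed chunk
max_steps = 10_000       # --max_steps

sequences_per_step = num_processes * per_device_batch * grad_accum
tokens_per_step = sequences_per_step * seq_len
total_sequences = sequences_per_step * max_steps
total_tokens = tokens_per_step * max_steps

print(sequences_per_step)                   # 1024
print(tokens_per_step)                      # 2097152
print(total_sequences)                      # 10240000
print(f"{total_tokens / 1e9:.2f}B tokens")  # 20.97B tokens
```

The total of 10,240,000 sequences matches the `max_samples` cap in the data-loading example, and the roughly 21B tokens are consistent with the 20B-token training budget.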
```bash
# Train LLaMA Resformer 82M on 8 GPUs with DeepSpeed ZeRO-2:
mkdir -p logs output

# Edit CACHE and CODE_DIR in the script, then:
bash scripts/run_llama_resformer_82M.sh

# Equivalent manual launch:
accelerate launch \
    --config_file config/accelerate_deepspeed_zero2_config.yml \
    --num_processes 8 \
    src/train.py \
    --do_train \
    --bf16 \
    --modeling_name modeling_llama_resformer \
    --model_type llama \
    --config_overrides "vocab_size=50277,hidden_size=512,intermediate_size=1792,num_hidden_layers=8,num_attention_heads=8,num_key_value_heads=8,max_position_embeddings=2048,hidden_act=silu,rope_theta=10000,rms_norm_eps=1e-5,attention_dropout=0.0,tie_word_embeddings=False,eos_token_id=0,bos_token_id=0,use_cache=False" \
    --tokenizer_name data/tokenizer/RedPajama-INCITE-Base-7B \
    --attn_implementation flash_attention_2 \
    --dataset_dir data/processed_slimpajama_20B \
    --output_dir output/modeling_llama_resformer_82M \
    --domain_weight_train '{"RedPajamaCommonCrawl_length2048":0.5,"RedPajamaC4_length2048":0.2,"RedPajamaGithub_length2048":0.1,"RedPajamaStackExchange_length2048":0.05,"RedPajamaWikipedia_length2048":0.05,"RedPajamaBook_length2048":0.05,"RedPajamaArXiv_length2048":0.05}' \
    --domain_weight_eval '{"RedPajamaCommonCrawl_length2048":1.0}' \
    --max_steps 10000 \
    --per_device_train_batch_size 32 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_name linear_warmup_cosine \
    --learning_rate 6e-4 \
    --lr_end 6e-5 \
    --warmup_step 120 \
    --num_warmup_stop_steps 10000 \
    --optim adamw_torch \
    --weight_decay 0.1 \
    --max_grad_norm 1.0 \
    --seed 42 \
    2>&1 | tee logs/modeling_llama_resformer_82M.log

# Supported --modeling_name values:
#   modeling_llama_baseline
#   modeling_llama_resformer
#   modeling_llama_svformer
#   modeling_llama_NeuTRENO_lambda04
#   modeling_llama_NeuTRENO_resformer
#   modeling_gpt2_baseline
```

---

### `FullTrainingArguments` — Extended training arguments

Extends HuggingFace `TrainingArguments` with three extra fields for custom LR schedulers.
```python
from training_args import FullTrainingArguments

args = FullTrainingArguments(
    output_dir="output/my_run",
    do_train=True,
    bf16=True,
    max_steps=10000,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    lr_scheduler_name="linear_warmup_cosine",  # or "linear_warmup_exponential" / "linear_warmup_linear"
    learning_rate=6e-4,
    lr_end=6e-5,                  # final LR after decay (new field)
    warmup_steps=120,
    num_warmup_stop_steps=10000,  # step at which decay ends (new field)
    weight_decay=0.1,
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.95,
)
```

---

## Custom LR Schedulers

### `LinearWarmupCosineLR` — Cosine decay with linear warmup

Linearly warms up from `lr_start` (1e-7) to `base_lr` over `num_warmup_steps`, then follows a cosine decay to `lr_end` until `num_warmup_stop_steps`, and holds `lr_end` for the remainder of training.

```python
import torch
from trainer import LinearWarmupCosineLR

optimizer = torch.optim.AdamW([torch.zeros(1, requires_grad=True)], lr=6e-4)
scheduler = LinearWarmupCosineLR(
    optimizer,
    num_warmup_steps=120,
    num_warmup_stop_steps=10000,
    num_training_steps=10000,
    lr_start=1e-7,
    lr_end=6e-5,
)

for step in range(10000):
    optimizer.step()
    scheduler.step()
    if step in (0, 120, 5000, 9999):
        print(f"step {step}: lr = {scheduler.get_last_lr()[0]:.2e}")

# Expected shape of the output (exact values depend on the scheduler's step indexing):
#   step 0:    close to lr_start (early warmup)
#   step 120:  about 6.00e-04 (peak)
#   step 5000: between peak and floor (mid cosine)
#   step 9999: about 6.00e-05 (floor)
```

---

### `LinearWarmupExponentialLR` — Exponential decay with linear warmup

Same warmup phase, then exponential decay to `lr_end` (reaching within 1e-10 of `lr_end` at `num_warmup_stop_steps`), then constant at `lr_end`.
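Both schedules share the same three phases (linear warmup, decay, constant floor), which can be written as a closed-form function of the step. The sketch below is framework-free and is not the repository's implementation: `lr_at` is a hypothetical helper, and the exponential branch assumes the decay rate is chosen so that the gap to `lr_end` shrinks to `tol = 1e-10` exactly at the stop step.

```python
import math


def lr_at(step, base_lr=6e-4, lr_start=1e-7, lr_end=6e-5,
          warmup=120, stop=10000, kind="cosine", tol=1e-10):
    """Learning rate at a given step (hypothetical helper, for illustration)."""
    if step < warmup:
        # Phase 1: linear ramp from lr_start to base_lr.
        return lr_start + (base_lr - lr_start) * step / warmup
    if step >= stop:
        # Phase 3: hold the floor for the rest of training.
        return lr_end
    # Phase 2: decay from base_lr toward lr_end between warmup and stop.
    t, span = step - warmup, stop - warmup
    if kind == "cosine":
        return lr_end + 0.5 * (base_lr - lr_end) * (1 + math.cos(math.pi * t / span))
    if kind == "exponential":
        # Assumed parameterization: per-step factor that leaves a gap of `tol` at `stop`.
        gamma = (tol / (base_lr - lr_end)) ** (1 / span)
        return lr_end + (base_lr - lr_end) * gamma ** t
    raise ValueError(f"unknown schedule kind: {kind}")


for step in (0, 120, 5060, 9999, 10000):
    print(step, f"{lr_at(step, kind='cosine'):.3e}", f"{lr_at(step, kind='exponential'):.3e}")
```

The two curves agree at the peak and at the floor; in between, the exponential schedule drops much faster early on and spends most of the run near `lr_end`.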
```python
import torch
from trainer import LinearWarmupExponentialLR

optimizer = torch.optim.AdamW([torch.zeros(1, requires_grad=True)], lr=6e-4)
scheduler = LinearWarmupExponentialLR(
    optimizer,
    num_warmup_steps=120,
    num_warmup_stop_steps=10000,
    num_training_steps=10000,
    lr_start=1e-7,
    lr_end=6e-5,
)
```

---

## Analysis Tools

### `analyze/get_entropy.py` — Per-layer attention entropy

Loads a LLaMA model and a JSON list of text samples, runs forward passes with `output_attentions=True`, and computes the Shannon entropy of the averaged attention distribution at each layer. High entropy indicates spread-out attention; low entropy indicates focused attention.

```bash
# Place sampled_data.json in analyze/ (a JSON list of strings, e.g. ["text1", "text2", ...])
# Set model_name_or_path inside the script, then:
cd analyze
python get_entropy.py

# Output (printed to stdout): list of average per-layer entropy values, e.g.
# [2.31, 2.45, 2.67, 2.89, 3.01, 2.95, 2.78, 2.60]   (one float per layer)
```

```python
# Core function:
import numpy as np
import torch
from scipy.stats import entropy as scipy_entropy


def calculate_entropy(attention_weights):
    # attention_weights: (batch=1, num_heads, seq, seq)
    average_attention = torch.mean(attention_weights, dim=1).squeeze(0)  # (seq, seq), averaged over heads
    token_importance = average_attention.sum(dim=0)                      # (seq,), attention received per token
    token_importance = np.array(token_importance.tolist())
    token_importance = token_importance / token_importance.sum()         # normalize to a distribution
    return scipy_entropy(token_importance)
```

---

### `analyze/get_similarity.py` — Per-layer token representation similarity

Computes the average cosine similarity between all pairs of token representations (hidden states or key vectors) at each layer. Higher similarity across layers is associated with the representation collapse / token-uniformity problem that value residual learning is designed to mitigate.
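A toy example helps calibrate what the metric reports. The sketch below uses plain Python lists instead of model activations; `mean_pairwise_cosine` is a hypothetical helper that averages over unordered pairs, which gives the same value as averaging the off-diagonal entries of the symmetric similarity matrix.

```python
import math
from itertools import combinations


def mean_pairwise_cosine(vectors):
    """Average cosine similarity over all distinct pairs of token vectors."""
    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.hypot(*a) * math.hypot(*b))

    pairs = list(combinations(vectors, 2))
    return sum(cosine(a, b) for a, b in pairs) / len(pairs)


# Mutually orthogonal token vectors: no collapse at all.
distinct = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Token vectors that all point in the same direction: fully collapsed.
collapsed = [[1.0, 1.0, 0.0], [2.0, 2.0, 0.0], [3.0, 3.0, 0.0]]

print(mean_pairwise_cosine(distinct))
print(mean_pairwise_cosine(collapsed))
```

The first value is 0 and the second is 1 up to floating-point rounding, so per-layer values drifting toward 1 indicate increasingly uniform token representations.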
```bash
cd analyze
# Edit model_name_or_path, embedding_type ("hidden" or "key"), head_operation inside script
python get_similarity.py

# Output: list of average per-layer similarity values, e.g.
# [0.12, 0.18, 0.25, 0.31, 0.38, 0.44, 0.50, 0.57]
```

```python
import torch


def calculate_similarity(data, head_operation='average'):
    # data: (1, num_heads, seq, dim) for key; (1, seq, dim) for hidden
    if head_operation == 'average':
        data = data.mean(dim=1)  # (1, seq, dim)
    elif head_operation == 'concatenate':
        data = data.transpose(1, 2).reshape(data.size(0), data.size(2), -1)
    data = data.squeeze()  # (seq, dim)
    data_norm = torch.nn.functional.normalize(data, p=2, dim=-1)
    sim_matrix = torch.matmul(data_norm, data_norm.T)  # (seq, seq)
    # Exclude the diagonal (each token's similarity with itself).
    mask = torch.eye(sim_matrix.size(0), device=data.device).bool()
    return sim_matrix[~mask].mean().item()
```

---

### `analyze/plot_relative_loss.py` — Plot relative training loss between runs

Reads training log files, extracts step-wise loss, applies a 100-step moving average, and plots the loss difference (method − baseline) over training steps. Positive values mean the method is worse than the baseline; negative values mean it is better.

```python
# Configure file_ls_all at the bottom of plot_relative_loss.py:
file_ls_all = {
    "ResformerVsBaseline": [[
        ["logs/modeling_llama_82M.log", "logs/modeling_llama_resformer_82M.log"],
        ["logs/modeling_llama_82M.log", "logs/modeling_llama_svformer_82M.log"],
    ], [
        "Resformer",
        "SVformer",
    ]],
}

# Then run: python analyze/plot_relative_loss.py
# Output: figure_v2/ResformerVsBaseline.pdf
#   X-axis: Training Step (0–10000)
#   Y-axis: Relative Training Loss (method loss − baseline loss)
#   Dashed red line at y=0 marks the baseline
```

---

## Summary

Value Residual Learning is primarily used for language model pretraining research.
The two main use cases are: (1) training Resformer or SVformer from scratch on large text corpora (20B+ tokens) to validate the value residual hypothesis, namely that re-using early-layer value representations improves gradient flow and reduces representation collapse, and (2) running ablation studies by swapping the `--modeling_name` flag to compare baseline LLaMA/GPT-2 against Resformer, SVformer, and NeuTRENO variants under identical data and optimizer settings. The analysis tools (`get_entropy.py`, `get_similarity.py`, `plot_relative_loss.py`) complement training by providing interpretability metrics on existing checkpoints.

Integration into a broader research workflow typically follows the pipeline: download SlimPajama → reorganize by domain → select a 20B-token subset → tokenize into fixed-length chunks → launch distributed training with Accelerate + DeepSpeed ZeRO-2 → evaluate per-domain perplexity, logged every 1000 steps. The modular design means any step can be replaced independently: alternative tokenizers, datasets, or base architectures (GPT-2 vs. LLaMA) can be swapped by changing a few flags, while the value-residual mechanism itself is confined to a handful of lines in each `LlamaAttention.forward` method, making it straightforward to port to other transformer implementations.