Burn
https://github.com/tracel-ai/burn
Burn is a next generation Deep Learning Framework that doesn't compromise on flexibility, efficiency
...
Context Summary (auto-generated)
# Burn - Deep Learning Framework

Burn is a next-generation tensor library and deep learning framework written in Rust that prioritizes flexibility, efficiency, and portability. It provides a comprehensive ecosystem for building, training, and deploying neural networks across multiple hardware backends, including CUDA, ROCm, Metal, Vulkan, WebGPU, and CPU. The framework leverages Rust's type system to perform optimizations typically only available in static-graph frameworks while maintaining the flexibility of dynamic computation graphs.

The framework's modular architecture centers around the `Backend` trait, which enables writing backend-agnostic code that can run on any supported hardware without modification. Burn provides automatic differentiation through a backend decorator pattern, kernel fusion for performance optimization, and a complete training infrastructure with metrics, checkpointing, and a terminal-based dashboard. The library supports ONNX model import, PyTorch/SafeTensors weight loading, and deployment targets ranging from embedded devices (`no_std`) to WebAssembly for browser inference.

## Tensor Creation and Operations

Burn tensors are the fundamental data structure, parameterized by backend type `B`, dimensionality `D`, and an optional element kind (`Float`, `Int`, or `Bool`). Tensors support all standard operations with automatic broadcasting and are designed with ownership semantics that enable safe memory management and potential in-place optimizations.
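The automatic broadcasting mentioned above follows the familiar NumPy-style rule: shapes are aligned from the trailing dimension, and each pair of sizes must be equal or contain a 1. A Burn-free sketch of just the shape rule, for intuition (Burn applies this logic inside its tensor ops; `broadcast_shape` is an illustrative helper, not a Burn API):

```rust
/// NumPy-style broadcast rule: align shapes from the trailing dimension;
/// each size pair must be equal or contain a 1. Returns the broadcast shape.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let len = a.len().max(b.len());
    let mut out = vec![0; len];
    for i in 0..len {
        // Missing leading dimensions are treated as size 1.
        let da = if i < len - a.len() { 1 } else { a[i - (len - a.len())] };
        let db = if i < len - b.len() { 1 } else { b[i - (len - b.len())] };
        out[i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None, // incompatible shapes
        };
    }
    Some(out)
}

fn main() {
    // [4, 3] op [3] -> [4, 3]: the row vector is repeated along dim 0.
    assert_eq!(broadcast_shape(&[4, 3], &[3]), Some(vec![4, 3]));
    // [8, 1, 6] op [7, 1] -> [8, 7, 6]
    assert_eq!(broadcast_shape(&[8, 1, 6], &[7, 1]), Some(vec![8, 7, 6]));
    // [4, 3] op [2] is incompatible.
    assert_eq!(broadcast_shape(&[4, 3], &[2]), None);
    println!("ok");
}
```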
```rust
use burn::backend::Wgpu;
use burn::tensor::{Tensor, TensorData, Int, Bool, Distribution};

type Backend = Wgpu;

fn main() {
    let device = Default::default();

    // Create tensors from arrays
    let tensor_1d = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0], &device);
    let tensor_2d = Tensor::<Backend, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);

    // Create tensors with specific shapes
    let zeros = Tensor::<Backend, 2>::zeros([3, 3], &device);
    let ones = Tensor::<Backend, 2>::ones([3, 3], &device);
    let random = Tensor::<Backend, 2>::random([32, 32], Distribution::Default, &device);

    // Integer and boolean tensors
    let int_tensor = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
    let bool_tensor = Tensor::<Backend, 1, Bool>::from_data([true, false, true], &device);

    // Tensor operations (note: operations consume tensors, use clone() to reuse)
    let a = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
    let min = a.clone().min();
    let max = a.clone().max();
    let normalized = (a.clone() - min.clone()).div(max - min);

    // Matrix operations
    let mat_a = Tensor::<Backend, 2>::random([4, 3], Distribution::Default, &device);
    let mat_b = Tensor::<Backend, 2>::random([3, 5], Distribution::Default, &device);
    let product = mat_a.matmul(mat_b); // [4, 5]

    // Shape manipulation
    let tensor = Tensor::<Backend, 2>::random([6, 4], Distribution::Default, &device);
    let reshaped = tensor.clone().reshape([2, 12]);
    let transposed = tensor.clone().transpose();
    let flattened = tensor.flatten::<1>(0, 1); // Flatten dimensions 0 and 1 into one

    // Slicing and indexing
    let data = Tensor::<Backend, 2>::random([10, 10], Distribution::Default, &device);
    let sliced = data.clone().slice([2..5, 3..8]); // [3, 5]
    let selected = data.select(0, Tensor::<Backend, 1, Int>::from_ints([0, 2, 4], &device));

    println!("Normalized: {}", normalized.to_data());
    println!("Product shape: {:?}", product.dims());
}
```

## Module Definition with Derive Macro

The `#[derive(Module)]` macro
automatically implements the `Module` trait for custom neural network structs, providing parameter management, serialization, and device transfer capabilities. Each field must also implement the `Module` trait.

```rust
use burn::nn::{
    conv::{Conv2d, Conv2dConfig},
    pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
    Dropout, DropoutConfig, Linear, LinearConfig, Relu,
};
use burn::prelude::*;

#[derive(Module, Debug)]
pub struct ConvNet<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

#[derive(Config, Debug)]
pub struct ConvNetConfig {
    num_classes: usize,
    hidden_size: usize,
    #[config(default = "0.5")]
    dropout: f64,
}

impl ConvNetConfig {
    /// Initialize the model on the specified device
    pub fn init<B: Backend>(&self, device: &B::Device) -> ConvNet<B> {
        ConvNet {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}

impl<B: Backend> ConvNet<B> {
    /// Forward pass: Images [batch_size, height, width] -> Output [batch_size, num_classes]
    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = images.dims();
        let x = images.reshape([batch_size, 1, height, width]);

        let x = self.conv1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.conv2.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        let x = self.pool.forward(x);
        let x = x.reshape([batch_size, 16 * 8 * 8]);

        let x = self.linear1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);
        self.linear2.forward(x)
    }
}

fn main() {
    use burn::backend::Wgpu;
    type MyBackend = Wgpu;
    let device = Default::default();

    let config = ConvNetConfig::new(10, 512);
    let model = config.init::<MyBackend>(&device);

    // Print model structure
    println!("{}", model);

    // Save configuration
    config.save("model_config.json").expect("Failed to save config");
}
```

## Autodiff Backend for Training

Wrap any backend with `Autodiff<B>` to enable automatic differentiation. The `backward()` method returns gradients that can be retrieved per-tensor, enabling flexible gradient computation without runtime mode switches.

```rust
use burn::backend::{Autodiff, Wgpu};
use burn::tensor::{Tensor, Distribution};

fn main() {
    type Backend = Autodiff<Wgpu>;
    let device = Default::default();

    // Create tensors with gradient tracking
    let x = Tensor::<Backend, 2>::random([32, 32], Distribution::Default, &device);
    let y = Tensor::<Backend, 2>::random([32, 32], Distribution::Default, &device)
        .require_grad(); // Enable gradient tracking

    // Forward computation
    let z = x.clone() + y.clone();
    let z = z.matmul(x);
    let loss = z.exp().mean();

    // Backward pass - returns gradients container
    let grads = loss.backward();

    // Retrieve specific gradients
    let y_grad = y.grad(&grads).expect("Gradient should exist");
    println!("Gradient shape: {:?}", y_grad.dims());

    // For validation/inference without gradients, use inner()
    let inner_tensor = y.inner(); // Returns Tensor<Wgpu, 2> without autodiff
}

// Using autodiff with modules
use burn::tensor::backend::AutodiffBackend;

fn compute_loss<B: AutodiffBackend>(
    model_output: Tensor<B, 2>,
    targets: Tensor<B, 1, burn::tensor::Int>,
) -> Tensor<B, 1> {
    use burn::nn::loss::CrossEntropyLossConfig;

    CrossEntropyLossConfig::new()
        .init(&model_output.device())
        .forward(model_output, targets)
}
```

## Data Loading with Batcher

The `Batcher` trait converts dataset items into batched tensors. Combined with `DataLoaderBuilder`, it provides parallel data loading with shuffling and multi-worker support.
```rust
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::vision::{MnistDataset, MnistItem};
use burn::prelude::*;

#[derive(Clone, Default)]
pub struct MnistBatcher;

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {
    fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
        // Convert images to tensors with normalization
        let images: Vec<Tensor<B, 3>> = items
            .iter()
            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
            .map(|data| Tensor::<B, 2>::from_data(data, device))
            .map(|tensor| tensor.reshape([1, 28, 28]))
            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // Normalize
            .collect();

        // Convert labels to a tensor
        let targets: Vec<Tensor<B, 1, Int>> = items
            .iter()
            .map(|item| {
                Tensor::<B, 1, Int>::from_data(
                    [(item.label as i64).elem::<B::IntElem>()],
                    device,
                )
            })
            .collect();

        MnistBatch {
            images: Tensor::cat(images, 0),
            targets: Tensor::cat(targets, 0),
        }
    }
}

fn main() {
    use burn::backend::NdArray;
    type Backend = NdArray;

    let batcher = MnistBatcher::default();
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(64)
        .shuffle(42)
        .num_workers(4)
        .build(MnistDataset::train());

    // Iterate over batches
    for batch in dataloader.iter() {
        println!("Batch images shape: {:?}", batch.images.dims());
        println!("Batch targets shape: {:?}", batch.targets.dims());
        break; // Just show the first batch
    }
}
```

## Training with SupervisedTraining and Learner

The `SupervisedTraining` struct provides a complete training loop with metrics, checkpointing, and a terminal dashboard. Implement the `TrainStep` and `InferenceStep` traits to define how your model processes batches.
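The `TrainStep`/`InferenceStep` split is an inversion of control: the training infrastructure owns the loop, while the model only defines what a single step means. A minimal Burn-free sketch of that shape (the trait, `MeanModel`, and `fit` below are illustrative stand-ins, not Burn's actual signatures):

```rust
/// One optimization step on a batch; the loop below stays generic over this.
trait TrainStep<Batch> {
    /// Returns the loss for the batch (a real impl would also produce gradients).
    fn step(&mut self, batch: &Batch) -> f32;
}

/// A toy "model": learns the mean of scalar targets by gradient descent.
struct MeanModel {
    mu: f32,
    lr: f32,
}

impl TrainStep<Vec<f32>> for MeanModel {
    fn step(&mut self, batch: &Vec<f32>) -> f32 {
        let n = batch.len() as f32;
        // MSE loss and its gradient with respect to mu.
        let loss = batch.iter().map(|y| (self.mu - y).powi(2)).sum::<f32>() / n;
        let grad = 2.0 * batch.iter().map(|y| self.mu - y).sum::<f32>() / n;
        self.mu -= self.lr * grad;
        loss
    }
}

/// Generic training loop: knows nothing about the model beyond TrainStep.
fn fit<B, M: TrainStep<B>>(model: &mut M, batches: &[B], epochs: usize) -> f32 {
    let mut last = 0.0;
    for _ in 0..epochs {
        for b in batches {
            last = model.step(b);
        }
    }
    last
}

fn main() {
    let mut model = MeanModel { mu: 0.0, lr: 0.1 };
    let batches = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    fit(&mut model, &batches, 200);
    // mu settles between the two batch means (overall mean 2.5).
    println!("learned mu = {:.2}", model.mu);
    assert!((model.mu - 2.5).abs() < 0.5);
}
```

Burn's real traits return richer types (`TrainOutput`, `ClassificationOutput`), but the division of labor is the same, as the next block shows.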
```rust
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::vision::MnistDataset;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::optim::AdamConfig;
use burn::prelude::*;
use burn::record::CompactRecorder;
use burn::tensor::backend::AutodiffBackend;
use burn::train::{
    ClassificationOutput, InferenceStep, Learner, SupervisedTraining, TrainOutput, TrainStep,
    metric::{AccuracyMetric, LossMetric},
};

// Assume Model, MnistBatcher, and MnistBatch are defined as shown above

impl<B: Backend> Model<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        let output = self.forward(images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput::new(loss, output, targets)
    }
}

impl<B: AutodiffBackend> TrainStep for Model<B> {
    type Input = MnistBatch<B>;
    type Output = ClassificationOutput<B>;

    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);
        TrainOutput::new(self, item.loss.backward(), item)
    }
}

impl<B: Backend> InferenceStep for Model<B> {
    type Input = MnistBatch<B>;
    type Output = ClassificationOutput<B>;

    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
}

#[derive(Config)]
pub struct TrainingConfig {
    pub model: ModelConfig,
    pub optimizer: AdamConfig,
    #[config(default = 10)]
    pub num_epochs: usize,
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 4)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 1.0e-4)]
    pub learning_rate: f64,
}

pub fn train<B: AutodiffBackend>(
    artifact_dir: &str,
    config: TrainingConfig,
    device: B::Device,
) {
    // Create the artifact directory
    std::fs::create_dir_all(artifact_dir).ok();
    config.save(format!("{artifact_dir}/config.json")).unwrap();

    B::seed(&device, config.seed);

    let batcher = MnistBatcher::default();

    let dataloader_train = DataLoaderBuilder::new(batcher.clone())
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher)
        .batch_size(config.batch_size)
        .num_workers(config.num_workers)
        .build(MnistDataset::test());

    let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)
        .metrics((AccuracyMetric::new(), LossMetric::new()))
        .with_file_checkpointer(CompactRecorder::new())
        .num_epochs(config.num_epochs)
        .summary();

    let model = config.model.init::<B>(&device);

    let result = training.launch(Learner::new(
        model,
        config.optimizer.init(),
        config.learning_rate,
    ));

    // Save the final model
    result
        .model
        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
        .expect("Failed to save model");
}
```

## Saving and Loading Models

Burn provides multiple formats for model persistence, including Burnpack (native), SafeTensors, and PyTorch checkpoint loading. The `burn-store` crate offers advanced features such as key remapping and partial loading.
```rust
use burn::record::{CompactRecorder, FullPrecisionSettings, NamedMpkFileRecorder};
use burn::module::Module;

// Basic save/load with recorders
fn save_load_basic<B: Backend>(model: &Model<B>, device: &B::Device) {
    let model_path = "model_weights";

    // Save with full precision MessagePack format
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    model
        .save_file(model_path, &recorder)
        .expect("Failed to save model");

    // Load model weights
    let mut loaded_model = ModelConfig::new(10, 512).init::<B>(device);
    loaded_model = loaded_model
        .load_file(model_path, &recorder, device)
        .expect("Failed to load model");

    // Compact format (half precision, compressed)
    let compact_recorder = CompactRecorder::new();
    model.save_file("model_compact", &compact_recorder).unwrap();
}

// Advanced loading with burn-store
use burn_store::{BurnpackStore, ModuleSnapshot, PytorchStore, SafetensorsStore};

fn load_from_pytorch<B: Backend>(model: &mut Model<B>) {
    // Load from a PyTorch checkpoint
    let mut store = PytorchStore::from_file("pytorch_model.pt");
    let result = model.load_from(&mut store).expect("Failed to load");

    println!("Loaded {} tensors", result.applied.len());
    if !result.missing.is_empty() {
        println!("Missing: {:?}", result.missing);
    }
}

fn load_from_safetensors<B: Backend>(model: &mut Model<B>) {
    use burn_store::PyTorchToBurnAdapter;

    // Load SafeTensors with PyTorch format adaptation
    let mut store = SafetensorsStore::from_file("model.safetensors")
        .with_from_adapter(PyTorchToBurnAdapter);
    model.load_from(&mut store).expect("Failed to load");
}

fn load_with_key_remapping<B: Backend>(model: &mut Model<B>) {
    // Remap keys when the model structure differs
    let mut store = PytorchStore::from_file("model.pt")
        .with_key_remapping(r"^model\.", "") // Remove the "model." prefix
        .with_key_remapping(r"^layer", "encoder.layer")
        .allow_partial(true); // Continue even if some keys are missing
    model.load_from(&mut store).expect("Failed to load");
}

fn save_for_pytorch_compatibility<B: Backend>(model: &Model<B>) {
    use burn_store::BurnToPyTorchAdapter;

    let mut store = SafetensorsStore::from_file("for_pytorch.safetensors")
        .with_to_adapter(BurnToPyTorchAdapter)
        .skip_enum_variants(true);
    model.save_into(&mut store).expect("Failed to save");
}
```

## Inference

Load trained models and run inference without autodiff overhead. The `inner()` method extracts the base tensor from autodiff-wrapped tensors when needed.

```rust
use burn::prelude::*;
use burn::record::{CompactRecorder, Recorder};
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataset::vision::{MnistDataset, MnistItem};
use burn::data::dataset::Dataset;

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
    // Load the configuration
    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
        .expect("Config should exist");

    // Load model weights
    let record = CompactRecorder::new()
        .load(format!("{artifact_dir}/model").into(), &device)
        .expect("Model weights should exist");

    // Initialize the model with the loaded weights
    let model = config.model.init::<B>(&device).load_record(record);

    // Prepare the input
    let label = item.label;
    let batcher = MnistBatcher::default();
    let batch = batcher.batch(vec![item], &device);

    // Run inference
    let output = model.forward(batch.images);
    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();

    println!("Predicted: {}, Expected: {}", predicted, label);
}

fn main() {
    use burn::backend::Wgpu;
    type MyBackend = Wgpu;
    let device = Default::default();

    let test_item = MnistDataset::test().get(42).unwrap();
    infer::<MyBackend>("/tmp/mnist_model", device, test_item);
}
```

## Backend Configuration and Selection

Burn supports multiple backends that can be swapped without code changes.
Use type aliases to configure backends for different deployment scenarios.

```rust
use burn::backend::{Autodiff, NdArray, Wgpu};
use burn::tensor::{f16, Tensor};

// GPU backend with autodiff for training
type TrainBackend = Autodiff<Wgpu>;

// CPU backend for inference (no autodiff needed)
type InferBackend = NdArray;

// Backend with specific precision
type WgpuF16 = Wgpu<f16, i32>;

fn backend_example() {
    // WGPU (GPU) backend
    {
        use burn::backend::wgpu::WgpuDevice;
        type Backend = Wgpu;

        let device = WgpuDevice::default(); // Auto-select the best GPU
        // Or a specific device:
        // let device = WgpuDevice::DiscreteGpu(0);
        // let device = WgpuDevice::IntegratedGpu(0);

        let tensor = Tensor::<Backend, 2>::zeros([100, 100], &device);
    }

    // NdArray (CPU) backend
    {
        use burn::backend::ndarray::NdArrayDevice;
        type Backend = NdArray;

        let device = NdArrayDevice::Cpu;
        let tensor = Tensor::<Backend, 2>::zeros([100, 100], &device);
    }

    // LibTorch backend (when available)
    #[cfg(feature = "tch")]
    {
        use burn::backend::{libtorch::LibTorchDevice, LibTorch};
        type Backend = LibTorch;

        let device = LibTorchDevice::Cuda(0); // GPU
        // let device = LibTorchDevice::Cpu; // CPU
    }

    // Remote backend for distributed computing
    #[cfg(feature = "remote")]
    {
        use burn::backend::RemoteDevice;

        // Server side:
        // burn::server::start::<Wgpu>(Default::default(), 3000);

        // Client side:
        let device = RemoteDevice::new("ws://localhost:3000");
    }
}
```

## Built-in Neural Network Layers

Burn provides comprehensive neural network modules, including convolutions, normalization, pooling, attention, RNNs, and loss functions that mirror PyTorch's API.
```rust
use burn::nn::{
    // Linear and embedding
    Linear, LinearConfig, Embedding, EmbeddingConfig,
    // Convolutions
    conv::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Conv3d, Conv3dConfig},
    conv::{ConvTranspose1d, ConvTranspose2d, ConvTranspose3d},
    // Normalization
    BatchNorm, BatchNormConfig, LayerNorm, LayerNormConfig,
    GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig,
    RmsNorm, RmsNormConfig,
    // Pooling
    pool::{AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d},
    pool::{AdaptiveAvgPool1d, AdaptiveAvgPool2d},
    // Activations
    Relu, Gelu, Selu, LeakyRelu, Dropout, DropoutConfig,
    // Transformer
    attention::MultiHeadAttention,
    transformer::{TransformerEncoder, TransformerDecoder},
    PositionalEncoding, RotaryEncoding,
    // RNN
    Lstm, LstmConfig, Gru, GruConfig,
    // Loss functions
    loss::{CrossEntropyLoss, CrossEntropyLossConfig},
    loss::{MseLoss, HuberLoss, BinaryCrossEntropyLoss},
};
use burn::prelude::*;

fn layers_example<B: Backend>(device: &B::Device) {
    // Linear layer
    let linear = LinearConfig::new(512, 256)
        .with_bias(true)
        .init(device);

    // Convolution
    let conv = Conv2dConfig::new([3, 64], [3, 3])
        .with_stride([1, 1])
        .with_padding(burn::nn::PaddingConfig2d::Same)
        .init(device);

    // Layer normalization
    let norm = LayerNormConfig::new(512).init(device);

    // Multi-head attention
    let attention = burn::nn::attention::MultiHeadAttentionConfig::new(512, 8)
        .with_dropout(0.1)
        .init(device);

    // LSTM (d_input, d_hidden, bias)
    let lstm = LstmConfig::new(256, 512, true)
        .with_num_layers(2)
        .init(device);

    // Loss
    let ce_loss = CrossEntropyLossConfig::new().init(device);
}
```

## Dataset Transformations

Burn provides lazy dataset transformations for sampling, shuffling, partitioning, and mapping data without unnecessary allocations.
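"Lazy" here means a wrapper stores only the base dataset plus a transform and applies it per item at access time, rather than materializing a transformed copy up front. A Burn-free sketch of the idea (the `Dataset` trait and `Mapped` type below are illustrative stand-ins, not Burn's actual API):

```rust
use std::cell::Cell;

/// Minimal stand-in for a dataset trait (illustrative, not Burn's).
trait Dataset<T> {
    fn get(&self, index: usize) -> Option<T>;
    fn len(&self) -> usize;
}

struct InMem(Vec<i32>);

impl Dataset<i32> for InMem {
    fn get(&self, i: usize) -> Option<i32> {
        self.0.get(i).copied()
    }
    fn len(&self) -> usize {
        self.0.len()
    }
}

/// Lazy mapper: no allocation up front; the closure runs on each access.
struct Mapped<D, F> {
    base: D,
    f: F,
    calls: Cell<usize>, // counts accesses, to demonstrate laziness
}

impl<D: Dataset<i32>, F: Fn(i32) -> i32> Dataset<i32> for Mapped<D, F> {
    fn get(&self, i: usize) -> Option<i32> {
        self.calls.set(self.calls.get() + 1);
        self.base.get(i).map(|x| (self.f)(x))
    }
    fn len(&self) -> usize {
        self.base.len()
    }
}

fn main() {
    let mapped = Mapped {
        base: InMem(vec![1, 2, 3, 4]),
        f: |x| x * 10,
        calls: Cell::new(0),
    };
    // Nothing has been transformed yet:
    assert_eq!(mapped.calls.get(), 0);
    // The transform runs per access:
    assert_eq!(mapped.get(2), Some(30));
    assert_eq!(mapped.calls.get(), 1);
    println!("len = {}", mapped.len());
}
```

Burn's `MapperDataset`, `PartialDataset`, and friends compose the same way, as the real code below shows.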
```rust
use burn::data::dataset::{
    Dataset, InMemDataset,
    transform::{
        ComposedDataset, Mapper, MapperDataset, PartialDataset, SamplerDataset,
        ShuffledDataset,
    },
};

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct DataItem {
    features: Vec<f32>,
    label: usize,
}

struct NormalizeMapper;

impl Mapper<DataItem, DataItem> for NormalizeMapper {
    fn map(&self, item: &DataItem) -> DataItem {
        let max = item.features.iter().cloned().fold(f32::MIN, f32::max);
        DataItem {
            features: item.features.iter().map(|x| x / max).collect(),
            label: item.label,
        }
    }
}

fn dataset_transformations() {
    // Create an in-memory dataset
    let items = vec![
        DataItem { features: vec![1.0, 2.0], label: 0 },
        DataItem { features: vec![3.0, 4.0], label: 1 },
        // ... more items
    ];
    let dataset = InMemDataset::new(items);

    // Shuffle the dataset
    let shuffled = ShuffledDataset::new(dataset.clone(), 42);

    // Create a train/test split (80/20)
    let len = dataset.len();
    let train = PartialDataset::new(shuffled.clone(), 0, len * 8 / 10);
    let test = PartialDataset::new(shuffled, len * 8 / 10, len);

    // Apply a transformation
    let normalized = MapperDataset::new(train, NormalizeMapper);

    // Sample a fixed number of items (useful for checkpointing)
    let sampled = SamplerDataset::new(normalized, 1000);

    // Compose multiple datasets
    let combined = ComposedDataset::new(vec![
        Box::new(dataset.clone()),
        Box::new(dataset),
    ]);
}

// Loading from Hugging Face
use burn::data::dataset::{HuggingfaceDatasetLoader, SqliteDataset};

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct TextItem {
    text: String,
    label: usize,
}

fn load_huggingface() {
    let dataset: SqliteDataset<TextItem> = HuggingfaceDatasetLoader::new("imdb")
        .dataset("train")
        .unwrap();

    println!("Dataset size: {}", dataset.len());
}
```

## Summary

Burn provides a complete deep learning ecosystem in Rust, enabling researchers and engineers to build production-ready models with type-safe APIs and cross-platform deployment.
The framework's modular design allows mixing and matching backends (CUDA, ROCm, Metal, WebGPU, CPU) without code changes, while the autodiff decorator pattern provides clean separation between training and inference code paths.

Key integration patterns include:

1. using the `#[derive(Module)]` and `#[derive(Config)]` macros to define models with serializable configurations,
2. implementing the `TrainStep` and `InferenceStep` traits for custom training loops with the `SupervisedTraining` infrastructure,
3. leveraging `DataLoaderBuilder` with custom `Batcher` implementations for efficient parallel data loading, and
4. using `burn-store` for flexible model weight management, including PyTorch/SafeTensors interoperability.

For deployment, models can target WebAssembly for browser inference, embedded systems via `no_std` support, or distributed training with the remote backend, all while maintaining identical model code across platforms.