### Install LiveTalk and Dependencies

Source: https://context7.com/gair-nlp/livetalk/llms.txt

This bash script outlines the steps to install LiveTalk, including cloning the repository, applying patches, creating a conda environment, installing dependencies, and downloading model checkpoints. It ensures all necessary components and pre-trained models are set up for the project.

```bash
# Clone the repository
git clone https://github.com/GAIR-NLP/livetalk.git
cd livetalk

# Clone OmniAvatar dependency
git clone https://github.com/Omni-Avatar/OmniAvatar

# Apply patches
bash scripts/add_patch.sh

# Create conda environment
conda create -n livetalk python=3.10 -y
conda activate livetalk
pip install -r requirements.txt
conda install -c conda-forge ffmpeg
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
python setup.py develop

# Download model checkpoints
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir pretrained_checkpoints/Wan2.1-T2V-1.3B
huggingface-cli download GAIR/LiveTalk-1.3B-V0.1 --local-dir-use-symlinks False --local-dir pretrained_checkpoints/LiveTalk-1.3B-V0.1
huggingface-cli download facebook/wav2vec2-base-960h --local-dir-use-symlinks False --local-dir pretrained_checkpoints/wav2vec2
```

--------------------------------

### Set Up Conda Environment and Install Dependencies

Source: https://github.com/gair-nlp/livetalk/blob/main/README.md

This sequence of commands sets up a dedicated conda environment for LiveTalk, installs Python 3.10, and then installs all required Python packages listed in the requirements.txt file. It also installs ffmpeg and a specific version of flash-attention.

```bash
conda create -n livetalk python=3.10 -y
conda activate livetalk
pip install -r requirements.txt
conda install -c conda-forge ffmpeg
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
python setup.py develop
```

--------------------------------

### Batch Inference Script Execution

Source: https://context7.com/gair-nlp/livetalk/llms.txt

Provides command-line examples for running batch inference using the `inference.py` script. It covers basic text-to-video generation, image-to-video (i2v) inference, and distributed inference across multiple GPUs. The expected outputs, including video naming and resolution, are also detailed.

```bash
# Basic inference
python inference.py \
    --config_path configs/causal_inference.yaml \
    --checkpoint_path pretrained_checkpoints/LiveTalk-1.3B-V0.1/model.safetensors \
    --data_path prompts.txt \
    --output_folder outputs/ \
    --num_output_frames 21 \
    --seed 42 \
    --num_samples 1

# Image-to-video inference
python inference.py \
    --config_path configs/causal_inference.yaml \
    --checkpoint_path pretrained_checkpoints/LiveTalk-1.3B-V0.1/model.safetensors \
    --data_path image_prompt_pairs.json \
    --output_folder outputs/ \
    --num_output_frames 21 \
    --i2v \
    --use_ema \
    --save_with_index

# Distributed inference (multi-GPU)
torchrun --nproc_per_node=4 inference.py \
    --config_path configs/causal_inference.yaml \
    --checkpoint_path pretrained_checkpoints/LiveTalk-1.3B-V0.1/model.safetensors \
    --data_path prompts.txt \
    --output_folder outputs/ \
    --num_output_frames 21

# Expected output:
# - Videos saved to outputs/ folder
# - Naming: {index}-{seed}_{model}.mp4 or {prompt[:100]}-{seed}.mp4
# - Resolution: 480x832 or 720x1280 depending on max_hw
# - Frame rate: 16 FPS
```

--------------------------------

### Video Generation Pipeline

Source: https://context7.com/gair-nlp/livetalk/llms.txt

Demonstrates the Python code for configuring pipeline parameters and
generating video using a pre-trained model. It includes setting diffusion steps, sampling solvers, preparing input noise and text prompts, and performing inference. The output includes the generated video and latents, with an example of saving the video using torchvision.

```python
import einops
import torch
from torchvision.io import write_video

# `pipeline` and `device` are assumed to exist already, created as in the
# multi-step pipeline initialization snippet.
pipeline.sampling_steps = 50      # Number of diffusion steps
pipeline.sample_solver = 'unipc'  # Options: 'unipc', 'dpm++'

batch_size = 1
num_frames = 21  # Must be multiple of num_frame_per_block (default: 3)
noise = torch.randn([batch_size, num_frames, 16, 60, 104], device=device, dtype=torch.bfloat16)

text_prompts = [
    "A person speaking naturally with expressive facial movements and synchronized lip movements"
]

# Optional: Provide initial latent for image-to-video
# initial_latent = pipeline.vae.encode_to_latent(reference_image)
# Shape: [B, num_input_frames, 16, 60, 104]

video, latents = pipeline.inference(
    noise=noise,
    text_prompts=text_prompts,
    initial_latent=None,  # Set to initial_latent for I2V
    return_latents=True,
    start_frame_index=0
)

# Output
# video: [batch_size, num_frames, 3, height, width], range [0, 1]
# latents: [batch_size, num_frames, 16, 60, 104]

# write_video expects frames as [T, H, W, C] uint8, so move channels last
output = einops.rearrange(video, 'b t c h w -> b t h w c').cpu()
write_video("output.mp4", (output[0] * 255).byte(), fps=16)
```

--------------------------------

### Causal Inference Pipeline for Few-Step Video Generation

Source: https://context7.com/gair-nlp/livetalk/llms.txt

This Python script demonstrates how to use the CausalInferencePipeline for real-time video generation with a distilled 4-step diffusion model. It includes configuration loading, pipeline initialization, checkpoint loading, device/dtype setup, generation parameter configuration, noise tensor preparation, and video saving. The pipeline supports multimodal conditioning via text, image, and audio inputs.

```python
import torch
from omegaconf import OmegaConf
from scripts.inference_example import CausalInferencePipeline
import numpy as np
import imageio

# Load configuration
args = OmegaConf.load("configs/causal_inference.yaml")
device = torch.device("cuda:0")

# Initialize pipeline
pipeline = CausalInferencePipeline.from_pretrained(
    args=args,
    device=device
)

# Load checkpoint (optional)
# Note: torch.load cannot parse the safetensors format. If this file is a true
# .safetensors checkpoint, read it with safetensors.torch.load_file instead.
checkpoint = torch.load("pretrained_checkpoints/LiveTalk-1.3B-V0.1/model.safetensors", map_location="cpu")
pipeline.generator.load_state_dict(checkpoint['generator'])

# Cast to bfloat16 and move the submodules to the GPU
pipeline = pipeline.to(dtype=torch.bfloat16)
pipeline.generator.to(device=device)
pipeline.text_encoder.to(device=device)
pipeline.vae.to(device=device)

# Configure generation parameters
text_prompt = "A realistic video of a person speaking directly to the camera. The individual maintains steady eye contact with clear, expressive facial features."
image_path = "examples/inference/example1.jpg"
audio_path = "examples/inference/example1.wav"
video_duration = 5  # seconds

# Calculate number of frames: num_frames = (duration * fps + 4) // 4
num_frames = (video_duration * args.fps + 4) // 4

# Prepare noise tensor
noise = torch.randn([1, num_frames, 16, 64, 64], device=device, dtype=torch.bfloat16)

# Generate video with multimodal conditioning
video = pipeline(
    noise=noise,
    text_prompts=text_prompt,
    image_path=image_path,
    audio_path=audio_path,
    initial_latent=None,
    return_latents=False
)

# Save output: [T, C, H, W] in [0, 1] -> [T, H, W, C] uint8
video_np = (video.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy() * 255).astype(np.uint8)
imageio.mimsave(
    "output_video.mp4",
    video_np,
    fps=args.fps,
    codec="libx264",
    macro_block_size=None,
    ffmpeg_params=["-crf", "18", "-preset", "veryfast", "-pix_fmt", "yuv420p"]
)
print(f"Generated video shape: {video.shape}")  # Expected: [1, num_frames, 3, H, W]
```

--------------------------------

### YAML Configuration for Causal Inference

Source:
https://context7.com/gair-nlp/livetalk/llms.txt

Defines the structure for YAML configuration files used in LiveTalk, specifying model paths, input data locations, generation parameters, and causal inference settings. This allows for flexible and reproducible experiment setups.

```yaml
# configs/causal_inference.yaml

# Model paths
dtype: "bf16"
text_encoder_path: pretrained_checkpoints/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
dit_path: pretrained_checkpoints/LiveTalk-1.3B-V0.1/model.safetensors
vae_path: pretrained_checkpoints/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
wav2vec_path: pretrained_checkpoints/wav2vec2

# Input data paths
image_path: examples/inference/example1.jpg
audio_path: examples/inference/example1.wav
prompt: "A realistic video of a person speaking directly to the camera."
output_path: "output_video.mp4"
video_duration: 5  # Duration in seconds (should be 3n+2, e.g., 5, 8, 11, 14, 17, 20)

# Generation parameters
max_hw: 720  # 720: 480p; 1280: 720p
image_sizes_720: [[512,512]]
fps: 16
sample_rate: 16000
num_steps: 4
local_attn_size: 15

# Causal inference parameters
denoising_step_list: [1000, 750, 500, 250]
warp_denoising_step: true
num_transformer_blocks: 30
frame_seq_length: 1024
num_frame_per_block: 3
independent_first_frame: False
```

--------------------------------

### Audio Embedding Extraction with Wav2Vec2

Source: https://context7.com/gair-nlp/livetalk/llms.txt

Demonstrates the Python code for loading and processing audio files using the Wav2Vec2 model for feature extraction. This snippet includes initializing the feature extractor and audio encoder, loading an audio file with librosa, trimming it to a specific duration, and extracting input values for further processing.

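The duration and length bookkeeping in this snippet is plain integer arithmetic and can be checked without loading any model. The sketch below is a minimal illustration, assuming the formulas `num_frames = (duration * fps + 4) // 4` and `audio_len = num_frames * 4 - 3` that appear in the code on this page, with `fps = 16`. The helper names `latent_frames` and `audio_length` are hypothetical and not part of LiveTalk.

```python
def latent_frames(duration_s: int, fps: int = 16) -> int:
    """Number of latent frames for a clip, per the formula used on this page."""
    return (duration_s * fps + 4) // 4


def audio_length(num_frames: int) -> int:
    """Audio sequence length passed to the encoder for `num_frames` latent frames."""
    return num_frames * 4 - 3


for duration in (5, 8, 11):
    frames = latent_frames(duration)
    # Durations of the form 3n+2 give a frame count divisible by 3.
    print(duration, frames, audio_length(frames), frames % 3 == 0)
```

For a 5-second clip this gives 21 latent frames and an audio length of 81. With `fps = 16`, a duration of 3n+2 seconds yields 12n+9 latent frames, which is always a multiple of 3; this is probably why the configuration comment recommends such durations when `num_frame_per_block` is 3.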
```python
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2FeatureExtractor
from OmniAvatar.models.wav2vec import Wav2VecModel

# Initialize audio encoder
device = torch.device("cuda:0")
dtype = torch.bfloat16

wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    "pretrained_checkpoints/wav2vec2"
)
audio_encoder = Wav2VecModel.from_pretrained(
    "pretrained_checkpoints/wav2vec2",
    local_files_only=True
).to(device=device, dtype=dtype)
audio_encoder.feature_extractor._freeze_parameters()
audio_encoder.eval()

# Load and process audio
audio_path = "examples/inference/example1.wav"
audio, sr = librosa.load(audio_path, sr=16000)

# Trim audio to match video duration
video_duration = 5.0  # seconds
max_samples = int(video_duration * sr)
if len(audio) > max_samples:
    audio = audio[:max_samples]

# Extract features
input_values = np.squeeze(
    wav_feature_extractor(audio, sampling_rate=16000).input_values
)
input_values = torch.from_numpy(input_values).float().to(device=device, dtype=dtype)
input_values = input_values.unsqueeze(0)

# Calculate audio length for video frames
# int() keeps the counts integral even though video_duration is a float
num_frames = int(video_duration * 16 + 4) // 4  # fps=16
audio_len = num_frames * 4 - 3
```

--------------------------------

### PyTorch Cross-Attention Cache Initialization

Source: https://context7.com/gair-nlp/livetalk/llms.txt

This Python code initializes the cross-attention cache for text conditioning in transformer models. It creates tensors for keys and values, along with an initialization flag for each transformer block, supporting efficient conditioning during generation.

```python
import torch


def initialize_crossattn_cache(num_transformer_blocks, batch_size, dtype, device):
    """Initialize cross-attention cache for text conditioning."""
    crossattn_cache = []
    for _ in range(num_transformer_blocks):
        crossattn_cache.append({
            "k": torch.zeros([batch_size, 512, 12*128], dtype=dtype, device=device),
            "v": torch.zeros([batch_size, 512, 12*128], dtype=dtype, device=device),
            "is_init": False
        })
    return crossattn_cache


# Usage example
device = torch.device("cuda:0")
dtype = torch.bfloat16
num_transformer_blocks = 30
batch_size = 1

crossattn_cache = initialize_crossattn_cache(
    num_transformer_blocks=num_transformer_blocks,
    batch_size=batch_size,
    dtype=dtype,
    device=device
)
```

--------------------------------

### Clone LiveTalk and OmniAvatar Repositories

Source: https://github.com/gair-nlp/livetalk/blob/main/README.md

This snippet demonstrates how to clone the LiveTalk repository and its dependency, OmniAvatar, using git. These commands are essential for setting up the project locally.

```bash
git clone https://github.com/GAIR-NLP/livetalk.git
cd livetalk
git clone https://github.com/Omni-Avatar/OmniAvatar
```

--------------------------------

### Download Model Checkpoints using Hugging Face CLI

Source: https://github.com/gair-nlp/livetalk/blob/main/README.md

These commands use the Hugging Face CLI to download pre-trained model checkpoints required for LiveTalk. This includes the Wan2.1 base model, the LiveTalk model itself, and a Wav2Vec2 model for audio processing. The checkpoints are saved into specific directories within the project.

```bash
# Download Wan2.1 base model
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir pretrained_checkpoints/Wan2.1-T2V-1.3B

# Download LiveTalk model checkpoint
huggingface-cli download GAIR/LiveTalk-1.3B-V0.1 --local-dir-use-symlinks False --local-dir pretrained_checkpoints/LiveTalk-1.3B-V0.1

# Download Wav2Vec2 model for audio processing
huggingface-cli download facebook/wav2vec2-base-960h --local-dir-use-symlinks False --local-dir pretrained_checkpoints/wav2vec2
```

--------------------------------

### CausVid Training with Distribution Matching Distillation

Source: https://context7.com/gair-nlp/livetalk/llms.txt

Trains the causal few-step generator using distribution matching distillation from a teacher model. This involves loading configuration, initializing the CausVid model, setting training parameters, and configuring an optimizer. Dependencies include PyTorch, OmegaConf, and custom dataset/model modules.

```python
import torch
from omegaconf import OmegaConf
from model.causvid import CausVid
from utils.dataset import TextDataset
from torch.utils.data import DataLoader

# Load configuration
config = OmegaConf.load("configs/dmd.yaml")
device = torch.device("cuda:0")

# Initialize CausVid model
model = CausVid(args=config, device=device)
model.train()

# Configure training parameters
config.num_train_timestep = 1000
config.guidance_scale = 6.0
config.real_guidance_scale = 6.0
config.fake_guidance_scale = 0.0
config.timestep_shift = 5.0
config.teacher_forcing = False
config.gradient_checkpointing = True

# Setup optimizer
optimizer = torch.optim.AdamW(
    model.generator.parameters(),
    lr=2e-5,
    betas=(0.9, 0.999),
    weight_decay=0.01
)
```

--------------------------------

### Causal Diffusion Inference Pipeline for Multi-Step Video Generation

Source: https://context7.com/gair-nlp/livetalk/llms.txt

This Python script initializes the CausalDiffusionInferencePipeline for high-quality video generation using multi-step sampling
(e.g., 50 steps with UniPC or DPM++ solvers). It loads configuration, sets up the CUDA device, and instantiates the pipeline. The pipeline is designed for more computationally intensive, higher-fidelity video synthesis compared to the few-step approach.

```python
import torch
from omegaconf import OmegaConf
from pipeline import CausalDiffusionInferencePipeline

# Load configuration
config = OmegaConf.load("configs/causal_inference.yaml")
device = torch.device("cuda:0")

# Initialize pipeline
pipeline = CausalDiffusionInferencePipeline(
    args=config,
    device=device
)
```

--------------------------------

### PyTorch KV Cache Initialization for Causal Generation

Source: https://context7.com/gair-nlp/livetalk/llms.txt

This Python code initializes the KV cache for causal attention, essential for block-wise autoregressive generation. It calculates cache sizes based on local attention window and sequence length, initializing tensors for keys, values, and attention indices for each transformer block.

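The cache sizing described here can be worked through by hand before allocating anything on a GPU. The following is a minimal sketch, assuming the configuration values shown on this page (`local_attn_size=15`, `frame_seq_length=1024`, 30 transformer blocks), a hidden width of `12 * 128`, and 2 bytes per bfloat16 element. The helper names `kv_cache_tokens` and `kv_cache_bytes` are hypothetical, and the small index tensors are ignored.

```python
def kv_cache_tokens(local_attn_size: int, frame_seq_length: int, default: int = 32760) -> int:
    """Tokens held per block: a local window of frames, or the default when the window is -1."""
    if local_attn_size != -1:
        return local_attn_size * frame_seq_length
    return default


def kv_cache_bytes(tokens: int, batch_size: int = 1, width: int = 12 * 128,
                   num_blocks: int = 30, bytes_per_elem: int = 2) -> int:
    """Total bytes for the key and value tensors across all blocks."""
    per_tensor = batch_size * tokens * width * bytes_per_elem
    return 2 * per_tensor * num_blocks  # 2 = one key tensor plus one value tensor


tokens = kv_cache_tokens(local_attn_size=15, frame_seq_length=1024)
total = kv_cache_bytes(tokens)
print(tokens)                   # 15360
print(round(total / 2**30, 2))  # 2.64 (GiB)
```

Under these assumptions the key and value buffers alone take roughly 2.6 GiB at batch size 1, which is worth budgeting for alongside the model weights.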
```python
import torch


def initialize_kv_cache(num_transformer_blocks, batch_size, local_attn_size, frame_seq_length, dtype, device):
    """Initialize KV cache for causal attention with local attention window."""
    kv_cache = []

    # Calculate KV cache size based on local attention window
    if local_attn_size != -1:
        kv_cache_size = local_attn_size * frame_seq_length
    else:
        kv_cache_size = 32760  # Default max cache size

    # Initialize cache for each transformer block
    for _ in range(num_transformer_blocks):
        kv_cache.append({
            "k": torch.zeros([batch_size, kv_cache_size, 12*128], dtype=dtype, device=device),
            "v": torch.zeros([batch_size, kv_cache_size, 12*128], dtype=dtype, device=device),
            "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
            "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
        })
    return kv_cache


# Usage example
device = torch.device("cuda:0")
dtype = torch.bfloat16
num_transformer_blocks = 30
batch_size = 1
local_attn_size = 15
frame_seq_length = 1024

kv_cache = initialize_kv_cache(
    num_transformer_blocks=num_transformer_blocks,
    batch_size=batch_size,
    local_attn_size=local_attn_size,
    frame_seq_length=frame_seq_length,
    dtype=dtype,
    device=device
)
```

--------------------------------

### PyTorch Training Loop for Generator and Critic Models

Source: https://context7.com/gair-nlp/livetalk/llms.txt

This Python code snippet demonstrates a typical training loop for a generative model, including calculating generator and critic losses, performing backpropagation, and saving model checkpoints. It utilizes PyTorch for deep learning operations.

```python
import torch

# Assumes `model`, `optimizer`, `config`, and `device` from the training setup,
# plus a `dataloader` and `num_epochs` defined by the caller.
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        # Get clean latents from real data
        clean_latent = batch['latents'].to(device, dtype=torch.bfloat16)

        # Encode text prompts
        conditional_dict = model.text_encoder(batch['prompts'])
        unconditional_dict = model.text_encoder([""] * len(batch['prompts']))

        # Compute generator loss (DMD loss)
        generator_loss, log_dict = model.generator_loss(
            image_or_video_shape=clean_latent.shape,
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
            clean_latent=clean_latent,
            initial_latent=None
        )

        # Backpropagation
        optimizer.zero_grad()
        generator_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.generator.parameters(), 1.0)
        optimizer.step()

        # Train critic (fake score)
        # Only the loss is computed here; the critic's own backward pass and
        # optimizer step are not part of this snippet.
        critic_loss, critic_log_dict = model.critic_loss(
            image_or_video_shape=clean_latent.shape,
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
            clean_latent=clean_latent
        )

        print(f"Epoch {epoch}, Batch {batch_idx}: "
              f"Generator Loss: {generator_loss.item():.4f}, "
              f"Critic Loss: {critic_loss.item():.4f}")

# Save checkpoint
torch.save({
    'generator': model.generator.state_dict(),
    'generator_ema': model.generator_ema.state_dict() if hasattr(model, 'generator_ema') else None,
    'config': config
}, "checkpoint.pth")
```

--------------------------------

### Generate Audio Embeddings with Torch

Source: https://context7.com/gair-nlp/livetalk/llms.txt

Generates audio embeddings using a PyTorch audio encoder. It processes input values and concatenates hidden states to create a comprehensive audio representation, which is then reshaped for video conditioning. Dependencies include PyTorch.

```python
import torch

# Continues from the audio extraction snippet: `audio_encoder`, `input_values`,
# and `audio_len` are assumed to be defined there.
with torch.no_grad():
    hidden_states = audio_encoder(
        input_values,
        seq_len=audio_len,
        output_hidden_states=True
    )
    audio_embeddings = hidden_states.last_hidden_state
    for mid_hidden_states in hidden_states.hidden_states:
        audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)

# Reshape for video conditioning
audio_emb = audio_embeddings.permute(0, 2, 1)[:, :, :, None, None]
# Repeat the first step three times at the front of the time axis
audio_emb = torch.cat([audio_emb[:, :, :1].repeat(1, 1, 3, 1, 1), audio_emb], 2)
print(f"Audio embedding shape: {audio_emb.shape}")
# Expected: [1, C, audio_len+3, 1, 1]. C is the concatenated width, not 768:
# for wav2vec2-base, the last hidden state plus 13 hidden states gives
# 14 * 768 = 10752 channels.
```

--------------------------------

### VAE Encoding and Decoding with WanVAEWrapper

Source: https://context7.com/gair-nlp/livetalk/llms.txt

Encodes images and videos into a latent space and decodes latents back to pixel space using the Wan VAE. It handles image loading, transformation, normalization, and uses a pre-trained VAE model. Dependencies include PyTorch, PIL, and torchvision.

```python
import torch
from PIL import Image
import torchvision.transforms as TT
from utils.wan_wrapper import WanVAEWrapper

# Initialize VAE
device = torch.device("cuda:0")
dtype = torch.bfloat16
vae = WanVAEWrapper(vae_path="pretrained_checkpoints/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
vae.to(device=device, dtype=dtype)
vae.eval()

# Encode image to latent
image_path = "examples/inference/example1.jpg"
image = Image.open(image_path).convert("RGB")
transform = TT.Compose([
    TT.Resize((480, 832)),
    TT.ToTensor(),
])
image_tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 480, 832]

# Normalize to [-1, 1]
image_tensor = image_tensor * 2.0 - 1.0
image_tensor = image_tensor.unsqueeze(2)  # [1, 3, 1, 480, 832]

# Encode to latent space
latent = vae.encode_to_latent(image_tensor.to(dtype=dtype))
print(f"Latent shape: {latent.shape}")  # [1, 1, 16, 60, 104]

# Decode back to pixels
reconstructed = vae.decode_to_pixel(latent)
reconstructed = (reconstructed * 0.5 + 0.5).clamp(0, 1)  # Denormalize to [0, 1]
print(f"Reconstructed shape: {reconstructed.shape}")  # [1, 1, 3, 480, 832]

# For video: latent shape [B, num_frames, 16, 60, 104]
# video_latent = torch.randn([1, 21, 16, 60, 104], device=device, dtype=dtype)
# video_pixels = vae.decode_to_pixel(video_latent)  # [1, 21, 3, 480, 832]
```

--------------------------------

### Run LiveTalk Inference Script

Source: https://github.com/gair-nlp/livetalk/blob/main/README.md

This bash script executes the inference process for LiveTalk. It requires specific input formats (image, audio, text prompt) and outputs a synchronized video. The script assumes the necessary model checkpoints have been downloaded.

```bash
bash ./scripts/inference.sh
```

--------------------------------

### Apply Patches to OmniAvatar

Source: https://github.com/gair-nlp/livetalk/blob/main/README.md

This bash script applies necessary patches to the OmniAvatar repository to ensure compatibility with LiveTalk. This step is crucial after cloning both repositories.

```bash
bash scripts/add_patch.sh
```

--------------------------------

### Text Encoding with UMT5 using WanTextEncoder

Source: https://context7.com/gair-nlp/livetalk/llms.txt

Generates text embeddings for conditional video generation using the UMT5-XXL encoder. It initializes the text encoder with pre-trained weights and tokenizers, then encodes text prompts into conditional and unconditional embeddings suitable for classifier-free guidance. Dependencies include PyTorch and a custom WanTextEncoder wrapper.

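The classifier-free guidance step that consumes these two sets of embeddings is a linear extrapolation between two model predictions. Below is a minimal sketch on plain Python lists, assuming the standard formula `uncond + scale * (cond - uncond)` and the scale of 7.5 used in the code that follows. `apply_cfg` is a hypothetical helper; real code would apply the same expression to the model's output tensors for each embedding set, not to raw lists.

```python
def apply_cfg(uncond, cond, guidance_scale):
    """Combine unconditional and conditional predictions element-wise."""
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]


uncond = [0.0, 1.0, 2.0]
cond = [1.0, 1.0, 0.0]

print(apply_cfg(uncond, cond, 0.0))  # scale 0 returns the unconditional prediction
print(apply_cfg(uncond, cond, 1.0))  # scale 1 returns the conditional prediction
print(apply_cfg(uncond, cond, 7.5))  # larger scales extrapolate past the conditional one
```

Elements where the two predictions agree are unchanged at any scale, so guidance only amplifies the directions in which the prompt actually shifts the output.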
```python
import torch
from utils.wan_wrapper import WanTextEncoder

# Initialize text encoder
device = torch.device("cuda:0")
dtype = torch.bfloat16
text_encoder = WanTextEncoder(
    text_encoder_path="pretrained_checkpoints/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    tokenizer_path="pretrained_checkpoints/Wan2.1-T2V-1.3B/google/umt5-xxl/"
)
text_encoder.to(device=device, dtype=dtype)
text_encoder.eval()

# Encode text prompts
text_prompts = [
    "A realistic video of a person speaking directly to the camera with natural expressions",
    "A person talking with animated facial features and synchronized lip movements"
]

# Generate conditional embeddings
with torch.no_grad():
    conditional_dict = text_encoder(text_prompts)
    print(f"Prompt embeddings shape: {conditional_dict['prompt_embeds'].shape}")
    # Expected: [batch_size, seq_length, embedding_dim]

    # Generate unconditional embeddings for classifier-free guidance
    negative_prompt = ""
    unconditional_dict = text_encoder([negative_prompt] * len(text_prompts))

# Use in classifier-free guidance
guidance_scale = 7.5
# combined_output = unconditional_output + guidance_scale * (conditional_output - unconditional_output)
```

--------------------------------

### PyTorch KV and Cross-Attention Cache Reset Function

Source: https://context7.com/gair-nlp/livetalk/llms.txt

This Python function resets the KV and cross-attention caches for new inference or generation tasks. It sets the initialization flag for cross-attention caches to False and resets the global and local end indices for KV caches to zero.

```python
import torch


def reset_caches(kv_cache, crossattn_cache, device):
    """Reset KV and cross-attention caches for new inference."""
    # Reset cross-attention cache initialization flags
    for block_cache in crossattn_cache:
        block_cache["is_init"] = False

    # Reset KV cache indices
    for block_cache in kv_cache:
        block_cache["global_end_index"] = torch.tensor([0], dtype=torch.long, device=device)
        block_cache["local_end_index"] = torch.tensor([0], dtype=torch.long, device=device)


# Usage example
device = torch.device("cuda:0")
num_transformer_blocks = 30
batch_size = 1
local_attn_size = 15
frame_seq_length = 1024
dtype = torch.bfloat16

# Assuming kv_cache and crossattn_cache are already initialized
# kv_cache = initialize_kv_cache(...)
# crossattn_cache = initialize_crossattn_cache(...)

# Reset for new generation
# reset_caches(kv_cache, crossattn_cache, device)
```
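Note that the reset function only rewinds bookkeeping: it does not zero the key and value buffers, which are presumably overwritten as the next generation proceeds. The following is a minimal, torch-free sketch of that behaviour, using plain Python lists and integers as stand-ins for the tensors (an assumption made so the example runs without a GPU). `make_caches` is a hypothetical helper, and this `reset_caches` is a simplified stand-in for the function above.

```python
def make_caches(num_blocks: int):
    """Build toy caches mirroring the dictionary layout used on this page."""
    kv_cache = [{"k": [0.0] * 4, "v": [0.0] * 4,
                 "global_end_index": 0, "local_end_index": 0}
                for _ in range(num_blocks)]
    crossattn_cache = [{"k": [0.0] * 4, "v": [0.0] * 4, "is_init": False}
                       for _ in range(num_blocks)]
    return kv_cache, crossattn_cache


def reset_caches(kv_cache, crossattn_cache):
    """Rewind indices and flags; the k/v buffers themselves are left untouched."""
    for block in crossattn_cache:
        block["is_init"] = False
    for block in kv_cache:
        block["global_end_index"] = 0
        block["local_end_index"] = 0


kv_cache, crossattn_cache = make_caches(num_blocks=2)

# Simulate one generation pass that filled part of the cache.
kv_cache[0]["k"][0] = 1.5
kv_cache[0]["global_end_index"] = 3
kv_cache[0]["local_end_index"] = 3
crossattn_cache[0]["is_init"] = True

reset_caches(kv_cache, crossattn_cache)
print(kv_cache[0]["global_end_index"], crossattn_cache[0]["is_init"], kv_cache[0]["k"][0])
```

After the reset the indices and flag are back at their initial values while the stale value 1.5 is still in the buffer, so correctness depends on readers of the cache honouring the end indices.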