Autoregressive Video Model
A simple implementation to understand how autoregressive models work for video generation.
Overview
This codebase demonstrates the core concepts of autoregressive video modeling:
- Autoregressive Generation: Predicting future frames based on past frames
- Patch-based Processing: Dividing frames into patches for efficient processing
- Transformer Architecture: Using self-attention for temporal modeling
- Causal Masking: Ensuring the model only attends to previous frames
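For intuition, here is a minimal sketch of the kind of causal mask that enforces the "only attend to previous frames" rule; it is illustrative rather than a copy of the mask built in model.py:

```python
import torch

def build_causal_mask(num_frames: int) -> torch.Tensor:
    """True marks positions a frame must NOT attend to (i.e. the future)."""
    # Upper triangle above the diagonal = future positions.
    return torch.triu(torch.ones(num_frames, num_frames), diagonal=1).bool()

mask = build_causal_mask(4)
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```

A mask of this shape can be passed to PyTorch attention layers so that frame t only sees frames 0..t.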
Key Components
1. Model Architecture (model.py)
- VideoAutoregressiveModel: Main model class
- Patchify/Unpatchify: Converts frames to/from patch sequences
- Positional Encoding: Adds temporal position information
- Causal Transformer: Ensures autoregressive property
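A rough sketch of what Patchify/Unpatchify do; the shapes, argument names, and patch layout here are illustrative and may differ from the actual model.py:

```python
import torch

def patchify(frames: torch.Tensor, patch_size: int = 8) -> torch.Tensor:
    """(B, T, C, H, W) -> (B, T * num_patches, patch_dim)."""
    B, T, C, H, W = frames.shape
    ph = pw = patch_size
    x = frames.reshape(B, T, C, H // ph, ph, W // pw, pw)
    x = x.permute(0, 1, 3, 5, 2, 4, 6)              # (B, T, H/ph, W/pw, C, ph, pw)
    return x.reshape(B, T * (H // ph) * (W // pw), C * ph * pw)

def unpatchify(tokens: torch.Tensor, T: int, C: int, H: int, W: int,
               patch_size: int = 8) -> torch.Tensor:
    """(B, T * num_patches, patch_dim) -> (B, T, C, H, W)."""
    B = tokens.shape[0]
    ph = pw = patch_size
    x = tokens.reshape(B, T, H // ph, W // pw, C, ph, pw)
    x = x.permute(0, 1, 4, 2, 5, 3, 6)
    return x.reshape(B, T, C, H, W)

frames = torch.randn(2, 4, 1, 32, 32)               # 2 clips, 4 frames, 1 channel, 32x32
tokens = patchify(frames)                            # (2, 64, 64): 16 patches per frame
assert torch.allclose(unpatchify(tokens, 4, 1, 32, 32), frames)
```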
2. Dataset (dataset.py)
- SyntheticVideoDataset: Generates simple moving shapes
- Motion types: linear, circular, bounce
- Makes it easy to study temporal patterns without complex real data
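Conceptually, the dataset looks like the sketch below; the class and parameter names are made up for illustration, and dataset.py supports several motion types rather than only the linear motion shown here:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class ToySquareDataset(Dataset):
    """Clips of a bright square moving linearly across a dark frame."""

    def __init__(self, num_clips: int = 1000, num_frames: int = 16, size: int = 32):
        self.num_clips, self.num_frames, self.size = num_clips, num_frames, size

    def __len__(self):
        return self.num_clips

    def __getitem__(self, idx):
        rng = np.random.default_rng(idx)
        x, y = rng.integers(0, self.size - 4, size=2)        # start position
        dx, dy = rng.integers(-2, 3, size=2)                  # per-frame velocity
        frames = np.zeros((self.num_frames, 1, self.size, self.size), dtype=np.float32)
        for t in range(self.num_frames):
            px = int(np.clip(x + t * dx, 0, self.size - 4))
            py = int(np.clip(y + t * dy, 0, self.size - 4))
            frames[t, 0, py:py + 4, px:px + 4] = 1.0          # draw a 4x4 square
        return torch.from_numpy(frames)                       # (T, 1, H, W)
```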
3. Training (train.py)
- Standard PyTorch training loop
- MSE loss for frame prediction
- Gradient clipping for stability
- Checkpoint saving and tensorboard logging
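One epoch of train.py conceptually looks like the simplified sketch below; the model call signature is an assumption, and the real script additionally handles checkpointing and TensorBoard logging:

```python
import torch
import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer, device, clip_norm: float = 1.0):
    model.train()
    for frames in loader:                                 # frames: (B, T, C, H, W)
        frames = frames.to(device)
        inputs, targets = frames[:, :-1], frames[:, 1:]   # predict frame t from frames < t
        preds = model(inputs)                             # assumed to match targets' shape
        loss = F.mse_loss(preds, targets)                 # MSE loss on pixel values

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)  # gradient clipping
        optimizer.step()
```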
4. Inference (inference.py)
- VideoGenerator: Handles video generation
- Supports generation from noise or continuation of an existing clip
- Visualization utilities
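A generic way to inspect a generated clip with matplotlib (a sketch, not the repository's own visualization utility):

```python
import matplotlib.pyplot as plt
import torch

def show_frames(frames: torch.Tensor, path: str = "frames.png"):
    """Save a row of grayscale frames. Expects (T, 1, H, W) or (T, H, W)."""
    frames = frames.detach().cpu()
    if frames.dim() == 4:
        frames = frames.squeeze(1)
    T = frames.shape[0]
    fig, axes = plt.subplots(1, T, figsize=(1.5 * T, 1.5), squeeze=False)
    for t in range(T):
        axes[0, t].imshow(frames[t], cmap="gray")
        axes[0, t].set_title(f"t={t}", fontsize=8)
        axes[0, t].axis("off")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
```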
Quick Start
# Run demo (trains for 10 epochs on synthetic data)
python example.py --mode demo
# Full training
python example.py --mode train
# Generate videos from checkpoint
python example.py --mode generate --checkpoint ./checkpoints/best_model.pt
Understanding the Code
How Autoregressive Video Generation Works
- Input Processing:
- Video frames are divided into patches (e.g., 8x8 pixels)
- Each patch becomes a token in the sequence
- Frames are flattened into a long sequence of patches
- Autoregressive Prediction:
- Given frames 1 to t-1, predict frame t
- Causal mask ensures no “looking ahead”
- Model learns temporal patterns in video
- Generation Process:
- Start with a few initial frames
- Predict next frame patch by patch
- Add predicted frame to context
- Repeat until the desired length is reached (a minimal sketch of this loop follows this list)
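Putting the three steps together, the generation loop looks roughly like the sketch below; the real generate() in model.py may work patch by patch and expose a different interface:

```python
import torch

@torch.no_grad()
def generate_frames(model, context: torch.Tensor, num_new_frames: int, max_frames: int = 16):
    """Autoregressively extend a clip.

    context: (B, T0, C, H, W) seed frames; returns (B, T0 + num_new_frames, C, H, W).
    """
    model.eval()
    frames = context
    for _ in range(num_new_frames):
        window = frames[:, -max_frames:]        # keep only the most recent frames as context
        pred = model(window)                    # assumed: one prediction per input frame
        next_frame = pred[:, -1:]               # prediction for the frame after the last input
        frames = torch.cat([frames, next_frame], dim=1)  # append and repeat
    return frames
```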
Key Concepts Illustrated
- Patchification: See patchify() in model.py
- Causal Masking: See the transformer mask in forward()
- Sequential Generation: See the generate() method
- Temporal Embeddings: Frame position embeddings
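Frame position (temporal) embeddings can be as simple as a learned per-frame vector added to every patch token of that frame. A minimal sketch, not necessarily the encoding used in model.py:

```python
import torch
import torch.nn as nn

class FramePositionEmbedding(nn.Module):
    """Adds a learned per-frame embedding to all patch tokens of that frame."""

    def __init__(self, max_frames: int, d_model: int):
        super().__init__()
        self.table = nn.Embedding(max_frames, d_model)

    def forward(self, tokens: torch.Tensor, patches_per_frame: int) -> torch.Tensor:
        # tokens: (B, T * patches_per_frame, d_model)
        B, L, _ = tokens.shape
        T = L // patches_per_frame
        frame_ids = torch.arange(T, device=tokens.device).repeat_interleave(patches_per_frame)
        return tokens + self.table(frame_ids)   # broadcast over the batch dimension
```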
Experiments to Try
- Architecture Changes:
- Vary model size (d_model, num_layers)
- Change patch size (smaller = more detail)
- Adjust sequence length (max_frames)
- Training Variations:
- Different motion patterns in dataset
- Temperature sampling during generation
- Various learning rates and schedules
- Advanced Extensions:
- Add conditional generation (e.g., text prompts)
- Implement different positional encodings
- Try different attention mechanisms
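Temperature sampling assumes the model outputs a distribution rather than a single MSE-regressed frame; if you extend the model to predict logits over a discretized patch vocabulary (an extension, not the current setup), the sampling step could look like:

```python
import torch

def sample_with_temperature(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """logits: (..., vocab_size). Lower temperature -> sharper, more deterministic samples."""
    probs = torch.softmax(logits / max(temperature, 1e-6), dim=-1)
    flat = probs.reshape(-1, probs.shape[-1])
    samples = torch.multinomial(flat, num_samples=1)
    return samples.reshape(probs.shape[:-1])
```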
Requirements
torch
numpy
matplotlib
tqdm
tensorboard
Notes
- This is a simplified implementation for learning purposes
- Real video models often use more sophisticated architectures
- Consider GPU memory when increasing resolution or sequence length