---
license: mit
tags:
- world-model
- diffusion
- edm
- pixel-space
- snake-game
- generative-model
library_name: pytorch
---

# Snake World Model - Pixel-Space EDM v2

A DIAMOND-style pixel-space EDM (Elucidating the Design Space of Diffusion-Based Generative Models) implementation for world modeling in the Snake game environment. The model predicts future game frames conditioned on previous frames and the player's actions.

See the GitHub repository for the training and play code: https://github.com/roastedpotato66/snake-world-modeling

## Model Details

### Architecture

The model uses a **DIAMOND-style UNet** that operates directly in pixel space:

- **Input**: `(B, 15, 64, 64)` - noisy target frame (3 channels) + 4 context frames (12 channels)
- **Output**: `(B, 3, 64, 64)` - denoised next-frame prediction
- **Base dimensions**: 128 channels, 512-dimensional conditioning
- **Resolution**: 64×64 RGB images

#### Key Components

1. **UNet Encoder-Decoder**
   - Encoder: 64×64 → 32×32 → 16×16 → 8×8 (3 downsampling blocks)
   - Bottleneck: self-attention at 8×8 resolution for global reasoning
   - Decoder: 8×8 → 16×16 → 32×32 → 64×64 (3 upsampling blocks)
   - Skip connections between encoder and decoder
2. **Adaptive Group Normalization**
   - Conditions normalization on the combined action + noise-level embedding
   - Enables strong action conditioning throughout the network
3. **EDM Preconditioning**
   - Preconditioned network output: `c_skip * x_noisy + c_out * network(x)`
   - Stable training, and sampling needs very few denoising steps (only 3)
4. **Frame Stacking**
   - 4 previous frames concatenated channel-wise (12 channels total)
   - Provides temporal context for prediction

### Training Metrics

- **Best Epoch**: 34
- **Best Validation MSE**: 0.000137
- **Training Loss (final)**: 0.000798
- **CFG Difference**: 0.003674

## Usage

The simplest way to get started is to download `model.pt` directly, create an `output/` folder, and move the file into it. Before playing, you will need to generate some data (1k is enough) for initialization.
Follow the instructions in the GitHub repository's README.

### Interactive Play

Use the play script provided in the GitHub repository:

```bash
python scripts/play_pixel_edm.py \
  --model_path model.pt \
  --data_dir data/images \
  --cfg_scale 2.0 \
  --steps 3
```

**Controls:**

- `WASD` or Arrow Keys - Move the snake
- `R` - Reset with a new random seed from the data
- `ESC` - Quit

## Training Details

### Dataset

- **Format**: 64×64 RGB images from the Snake game
- **Context**: 4 consecutive frames
- **Actions**: One-hot encoded [UP, DOWN, LEFT, RIGHT]
- **Special events**: Death and eating events weighted 5× to balance training

### Training Configuration

- **Loss**: Weighted MSE with EDM loss weighting
- **Optimizer**: AdamW (lr=1e-4, weight_decay=1e-4)
- **Scheduler**: Cosine annealing over 40 epochs
- **Mixed precision**: BF16 AMP
- **EMA**: Exponential moving average (decay=0.9999)
- **Weighted sampling**: 5× weight for death/eating events
- **CFG dropout**: 30% for classifier-free guidance
- **Batch size**: 512
- **Denoising steps**: 3 (Euler method)
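The EDM preconditioning and the "EDM loss weighting" in the loss entry can be illustrated with the standard formulas from Karras et al. (2022). The NumPy sketch below is not taken from the repository. `SIGMA_DATA = 0.5` and the log-normal noise distribution (`p_mean=-1.2`, `p_std=1.2`) are the EDM paper's defaults and serve only as assumptions here. `network(x_scaled, c_noise, context)` is a hypothetical signature standing in for the UNet.

```python
import numpy as np

SIGMA_DATA = 0.5  # assumption: EDM paper default, not stated in this card


def edm_coefficients(sigma, sigma_data=SIGMA_DATA):
    """Return (c_skip, c_out, c_in, c_noise) as defined in the EDM paper."""
    sigma = np.asarray(sigma, dtype=np.float64)
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = sigma * sigma_data / np.sqrt(denom)
    c_in = 1.0 / np.sqrt(denom)
    c_noise = 0.25 * np.log(sigma)
    return c_skip, c_out, c_in, c_noise


def edm_loss_weight(sigma, sigma_data=SIGMA_DATA):
    """lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    sigma = np.asarray(sigma, dtype=np.float64)
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def denoise(network, x_noisy, sigma, context):
    """D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise, context)."""
    c_skip, c_out, c_in, c_noise = edm_coefficients(sigma)
    return c_skip * x_noisy + c_out * network(c_in * x_noisy, c_noise, context)


def edm_loss(network, x0, context, rng, p_mean=-1.2, p_std=1.2):
    """EDM-weighted MSE for one batch of clean targets x0 with shape (B, 3, H, W)."""
    batch = x0.shape[0]
    # ln(sigma) ~ Normal(p_mean, p_std^2): assumed EDM default noise distribution
    sigma = np.exp(p_mean + p_std * rng.standard_normal(batch))
    sigma_b = sigma.reshape(batch, 1, 1, 1)  # broadcast over (C, H, W)
    x_noisy = x0 + sigma_b * rng.standard_normal(x0.shape)
    pred = denoise(network, x_noisy, sigma_b, context)
    per_sample_mse = ((pred - x0) ** 2).mean(axis=(1, 2, 3))
    return float((edm_loss_weight(sigma) * per_sample_mse).mean())
```

With these definitions, `edm_loss_weight(sigma) * c_out**2 == 1` at every noise level. The raw network output `F` therefore receives a unit-weighted regression target regardless of `sigma`, which is the motivation for this weighting in the EDM paper.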
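The 15-channel input described under Architecture and the **CFG dropout** entry amount to simple array bookkeeping, as the sketch below shows. It rests on several assumptions that the card does not confirm:

- Frames are stored as `(3, 64, 64)` float arrays, oldest first.
- The noisy target precedes the context channels, following the order in which the card lists them.
- A dropped action is an all-zero vector.

All function names are hypothetical.

```python
import numpy as np

ACTIONS = ("UP", "DOWN", "LEFT", "RIGHT")  # one-hot order stated in the card
NUM_CONTEXT = 4  # number of stacked context frames


def one_hot_action(name):
    """Encode an action name as a length-4 one-hot vector."""
    vec = np.zeros(len(ACTIONS), dtype=np.float32)
    vec[ACTIONS.index(name)] = 1.0
    return vec


def stack_context(frames):
    """Concatenate NUM_CONTEXT frames of shape (3, H, W), oldest first, into (12, H, W)."""
    if len(frames) != NUM_CONTEXT:
        raise ValueError(f"expected {NUM_CONTEXT} frames, got {len(frames)}")
    return np.concatenate(frames, axis=0)


def build_unet_input(x_noisy, context):
    """Join the noisy target (B, 3, H, W) and context (B, 12, H, W) into (B, 15, H, W)."""
    return np.concatenate([x_noisy, context], axis=1)


def cfg_dropout(actions, rng, p_drop=0.3):
    """Replace each row of `actions` (B, 4) with the assumed null action (zeros) w.p. p_drop."""
    keep = rng.random(actions.shape[0]) >= p_drop
    return actions * keep[:, None].astype(actions.dtype)
```

Dropping the action for roughly 30% of training samples teaches the same network an unconditional prediction. That prediction is what classifier-free guidance contrasts against at sampling time.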
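Sampling with `--steps 3` and `--cfg_scale 2.0` can be sketched as a first-order Euler solver over an EDM noise schedule, with classifier-free guidance applied to the denoised estimate. The schedule parameters (`sigma_min=0.002`, `sigma_max=80`, `rho=7`) are EDM paper defaults used here as assumptions, and the repository's play script may use different values. `denoiser(x, sigma, context, action)` is a hypothetical stand-in for the preconditioned model.

```python
import numpy as np


def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise levels sigma_0 > ... > sigma_{N-1}, followed by a final 0."""
    ramp = np.linspace(0.0, 1.0, num_steps)
    inv_max, inv_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = (inv_max + ramp * (inv_min - inv_max)) ** rho
    return np.append(sigmas, 0.0)


def guided_denoise(denoiser, x, sigma, context, action, cfg_scale):
    """Classifier-free guidance on the denoised estimate (null action = zeros, assumed)."""
    d_cond = denoiser(x, sigma, context, action)
    d_uncond = denoiser(x, sigma, context, np.zeros_like(action))
    return d_uncond + cfg_scale * (d_cond - d_uncond)


def sample_next_frame(denoiser, context, action, rng, steps=3, cfg_scale=2.0):
    """Generate one (B, 3, H, W) frame from pure noise with `steps` Euler steps."""
    sigmas = karras_sigmas(steps)
    shape = (context.shape[0], 3) + context.shape[2:]
    x = sigmas[0] * rng.standard_normal(shape)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = guided_denoise(denoiser, x, sigma, context, action, cfg_scale)
        d = (x - denoised) / sigma        # dx/dsigma of the probability-flow ODE
        x = x + (sigma_next - sigma) * d  # first-order Euler step toward sigma_next
    return x
```

In this sketch each step costs two network evaluations (conditional and unconditional), so 3 guided steps amount to 6 forward passes per generated frame. With `cfg_scale=1.0` the guidance term reduces to the plain conditional prediction.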
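The **EMA** and **Weighted sampling** entries each reduce to a few lines. In this sketch, parameters are modeled as a plain dict of NumPy arrays, and `is_event` is a hypothetical per-transition flag marking death or eating. In PyTorch, the same weights could be passed to `torch.utils.data.WeightedRandomSampler`, which does not require them to sum to one.

```python
import numpy as np


def ema_update(ema_params, params, decay=0.9999):
    """In-place EMA: ema <- decay * ema + (1 - decay) * param for every array."""
    for name, value in params.items():
        ema_params[name] *= decay
        ema_params[name] += (1.0 - decay) * value


def event_sampling_probs(is_event, event_weight=5.0):
    """Sampling probabilities where flagged (death/eating) transitions get 5x weight."""
    weights = np.where(np.asarray(is_event, dtype=bool), event_weight, 1.0)
    return weights / weights.sum()


# Usage: draw one batch of 512 indices with replacement
rng = np.random.default_rng(0)
probs = event_sampling_probs([False, False, True, False])
batch_idx = rng.choice(len(probs), size=512, p=probs)
```

Each flagged transition is drawn five times as often as an ordinary one. Rare deaths and food pickups, which determine most of the interesting game dynamics, therefore appear regularly in every batch of 512.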