Visual World Model — bin-pick-pack (proprio + visual latents)
An MLP-based world model (SystemDynamicsEnsemble) trained on the bin-pick-pack-coffee-capsules manipulation dataset. It predicts the next proprioceptive state (17D) and the next visual latent (32D) from a history of states, actions, and visual latents.
Architecture
- Type: MLP ensemble (2 heads)
- State dim: 17 (7 joint positions + 10 EEF pose as xyz/rot6d/gripper)
- Action dim: 17 (same decomposition)
- Visual dim: 32 (LAM-encoded visual latents from front camera)
- History horizon: 2
- Forecast horizon: 1
- Checkpoint size: 1.9 MB
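The 17-D state and action vectors above decompose as 7 + 3 + 6 + 1. The following is a minimal sketch of slicing such a vector into named parts. The field names and, importantly, the ordering of the components are assumptions made for illustration; check the dataset and training config for the actual layout.

```python
from typing import Dict, List, Sequence, Tuple

# ASSUMED layout: the model card gives the component sizes, not their order.
LAYOUT: Tuple[Tuple[str, int], ...] = (
    ("joint_pos", 7),   # 7 joint positions
    ("eef_xyz", 3),     # end-effector position
    ("eef_rot6d", 6),   # end-effector rotation in 6D representation
    ("gripper", 1),     # gripper scalar
)


def split_vector(vec: Sequence[float]) -> Dict[str, List[float]]:
    """Slice a flat 17-D vector into named components following LAYOUT."""
    expected = sum(width for _, width in LAYOUT)
    if len(vec) != expected:
        raise ValueError(f"expected {expected} values, got {len(vec)}")
    parts: Dict[str, List[float]] = {}
    offset = 0
    for name, width in LAYOUT:
        parts[name] = list(vec[offset:offset + width])
        offset += width
    return parts


# Placeholder vector, just to show the resulting component sizes.
state = [float(i) for i in range(17)]
parts = split_vector(state)
print({name: len(values) for name, values in parts.items()})
```

The same helper applies to the action vector, since the card states it uses the same decomposition.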
Training
- Dataset: villekuosmanen/bin_pick_pack_coffee_capsules — 47865 frames, 200 episodes
- Visual latents: Precomputed from fine-tuned LAM encoder (pravsels/lam-binpack-finetune)
- Split: 35387 train / 12078 val sequences (val_ratio=0.25, seed=0)
- Epochs: 50
- Batch size: 64
- Learning rate: 3e-4 (cosine schedule, min_lr=3e-5)
- Final train loss: 0.09442
- Final val loss: 0.09916
- Visual loss: 0.07207
- W&B: pravsels/binpack-world-model/runs/2pq0n2mx
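Two of the figures above can be sanity-checked with a few lines of arithmetic. The sketch below assumes that training sequences are sliding windows of length history + forecast that never cross episode boundaries, and that the cosine schedule is stepped once per epoch with no warmup. Both are assumptions, not details confirmed by the training code; the helper names are hypothetical.

```python
import math


def count_windows(total_frames: int, num_episodes: int,
                  history_horizon: int, forecast_horizon: int) -> int:
    """Number of sliding windows if each episode loses (window - 1) frames.

    Assumes every episode has at least `window` frames.
    """
    window = history_horizon + forecast_horizon
    return total_frames - num_episodes * (window - 1)


def cosine_lr(epoch: int, total_epochs: int, base_lr: float, min_lr: float) -> float:
    """Cosine decay from base_lr (first epoch) to min_lr (last epoch)."""
    progress = epoch / max(total_epochs - 1, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


n_windows = count_windows(47865, 200, history_horizon=2, forecast_horizon=1)
print(n_windows, 35387 + 12078)  # both are 47465

for epoch in (0, 25, 49):
    print(epoch, cosine_lr(epoch, 50, base_lr=3e-4, min_lr=3e-5))
```

Under the windowing assumption, the 47865 frames and 200 episodes give 47465 sequences, which matches the sum of the reported train and val counts.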
Files
| File | Description |
|---|---|
| best.pt | Best checkpoint (epoch 50) |
| config.yaml | Training configuration (Isambard) |
Checkpoint format
import torch
checkpoint = torch.load("best.pt", map_location="cpu")
# checkpoint["model_state_dict"]: pass to SystemDynamicsEnsemble.load_state_dict()
# checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc.: training metadata
Integrity
sha256: 050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410 best.pt
Verified by running sha256sum twice on the source file.
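For environments without the sha256sum command, the same check can be done with the Python standard library. This is a minimal sketch: the helper names are hypothetical, and it assumes best.pt sits in the current working directory.

```python
import hashlib
from pathlib import Path

# Digest copied from the Integrity section of this model card.
EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_of_file(path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA-256 digest equals the expected hex string."""
    return sha256_of_file(path) == expected.lower()


if __name__ == "__main__":
    ckpt_path = Path("best.pt")  # assumed location of the downloaded checkpoint
    if ckpt_path.exists():
        print("match" if verify_checkpoint(ckpt_path) else "MISMATCH")
    else:
        print("best.pt not found")
```

Running the check before torch.load avoids deserializing a file that was corrupted or altered in transit.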
Usage
from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model
import torch
model = build_system_dynamics_model(
    state_dim=17, action_dim=17, visual_dim=32,
    ensemble_size=2, history_horizon=2, device="cpu",
)
ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
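# Single-step prediction
# Placeholder inputs with the expected shapes; replace with real histories.
state_hist = torch.zeros(1, 2, 17)   # (batch, history, state_dim)
action_hist = torch.zeros(1, 2, 17)  # (batch, history, action_dim)
visual_hist = torch.zeros(1, 2, 32)  # (batch, history, visual_dim)
with torch.no_grad():  # inference only, no gradients needed
    state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)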