Visual World Model — bin-pick-pack (proprio + visual latents)

An MLP-based world model (SystemDynamicsEnsemble) trained on the bin-pick-pack-coffee-capsules manipulation dataset. It predicts the next proprioceptive state (17D) and the next visual latent (32D) from a history of states, actions, and visual latents.

Architecture

  • Type: MLP ensemble (2 heads)
  • State dim: 17 (7 joint positions + 10D end-effector (EEF) pose: xyz position (3), rot6d rotation (6), gripper (1))
  • Action dim: 17 (same decomposition)
  • Visual dim: 32 (LAM-encoded visual latents from front camera)
  • History horizon: 2
  • Forecast horizon: 1
  • Checkpoint size: 1.9 MB
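
The 17D state decomposition listed above can be sketched as a set of index slices. This is a minimal illustration only: the card gives the component sizes (7 + 3 + 6 + 1) but not the index order, so the layout below, and the names STATE_LAYOUT and split_state, are assumptions rather than part of the released code.

```python
from typing import Dict, List, Sequence

STATE_DIM = 17

# ASSUMED index order: joints first, then EEF xyz, rot6d, gripper.
# Check against the dataset or config.yaml before relying on it.
STATE_LAYOUT = {
    "joint_pos": slice(0, 7),    # 7 joint positions
    "eef_xyz": slice(7, 10),     # end-effector position (3)
    "eef_rot6d": slice(10, 16),  # 6D rotation representation (6)
    "gripper": slice(16, 17),    # gripper scalar (1)
}


def split_state(state: Sequence[float]) -> Dict[str, List[float]]:
    """Split one flat 17D state vector into named components."""
    if len(state) != STATE_DIM:
        raise ValueError(f"expected {STATE_DIM} values, got {len(state)}")
    return {name: list(state[idx]) for name, idx in STATE_LAYOUT.items()}


# Example with a dummy state whose values equal their indices.
parts = split_state([float(i) for i in range(STATE_DIM)])
print({name: len(values) for name, values in parts.items()})
```

Since the action vector is described as using the same decomposition, the same slices would apply to it under the same assumption.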

Training

Files

File         Description
best.pt      Best checkpoint (epoch 50)
config.yaml  Training configuration (Isambard)

Checkpoint format

import torch

checkpoint = torch.load("best.pt", map_location="cpu")
# checkpoint["model_state_dict"]: weights, passed to SystemDynamicsEnsemble.load_state_dict()
# checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc.: training metadata

Integrity

sha256: 050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410  best.pt

Verified by running sha256sum twice on the source file.
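
For environments without sha256sum, the same check can be done with Python's standard library. This is a minimal sketch; the helper names sha256_of_file and verify_checkpoint are illustrative and not part of this repository.

```python
import hashlib
from pathlib import Path

# Digest published above for best.pt.
EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_of_file(path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's digest matches the expected hex string."""
    return sha256_of_file(path) == expected.lower()


if __name__ == "__main__":
    ckpt_path = Path("best.pt")
    if ckpt_path.exists():
        print("checksum OK" if verify_checkpoint(ckpt_path) else "checksum MISMATCH")
    else:
        print("best.pt not found in the current directory")
```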

Usage

from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model
import torch

model = build_system_dynamics_model(
    state_dim=17, action_dim=17, visual_dim=32,
    ensemble_size=2, history_horizon=2, device="cpu",
)
ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

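# Single-step prediction on placeholder inputs (replace with real histories).
# Shapes are (batch, history_horizon, dim).
state_hist = torch.zeros(1, 2, 17)
action_hist = torch.zeros(1, 2, 17)
visual_hist = torch.zeros(1, 2, 32)
with torch.no_grad():
    state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)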
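
Because the forecast horizon is 1, multi-step prediction means feeding each prediction back into the history window. The sketch below shows that loop in a framework-agnostic way. The predict_next callable is hypothetical: in practice it would wrap the model call shown above and return the next state and next visual latent. The alignment between the action history and the state history is also an assumption, since the card does not specify it.

```python
from collections import deque
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# HYPOTHETICAL predictor signature: (state_hist, action_hist, visual_hist)
# -> (next_state, next_visual). Wrap the real model to match it.
PredictFn = Callable[
    [Sequence[Vector], Sequence[Vector], Sequence[Vector]], Tuple[Vector, Vector]
]


def rollout(
    predict_next: PredictFn,
    state_hist: Sequence[Vector],
    action_hist: Sequence[Vector],
    visual_hist: Sequence[Vector],
    future_actions: Sequence[Vector],
    history_horizon: int = 2,
) -> List[Tuple[Vector, Vector]]:
    """Autoregressive rollout: one prediction from the seed history,
    plus one more for every entry in future_actions."""
    # Bounded deques drop the oldest entry when a new one is appended.
    states = deque(state_hist, maxlen=history_horizon)
    actions = deque(action_hist, maxlen=history_horizon)
    visuals = deque(visual_hist, maxlen=history_horizon)

    trajectory = []
    for step in range(len(future_actions) + 1):
        next_state, next_visual = predict_next(
            list(states), list(actions), list(visuals)
        )
        trajectory.append((next_state, next_visual))
        if step < len(future_actions):
            # Slide the window: predictions become history for the next step.
            states.append(next_state)
            visuals.append(next_visual)
            actions.append(future_actions[step])
    return trajectory


def toy_predict(state_hist, action_hist, visual_hist):
    """Stand-in dynamics for the demo: state += last action, visual unchanged."""
    next_state = [s + a for s, a in zip(state_hist[-1], action_hist[-1])]
    return next_state, list(visual_hist[-1])


# 1D toy vectors keep the demo readable; the real model uses 17D/17D/32D.
trajectory = rollout(
    toy_predict,
    state_hist=[[0.0], [1.0]],
    action_hist=[[1.0], [1.0]],
    visual_hist=[[0.5], [0.5]],
    future_actions=[[2.0], [3.0]],
)
print([state for state, _ in trajectory])
```

Prediction errors compound over such rollouts, so long horizons should be treated with caution.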

Dataset used to train pravsels/visual-wm-binpack