# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones.
## Key Idea
Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models the evolution of a belief state as the system reasons about a question:
z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer)
This trajectory is supervised by a JEPA objective: an EMA target encoder generates the target latent state for each step, and the online predictor learns to match it. The JEPA loss therefore encourages the model to learn meaningful intermediate reasoning states, not just the final answer.
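The K-step refinement and its JEPA supervision can be sketched in a few lines (illustrative NumPy only; `predict` stands in for the shared predictor block, and none of these names are taken from the repo):

```python
import numpy as np

def smooth_l1(pred, target, beta=1.0):
    # SmoothL1 (Huber) distance between predicted and target latents.
    d = np.abs(pred - target)
    return np.where(d < beta, 0.5 * d ** 2 / beta, d - 0.5 * beta).mean()

def rollout(z0, predict, K=3):
    # Unroll the weight-tied predictor: z0 -> z1 -> ... -> zK.
    traj = [z0]
    for _ in range(K):
        traj.append(predict(traj[-1]))
    return traj

def jepa_loss(online_traj, target_traj):
    # Per-step prediction loss against the EMA target trajectory
    # (targets are treated as constants, i.e. stop-gradient).
    return sum(smooth_l1(z, t) for z, t in zip(online_traj[1:], target_traj[1:]))
```

In training, `target_traj` comes from the EMA copy of the encoder, so the online predictor is pulled toward slowly moving targets rather than its own outputs.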
## Architecture
```
┌──────────────┐      ┌────────────────────┐      ┌──────────────────┐      ┌───────────────┐
│ DINOv3-L/16  │────▶│ Evidence Memory     │────▶│ Latent Rollout    │────▶│ Disc. Head     │
│ (frozen)     │      │ (Perceiver Resampl)│      │ z₀→z₁→z₂→z₃      │      │ (MC scoring)  │
└──────────────┘      └────────┬───────────┘      │ (shared block)   │      └───────────────┘
                               │                  └────────┬─────────┘      ┌───────────────┐
┌──────────────┐               │                           │           ├───▶│ Gen. Decoder  │
│ Qwen3-Embed  │──────────────┘                  ┌───────┴────────┐        │ (Qwen3.5-4B)  │
│ 0.6B         │                                  │ Target Encoder │        └───────────────┘
│ (frozen)     │                                  │ (EMA copy)     │
└──────────────┘                                  └────────────────┘
                                                           │
┌──────────────┐                                  JEPA Loss:
│ Phase 3 opt: │──────────┘                       SmoothL1/Cosine
│ PaddleOCR-VL │                                  + SIGReg (purist)
│ SAM 3.1      │                                  / VICReg (hybrid)
└──────────────┘
```
## Component Stack
| Module | Primary Choice | Alternative | Notes |
|---|---|---|---|
| Visual backbone | timm/vit_large_patch16_dinov3.lvd1689m — DINOv3-L/16, 1024-dim, 300M | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen Phase 1; last 6 layers unfrozen Phase 2 |
| Text encoder | Qwen/Qwen3-Embedding-0.6B — 1024-dim, 596M | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen Phase 1; last 4 layers unfrozen Phase 2 |
| Evidence memory | Perceiver Resampler, 64 queries, 4 cross-attn layers | Q-Former as baseline | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) |
| OCR / doc / charts | PaddlePaddle/PaddleOCR-VL-1.5 — 958M | MinerU2.5 for heavy PDF parsing | Phase 3 only, offline token extraction |
| Segmentation | jetjodh/sam3.1 — SAM 3.1, non-gated mirror | SAM 2.1-Large (stable) | Phase 3 optional, offline mask extraction |
| Latent rollout | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates |
| Target encoder | EMA copy (cosine 0.996→1.0) of evidence+rollout | Frozen target (ablation baseline) | From I-JEPA |
| JEPA loss | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasis in purist branch |
| Disc. head | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings |
| Gen. decoder | Qwen/Qwen3.5-4B — 4.7B, multimodal | HuggingFaceTB/SmolLM3-3B (cheaper); Gemma3-4B | Phase 3+, cross-attends to z_K + evidence |
| Teacher/baseline | InternVL3.5 / Qwen3-VL | External comparison only | NOT used as internal module |
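As a concrete reading of the "sigmoid evidence gates" entry, one gated rollout step might look like the following (illustrative NumPy; the gate parameterization `W_g`/`b_g` and the injection point are assumptions, not the repo's exact layout):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_step(z, evidence, predictor, W_g, b_g):
    # z: (D,) current belief state; evidence: (D,) pooled evidence readout.
    # W_g: (2D, D) and b_g: (D,) are hypothetical gate parameters.
    # The gate sees both the belief state and the evidence, and decides
    # per-dimension how much fresh evidence to inject before predicting.
    g = sigmoid(np.concatenate([z, evidence]) @ W_g + b_g)
    return predictor(z + g * evidence), g
```

A gate near 0 lets the predictor reason from the current belief alone; a gate near 1 re-reads the evidence memory at that step.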
## Training Protocol

### Phase 1: Reasoning Core (15–20 epochs)
- Freeze all perception (DINOv3 + Qwen3-Embedding)
- Train evidence memory + latent rollout + discriminative head
- Full JEPA loss + task loss
- LR: 3e-4, effective batch: 64
### Phase 2: Perception Fine-tuning (10 epochs)
- Unfreeze last 6 DINOv3 layers + last 4 Qwen3-Embedding layers (1e-5)
- Continue training reasoning core (1e-4)
### Phase 3: Enriched Evidence + Generative Decoder (10 epochs)
- Enable PaddleOCR-VL tokens, SAM 3.1 masks, layout/chart tokens
- Attach Qwen3.5-4B generative decoder for open-ended answers
- End-to-end fine-tuning, LR: 5e-5
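The per-phase learning rates above, and the cosine EMA schedule for the target encoder (0.996 → 1.0, per the component table), can be sketched as follows (illustrative; function and key names are not taken from the repo):

```python
import math

# Per-phase learning rates from the protocol above.
PHASE_LR = {
    1: {"reasoning_core": 3e-4},
    2: {"reasoning_core": 1e-4, "unfrozen_backbone_layers": 1e-5},
    3: {"all": 5e-5},
}

def ema_momentum(step, total_steps, m0=0.996, m1=1.0):
    # Cosine ramp of the target-encoder EMA coefficient from m0 to m1,
    # as in I-JEPA: fast-moving targets early, near-frozen targets late.
    cos = math.cos(math.pi * step / total_steps)
    return m1 - (m1 - m0) * (cos + 1.0) / 2.0
```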
## Target Benchmarks (9)
| Benchmark | Type | Metric | Key Challenge |
|---|---|---|---|
| MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images |
| MathVista | Mixed MC/Open | Accuracy | Mathematical reasoning |
| ScienceQA | MC | Accuracy | Scientific diagrams, nullable images |
| AI2D | MC | Accuracy | Science diagram comprehension |
| MMBench | MC | CircularEval Acc | General visual understanding |
| MMStar | MC | Accuracy | Vision-dependent questions |
| DocVQA | Open | ANLS | Document text extraction |
| TextVQA | Open | VQA Accuracy | Scene text reading |
| ChartQA | Open | Relaxed Accuracy | Chart data extraction |
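The two less-common metrics in this table can be sketched from their standard definitions, ANLS with the usual 0.5 threshold and ChartQA-style relaxed accuracy with 5% numeric tolerance (illustrative code, not taken from mr_jepa/evaluation/metrics.py):

```python
def levenshtein(a, b):
    # Standard dynamic-programming edit distance.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]

def anls(pred, golds, tau=0.5):
    # Per-sample ANLS term: best normalized Levenshtein similarity over
    # gold answers, zeroed when it falls below the tau threshold.
    best = 0.0
    for g in golds:
        d = levenshtein(pred.lower(), g.lower())
        best = max(best, 1.0 - d / max(len(pred), len(g), 1))
    return best if best >= tau else 0.0

def relaxed_accuracy(pred, gold, tol=0.05):
    # ChartQA-style: numeric answers within 5% relative tolerance,
    # otherwise exact (case-insensitive) string match.
    try:
        return abs(float(pred) - float(gold)) <= tol * abs(float(gold))
    except ValueError:
        return pred.strip().lower() == gold.strip().lower()
```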
## Experimental Branches

### Hybrid-main (competitive)
- DINOv3-L backbone, SmoothL1 + VICReg, K=3
- Full enriched evidence in Phase 3
- Target: state-of-the-art on all benchmarks
### Purist-side (scientific validation)
- DINOv3-B backbone, Cosine + SIGReg, K=5
- No enriched evidence, pure JEPA reasoning
- Target: demonstrate JEPA contributes beyond perception
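The two branches differ in only a handful of hyperparameters; a minimal sketch as a config dataclass (field names are illustrative, not necessarily those in mr_jepa/configs/model_config.py):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BranchConfig:
    backbone: str            # visual backbone variant
    loss_fn: str             # JEPA prediction loss
    regularizer: str         # anti-collapse term
    K: int                   # rollout depth
    enriched_evidence: bool  # OCR/SAM/layout/chart tokens in Phase 3

HYBRID_MAIN = BranchConfig("dinov3-l", "smooth_l1", "vicreg", 3, True)
PURIST_SIDE = BranchConfig("dinov3-b", "cosine", "sigreg", 5, False)
```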
## Ablation Experiments
Each experiment maps 1:1 to a CLI flag in train_mrjepa.py.
| Experiment | CLI flag | Modification | Purpose |
|---|---|---|---|
| hybrid_main | (default) | Full model | Baseline |
| no_jepa | --no_jepa | Remove L_JEPA, task loss only | Validate JEPA objective |
| no_rollout | --no_rollout | K=0, use z₀ directly | Validate iterative refinement |
| no_gate | --no_evidence_gate | Remove evidence gating | Validate adaptive evidence flow |
| K1 / K5 / K7 | --K 1/5/7 | Vary rollout depth | Find optimal depth |
| dinov2_ablation | --backbone dinov2 | DINOv2-L/14 backbone | DINOv3 vs DINOv2 |
| mse_loss | --loss_fn mse | MSE (L2) JEPA loss | Original I-JEPA loss |
| cosine_loss | --loss_fn cosine | Cosine similarity JEPA loss | Purist-style loss |
| no_sigreg | --no_sigreg | Disable SIGReg anti-collapse | Test regularization |
| vicreg_only | --no_sigreg --use_vicreg | VICReg only | Alternative anti-collapse |
| purist | --purist | DINOv3-B, K=5, Cosine+SIGReg | Isolate JEPA contribution |
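A minimal sketch of how these flags could be declared with argparse (the flag names come from the table above; the default value "smoothl1" for --loss_fn and the help strings are assumptions, not copied from train_mrjepa.py):

```python
import argparse

def build_parser():
    # Mirrors the ablation table; defaults reproduce hybrid_main.
    p = argparse.ArgumentParser("train_mrjepa")
    p.add_argument("--no_jepa", action="store_true",
                   help="drop L_JEPA, train with task loss only")
    p.add_argument("--no_rollout", action="store_true",
                   help="K=0, answer from z0 directly")
    p.add_argument("--no_evidence_gate", action="store_true",
                   help="remove sigmoid evidence gating")
    p.add_argument("--K", type=int, default=3, help="rollout depth")
    p.add_argument("--backbone", choices=["dinov3", "dinov2"],
                   default="dinov3")
    p.add_argument("--loss_fn", choices=["smoothl1", "mse", "cosine"],
                   default="smoothl1")
    p.add_argument("--no_sigreg", action="store_true")
    p.add_argument("--use_vicreg", action="store_true")
    p.add_argument("--purist", action="store_true",
                   help="DINOv3-B, K=5, Cosine+SIGReg branch")
    return p
```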
## Project Structure
```
MR-JEPA/
├── README.md                  # This file
├── train_mrjepa.py            # Complete training script (CLI, all ablations)
├── test_architecture.py       # Architecture validation tests (synthetic data)
│
├── mr_jepa/
│   ├── __init__.py
│   ├── ARCHITECTURE.md        # Detailed architecture specification
│   │
│   ├── configs/
│   │   ├── __init__.py
│   │   └── model_config.py    # All hyperparameter dataclasses
│   │
│   ├── models/
│   │   ├── __init__.py
│   │   ├── mr_jepa.py         # Main model (integrates all components)
│   │   ├── backbones.py       # Visual (DINOv3/v2) + Text (Qwen3-Embedding)
│   │   ├── evidence_memory.py # Perceiver Resampler multimodal fusion
│   │   ├── latent_rollout.py  # K-step shared predictor + evidence gates
│   │   ├── target_encoder.py  # EMA encoder + JEPA/SIGReg/VICReg losses
│   │   └── answer_heads.py    # Discriminative (MC) + Generative (open-ended)
│   │
│   ├── data/
│   │   ├── __init__.py
│   │   ├── unified_dataset.py # 9-benchmark unified loader with format quirks
│   │   └── data_utils.py      # Collator, dataloader factory, benchmark configs
│   │
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py         # 3-phase training loop
│   │   └── phase_scheduler.py # Phase transitions, LR scheduling
│   │
│   ├── evaluation/
│   │   ├── __init__.py
│   │   └── metrics.py         # Accuracy, ANLS, VQA Acc, Relaxed Acc
│   │
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py   # Trajectory PCA, gate analysis
│       └── ablation.py        # Systematic ablation runner
│
├── results/                   # Training results (auto-pushed)
│   ├── hybrid_main.json
│   ├── no_jepa.json
│   ├── no_rollout.json
│   └── ...
│
└── checkpoints/               # Best model checkpoints (auto-pushed)
    ├── hybrid_main_best.pt
    └── ...
```
## Paper Contribution
A world model for multimodal reasoning: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks compared to single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, validated through ablation studies across 9 benchmarks. The JEPA objective (not human chain-of-thought) supervises a latent trajectory generated by an EMA target encoder, showing that self-supervised dynamics training transfers to discriminative reasoning tasks.
## Key References
- I-JEPA (Assran et al., 2023) — arxiv:2301.08243: JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor
- LeWorldModel (Maes et al., 2025) — arxiv:2603.19312: SIGReg anti-collapse, end-to-end JEPA
- Coconut (Hao et al., 2024) — arxiv:2412.06769: Chain of Continuous Thought, latent reasoning
- DINOv3 (Meta, 2025) — arxiv:2508.10104: Dense SSL with RoPE + Gram anchoring
- SoftCoT++ (Xu et al., 2025) — arxiv:2505.11484: Soft chain-of-thought with contrastive learning
## License
Apache-2.0