MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture

A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones.

Key Idea

Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models the evolution of a belief state as the system reasons about a question:

z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer)

This trajectory is supervised by a JEPA objective: an EMA (exponential-moving-average) target encoder generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn meaningful intermediate reasoning states, not just the final answer.
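The rollout-plus-EMA scheme can be sketched in a few lines. Everything below (class names, the sigmoid gate placement, the per-step SmoothL1 loss) is an illustrative reconstruction from the description above, not the actual train_mrjepa.py code:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentRollout(nn.Module):
    """Weight-tied predictor applied K times to refine a belief state."""
    def __init__(self, dim: int = 1024, steps: int = 3):
        super().__init__()
        self.steps = steps
        # One shared transformer block reused at every step (weight tying).
        self.block = nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True)
        self.gate = nn.Linear(dim, dim)  # sigmoid evidence gate

    def forward(self, z0: torch.Tensor, evidence: torch.Tensor):
        z, trajectory = z0, [z0]
        for _ in range(self.steps):
            g = torch.sigmoid(self.gate(z))   # how much evidence to admit
            z = self.block(z + g * evidence)  # refine the belief state
            trajectory.append(z)
        return trajectory                     # [z0, z1, ..., zK]

def jepa_loss(online: LatentRollout, target: LatentRollout,
              z0: torch.Tensor, evidence: torch.Tensor) -> torch.Tensor:
    """Online predictor chases latent states produced by the EMA target copy."""
    with torch.no_grad():
        target_traj = target(z0, evidence)
    online_traj = online(z0, evidence)
    return sum(F.smooth_l1_loss(zo, zt)       # hybrid-branch loss choice
               for zo, zt in zip(online_traj[1:], target_traj[1:]))

@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, m: float = 0.996):
    """EMA target update; m is annealed 0.996 -> 1.0 over training."""
    for po, pt in zip(online.parameters(), target.parameters()):
        pt.mul_(m).add_(po, alpha=1.0 - m)
```

Because the predictor is weight-tied, K only changes compute, not parameter count, which is what makes the K=1/5/7 ablations cheap.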


Architecture

┌──────────────┐     ┌────────────────────┐     ┌──────────────────┐     ┌───────────────┐
│ DINOv3-L/16  │────▶│   Evidence Memory  │────▶│  Latent Rollout  │──┬─▶│  Disc. Head   │
│   (frozen)   │     │ (Perceiver Resampl)│     │   z₀→z₁→z₂→z₃    │  │  │ (MC scoring)  │
└──────────────┘     └─────────┬──────────┘     │  (shared block)  │  │  └───────────────┘
                               ▲                └────────┬─────────┘  │  ┌───────────────┐
┌──────────────┐               │                         │            └─▶│ Gen. Decoder  │
│ Qwen3-Embed  │───────────────┤                ┌───────┴────────┐       │ (Qwen3.5-4B)  │
│    0.6B      │               │                │ Target Encoder │       └───────────────┘
│   (frozen)   │               │                │   (EMA copy)   │
└──────────────┘               │                └───────┬────────┘
                               │                        │
┌──────────────┐               │                   JEPA Loss:
│ Phase 3 opt: │───────────────┘                   SmoothL1/Cosine
│ PaddleOCR-VL │                                   + SIGReg (purist)
│ SAM 3.1      │                                   / VICReg (hybrid)
└──────────────┘

Component Stack

| Module | Primary choice | Alternative | Notes |
|---|---|---|---|
| Visual backbone | timm/vit_large_patch16_dinov3.lvd1689m (DINOv3-L/16, 1024-dim, 300M) | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen in Phase 1; last 6 layers unfrozen in Phase 2 |
| Text encoder | Qwen/Qwen3-Embedding-0.6B (1024-dim, 596M) | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen in Phase 1; last 4 layers unfrozen in Phase 2 |
| Evidence memory | Perceiver Resampler, 64 queries, 4 cross-attn layers | Q-Former as baseline | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) |
| OCR / doc / charts | PaddlePaddle/PaddleOCR-VL-1.5 (958M) | MinerU2.5 for heavy PDF parsing | Phase 3 only; offline token extraction |
| Segmentation | jetjodh/sam3.1 (SAM 3.1, non-gated mirror) | SAM 2.1-Large (stable) | Phase 3 optional; offline mask extraction |
| Latent rollout | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates |
| Target encoder | EMA copy (cosine 0.996→1.0) of evidence memory + rollout | Frozen target (ablation baseline) | From I-JEPA |
| JEPA loss | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasized in the purist branch |
| Disc. head | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings |
| Gen. decoder | Qwen/Qwen3.5-4B (4.7B, multimodal) | HuggingFaceTB/SmolLM3-3B (cheaper); Gemma3-4B | Phase 3+; cross-attends to z_K + evidence |
| Teacher/baseline | InternVL3.5 / Qwen3-VL | — | External comparison only; not used as an internal module |
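A minimal sketch of the evidence memory as described (learned queries, stacked cross-attention, modality-typed tokens). The residual/norm placement and initialization are assumptions, not the repo's actual evidence_memory.py:

```python
import torch
import torch.nn as nn

class EvidenceMemory(nn.Module):
    """Perceiver-Resampler-style fusion: a fixed set of learned queries
    cross-attends to concatenated modality tokens, compressing variable-length
    evidence into a fixed number of slots (64 in MR-JEPA)."""
    def __init__(self, dim: int = 1024, num_queries: int = 64,
                 num_layers: int = 4, num_modalities: int = 6):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
        # One type embedding per modality (visual/text/OCR/layout/chart/SAM).
        self.type_embed = nn.Embedding(num_modalities, dim)
        self.layers = nn.ModuleList(
            nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
            for _ in range(num_layers))
        self.norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(num_layers))

    def forward(self, tokens: torch.Tensor, modality_ids: torch.Tensor):
        # tokens: (B, N, dim); modality_ids: (B, N) integer type per token.
        kv = tokens + self.type_embed(modality_ids)
        q = self.queries.expand(tokens.size(0), -1, -1)
        for attn, norm in zip(self.layers, self.norms):
            out, _ = attn(q, kv, kv)
            q = norm(q + out)          # residual cross-attention
        return q                       # (B, num_queries, dim) evidence memory
```

The fixed-size output is what lets Phase 3 add OCR/SAM tokens without changing the rollout's interface: extra evidence only grows N, never the memory size.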

Training Protocol

Phase 1: Reasoning Core (15–20 epochs)

  • Freeze all perception (DINOv3 + Qwen3-Embedding)
  • Train evidence memory + latent rollout + discriminative head
  • Full JEPA loss + task loss
  • LR: 3e-4, effective batch: 64

Phase 2: Perception Fine-tuning (10 epochs)

  • Unfreeze last 6 DINOv3 layers + last 4 Qwen3-Embedding layers (1e-5)
  • Continue training reasoning core (1e-4)

Phase 3: Enriched Evidence + Generative Decoder (10 epochs)

  • Enable PaddleOCR-VL tokens, SAM 3.1 masks, layout/chart tokens
  • Attach Qwen3.5-4B generative decoder for open-ended answers
  • End-to-end fine-tuning, LR: 5e-5
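The freeze/unfreeze schedule above can be expressed as optimizer parameter groups. The submodule names (.visual, .text, .core) are hypothetical placeholders for this sketch; only the learning rates and layer counts come from the protocol:

```python
import torch
from torch import nn

def build_optimizer(model: nn.Module, phase: int) -> torch.optim.AdamW:
    """Freeze/unfreeze submodules and set per-group LRs for each phase.
    Assumes model exposes .visual.layers, .text.layers, and .core
    (hypothetical attribute names)."""
    for p in model.parameters():          # start fully frozen
        p.requires_grad = False
    if phase == 1:                        # reasoning core only, LR 3e-4
        for p in model.core.parameters():
            p.requires_grad = True
        groups = [{"params": model.core.parameters(), "lr": 3e-4}]
    elif phase == 2:                      # top perception layers 1e-5, core 1e-4
        top = (list(model.visual.layers[-6:].parameters())
               + list(model.text.layers[-4:].parameters()))
        for p in top + list(model.core.parameters()):
            p.requires_grad = True
        groups = [{"params": top, "lr": 1e-5},
                  {"params": model.core.parameters(), "lr": 1e-4}]
    else:                                 # Phase 3: end-to-end at 5e-5
        for p in model.parameters():
            p.requires_grad = True
        groups = [{"params": model.parameters(), "lr": 5e-5}]
    return torch.optim.AdamW(groups)
```

Rebuilding the optimizer at each phase boundary (rather than mutating one optimizer) keeps the per-group learning rates explicit and avoids stale momentum on newly unfrozen weights.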

Target Benchmarks (9)

| Benchmark | Type | Metric | Key challenge |
|---|---|---|---|
| MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images |
| MathVista | Mixed MC/open | Accuracy | Mathematical reasoning |
| ScienceQA | MC | Accuracy | Scientific diagrams, nullable images |
| AI2D | MC | Accuracy | Science diagram comprehension |
| MMBench | MC | CircularEval accuracy | General visual understanding |
| MMStar | MC | Accuracy | Vision-dependent questions |
| DocVQA | Open | ANLS | Document text extraction |
| TextVQA | Open | VQA accuracy | Scene text reading |
| ChartQA | Open | Relaxed accuracy | Chart data extraction |
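Two of the open-ended metrics have simple closed forms. The sketch below implements the standard ANLS (threshold τ = 0.5) and ChartQA-style relaxed accuracy (5% numeric tolerance); the repo's metrics.py may normalize answers differently:

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,             # deletion
                           cur[-1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]

def anls(pred: str, answers: list[str], tau: float = 0.5) -> float:
    """Normalized Levenshtein Similarity for one DocVQA question:
    best score over ground-truth answers, zeroed past the threshold."""
    best = 0.0
    for gt in answers:
        d = levenshtein(pred.lower().strip(), gt.lower().strip())
        nl = d / max(len(pred), len(gt), 1)
        best = max(best, 1.0 - nl if nl < tau else 0.0)
    return best

def relaxed_accuracy(pred: str, gt: str, tol: float = 0.05) -> bool:
    """ChartQA metric: numeric answers may deviate up to 5% relative."""
    try:
        p, g = float(pred), float(gt)
        return abs(p - g) <= tol * abs(g) if g != 0 else p == g
    except ValueError:
        return pred.strip().lower() == gt.strip().lower()
```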

Experimental Branches

Hybrid-main (competitive)

  • DINOv3-L backbone, SmoothL1 + VICReg, K=3
  • Full enriched evidence in Phase 3
  • Target: state-of-the-art on all benchmarks

Purist-side (scientific validation)

  • DINOv3-B backbone, Cosine + SIGReg, K=5
  • No enriched evidence, pure JEPA reasoning
  • Target: demonstrate JEPA contributes beyond perception

Ablation Experiments

Each experiment maps 1:1 to a CLI flag in train_mrjepa.py.

| Experiment | CLI flag | Modification | Purpose |
|---|---|---|---|
| hybrid_main | (default) | Full model | Baseline |
| no_jepa | --no_jepa | Remove L_JEPA; task loss only | Validate the JEPA objective |
| no_rollout | --no_rollout | K=0, use z₀ directly | Validate iterative refinement |
| no_gate | --no_evidence_gate | Remove evidence gating | Validate adaptive evidence flow |
| K1 / K5 / K7 | --K 1/5/7 | Vary rollout depth | Find the optimal depth |
| dinov2_ablation | --backbone dinov2 | DINOv2-L/14 backbone | DINOv3 vs. DINOv2 |
| mse_loss | --loss_fn mse | MSE (L2) JEPA loss | Original I-JEPA loss |
| cosine_loss | --loss_fn cosine | Cosine-similarity JEPA loss | Purist-style loss |
| no_sigreg | --no_sigreg | Disable SIGReg anti-collapse | Test regularization |
| vicreg_only | --no_sigreg --use_vicreg | VICReg only | Alternative anti-collapse |
| purist | --purist | DINOv3-B, K=5, Cosine+SIGReg | Isolate the JEPA contribution |
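The flag set above maps onto a small argparse parser, e.g. (a sketch; the real train_mrjepa.py may define additional options and different loss-name spellings):

```python
import argparse

def build_arg_parser() -> argparse.ArgumentParser:
    """Ablation flags from the table above."""
    p = argparse.ArgumentParser("train_mrjepa")
    p.add_argument("--no_jepa", action="store_true",
                   help="drop L_JEPA; train with task loss only")
    p.add_argument("--no_rollout", action="store_true",
                   help="K=0, answer directly from z0")
    p.add_argument("--no_evidence_gate", action="store_true",
                   help="remove sigmoid evidence gating")
    p.add_argument("--K", type=int, default=3, help="rollout depth")
    p.add_argument("--backbone", choices=["dinov3", "dinov2"],
                   default="dinov3")
    p.add_argument("--loss_fn", choices=["smooth_l1", "cosine", "mse"],
                   default="smooth_l1")
    p.add_argument("--no_sigreg", action="store_true",
                   help="disable SIGReg anti-collapse")
    p.add_argument("--use_vicreg", action="store_true",
                   help="use VICReg as the anti-collapse term")
    p.add_argument("--purist", action="store_true",
                   help="DINOv3-B, K=5, Cosine+SIGReg branch")
    return p
```

Keeping every ablation behind a flag (instead of separate scripts) guarantees all experiments share one code path, so measured deltas reflect the ablated component only.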

Project Structure

MR-JEPA/
├── README.md                    # This file
├── train_mrjepa.py              # Complete training script (CLI, all ablations)
├── test_architecture.py         # Architecture validation tests (synthetic data)
│
├── mr_jepa/
│   ├── __init__.py
│   ├── ARCHITECTURE.md          # Detailed architecture specification
│   │
│   ├── configs/
│   │   ├── __init__.py
│   │   └── model_config.py      # All hyperparameter dataclasses
│   │
│   ├── models/
│   │   ├── __init__.py
│   │   ├── mr_jepa.py           # Main model (integrates all components)
│   │   ├── backbones.py         # Visual (DINOv3/v2) + Text (Qwen3-Embedding)
│   │   ├── evidence_memory.py   # Perceiver Resampler multimodal fusion
│   │   ├── latent_rollout.py    # K-step shared predictor + evidence gates
│   │   ├── target_encoder.py    # EMA encoder + JEPA/SIGReg/VICReg losses
│   │   └── answer_heads.py      # Discriminative (MC) + Generative (open-ended)
│   │
│   ├── data/
│   │   ├── __init__.py
│   │   ├── unified_dataset.py   # 9-benchmark unified loader with format quirks
│   │   └── data_utils.py        # Collator, dataloader factory, benchmark configs
│   │
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py           # 3-phase training loop
│   │   └── phase_scheduler.py   # Phase transitions, LR scheduling
│   │
│   ├── evaluation/
│   │   ├── __init__.py
│   │   └── metrics.py           # Accuracy, ANLS, VQA Acc, Relaxed Acc
│   │
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py     # Trajectory PCA, gate analysis
│       └── ablation.py          # Systematic ablation runner
│
├── results/                     # Training results (auto-pushed)
│   ├── hybrid_main.json
│   ├── no_jepa.json
│   ├── no_rollout.json
│   └── ...
│
└── checkpoints/                 # Best model checkpoints (auto-pushed)
    ├── hybrid_main_best.pt
    └── ...

Paper Contribution

A world model for multimodal reasoning: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks compared to single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, validated through ablation studies across 9 benchmarks. The JEPA objective (not human chain-of-thought) supervises a latent trajectory generated by an EMA target encoder, showing that self-supervised dynamics training transfers to discriminative reasoning tasks.


Key References

  1. I-JEPA (Assran et al., 2023) — arXiv:2301.08243: JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor
  2. LeWorldModel (Maes et al., 2025) — arXiv:2603.19312: SIGReg anti-collapse, end-to-end JEPA
  3. Coconut (Hao et al., 2024) — arXiv:2412.06769: Chain of Continuous Thought, latent reasoning
  4. DINOv3 (Meta, 2025) — arXiv:2508.10104: dense SSL with RoPE + Gram anchoring
  5. SoftCoT++ (Xu et al., 2025) — arXiv:2505.11484: soft chain-of-thought with contrastive learning

License

Apache-2.0
