🧠 BrainGemma3D: Brain Report Automation via Inflated Vision Transformers in 3D

BrainGemma3D is a multimodal vision-language model that generates clinically accurate radiology reports directly from native 3D brain MRI volumes. Unlike 2D slice-based approaches, BrainGemma3D processes MRI scans volumetrically, preserving the spatial context critical for accurate neuroradiological interpretation.

GitHub Repository · Kaggle Notebook · MedGemma Challenge

🎯 Key Features

  • 🔬 Native 3D Processing: a 2D medical vision encoder (MedSigLIP) inflated to 3D for volumetric understanding
  • 📏 Clinical Accuracy: 95.1% F1 score on pathology entity recognition (on the BraTS test set)
  • 🧭 Spatial Awareness: 68.9% laterality F1 (correct left/right hemisphere localization)
  • 🔍 Interpretable: LIME-based 3D attribution maps show which brain regions drive predictions
  • 🚀 Efficient: processes full 3D volumes with just 32 compressed visual tokens
  • 🏥 Research-Ready: pre-trained on 369 brain tumor cases + 99 healthy controls

πŸ—οΈ Architecture

BrainGemma3D combines:

  1. 3D Vision Encoder: MedSigLIP inflated to 3D via center-frame initialization (Conv2D → Conv3D)
    Base model: google/medsiglip-448

  2. Token Compressor: 2-layer Perceiver that reduces 3D patches to 32 visual tokens

  3. Vision-Language Projector: 2-layer MLP that projects visual tokens to language model embedding space

  4. Language Model: 4-bit quantized MedGemma-1.5-4B-IT with LoRA adapters
    Base model: google/medgemma-1.5-4b-it
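
Center-frame inflation can be sketched in a few lines of PyTorch. This is an illustrative reconstruction, not the repository's exact implementation; the helper name and the temporal kernel size are assumptions. The 2D kernel is copied into the central depth slice of a zero-initialized 3D kernel, so the inflated convolution initially reproduces the 2D encoder's response on each slab:

```python
import torch
import torch.nn as nn

def inflate_conv2d_to_3d(conv2d: nn.Conv2d, time_kernel: int = 2) -> nn.Conv3d:
    """Center-frame inflation: copy the 2D kernel into the central depth
    slice, zero elsewhere, so the 3D conv starts as the 2D conv."""
    conv3d = nn.Conv3d(
        conv2d.in_channels, conv2d.out_channels,
        kernel_size=(time_kernel, *conv2d.kernel_size),
        stride=(1, *conv2d.stride),
        padding=(0, *conv2d.padding),
        bias=conv2d.bias is not None,
    )
    with torch.no_grad():
        conv3d.weight.zero_()                          # all depth slices start at zero
        conv3d.weight[:, :, time_kernel // 2] = conv2d.weight  # center slice = 2D weights
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d
```

Because only the center slice is non-zero at initialization, the inflated patch embedding produces the same features as the 2D encoder before any 3D fine-tuning, which is what makes the warm start stable.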


🚀 Usage

Requirements

pip install torch torchvision transformers huggingface_hub nibabel scikit-image lime

Model Download

from huggingface_hub import snapshot_download

# 1. Download the repository containing our custom architecture from Hugging Face
repo_id = "praiselab-picuslab/BrainGemma3D"
print(f"Downloading repository: {repo_id}...")
local_dir = snapshot_download(repo_id)
print(f"✅ Repository downloaded to: {local_dir}")

Quick Start

import os
import torch
import sys
sys.path.append(local_dir)

from medgemma3d_architecture import MedGemma3D, load_nifti_volume, CANONICAL_PROMPT

# Automatically select the optimal hardware accelerator (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Hardware accelerator selected: {device}")

# 2. Instantiate the base architecture (3D-inflated MedSigLIP + MedGemma)
model = MedGemma3D(
    vision_model_dir=f"{local_dir}/vision_model",
    language_model_dir=f"{local_dir}/language_model",
    depth=2,                # 2-layer Perceiver token compressor
    num_vision_tokens=32,   # compressed visual tokens fed to the language model
    freeze_vision=True,
    freeze_language=True,
    device_map={"": 0} if device == "cuda" else None,
)

# 3. Load projector
proj_path = os.path.join(local_dir, "projector_vis_scale.pt")
print(f"Loading custom projector weights from: {proj_path}...")

# Load the checkpoint into memory
ckpt = torch.load(proj_path, map_location=device)

# Inject the weights into the visual projector (which bridges Vision and Language)
model.vision_projector.load_state_dict(ckpt["vision_projector"])

# Load the visual scaling factor, ensuring correct tensor formatting
if ckpt.get("vis_scale") is not None:
    if isinstance(ckpt["vis_scale"], torch.Tensor):
        model.vis_scale.data = ckpt["vis_scale"].to(device)
    else:
        model.vis_scale.data.fill_(ckpt["vis_scale"])

# Transition the model to evaluation mode for inference
model.eval()
print("✅ BrainGemma3D is fully loaded and ready for inference!")

# 4. Load MRI scan
volume = load_nifti_volume(
    "path/to/brain_flair.nii.gz",
    target_size=(32, 128, 128)
).to(device)

if volume.ndim == 4:  # add a batch dimension if missing
    volume = volume.unsqueeze(0)

# 5. Generate report
with torch.no_grad():
    report = model.generate_report(
        volume,
        prompt=CANONICAL_PROMPT,
        max_new_tokens=256,
        temperature=0.1,
        top_p=0.9,
    )

print("\n===== GENERATED REPORT =====\n")
print(report)

Expected Output Example

Generated Report:
The lesion area is in the left parietal and frontal lobes with mixed high-signal 
areas. Edema signals are mainly observed around these lesions, indicating significant 
edema presence affecting parts of both frontal and temporal regions as well as some 
portions within the parietal lobe. Necrosis may be present at low signal intensity 
or scattered throughout certain sections of the brain tissue affected by edema. 
Ventricular compression effects on adjacent ventricles can occur due to pressure 
from surrounding tissues near the ventricular system.

🎓 Training Pipeline

BrainGemma3D is trained in three progressive stages to prevent catastrophic forgetting:

Phase 1: Contrastive Grounding (Image-Text Alignment)

  • Goal: Align 3D visual features with textual report embeddings
  • Loss: InfoNCE (CLIP-style contrastive learning)
  • Trainable: 3D Vision Encoder + Projector
  • Frozen: Language Model
  • Epochs: 100
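
The CLIP-style InfoNCE objective used in this phase can be sketched as a symmetric cross-entropy over the batch similarity matrix (a minimal illustration; the temperature value is an assumption):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric InfoNCE: matched (volume, report) pairs are positives,
    every other pair in the batch is a negative."""
    img = F.normalize(img_emb, dim=-1)
    txt = F.normalize(txt_emb, dim=-1)
    logits = img @ txt.t() / temperature                  # (B, B) cosine similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, targets)           # image -> text direction
    loss_t2i = F.cross_entropy(logits.t(), targets)       # text -> image direction
    return (loss_i2t + loss_t2i) / 2
```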

Phase 2A: Projector Warmup

  • Goal: Train the projector to condition the LM effectively
  • Loss: Next-token prediction (Cross-Entropy)
  • Trainable: Projector only
  • Frozen: Vision Encoder + Language Model
  • Epochs: 100
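
The warmup step can be sketched with toy stand-ins: visual tokens are projected into the LM embedding space, prepended to the text embeddings, and cross-entropy is computed only on (shifted) text positions. The function and argument names below are hypothetical, not the repository's API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def warmup_step(vision_tokens, text_embeds, text_labels, lm, projector):
    """One Phase-2A step: only the projector receives gradients.
    vision_tokens: (B, 32, d_vis), text_embeds: (B, T, d_lm), text_labels: (B, T)."""
    vis = projector(vision_tokens)                 # (B, 32, d_lm)
    inputs = torch.cat([vis, text_embeds], dim=1)  # visual prefix + text
    logits = lm(inputs)                            # (B, 32 + T, vocab)
    logits = logits[:, vis.size(1):-1]             # text positions, shifted by one
    labels = text_labels[:, 1:]
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
```

With the LM frozen, gradients flow only through `projector`, so the warmup teaches the projector to emit embeddings the language model already understands.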

Phase 2B: LoRA Linguistic Specialization

  • Goal: Adapt LM to generate structured clinical reports
  • Loss: Next-token prediction (Cross-Entropy)
  • Trainable: Projector + LoRA adapters (rank=4) on LM attention layers
  • Frozen: Vision Encoder + LM base weights
  • Epochs: 100
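
A rank-4 LoRA update can be illustrated in plain PyTorch (a from-scratch sketch of the idea, not the adapter implementation used here): the base weight stays frozen and only low-rank factors A and B train, with B zero-initialized so training starts from the base model's behavior.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen Linear plus a trainable rank-r update: y = Wx + (alpha/r) * B(A(x))."""
    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 8):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False            # base weights stay frozen
        self.A = nn.Linear(base.in_features, r, bias=False)
        self.B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.B.weight)          # zero init: starts as the base layer
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.B(self.A(x))
```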

Dataset:

  • 369 BraTS 2020 brain tumor MRI cases with radiologist-written reports from TextBraTS 2021
  • 99 healthy control scans with synthetic reports from MPI-Leipzig Mind-Brain-Body
  • Stratified group-based splits (70% train / 10% val / 20% test) to prevent patient leakage
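
Group-based splitting can be sketched with scikit-learn's GroupShuffleSplit, assuming one group ID per patient (0.125 of the remaining 80% yields the 10% validation share):

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

def grouped_split(patient_ids, seed=42):
    """70/10/20 split in which every scan from a patient stays in one split."""
    ids = np.asarray(patient_ids)
    idx = np.arange(len(ids))
    # hold out 20% of patients as the test set
    gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    trainval, test = next(gss.split(idx, groups=ids))
    # 0.125 of the remaining 80% of patients = 10% validation overall
    gss2 = GroupShuffleSplit(n_splits=1, test_size=0.125, random_state=seed)
    tr, val = next(gss2.split(trainval, groups=ids[trainval]))
    return trainval[tr], trainval[val], test
```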

📊 Performance

Evaluated on 468 subjects (369 BraTS pathological + 99 healthy controls) with group-based splits.

Quantitative Results (Test Set)

| Model | BLEU-1 | BLEU-4 | ROUGE-L | CIDEr | Lat F1 | Anat F1 | Path F1 |
|---|---|---|---|---|---|---|---|
| Med3DVLM (3D Generalist) | 0.051 | 0.005 | 0.083 | 0.007 | 0.300 | 0.225 | 0.119 |
| MedGemma 1.5 (2D Slice) | 0.245 | 0.024 | 0.189 | 0.029 | 0.526 | 0.461 | 0.413 |
| BrainGemma3D (Ours) | 0.302 | 0.098 | 0.289 | 0.293 | 0.689 | 0.691 | 0.951 |

Clinical Metrics Breakdown

  • Laterality F1: 0.689 – correct hemispheric localization (left/right)
  • Anatomy F1: 0.691 – accurate anatomical structure identification
  • Pathology F1: 0.951 – near-perfect pathological entity recognition
  • Healthy Specificity: 1.0 – zero hallucinations on healthy controls

Key Insight: The +130% relative gain in Pathology F1 (0.951 vs. 0.413 for the 2D baseline) demonstrates that native 3D processing is essential for diagnostic accuracy in neuroradiology.
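
The clinical entity metrics can be illustrated with a simple set-based F1 over extracted terms. The actual entity extractor is not specified here; the keyword matching and vocabulary below are assumptions for illustration only:

```python
import re

LATERALITY_TERMS = {"left", "right", "bilateral"}  # assumed vocabulary

def extract_terms(report: str, vocab: set) -> set:
    """Lowercase word-level matching against a fixed term vocabulary."""
    tokens = set(re.findall(r"[a-z]+", report.lower()))
    return tokens & vocab

def entity_f1(pred: str, ref: str, vocab: set) -> float:
    """Set-based F1 between predicted and reference entity mentions."""
    p, r = extract_terms(pred, vocab), extract_terms(ref, vocab)
    if not p and not r:
        return 1.0  # both reports silent on this category counts as agreement
    tp = len(p & r)
    precision = tp / len(p) if p else 0.0
    recall = tp / len(r) if r else 0.0
    return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
```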


πŸ” Interpretability

BrainGemma3D includes LIME-based 3D interpretability to visualize which brain regions drive diagnostic predictions.

from braingemma3d_interpretability import run_interpretability


# 6. Run interpretability analysis
weights, wvol = run_interpretability(
    model=model,
    load_nifti_volume=load_nifti_volume,
    CANONICAL_PROMPT=CANONICAL_PROMPT,
    mri_path="path/to/brain_flair.nii.gz",
    report=report,
    output_dir="./interpretability_output",
    lime_samples=100,      # Number of perturbations (more = better but slower)
    n_segments=20,         # Number of brain regions to analyze
    alpha=0.45,            # Overlay transparency
    clip_q=0.99,           # Heatmap clipping
    seed=42,
)

Output:

  • overlay_slices.png – full 3D heatmap (red = supportive, blue = contradicting)
  • lime_2x3_grid.png – 2×3 grid of selected slices (original + LIME overlay)
  • lime_top_supervoxels_grid.png – most influential supervoxels
  • lime_weights.json – supervoxel importance scores

Expected Output Example


Figure 1: LIME attribution maps for a BraTS sample. Red regions show supervoxels that positively contribute to pathology predictions. The model correctly focuses on tumor-affected areas in the left parietal and frontal lobes.


βš™οΈ Model Details

  • Model Type: Multimodal Vision-Language Model (3D MRI → Text)
  • Architecture: Inflated ViT + Perceiver Compressor + MLP Projector + Quantized Gemma-1.5
  • Input: 3D brain MRI FLAIR volumes (64×128×128 voxels)
  • Output: Free-form radiology reports (up to 256 tokens)
  • Parameters: ~4.45B total (≈450M vision + 4B language, 4-bit quantized)
  • Training Compute: 1× NVIDIA A100 64GB (≈12 GPU-hours total)
  • Framework: PyTorch 2.0, Transformers 4.40+

Preprocessing Requirements

  • Orientation: RAS (as-closest-canonical)
  • Resolution: Resampled to (64, 128, 128) via trilinear interpolation
  • Normalization: Percentile clipping (p1, p99) + z-score normalization
  • Format: NIfTI (.nii or .nii.gz)

⚠️ Limitations & Intended Use

✅ Intended Use

  • Research: Medical AI research, neuroradiology automation
  • Education: Teaching radiology residents about report generation
  • Prototyping: Building diagnostic support tools (non-clinical)

❌ Not Intended For

  • Clinical Diagnosis: This model is NOT FDA/CE approved for medical use
  • Primary Interpretation: Always verify with board-certified radiologists
  • Real-Time Emergency: Not validated for acute stroke or trauma cases

Known Limitations

  • Training Bias: Trained primarily on glioblastoma (BraTS dataset); may underperform on other pathologies
  • Language: English only (radiology reports)
  • Hallucination Risk: May generate plausible but incorrect anatomical details (always verify)
  • Compute Requirements: Requires a GPU with ≥16GB VRAM for inference

πŸ₯ Clinical Validation Notes

BrainGemma3D achieved 95.1% pathology F1 on the BraTS, but this does NOT imply clinical readiness. Key considerations:

  1. Dataset Homogeneity: BraTS contains predominantly glioblastomas; performance on other tumor types (meningiomas, metastases) is unknown
  2. Report Quality: Ground truth reports are from a single institution and may not generalize to other radiology practices
  3. No Radiologist Review: Generated reports have not been clinically validated by neuroradiologists
  4. Regulatory Status: Not cleared by FDA, EMA, or any regulatory body

Recommendation: Use only in research settings with appropriate ethical oversight and informed consent.


πŸ™ Acknowledgements

This project was developed by:

Mariano Barone · Francesco Di Serio · Giuseppe Riccio · Antonio Romano · Vincenzo Moscato

Department of Electrical Engineering and Information Technology
University of Naples Federico II, Italy

Built with ❤️ for the MedGemma Impact Challenge 🏆

Advancing Medical AI with Google's Health AI Developer Foundations
