🧠 BrainGemma3D: Brain Report Automation via Inflated Vision Transformers in 3D
BrainGemma3D is a multimodal vision-language model that generates clinically accurate radiology reports directly from native 3D brain MRI volumes. Unlike 2D slice-based approaches, BrainGemma3D processes MRI scans volumetrically, preserving the spatial context critical for accurate neuroradiological interpretation.
🎯 Key Features
- 🔬 Native 3D Processing: a 2D medical vision encoder (MedSigLIP) inflated to 3D for volumetric understanding
- Clinical Accuracy: 95.1% F1 on pathology entity recognition (BraTS dataset)
- 🧠 Spatial Awareness: 68.9% laterality F1 (correct left/right hemisphere localization)
- Interpretable: LIME-based 3D attribution maps show which brain regions drive predictions
- Efficient: compresses each full 3D volume into 32 visual tokens
- 🏥 Research-Ready: fine-tuned on 369 brain tumor cases + 99 healthy controls
🏗️ Architecture
BrainGemma3D combines:
- 3D Vision Encoder: MedSigLIP inflated to 3D via center-frame initialization (Conv2D → Conv3D)
  - Base model: google/medsiglip-448
- Token Compressor: 2-layer Perceiver that reduces the 3D patch tokens to 32 visual tokens
- Vision-Language Projector: 2-layer MLP that projects the visual tokens into the language model's embedding space
- Language Model: 4-bit quantized MedGemma-1.5-4B-IT with LoRA adapters
  - Base model: google/medgemma-1.5-4b-it
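The vision path above (inflated patch embedding, token compressor, projector) can be sketched in PyTorch. This is a minimal illustration under stated assumptions, not the repository's implementation: `inflate_conv2d_center`, `PerceiverCompressor`, the toy channel sizes, the depth kernel of 4, and the head count are hypothetical, and the compressor is a simplified Perceiver-style block (cross-attention plus feed-forward per layer). Center-frame initialization copies the pretrained 2D kernel into the central depth slice of a 3D kernel and zeros the other slices, so each 3D patch initially embeds exactly like its central 2D slice.

```python
import torch
import torch.nn as nn


def inflate_conv2d_center(conv2d: nn.Conv2d, depth_kernel: int) -> nn.Conv3d:
    """Inflate a 2D patch embedding into 3D with center-frame initialization (hypothetical helper)."""
    # String padding (e.g. "valid") is reused as-is; tuple padding gets 0 on the new depth axis
    padding = conv2d.padding if isinstance(conv2d.padding, str) else (0, *conv2d.padding)
    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(depth_kernel, *conv2d.kernel_size),
        stride=(depth_kernel, *conv2d.stride),
        padding=padding,
        bias=conv2d.bias is not None,
    )
    with torch.no_grad():
        conv3d.weight.zero_()
        # Only the central depth slice receives the pretrained 2D kernel
        conv3d.weight[:, :, depth_kernel // 2] = conv2d.weight
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d


class PerceiverCompressor(nn.Module):
    """Compress a variable number of patch tokens into a fixed set of learned latent tokens."""

    def __init__(self, dim: int, num_latents: int = 32, depth: int = 2, heads: int = 8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                "norm_latents": nn.LayerNorm(dim),
                "norm_patches": nn.LayerNorm(dim),
                "cross_attn": nn.MultiheadAttention(dim, heads, batch_first=True),
                "ff": nn.Sequential(
                    nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
                ),
            })
            for _ in range(depth)
        ])

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (B, N, dim) -> latents: (B, num_latents, dim)
        x = self.latents.unsqueeze(0).expand(patches.shape[0], -1, -1)
        for layer in self.layers:
            kv = layer["norm_patches"](patches)
            attended, _ = layer["cross_attn"](layer["norm_latents"](x), kv, kv)
            x = x + attended  # latents query the patch tokens
            x = x + layer["ff"](x)
        return x


if __name__ == "__main__":
    # Toy sizes for illustration only
    patch_embed_2d = nn.Conv2d(1, 64, kernel_size=16, stride=16, padding="valid")
    patch_embed_3d = inflate_conv2d_center(patch_embed_2d, depth_kernel=4)
    compressor = PerceiverCompressor(dim=64, num_latents=32, depth=2, heads=8)
    projector = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 256))  # 2-layer MLP

    volume = torch.randn(1, 1, 32, 128, 128)  # (B, C, D, H, W)
    patches = patch_embed_3d(volume).flatten(2).transpose(1, 2)  # (1, 512, 64): 8 x 8 x 8 patches
    visual_tokens = projector(compressor(patches))
    print(visual_tokens.shape)  # torch.Size([1, 32, 256])
```

However many patches the volume produces, the compressor always hands the language model the same 32 tokens, which keeps the LM's input length fixed.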
Usage
Requirements
```bash
pip install torch torchvision transformers nibabel scikit-image lime
```
Model Download
```python
from huggingface_hub import snapshot_download

# 1. Download the repository containing our custom architecture from Hugging Face
repo_id = "praiselab-picuslab/BrainGemma3D"
print(f"Downloading repository: {repo_id}...")
local_dir = snapshot_download(repo_id)
print(f"✅ Repository downloaded to: {local_dir}")
```
Quick Start
```python
import os
import sys

import torch

# Make the downloaded repository importable (local_dir comes from the Model Download step)
sys.path.append(local_dir)
from medgemma3d_architecture import MedGemma3D, load_nifti_volume, CANONICAL_PROMPT

# Automatically select the optimal hardware accelerator (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Hardware accelerator selected: {device}")

# 2. Instantiate the base architecture (3D-inflated MedSigLIP + MedGemma)
model = MedGemma3D(
    vision_model_dir=f"{local_dir}/vision_model",
    language_model_dir=f"{local_dir}/language_model",
    depth=2,
    num_vision_tokens=32,
    freeze_vision=True,
    freeze_language=True,
    device_map={"": 0} if device == "cuda" else None,
)

# 3. Load projector
proj_path = os.path.join(local_dir, "projector_vis_scale.pt")
print(f"Loading custom projector weights from: {proj_path}...")
# Load the checkpoint into memory
ckpt = torch.load(proj_path, map_location=device)
# Inject the weights into the visual projector (which bridges Vision and Language)
model.vision_projector.load_state_dict(ckpt["vision_projector"])
# Load the visual scaling factor, handling both tensor and plain-number checkpoints
if ckpt.get("vis_scale") is not None:
    if isinstance(ckpt["vis_scale"], torch.Tensor):
        model.vis_scale.data = ckpt["vis_scale"].to(device)
    else:
        model.vis_scale.data.fill_(ckpt["vis_scale"])

# Switch the model to evaluation mode for inference
model.eval()
print("✅ BrainGemma3D is fully loaded and ready for inference!")

# 4. Load MRI scan
volume = load_nifti_volume(
    "path/to/brain_flair.nii.gz",
    target_size=(32, 128, 128),
).to(device)
# Add a batch dimension if the loader returned a single 4D volume
if volume.ndim == 4:
    volume = volume.unsqueeze(0)

# 5. Generate report
with torch.no_grad():
    report = model.generate_report(
        volume,
        prompt=CANONICAL_PROMPT,
        max_new_tokens=256,
        temperature=0.1,
        top_p=0.9,
    )
print("\n===== GENERATED REPORT =====\n")
print(report)
```
Expected Output Example
Generated Report:
The lesion area is in the left parietal and frontal lobes with mixed high-signal
areas. Edema signals are mainly observed around these lesions, indicating significant
edema presence affecting parts of both frontal and temporal regions as well as some
portions within the parietal lobe. Necrosis may be present at low signal intensity
or scattered throughout certain sections of the brain tissue affected by edema.
Ventricular compression effects on adjacent ventricles can occur due to pressure
from surrounding tissues near the ventricular system.
Training Pipeline
BrainGemma3D is trained in three progressive stages to prevent catastrophic forgetting:
Phase 1: Contrastive Grounding (Image-Text Alignment)
- Goal: Align 3D visual features with textual report embeddings
- Loss: InfoNCE (CLIP-style contrastive learning)
- Trainable: 3D Vision Encoder + Projector
- Frozen: Language Model
- Epochs: 100
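The Phase 1 objective can be sketched as a symmetric, CLIP-style InfoNCE loss over a batch of matched (volume, report) embedding pairs: each volume must pick out its own report among all reports in the batch, and vice versa. This is a minimal sketch assuming both sides are already pooled to vectors of the same size; how the report embedding is produced and the fixed temperature of 0.07 are assumptions (CLIP itself learns the temperature).

```python
import torch
import torch.nn.functional as F


def clip_infonce_loss(volume_emb: torch.Tensor, report_emb: torch.Tensor,
                      temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE: the i-th volume should match the i-th report and no other in the batch."""
    volume_emb = F.normalize(volume_emb, dim=-1)
    report_emb = F.normalize(report_emb, dim=-1)
    logits = volume_emb @ report_emb.t() / temperature  # (B, B) scaled cosine similarities
    targets = torch.arange(logits.shape[0], device=logits.device)  # positives on the diagonal
    loss_volume_to_report = F.cross_entropy(logits, targets)
    loss_report_to_volume = F.cross_entropy(logits.t(), targets)
    return (loss_volume_to_report + loss_report_to_volume) / 2


if __name__ == "__main__":
    emb = torch.eye(4)  # four perfectly separable toy embeddings
    print(clip_infonce_loss(emb, emb).item() < 1e-3)  # True: every pair matched
    print(clip_infonce_loss(emb, emb.roll(1, dims=0)).item() > 10)  # True: every pair mismatched
```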
Phase 2A: Projector Warmup
- Goal: Train the projector to condition the LM effectively
- Loss: Next-token prediction (Cross-Entropy)
- Trainable: Projector only
- Frozen: Vision Encoder + Language Model
- Epochs: 100
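Phase 2A can be sketched as two pieces: freezing everything except the projector, and a next-token loss that only counts report tokens. This sketch assumes the LM sees `[visual tokens][prompt tokens][report tokens]` and that positions before the report are excluded with `-100`, the default `ignore_index` of `torch.nn.functional.cross_entropy`. `freeze_all_but` and `report_only_lm_loss` are hypothetical helpers; only `vision_projector` is a module name taken from the Quick Start, and the other toy module names are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

IGNORE_INDEX = -100  # default ignore_index of F.cross_entropy


def freeze_all_but(model: nn.Module, trainable_modules) -> None:
    """Enable gradients only for parameters that live under the named submodules."""
    for name, param in model.named_parameters():
        param.requires_grad = any(name == m or name.startswith(m + ".") for m in trainable_modules)


def report_only_lm_loss(logits: torch.Tensor, report_ids: torch.Tensor,
                        num_prefix_tokens: int) -> torch.Tensor:
    """Next-token cross-entropy over report tokens only.

    logits: (B, num_prefix_tokens + report_len, vocab) for the whole input sequence.
    report_ids: (B, report_len) ground-truth report token ids.
    """
    batch = report_ids.shape[0]
    prefix_labels = torch.full((batch, num_prefix_tokens), IGNORE_INDEX,
                               dtype=torch.long, device=report_ids.device)
    labels = torch.cat([prefix_labels, report_ids.long()], dim=1)
    # Position t predicts token t + 1, so drop the last logit and the first label
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1), ignore_index=IGNORE_INDEX)


if __name__ == "__main__":
    toy = nn.ModuleDict({
        "vision_encoder": nn.Linear(4, 4),  # placeholder names for illustration
        "vision_projector": nn.Linear(4, 8),
        "language_model": nn.Linear(8, 8),
    })
    freeze_all_but(toy, ["vision_projector"])
    print([n for n, p in toy.named_parameters() if p.requires_grad])
    # ['vision_projector.weight', 'vision_projector.bias']
```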
Phase 2B: LoRA Linguistic Specialization
- Goal: Adapt LM to generate structured clinical reports
- Loss: Next-token prediction (Cross-Entropy)
- Trainable: Projector + LoRA adapters (rank=4) on LM attention layers
- Frozen: Vision Encoder + LM base weights
- Epochs: 100
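Phase 2B attaches rank-4 LoRA adapters to the LM's attention projections. The hand-written wrapper below is a minimal sketch of the technique, not necessarily how the authors implemented it (a library such as PEFT may have been used; that is not stated here). The target names `q_proj`, `k_proj`, `v_proj`, `o_proj` and the scaling `alpha=8` are assumptions. Because `lora_B` starts at zero, the adapted model initially reproduces the frozen base model exactly; in Phase 2B the projector would additionally be kept trainable.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update: base(x) + (alpha / r) * x A^T B^T."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # base weights stay frozen
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at start
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.t() @ self.lora_B.t()) * self.scaling


def add_lora(model: nn.Module, target_names=("q_proj", "k_proj", "v_proj", "o_proj"),
             r: int = 4, alpha: float = 8.0) -> nn.Module:
    """Freeze every base weight, then wrap matching nn.Linear layers with LoRA adapters."""
    for param in model.parameters():
        param.requires_grad = False

    def _wrap(module: nn.Module) -> None:
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear) and name in target_names:
                setattr(module, name, LoRALinear(child, r=r, alpha=alpha))
            else:
                _wrap(child)

    _wrap(model)
    return model
```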
Dataset:
- 369 BraTS 2020 brain tumor MRI cases with radiologist-written reports from TextBraTS 2021
- 99 healthy control scans from the MPI-Leipzig Mind-Brain-Body dataset, paired with synthetic reports
- Stratified group-based splits (70% train / 10% val / 20% test) to prevent patient leakage
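A stratified group-based split can be sketched in plain Python: every group (patient) is assigned to exactly one split, and the assignment is done separately per label so that tumor and healthy cases each keep roughly the 70/10/20 ratio. `stratified_group_split` is a hypothetical helper that assumes all scans of a patient share one label; scikit-learn's `StratifiedGroupKFold` is an alternative for k-fold setups.

```python
import random
from collections import defaultdict


def stratified_group_split(subject_ids, labels, group_ids, fractions=(0.7, 0.1, 0.2), seed=42):
    """Split subjects into train/val/test so that no group (patient) spans two splits."""
    # One label per group (assumption: all scans of a patient share the same label)
    group_label = {}
    for label, group in zip(labels, group_ids):
        group_label.setdefault(group, label)

    groups_by_label = defaultdict(list)
    for group, label in group_label.items():
        groups_by_label[label].append(group)

    rng = random.Random(seed)
    split_of_group = {}
    for label in sorted(groups_by_label):
        groups = sorted(groups_by_label[label])
        rng.shuffle(groups)
        n_train = round(fractions[0] * len(groups))
        n_val = round(fractions[1] * len(groups))
        for i, group in enumerate(groups):  # test receives the remainder
            if i < n_train:
                split_of_group[group] = "train"
            elif i < n_train + n_val:
                split_of_group[group] = "val"
            else:
                split_of_group[group] = "test"

    splits = {"train": [], "val": [], "test": []}
    for subject, group in zip(subject_ids, group_ids):
        splits[split_of_group[group]].append(subject)
    return splits


if __name__ == "__main__":
    ids = [f"brats_{i}" for i in range(369)] + [f"healthy_{i}" for i in range(99)]
    labels = ["tumor"] * 369 + ["healthy"] * 99
    splits = stratified_group_split(ids, labels, group_ids=ids)  # one scan per subject here
    print({k: len(v) for k, v in splits.items()})  # {'train': 327, 'val': 47, 'test': 94}
```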
Performance
Evaluated on the held-out test split of 468 subjects (369 BraTS pathological + 99 healthy controls), using group-based splits.
Quantitative Results (Test Set)
| Model | BLEU-1 | BLEU-4 | ROUGE-L | CIDEr | Laterality F1 | Anatomy F1 | Pathology F1 |
|---|---|---|---|---|---|---|---|
| Med3DVLM (3D Generalist) | 0.051 | 0.005 | 0.083 | 0.007 | 0.300 | 0.225 | 0.119 |
| MedGemma 1.5 (2D Slice) | 0.245 | 0.024 | 0.189 | 0.029 | 0.526 | 0.461 | 0.413 |
| BrainGemma3D (Ours) | 0.302 | 0.098 | 0.289 | 0.293 | 0.689 | 0.691 | 0.951 |
Clinical Metrics Breakdown
- Laterality F1 (0.689): correct hemispheric localization (left/right)
- Anatomy F1 (0.691): accurate identification of anatomical structures
- Pathology F1 (0.951): near-perfect recognition of pathological entities
- Healthy Specificity (1.0): no healthy control in the test set was reported as pathological
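One common way to compute clinical F1 scores like these is to extract a fixed vocabulary of entities from the generated and reference reports and compare the two sets. The sketch below assumes that approach; the term lists and `entity_f1` are hypothetical and are not the vocabularies behind the table. A per-report score of this kind would then be averaged over the test set.

```python
import re

# Hypothetical term lists for illustration; the evaluation's actual vocabularies are not given here
LATERALITY_TERMS = {"left", "right", "bilateral"}
ANATOMY_TERMS = {"frontal", "parietal", "temporal", "occipital", "insula", "thalamus", "ventricle", "ventricles"}
PATHOLOGY_TERMS = {"lesion", "edema", "necrosis", "enhancement", "hemorrhage", "compression"}


def extract_entities(report: str, vocabulary: set) -> set:
    """Return the vocabulary terms that appear as whole words in the report."""
    return set(re.findall(r"[a-z]+", report.lower())) & vocabulary


def entity_f1(generated: str, reference: str, vocabulary: set) -> float:
    """Set-level F1 between the entities mentioned in the generated and reference reports."""
    predicted = extract_entities(generated, vocabulary)
    gold = extract_entities(reference, vocabulary)
    if not predicted and not gold:
        return 1.0  # neither report mentions any entity of this type
    true_positives = len(predicted & gold)
    if true_positives == 0:
        return 0.0
    precision = true_positives / len(predicted)
    recall = true_positives / len(gold)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    generated = "Lesion in the left frontal lobe with edema."
    reference = "Left frontal lesion with surrounding edema and necrosis."
    print(round(entity_f1(generated, reference, PATHOLOGY_TERMS), 3))  # 0.8
    print(entity_f1(generated, reference, LATERALITY_TERMS))  # 1.0
```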
Key Insight: The +130% relative gain in Pathology F1 over the 2D-slice baseline (0.951 vs 0.413) suggests that native 3D processing substantially improves diagnostic accuracy in neuroradiology.
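The relative gain is (new - baseline) / baseline. Applying it to the three clinical metrics in the table:

```python
# Clinical F1 scores from the table: (MedGemma 1.5 2D slice, BrainGemma3D)
scores = {
    "Laterality F1": (0.526, 0.689),
    "Anatomy F1": (0.461, 0.691),
    "Pathology F1": (0.413, 0.951),
}
for metric, (baseline, ours) in scores.items():
    relative_gain = (ours - baseline) / baseline
    print(f"{metric}: {relative_gain:+.0%}")
# Laterality F1: +31%
# Anatomy F1: +50%
# Pathology F1: +130%
```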
Interpretability
BrainGemma3D includes LIME-based 3D interpretability to visualize which brain regions drive diagnostic predictions.
```python
from braingemma3d_interpretability import run_interpretability

# 6. Run interpretability analysis (reuses model, load_nifti_volume, CANONICAL_PROMPT and report from the Quick Start)
weights, wvol = run_interpretability(
    model=model,
    load_nifti_volume=load_nifti_volume,
    CANONICAL_PROMPT=CANONICAL_PROMPT,
    mri_path="path/to/brain_flair.nii.gz",
    report=report,
    output_dir="./interpretability_output",
    lime_samples=100,  # Number of perturbations (more = better but slower)
    n_segments=20,  # Number of brain regions (supervoxels) to analyze
    alpha=0.45,  # Overlay transparency
    clip_q=0.99,  # Quantile used to clip the heatmap
    seed=42,
)
```
Output:
- overlay_slices.png: full 3D heatmap (red = supportive, blue = contradicting)
- lime_2x3_grid.png: 2×3 grid of selected slices (original + LIME overlay)
- lime_top_supervoxels_grid.png: most influential supervoxels
- lime_weights.json: supervoxel importance scores
Expected Output Example
Figure 1: LIME attribution maps for a BraTS sample. Red regions show supervoxels that positively contribute to pathology predictions. The model correctly focuses on tumor-affected areas in the left parietal and frontal lobes.
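The LIME procedure behind these maps can be sketched in NumPy: split the volume into supervoxels, score many randomly masked copies, and fit a locally weighted ridge regression whose coefficients are the supervoxel importances. This is a minimal sketch, not `run_interpretability` itself: the regular-grid `grid_supervoxels` is a hypothetical stand-in for a real supervoxel method such as SLIC, masked supervoxels are set to a constant baseline, and `score_fn` stands for whatever scalar the real pipeline explains (for example, the model's likelihood of its own report). The kernel width of 0.25, cosine distance, and ridge penalty of 1.0 are chosen to match the `lime` package's image-explainer defaults.

```python
import numpy as np


def grid_supervoxels(shape, blocks=(2, 2, 2)):
    """Hypothetical stand-in for SLIC: label a regular grid of blocks[0] x blocks[1] x blocks[2] supervoxels."""
    d, h, w = np.meshgrid(*[np.arange(n) * b // n for n, b in zip(shape, blocks)], indexing="ij")
    return (d * blocks[1] + h) * blocks[2] + w


def lime_supervoxel_weights(volume, segments, score_fn, num_samples=100,
                            kernel_width=0.25, ridge_alpha=1.0, baseline=0.0, seed=42):
    """Return one LIME importance weight per supervoxel label in `segments`."""
    rng = np.random.default_rng(seed)
    num_segments = int(segments.max()) + 1
    # Each row switches supervoxels on (1) or off (0); row 0 is the unperturbed volume
    masks = rng.integers(0, 2, size=(num_samples, num_segments))
    masks[0] = 1
    scores = np.empty(num_samples)
    for i, mask in enumerate(masks):
        keep = mask[segments].astype(bool)  # map supervoxel on/off flags to voxels
        scores[i] = score_fn(np.where(keep, volume, baseline))

    # Cosine distance to the all-ones mask, turned into LIME's exponential kernel
    distance = 1.0 - np.sqrt(masks.sum(axis=1) / num_segments)
    sample_w = np.sqrt(np.exp(-(distance ** 2) / kernel_width ** 2))

    # Weighted ridge regression with an unpenalized intercept (last column)
    X = np.hstack([masks.astype(float), np.ones((num_samples, 1))])
    penalty = ridge_alpha * np.eye(num_segments + 1)
    penalty[-1, -1] = 0.0
    coef = np.linalg.solve(X.T @ (sample_w[:, None] * X) + penalty, X.T @ (sample_w * scores))
    return coef[:-1]


if __name__ == "__main__":
    volume = np.ones((8, 8, 8))
    segments = grid_supervoxels(volume.shape)  # 8 supervoxels, labels 0..7
    # Toy score that depends only on supervoxel 0 (the first corner block)
    weights = lime_supervoxel_weights(volume, segments, lambda v: v[:4, :4, :4].mean())
    print(int(np.argmax(weights)))  # 0
```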
⚙️ Model Details
- Model Type: Multimodal Vision-Language Model (3D MRI → Text)
- Architecture: Inflated ViT + Perceiver Compressor + MLP Projector + 4-bit quantized MedGemma-1.5
- Input: 3D brain MRI FLAIR volumes (64×128×128 voxels)
- Output: Free-form radiology reports (up to 256 tokens)
- Parameters: ~4.45B in total (~450M vision + ~4B language; the language model is 4-bit quantized)
- Training Compute: 1× NVIDIA A100 64GB (~12 GPU-hours total)
- Framework: PyTorch 2.0, Transformers 4.40+
Preprocessing Requirements
- Orientation: RAS (as-closest-canonical)
- Resolution: Resampled to (64, 128, 128) via trilinear interpolation
- Normalization: Percentile clipping (p1, p99) + z-score normalization
- Format: NIfTI (.nii or .nii.gz)
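The preprocessing steps listed above can be sketched with NumPy, PyTorch, and nibabel. This is a minimal sketch under assumptions, not the repository's `load_nifti_volume`: `preprocess_volume` and `load_and_preprocess` are hypothetical names, the order (normalize, then resample) and the transpose that makes the axial axis the depth dimension are guesses, and the z-score is computed over the whole volume, background included.

```python
import numpy as np
import torch
import torch.nn.functional as F


def preprocess_volume(data: np.ndarray, target_size=(64, 128, 128)) -> torch.Tensor:
    """Percentile clip (p1, p99), z-score, then trilinear resample to target_size."""
    data = np.ascontiguousarray(data, dtype=np.float32)
    lo, hi = np.percentile(data, [1, 99])
    data = np.clip(data, lo, hi)
    data = (data - data.mean()) / (data.std() + 1e-8)
    vol = torch.from_numpy(data)[None, None]  # (1, 1, D, H, W)
    vol = F.interpolate(vol, size=target_size, mode="trilinear", align_corners=False)
    return vol  # (1, 1, *target_size)


def load_and_preprocess(path: str, target_size=(64, 128, 128)) -> torch.Tensor:
    """Load a NIfTI file, reorient it to RAS, and preprocess it (hypothetical helper)."""
    import nibabel as nib  # imported here so preprocess_volume works without nibabel installed

    img = nib.as_closest_canonical(nib.load(path))  # reorient to RAS
    data = img.get_fdata()  # (x, y, z) in RAS order
    # Assumption: use the axial (z) axis as depth, giving (z, y, x) = (D, H, W)
    data = np.ascontiguousarray(data.transpose(2, 1, 0))
    return preprocess_volume(data, target_size)
```

For example, `load_and_preprocess("path/to/brain_flair.nii.gz")` would return a `(1, 1, 64, 128, 128)` float tensor.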
⚠️ Limitations & Intended Use
Intended Use
- Research: Medical AI research, neuroradiology automation
- Education: Teaching radiology residents about report generation
- Prototyping: Building diagnostic support tools (non-clinical)
Not Intended For
- Clinical Diagnosis: This model is NOT FDA/CE approved for medical use
- Primary Interpretation: Always verify with board-certified radiologists
- Real-Time Emergency: Not validated for acute stroke or trauma cases
Known Limitations
- Training Bias: Trained primarily on glioblastoma (BraTS dataset); may underperform on other pathologies
- Language: English only (radiology reports)
- Hallucination Risk: May generate plausible but incorrect anatomical details (always verify)
- Compute Requirements: Requires a GPU with ≥16GB VRAM for inference
🏥 Clinical Validation Notes
BrainGemma3D achieved 95.1% pathology F1 on BraTS, but this does NOT imply clinical readiness. Key considerations:
- Dataset Homogeneity: BraTS contains predominantly glioblastomas; performance on other tumor types (meningiomas, metastases) is unknown
- Report Quality: Ground-truth reports come from a single institution and may not generalize to other radiology practices
- No Radiologist Review: Generated reports have not been clinically validated by neuroradiologists
- Regulatory Status: Not cleared by FDA, EMA, or any regulatory body
Recommendation: Use only in research settings with appropriate ethical oversight and informed consent.
Acknowledgements
This project was developed by:
Mariano Barone · Francesco Di Serio · Giuseppe Riccio · Antonio Romano · Vincenzo Moscato
Department of Electrical Engineering and Information Technology
University of Naples Federico II, Italy
Built With
- Google MedGemma: medical-domain language model
- Google MedSigLIP: medical vision encoder
- Hugging Face Transformers: model framework
Built with ❤️ for the MedGemma Impact Challenge
Advancing Medical AI with Google's Health AI Developer Foundations