VibeToken / examples /batch_inference.py
APGASU's picture
scripts
7bef20f verified
#!/usr/bin/env python3
"""Batch inference example for VibeToken.
Demonstrates how to process multiple images efficiently in batches.

Usage:
    # Auto mode (recommended)
    python examples/batch_inference.py --auto \
        --config configs/vibetoken_ll.yaml \
        --checkpoint path/to/checkpoint.bin \
        --input_dir path/to/images/ \
        --output_dir path/to/output/ \
        --batch_size 4

    # Manual mode
    python examples/batch_inference.py \
        --config configs/vibetoken_ll.yaml \
        --checkpoint path/to/checkpoint.bin \
        --input_dir path/to/images/ \
        --output_dir path/to/output/ \
        --batch_size 4 \
        --resolution 512 \
        --encoder_patch_size 16,32 \
        --decoder_patch_size 16
"""
import argparse
import time
from pathlib import Path
import torch
from PIL import Image
import numpy as np
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
def parse_patch_size(value):
    """Parse a patch-size command-line value.

    Args:
        value: ``None``, a single integer string such as ``'16'``, or two
            comma-separated integers such as ``'16,32'``.

    Returns:
        ``None`` when ``value`` is ``None``, an ``int`` for a single value,
        or a 2-tuple of ``int`` for a comma-separated pair.

    Raises:
        ValueError: If a component is not a valid integer, or if the value
            contains more than two comma-separated components.
    """
    if value is None:
        return None

    parts = value.split(',')
    if len(parts) == 1:
        return int(parts[0])
    if len(parts) != 2:
        # Reject extra components instead of silently ignoring them.
        raise ValueError(
            "patch size must be one integer or two comma-separated "
            f"integers, got {value!r}"
        )
    first, second = parts
    return (int(first), int(second))
def load_and_preprocess_image(path: Path, target_size: tuple = None, auto_mode: bool = False) -> tuple:
    """Load an image from disk and prepare it as an RGB numpy array.

    Args:
        path: Path to the image file.
        target_size: Optional target size (width, height) for resizing.
            Only used when ``auto_mode`` is False.
        auto_mode: If True, delegate preprocessing to
            ``auto_preprocess_image`` and ignore ``target_size``.

    Returns:
        A 3-tuple ``(image, patch_size, info)``:
            image: the preprocessed image as a numpy array.
            patch_size: the patch size returned by ``auto_preprocess_image``
                in auto mode, otherwise ``None``.
            info: the info object returned by ``auto_preprocess_image`` in
                auto mode, otherwise ``None``.
    """
    # Image.open is lazy and keeps the file open; convert() reads the pixel
    # data into a new image, so the handle can be released right away.
    with Image.open(path) as source:
        img = source.convert("RGB")

    if auto_mode:
        # Use centralized auto_preprocess_image
        img, patch_size, info = auto_preprocess_image(img, verbose=False)
        return np.array(img), patch_size, info

    if target_size:
        img = img.resize(target_size, Image.LANCZOS)
    # Always center crop with multiple=32 after the optional resize
    img = center_crop_to_multiple(img, multiple=32)
    return np.array(img), None, None
def main():
    """Run the batch-inference command-line tool.

    Parses the command-line arguments, loads the tokenizer, reconstructs
    every image found directly inside ``--input_dir`` and writes each result
    to ``--output_dir`` as ``<stem>_recon.png``.

    In auto mode images are processed one at a time with the settings
    returned by ``load_and_preprocess_image``. In manual mode images are
    resized to ``--resolution`` and processed in batches of ``--batch_size``.

    Invalid arguments (malformed patch size, non-positive batch size in
    manual mode, missing input directory) terminate via ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="VibeToken batch inference example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output images")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings per image")
    # Manual mode options
    parser.add_argument("--resolution", type=int, default=512, help="Target resolution (manual mode)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    args = parser.parse_args()

    # Parse patch sizes; report malformed values as a usage error
    try:
        encoder_patch_size = parse_patch_size(args.encoder_patch_size)
        decoder_patch_size = parse_patch_size(args.decoder_patch_size)
    except ValueError as exc:
        parser.error(f"invalid patch size: {exc}")

    # The batch size only matters in manual mode; a value < 1 would either
    # crash range() or silently process nothing.
    if not args.auto and args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    # Fail before the (potentially slow) tokenizer load if there is no input
    input_dir = Path(args.input_dir)
    if not input_dir.is_dir():
        parser.error(f"--input_dir is not a directory: {input_dir}")

    # Check CUDA
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )

    if args.auto:
        print("Running in AUTO MODE - optimal settings determined per image")
    else:
        print(f"Running in MANUAL MODE - resolution: {args.resolution}")
        if encoder_patch_size:
            print(f" Encoder patch size: {encoder_patch_size}")
        if decoder_patch_size:
            print(f" Decoder patch size: {decoder_patch_size}")

    # Find all images; sorted so the processing order is reproducible
    image_extensions = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    image_paths = sorted(
        p for p in input_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    print(f"Found {len(image_paths)} images")
    if not image_paths:
        print("No images found!")
        return

    total_time = 0.0
    num_processed = 0

    if args.auto:
        # AUTO MODE: Process images one by one since each may have different sizes
        for i, path in enumerate(image_paths):
            try:
                img_array, patch_size, info = load_and_preprocess_image(path, auto_mode=True)
                batch_array = img_array[np.newaxis, ...]  # Add batch dim

                start_time = time.perf_counter()
                # Reconstruct with auto-determined patch size;
                # cropped_size is indexed as (width, height) here.
                height, width = info['cropped_size'][1], info['cropped_size'][0]
                reconstructed = tokenizer.reconstruct(
                    batch_array,
                    encode_patch_size=patch_size,
                    decode_patch_size=patch_size,
                    target_height=height,
                    target_width=width,
                )
                # Wait for queued GPU work so the timing is meaningful
                if args.device == "cuda":
                    torch.cuda.synchronize()
                batch_time = time.perf_counter() - start_time
                total_time += batch_time
                num_processed += 1

                # Save output
                output_images = tokenizer.to_pil(reconstructed)
                output_path = output_dir / f"{path.stem}_recon.png"
                output_images[0].save(output_path)
                print(f"[{i+1}/{len(image_paths)}] {path.name}: "
                      f"{info['cropped_size'][0]}x{info['cropped_size'][1]}, "
                      f"patch_size={patch_size}, {batch_time:.2f}s")
            except Exception as e:
                # Best effort: one bad image must not abort the whole run
                print(f"Error processing {path}: {e}")
                continue
    else:
        # MANUAL MODE: Batch processing with uniform size
        target_size = (args.resolution, args.resolution)
        for batch_start in range(0, len(image_paths), args.batch_size):
            batch_paths = image_paths[batch_start:batch_start + args.batch_size]

            # Load batch, recording each name only when its image loaded so
            # that names and reconstructions stay aligned after a failure.
            batch_images = []
            batch_names = []
            for path in batch_paths:
                try:
                    img_array, _, _ = load_and_preprocess_image(path, target_size, auto_mode=False)
                except Exception as e:
                    print(f"Error loading {path}: {e}")
                    continue
                batch_images.append(img_array)
                batch_names.append(path.stem)
            if not batch_images:
                continue

            # Stack into batch tensor
            batch_array = np.stack(batch_images, axis=0)

            # Measure time
            start_time = time.perf_counter()
            # Reconstruct
            # NOTE(review): the target size is --resolution, while the input
            # was center-cropped with multiple=32 after resizing; these may
            # differ when --resolution is not a multiple of 32 - confirm.
            reconstructed = tokenizer.reconstruct(
                batch_array,
                encode_patch_size=encoder_patch_size,
                decode_patch_size=decoder_patch_size,
                target_height=args.resolution,
                target_width=args.resolution,
            )
            # Synchronize if GPU
            if args.device == "cuda":
                torch.cuda.synchronize()
            batch_time = time.perf_counter() - start_time
            total_time += batch_time
            num_processed += len(batch_images)

            # Save outputs
            output_images = tokenizer.to_pil(reconstructed)
            for name, img in zip(batch_names, output_images):
                output_path = output_dir / f"{name}_recon.png"
                img.save(output_path)

            # Guard against a zero elapsed time on very coarse clocks
            rate = len(batch_images) / batch_time if batch_time > 0 else float("inf")
            print(f"Processed batch {batch_start // args.batch_size + 1}: "
                  f"{len(batch_images)} images in {batch_time:.2f}s "
                  f"({rate:.2f} img/s)")

    # Summary
    if num_processed > 0:
        average = num_processed / total_time if total_time > 0 else float("inf")
        print(f"\nTotal: {num_processed} images in {total_time:.2f}s")
        print(f"Average: {average:.2f} images/sec")
        print(f"Per image: {total_time / num_processed * 1000:.1f}ms")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()