| | |
| | """Batch inference example for VibeToken. |
| | |
| | Demonstrates how to process multiple images efficiently in batches. |
| | |
| | Usage: |
| | # Auto mode (recommended) |
| | python examples/batch_inference.py --auto \ |
| | --config configs/vibetoken_ll.yaml \ |
| | --checkpoint path/to/checkpoint.bin \ |
| | --input_dir path/to/images/ \ |
| | --output_dir path/to/output/ \ |
| | --batch_size 4 |
| | |
| | # Manual mode |
| | python examples/batch_inference.py \ |
| | --config configs/vibetoken_ll.yaml \ |
| | --checkpoint path/to/checkpoint.bin \ |
| | --input_dir path/to/images/ \ |
| | --output_dir path/to/output/ \ |
| | --batch_size 4 \ |
| | --resolution 512 \ |
| | --encoder_patch_size 16,32 \ |
| | --decoder_patch_size 16 |
| | """ |
| |
|
| | import argparse |
| | import time |
| | from pathlib import Path |
| |
|
| | import torch |
| | from PIL import Image |
| | import numpy as np |
| |
|
| | import sys |
| | sys.path.insert(0, str(Path(__file__).parent.parent)) |
| |
|
| | from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple |
| |
|
| |
|
def parse_patch_size(value):
    """Parse a patch-size CLI argument.

    Args:
        value: String such as ``"16"`` (square patch) or ``"16,32"``
            (height, width). May be ``None`` when the flag was not supplied.

    Returns:
        ``None`` if *value* is ``None``, an ``int`` for a single size, or a
        tuple of ints for comma-separated sizes.

    Raises:
        ValueError: If any component is not a valid integer (including an
            empty component such as in ``"16,"``).
    """
    if value is None:
        return None
    if ',' in value:
        # Convert every component instead of silently dropping anything
        # past the second one; malformed components raise ValueError.
        return tuple(int(part) for part in value.split(','))
    return int(value)
| |
|
| |
|
def load_and_preprocess_image(path: Path, target_size: tuple = None, auto_mode: bool = False) -> tuple:
    """Load an image from disk and prepare it for the tokenizer.

    Args:
        path: Location of the image file.
        target_size: Optional (width, height) to resize to; used only when
            auto_mode is False.
        auto_mode: When True, delegate sizing decisions to
            auto_preprocess_image instead of resizing/cropping manually.

    Returns:
        A ``(image_array, patch_size, info)`` tuple. In manual mode the
        last two entries are ``None``.
    """
    img = Image.open(path).convert("RGB")

    if not auto_mode:
        # Manual mode: optional resize, then crop so both dimensions are
        # multiples of 32.
        if target_size:
            img = img.resize(target_size, Image.LANCZOS)
        cropped = center_crop_to_multiple(img, multiple=32)
        return np.array(cropped), None, None

    # Auto mode: helper picks the crop and an appropriate patch size.
    img, patch_size, info = auto_preprocess_image(img, verbose=False)
    return np.array(img), patch_size, info
| |
|
| |
|
def main():
    """Run batch reconstruction over a directory of images.

    Parses CLI arguments, loads the VibeToken tokenizer, then processes
    every supported image in ``--input_dir``:

    * Auto mode (``--auto``): images are handled one at a time because each
      image may get a different crop/patch size.
    * Manual mode: images are resized to ``--resolution``, stacked, and
      reconstructed in batches of ``--batch_size``.

    Reconstructions are written to ``--output_dir`` as ``<stem>_recon.png``
    and per-batch / aggregate throughput is printed.
    """
    parser = argparse.ArgumentParser(description="VibeToken batch inference example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output images")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")

    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings per image")

    parser.add_argument("--resolution", type=int, default=512, help="Target resolution (manual mode)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    args = parser.parse_args()

    encoder_patch_size = parse_patch_size(args.encoder_patch_size)
    decoder_patch_size = parse_patch_size(args.decoder_patch_size)

    # Fall back gracefully when CUDA was requested but is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )

    if args.auto:
        print("Running in AUTO MODE - optimal settings determined per image")
    else:
        print(f"Running in MANUAL MODE - resolution: {args.resolution}")
        if encoder_patch_size:
            print(f"  Encoder patch size: {encoder_patch_size}")
        if decoder_patch_size:
            print(f"  Decoder patch size: {decoder_patch_size}")

    input_dir = Path(args.input_dir)
    image_extensions = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    image_paths = [p for p in input_dir.iterdir() if p.suffix.lower() in image_extensions]
    print(f"Found {len(image_paths)} images")

    if not image_paths:
        print("No images found!")
        return

    target_size = (args.resolution, args.resolution) if not args.auto else None
    total_time = 0
    num_processed = 0

    if args.auto:
        # Per-image settings differ, so process one image at a time.
        for i, path in enumerate(image_paths):
            try:
                img_array, patch_size, info = load_and_preprocess_image(path, auto_mode=True)
                batch_array = img_array[np.newaxis, ...]

                start_time = time.time()

                # info['cropped_size'] is (width, height) -- note the swap.
                height, width = info['cropped_size'][1], info['cropped_size'][0]
                reconstructed = tokenizer.reconstruct(
                    batch_array,
                    encode_patch_size=patch_size,
                    decode_patch_size=patch_size,
                    target_height=height,
                    target_width=width,
                )

                # Synchronize so the timing reflects actual GPU work.
                if args.device == "cuda":
                    torch.cuda.synchronize()

                batch_time = time.time() - start_time
                total_time += batch_time
                num_processed += 1

                output_images = tokenizer.to_pil(reconstructed)
                output_path = output_dir / f"{path.stem}_recon.png"
                output_images[0].save(output_path)

                print(f"[{i+1}/{len(image_paths)}] {path.name}: "
                      f"{info['cropped_size'][0]}x{info['cropped_size'][1]}, "
                      f"patch_size={patch_size}, {batch_time:.2f}s")

            except Exception as e:
                # Best-effort batch job: report and keep going.
                print(f"Error processing {path}: {e}")
                continue
    else:
        for batch_start in range(0, len(image_paths), args.batch_size):
            batch_paths = image_paths[batch_start:batch_start + args.batch_size]

            # Collect names alongside successfully loaded images so output
            # filenames stay aligned even when some images fail to load.
            # (Building names from batch_paths up front would mislabel every
            # output after a failed load.)
            batch_images = []
            batch_names = []
            for path in batch_paths:
                try:
                    img_array, _, _ = load_and_preprocess_image(path, target_size, auto_mode=False)
                except Exception as e:
                    print(f"Error loading {path}: {e}")
                    continue
                batch_images.append(img_array)
                batch_names.append(path.stem)

            if not batch_images:
                continue

            batch_array = np.stack(batch_images, axis=0)

            start_time = time.time()

            reconstructed = tokenizer.reconstruct(
                batch_array,
                encode_patch_size=encoder_patch_size,
                decode_patch_size=decoder_patch_size,
                target_height=args.resolution,
                target_width=args.resolution,
            )

            # Synchronize so the timing reflects actual GPU work.
            if args.device == "cuda":
                torch.cuda.synchronize()

            batch_time = time.time() - start_time
            total_time += batch_time
            num_processed += len(batch_images)

            output_images = tokenizer.to_pil(reconstructed)
            for name, img in zip(batch_names, output_images):
                output_path = output_dir / f"{name}_recon.png"
                img.save(output_path)

            print(f"Processed batch {batch_start // args.batch_size + 1}: "
                  f"{len(batch_images)} images in {batch_time:.2f}s "
                  f"({len(batch_images) / batch_time:.2f} img/s)")

    if num_processed > 0:
        print(f"\nTotal: {num_processed} images in {total_time:.2f}s")
        print(f"Per image: {total_time / num_processed * 1000:.1f}ms")
        # Guard against a zero elapsed time on very fast (or clock-skewed) runs.
        if total_time > 0:
            print(f"Average: {num_processed / total_time:.2f} images/sec")
| |
|
| |
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
| |
|