| | import torch |
| | import numpy as np |
| | import cv2 |
| | from PIL import Image |
| | import logging |
| | import gc |
| | import time |
| | import os |
| | from typing import Optional, Dict, Any, Callable |
| | import warnings |
| | warnings.filterwarnings("ignore") |
| |
|
| | from diffusers import StableDiffusionXLPipeline, StableDiffusionXLInpaintPipeline, DPMSolverMultistepScheduler |
| | import open_clip |
| | from mask_generator import MaskGenerator |
| | from image_blender import ImageBlender |
| |
|
| | try: |
| | import spaces |
| | SPACES_AVAILABLE = True |
| | except ImportError: |
| | SPACES_AVAILABLE = False |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class BackgroundEngine:
    """
    Background generation engine for VividFlow.

    Integrates SDXL pipeline, OpenCLIP analysis, mask generation,
    and advanced image blending.
    """

    def __init__(self, device: str = "auto"):
        """Set up configuration and helpers; heavy models load lazily via load_models()."""
        # Resolve "auto" to a concrete device string up front.
        self.device = self._setup_device(device)

        # Model identifiers used by load_models().
        self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        self.clip_model_name = "ViT-B-32"
        self.clip_pretrained = "openai"

        # Model handles stay None until the corresponding loader runs.
        for handle in ("pipeline", "inpaint_pipeline", "clip_model",
                       "clip_preprocess", "clip_tokenizer"):
            setattr(self, handle, None)
        self.is_initialized = False
        self.inpaint_initialized = False

        # Generation defaults.
        self.max_image_size = 1024
        self.default_steps = 25
        self.use_fp16 = True

        # Helper components for foreground masking and compositing.
        self.mask_generator = MaskGenerator(self.max_image_size)
        self.image_blender = ImageBlender()

        logger.info(f"BackgroundEngine initialized on {self.device}")
| |
|
| | def _setup_device(self, device: str) -> str: |
| | """Setup computation device (ZeroGPU compatible)""" |
| | if os.getenv('SPACE_ID') is not None: |
| | return "cpu" |
| |
|
| | if device == "auto": |
| | if torch.cuda.is_available(): |
| | return "cuda" |
| | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): |
| | return "mps" |
| | return "cpu" |
| | return device |
| |
|
| | def _memory_cleanup(self): |
| | """Memory cleanup""" |
| | for _ in range(3): |
| | gc.collect() |
| |
|
| | is_spaces = os.getenv('SPACE_ID') is not None |
| | if not is_spaces and torch.cuda.is_available(): |
| | torch.cuda.empty_cache() |
| |
|
    def load_models(self, progress_callback: Optional[Callable] = None):
        """Load SDXL and OpenCLIP models.

        Idempotent: returns immediately if models are already loaded.

        Args:
            progress_callback: Optional callable(message: str, percent: int)
                invoked at coarse loading milestones (20/60/90/100).

        Raises:
            RuntimeError: if any model fails to download or initialize.
        """
        if self.is_initialized:
            logger.info("Models already loaded")
            return

        logger.info("Loading background generation models...")

        try:
            self._memory_cleanup()

            # Re-check CUDA at load time: on ZeroGPU Spaces __init__ chose
            # "cpu", but a GPU may be attached by the time loading runs.
            actual_device = "cuda" if torch.cuda.is_available() else self.device
            logger.info(f"Loading models to device: {actual_device}")

            if progress_callback:
                progress_callback("Loading OpenCLIP...", 20)

            # OpenCLIP for zero-shot subject analysis (see analyze_image_with_clip).
            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                self.clip_model_name,
                pretrained=self.clip_pretrained,
                device=actual_device
            )
            self.clip_tokenizer = open_clip.get_tokenizer(self.clip_model_name)
            self.clip_model.eval()

            logger.info("OpenCLIP loaded")

            if progress_callback:
                progress_callback("Loading SDXL pipeline...", 60)

            # SDXL text-to-image pipeline; half-precision weights when enabled.
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16 if self.use_fp16 else torch.float32,
                use_safetensors=True,
                variant="fp16" if self.use_fp16 else None
            )

            # DPM-Solver multistep scheduler: good quality at ~25 steps.
            self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipeline.scheduler.config
            )

            self.pipeline = self.pipeline.to(actual_device)

            if progress_callback:
                progress_callback("Applying optimizations...", 90)

            # Prefer xformers attention; fall back to attention slicing,
            # and continue silently if neither is available.
            try:
                self.pipeline.enable_xformers_memory_efficient_attention()
                logger.info("xformers enabled")
            except Exception:
                try:
                    self.pipeline.enable_attention_slicing()
                    logger.info("Attention slicing enabled")
                except Exception:
                    pass

            # VAE tiling/slicing lower peak memory on large images
            # (hasattr-guarded: not all diffusers versions expose these).
            if hasattr(self.pipeline, 'enable_vae_tiling'):
                self.pipeline.enable_vae_tiling()

            if hasattr(self.pipeline, 'enable_vae_slicing'):
                self.pipeline.enable_vae_slicing()

            # Inference-only: put submodules in eval mode.
            self.pipeline.unet.eval()
            if hasattr(self.pipeline, 'vae'):
                self.pipeline.vae.eval()

            self.is_initialized = True

            if progress_callback:
                progress_callback("Models loaded!", 100)

            logger.info("Background models loaded successfully")

        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            raise RuntimeError(f"Failed to load models: {str(e)}")
| |
|
| | def analyze_image_with_clip(self, image: Image.Image) -> str: |
| | """Analyze image using OpenCLIP""" |
| | if not self.clip_model: |
| | return "Unknown" |
| |
|
| | try: |
| | |
| | actual_device = "cuda" if torch.cuda.is_available() else self.device |
| |
|
| | image_input = self.clip_preprocess(image).unsqueeze(0).to(actual_device) |
| |
|
| | categories = [ |
| | "a photo of a person", |
| | "a photo of an animal", |
| | "a photo of an object", |
| | "a photo of nature", |
| | "a photo of a building" |
| | ] |
| |
|
| | text_inputs = self.clip_tokenizer(categories).to(actual_device) |
| |
|
| | with torch.no_grad(): |
| | image_features = self.clip_model.encode_image(image_input) |
| | text_features = self.clip_model.encode_text(text_inputs) |
| |
|
| | image_features /= image_features.norm(dim=-1, keepdim=True) |
| | text_features /= text_features.norm(dim=-1, keepdim=True) |
| |
|
| | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) |
| | best_match_idx = similarity.argmax().item() |
| |
|
| | category = categories[best_match_idx].replace("a photo of ", "") |
| | return category |
| |
|
| | except Exception as e: |
| | logger.error(f"CLIP analysis failed: {e}") |
| | return "unknown" |
| |
|
| | def enhance_prompt(self, user_prompt: str, foreground_image: Image.Image) -> str: |
| | """Smart prompt enhancement based on image analysis""" |
| | try: |
| | img_array = np.array(foreground_image.convert('RGB')) |
| |
|
| | |
| | lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB) |
| | avg_b = np.mean(lab[:, :, 2]) |
| | is_warm = avg_b > 128 |
| |
|
| | |
| | gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) |
| | avg_brightness = np.mean(gray) |
| | is_bright = avg_brightness > 127 |
| |
|
| | |
| | clip_analysis = self.analyze_image_with_clip(foreground_image) |
| | subject_type = clip_analysis |
| |
|
| | |
| | if is_warm and is_bright: |
| | lighting = "warm golden hour lighting, soft natural light" |
| | elif is_warm and not is_bright: |
| | lighting = "warm ambient lighting, cozy atmosphere" |
| | elif not is_warm and is_bright: |
| | lighting = "bright daylight, clear sky lighting" |
| | else: |
| | lighting = "soft diffused light, gentle shadows" |
| |
|
| | |
| | atmosphere_map = { |
| | "person": "professional, elegant composition", |
| | "animal": "natural, harmonious setting", |
| | "object": "clean product photography style", |
| | "nature": "scenic, peaceful atmosphere", |
| | "building": "architectural, balanced composition" |
| | } |
| | atmosphere = atmosphere_map.get(subject_type, "balanced composition") |
| |
|
| | quality_modifiers = "high quality, detailed, sharp focus, photorealistic" |
| |
|
| | |
| | user_prompt_lower = user_prompt.lower() |
| | if "sunset" in user_prompt_lower or "golden" in user_prompt_lower: |
| | lighting = "" |
| | if "dark" in user_prompt_lower or "night" in user_prompt_lower: |
| | lighting = lighting.replace("bright", "").replace("daylight", "") |
| |
|
| | |
| | fragments = [user_prompt] |
| | if lighting: |
| | fragments.append(lighting) |
| | fragments.append(atmosphere) |
| | fragments.append(quality_modifiers) |
| |
|
| | enhanced_prompt = ", ".join(filter(None, fragments)) |
| |
|
| | logger.debug(f"Enhanced: {enhanced_prompt[:80]}...") |
| | return enhanced_prompt |
| |
|
| | except Exception as e: |
| | logger.warning(f"Prompt enhancement failed: {e}") |
| | return f"{user_prompt}, high quality, detailed, photorealistic" |
| |
|
| | def _prepare_image(self, image: Image.Image) -> Image.Image: |
| | """Prepare image for processing""" |
| | if image.mode != 'RGB': |
| | image = image.convert('RGB') |
| |
|
| | width, height = image.size |
| | max_size = self.max_image_size |
| |
|
| | if width > max_size or height > max_size: |
| | ratio = min(max_size/width, max_size/height) |
| | new_width = int(width * ratio) |
| | new_height = int(height * ratio) |
| | image = image.resize((new_width, new_height), Image.LANCZOS) |
| |
|
| | width, height = image.size |
| | new_width = (width // 8) * 8 |
| | new_height = (height // 8) * 8 |
| |
|
| | if new_width != width or new_height != height: |
| | image = image.resize((new_width, new_height), Image.LANCZOS) |
| |
|
| | return image |
| |
|
| | def generate_background( |
| | self, |
| | prompt: str, |
| | width: int, |
| | height: int, |
| | negative_prompt: str = "blurry, low quality, distorted", |
| | num_inference_steps: int = 25, |
| | guidance_scale: float = 7.5 |
| | ) -> Image.Image: |
| | """Generate background using SDXL""" |
| | if not self.is_initialized: |
| | raise RuntimeError("Models not loaded") |
| |
|
| | logger.info(f"Generating background: {prompt[:50]}...") |
| |
|
| | try: |
| | |
| | actual_device = "cuda" if torch.cuda.is_available() else self.device |
| |
|
| | with torch.inference_mode(): |
| | result = self.pipeline( |
| | prompt=prompt, |
| | negative_prompt=negative_prompt, |
| | width=width, |
| | height=height, |
| | num_inference_steps=num_inference_steps, |
| | guidance_scale=guidance_scale, |
| | generator=torch.Generator(device=actual_device).manual_seed(42) |
| | ) |
| |
|
| | generated_image = result.images[0] |
| | logger.info("Background generation completed") |
| | return generated_image |
| |
|
| | except torch.cuda.OutOfMemoryError: |
| | logger.error("GPU memory exhausted") |
| | self._memory_cleanup() |
| | raise RuntimeError("GPU memory insufficient") |
| |
|
| | except Exception as e: |
| | logger.error(f"Generation failed: {e}") |
| | raise RuntimeError(f"Generation failed: {str(e)}") |
| |
|
    def generate_and_combine(
        self,
        original_image: Image.Image,
        prompt: str,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        progress_callback: Optional[Callable] = None,
        enable_prompt_enhancement: bool = True,
        feather_radius: int = 0,
        enhance_dark_edges: bool = False
    ) -> Dict[str, Any]:
        """
        Generate background and combine with foreground.

        Pipeline: prepare/resize the original -> (optionally) enhance the
        prompt from image analysis -> generate an SDXL background at the
        same size -> build a foreground mask -> blend the foreground over
        the generated background.

        Args:
            feather_radius: Gaussian blur radius for mask edge softening (0-20, default 0)
            enhance_dark_edges: Enhance mask edges for dark background images (default False)

        Returns dict with: combined_image, generated_scene, original_image, mask, success
        (on failure: success=False plus an "error" string instead).

        Raises:
            RuntimeError: only when models were never loaded; all other
                failures are caught and reported via the returned dict.
        """
        if not self.is_initialized:
            raise RuntimeError("Models not loaded")

        logger.info("Starting background generation and combination...")

        try:
            if progress_callback:
                progress_callback("Analyzing image...", 5)

            # Resize/convert so dimensions are SDXL-compatible (multiples of 8).
            processed_original = self._prepare_image(original_image)
            target_width, target_height = processed_original.size

            if progress_callback:
                progress_callback("Enhancing prompt...", 15)

            # Analysis-driven enhancement, or a fixed quality suffix when disabled.
            if enable_prompt_enhancement:
                enhanced_prompt = self.enhance_prompt(prompt, processed_original)
            else:
                enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"

            # Keep generated backgrounds free of subjects that would clash
            # with the composited foreground.
            enhanced_negative = f"{negative_prompt}, people, characters, cartoons, logos"

            if progress_callback:
                progress_callback("Generating background...", 30)

            # Generate at exactly the prepared foreground size so blending
            # needs no resampling.
            generated_background = self.generate_background(
                prompt=enhanced_prompt,
                width=target_width,
                height=target_height,
                negative_prompt=enhanced_negative,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale
            )

            if progress_callback:
                progress_callback("Creating mask...", 80)

            # Foreground/background separation mask from the project helper.
            logger.info("Generating mask...")
            combination_mask = self.mask_generator.create_gradient_based_mask(
                processed_original,
                combination_mode,
                focus_mode,
                enhance_dark_edges=enhance_dark_edges
            )

            if progress_callback:
                progress_callback("Blending images...", 90)

            # Composite the foreground onto the generated scene.
            logger.info("Blending images...")
            combined_image = self.image_blender.simple_blend_images(
                processed_original,
                generated_background,
                combination_mask,
                feather_radius=feather_radius
            )

            # Free intermediate GPU/CPU memory before returning.
            self._memory_cleanup()

            if progress_callback:
                progress_callback("Complete!", 100)

            logger.info("Background generation completed successfully")

            return {
                "combined_image": combined_image,
                "generated_scene": generated_background,
                "original_image": processed_original,
                "mask": combination_mask,
                "success": True
            }

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            self._memory_cleanup()
            return {
                "success": False,
                "error": str(e)
            }
| |
|
    def _load_inpaint_pipeline(self) -> bool:
        """Lazy load SDXL inpainting pipeline.

        Returns:
            True when the pipeline is ready; False if loading failed
            (callers then fall back to OpenCV inpainting).
        """
        if self.inpaint_initialized:
            return True

        try:
            logger.info("Loading SDXL inpainting pipeline...")
            # Re-check CUDA at call time (ZeroGPU may attach a GPU late).
            actual_device = "cuda" if torch.cuda.is_available() else self.device

            # fp16 weights only when a GPU is available.
            self.inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=torch.float16 if actual_device == "cuda" else torch.float32,
                variant="fp16" if actual_device == "cuda" else None,
                use_safetensors=True
            )
            self.inpaint_pipeline.to(actual_device)

            # Same scheduler family as the base pipeline for consistent results.
            self.inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self.inpaint_pipeline.scheduler.config
            )

            # xformers is optional; ignore failures silently.
            if actual_device == "cuda":
                try:
                    self.inpaint_pipeline.enable_xformers_memory_efficient_attention()
                except Exception:
                    pass

            self.inpaint_initialized = True
            logger.info("✓ SDXL inpainting pipeline loaded")
            return True

        except Exception as e:
            logger.error(f"Failed to load inpainting pipeline: {e}")
            self.inpaint_initialized = False
            return False
| |
|
| | def inpaint_region( |
| | self, |
| | image: Image.Image, |
| | mask: Image.Image, |
| | prompt: str, |
| | negative_prompt: str = "blurry, low quality, artifacts, seams", |
| | num_inference_steps: int = 20, |
| | guidance_scale: float = 7.5, |
| | strength: float = 0.99 |
| | ) -> Dict[str, Any]: |
| | """ |
| | Inpaint marked regions with background content. |
| | |
| | Args: |
| | image: The combined image with artifacts to fix |
| | mask: Binary mask where white = areas to inpaint |
| | prompt: Background description for inpainting |
| | negative_prompt: What to avoid |
| | num_inference_steps: Denoising steps (20 is usually enough) |
| | guidance_scale: How closely to follow prompt |
| | strength: How much to change masked area (0.99 = almost complete replacement) |
| | |
| | Returns: |
| | Dict with inpainted_image, success, error |
| | """ |
| | try: |
| | |
| | if not self._load_inpaint_pipeline(): |
| | |
| | return self._opencv_inpaint_fallback(image, mask) |
| |
|
| | logger.info("Starting region inpainting...") |
| |
|
| | |
| | image = self._prepare_image(image) |
| | mask = mask.resize(image.size, Image.LANCZOS).convert('L') |
| |
|
| | |
| | mask_array = np.array(mask) |
| | mask_array = (mask_array > 127).astype(np.uint8) * 255 |
| | mask = Image.fromarray(mask_array, mode='L') |
| |
|
| | |
| | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) |
| | mask_dilated = cv2.dilate(mask_array, kernel, iterations=1) |
| | mask = Image.fromarray(mask_dilated, mode='L') |
| |
|
| | actual_device = "cuda" if torch.cuda.is_available() else self.device |
| |
|
| | with torch.inference_mode(): |
| | result = self.inpaint_pipeline( |
| | prompt=prompt, |
| | negative_prompt=negative_prompt, |
| | image=image, |
| | mask_image=mask, |
| | width=image.size[0], |
| | height=image.size[1], |
| | num_inference_steps=num_inference_steps, |
| | guidance_scale=guidance_scale, |
| | strength=strength, |
| | generator=torch.Generator(device=actual_device).manual_seed(42) |
| | ) |
| |
|
| | inpainted = result.images[0] |
| |
|
| | |
| | inpainted = self._blend_inpaint_edges(image, inpainted, mask) |
| |
|
| | self._memory_cleanup() |
| |
|
| | logger.info("✓ Region inpainting completed") |
| | return { |
| | "inpainted_image": inpainted, |
| | "success": True |
| | } |
| |
|
| | except Exception as e: |
| | logger.error(f"Inpainting failed: {e}") |
| | self._memory_cleanup() |
| | return { |
| | "success": False, |
| | "error": str(e) |
| | } |
| |
|
| | def _opencv_inpaint_fallback( |
| | self, |
| | image: Image.Image, |
| | mask: Image.Image |
| | ) -> Dict[str, Any]: |
| | """Fallback to OpenCV inpainting for small areas or when SDXL unavailable""" |
| | try: |
| | logger.info("Using OpenCV inpainting fallback...") |
| |
|
| | img_array = np.array(image.convert('RGB')) |
| | mask_array = np.array(mask.convert('L')) |
| |
|
| | |
| | mask_binary = (mask_array > 127).astype(np.uint8) * 255 |
| |
|
| | |
| | inpainted = cv2.inpaint( |
| | img_array, |
| | mask_binary, |
| | inpaintRadius=5, |
| | flags=cv2.INPAINT_TELEA |
| | ) |
| |
|
| | result = Image.fromarray(inpainted) |
| |
|
| | logger.info("✓ OpenCV inpainting completed") |
| | return { |
| | "inpainted_image": result, |
| | "success": True |
| | } |
| |
|
| | except Exception as e: |
| | logger.error(f"OpenCV inpainting failed: {e}") |
| | return { |
| | "success": False, |
| | "error": str(e) |
| | } |
| |
|
| | def _blend_inpaint_edges( |
| | self, |
| | original: Image.Image, |
| | inpainted: Image.Image, |
| | mask: Image.Image, |
| | feather_pixels: int = 8 |
| | ) -> Image.Image: |
| | """Blend inpainted region edges for seamless transition""" |
| | try: |
| | orig_array = np.array(original).astype(np.float32) |
| | inpaint_array = np.array(inpainted).astype(np.float32) |
| | mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0 |
| |
|
| | |
| | if feather_pixels > 0: |
| | kernel_size = feather_pixels * 2 + 1 |
| | mask_feathered = cv2.GaussianBlur( |
| | mask_array, |
| | (kernel_size, kernel_size), |
| | feather_pixels / 2 |
| | ) |
| | else: |
| | mask_feathered = mask_array |
| |
|
| | |
| | mask_3d = mask_feathered[:, :, np.newaxis] |
| |
|
| | |
| | blended = inpaint_array * mask_3d + orig_array * (1 - mask_3d) |
| | blended = np.clip(blended, 0, 255).astype(np.uint8) |
| |
|
| | return Image.fromarray(blended) |
| |
|
| | except Exception as e: |
| | logger.warning(f"Edge blending failed: {e}, returning inpainted directly") |
| | return inpainted |
| |
|