# codic's picture
# Update app.py
# ea25121 verified
from skimage import img_as_uint
from skimage.filters import gaussian
from skimage.segmentation import clear_border
from scipy.ndimage import gaussian_filter
import gradio as gr
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
import torch
import numpy as np
from PIL import Image
import open3d as o3d
from pathlib import Path
import os
import cv2
from rembg import remove
# Hugging Face checkpoint used for both the DPT preprocessor and the depth model.
DPT_CHECKPOINT = "Intel/dpt-large"
# Built once at module import; process_image reuses these two objects for every request.
feature_extractor = DPTFeatureExtractor.from_pretrained(DPT_CHECKPOINT)
model = DPTForDepthEstimation.from_pretrained(DPT_CHECKPOINT)
def normalize_depth(depth):
    """Linearly rescale ``depth`` to float64 values in ``[0, 1]``.

    The minimum maps to 0 and the maximum to 1. A constant array (max == min),
    which would otherwise divide by zero and yield NaNs, maps to all zeros.
    """
    depth = np.asarray(depth, dtype=np.float64)
    lo = np.min(depth)
    span = np.max(depth) - lo
    if span == 0:
        return np.zeros_like(depth)
    return np.clip((depth - lo) / span, 0.0, 1.0)


def process_image(image_path, depth_map_path=None):
    """Obtain a depth map for an image and reconstruct a 3D mesh from it.

    Args:
        image_path: Path (str or Path) of the input image.
        depth_map_path: Optional path to a user-supplied depth map. Files ending
            in ``.npy`` are loaded with NumPy and inverted (``1 - depth``); any
            other file is opened with PIL as 8-bit grayscale. When omitted, the
            depth is predicted by the DPT model from the background-removed image.

    Returns:
        ``[depth_preview, gltf_path, gltf_path]``: a PIL image of the blended
        uint16 depth map, and the written glTF path twice (one for the 3D viewer
        output, one for the file-download output).

    Raises:
        ValueError: If no input image is provided.
    """
    if image_path is None:
        raise ValueError("No input image provided.")
    image_path = Path(image_path)
    image_raw = Image.open(image_path).convert("RGB")
    # Working resolution for the depth map: 2048 px wide, aspect ratio preserved.
    image = image_raw.resize(
        (2048, int(2048 * image_raw.size[1] / image_raw.size[0])),
        Image.Resampling.LANCZOS,
    )

    # Remove background using rembg, then flatten the cutout back to RGB.
    foreground = remove(image_raw)
    foreground = Image.fromarray(np.array(foreground)).convert("RGB")

    if depth_map_path:
        depth_map_path = str(depth_map_path)
        if depth_map_path.lower().endswith(".npy"):
            # Invert depth only for npy files. Cast first so integer arrays
            # do not wrap around on the subtraction.
            depth = 1.0 - np.load(depth_map_path).astype(np.float64)
        else:
            depth_image_raw = Image.open(depth_map_path).convert("L")
            depth = np.array(depth_image_raw.resize(image.size, Image.Resampling.NEAREST))
    else:
        # Generate depth map using DPT model.
        # NOTE(review): DPT's predicted_depth is generally described as relative
        # inverse depth; it is used here without inversion, unlike the .npy
        # path. Confirm the intended orientation.
        encoding = feature_extractor(foreground, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoding)
        prediction = torch.nn.functional.interpolate(
            outputs.predicted_depth.unsqueeze(1),
            size=image.size[::-1],  # torch expects (H, W); PIL size is (W, H)
            mode="bicubic",
            align_corners=False,
        ).squeeze()
        depth = prediction.cpu().numpy()

    # Every source is normalized and stored as uint16 the same way. Previously
    # .npy depth stayed float in [0, 1], and the later astype(np.uint16) in
    # create_3d_obj truncated it to 0/1.
    depth_image = img_as_uint(normalize_depth(depth))

    # Bring to the working resolution. The .npy branch is not resized above.
    depth_image = cv2.resize(depth_image, image.size, interpolation=cv2.INTER_CUBIC)
    # create_3d_obj needs a single-channel map.
    if depth_image.ndim == 3:
        depth_image = cv2.cvtColor(depth_image, cv2.COLOR_BGR2GRAY)

    # Blend in 10% of a lightly blurred copy to soften hard depth transitions.
    mask_blurred = cv2.GaussianBlur(depth_image, (3, 3), 0.8)
    blended_depth_image = cv2.addWeighted(depth_image, 0.9, mask_blurred, 0.1, 0)

    # Inpaint the region rembg identifies as the subject in the original photo.
    subject_mask = np.array(remove(image_raw, only_mask=True))
    background_inpainted = cv2.inpaint(
        np.array(image_raw), subject_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA
    )

    depth_preview = Image.fromarray(blended_depth_image)
    foreground_np = np.array(foreground)
    try:
        gltf_path = create_3d_obj(foreground_np, blended_depth_image, background_inpainted, image_path)
    except Exception as e:
        # Best-effort fallback: retry the Poisson reconstruction with a shallower octree.
        print(f"Error with default depth: {str(e)}. Retrying with a shallower depth.")
        gltf_path = create_3d_obj(
            foreground_np, blended_depth_image, background_inpainted, image_path, depth=9
        )
    return [depth_preview, gltf_path, gltf_path]
import cv2 # Make sure OpenCV is installed
def create_3d_obj(foreground, depth_image, background, image_path, depth=10,
                  depth_scale=1000.0, depth_trunc=None):
    """Reconstruct a colored mesh from an RGB image plus depth map and save it as glTF.

    Args:
        foreground: RGB image array; ``shape[:2]`` gives (H, W), and the depth
            map is resized to match. process_image passes ``np.array`` of a
            PIL RGB image.
        depth_image: 2D depth map (array-like). It is cast to ``uint16`` before use.
        background: Currently unused; kept for call-site compatibility.
        image_path: Path (or str) of the source image. The mesh is written to
            ``./<stem>.gltf`` in the current working directory.
        depth: Base Poisson octree depth, increased by ``int(min(W, H) / 500)``.
        depth_scale: Divisor Open3D applies to the raw uint16 depth values.
        depth_trunc: Scaled depth beyond which Open3D discards samples. ``None``
            (default) chooses a value just above the largest depth present,
            so no sample is discarded.

    Returns:
        The glTF file path as a string.

    Raises:
        ValueError: If the depth map is not 2D or resizes to an empty image.
        RuntimeError: If fewer than 100 points are produced, normals cannot be
            estimated, or the mesh file cannot be written.
    """
    depth_image = np.asarray(depth_image)
    if depth_image.ndim != 2:
        raise ValueError("Depth image should be a 2D array, but got: {}".format(depth_image.shape))
    image_path = Path(image_path)
    foreground = np.ascontiguousarray(foreground)
    h, w = foreground.shape[:2]

    depth_image = depth_image.astype(np.uint16)
    # Open3D requires the color and depth images to have the same size.
    depth_image_resized = cv2.resize(depth_image, (w, h), interpolation=cv2.INTER_LINEAR)
    if depth_image_resized.ndim != 2 or depth_image_resized.shape[0] == 0 or depth_image_resized.shape[1] == 0:
        raise ValueError(f"Resized depth image is not valid. Shape: {depth_image_resized.shape}")

    if depth_trunc is None:
        # Open3D's default depth_trunc (3.0) zeroes every sample whose scaled
        # depth exceeds it. With depth_scale=1000 that drops uint16 values
        # above 3000, i.e. most of the 0-65535 range.
        depth_trunc = float(depth_image_resized.max()) / depth_scale + 1.0

    depth_o3d = o3d.geometry.Image(depth_image_resized)
    image_o3d = o3d.geometry.Image(foreground)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d,
        depth_o3d,
        depth_scale=depth_scale,
        depth_trunc=depth_trunc,
        convert_rgb_to_intensity=False,
    )

    # Pinhole camera: fixed 1500 px focal length, principal point at the image center.
    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    camera_intrinsic.set_intrinsics(w, h, 1500, 1500, w / 2, h / 2)
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)
    if len(pcd.points) < 100:
        raise RuntimeError("Insufficient points in the point cloud for normals estimation.")

    # Poisson reconstruction needs consistently oriented normals.
    if not pcd.has_normals():
        pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.25, max_nn=50))
    if not pcd.has_normals():
        raise RuntimeError("Failed to estimate normals for the point cloud.")
    pcd.orient_normals_towards_camera_location(camera_location=np.array([0.0, 0.0, 1500.0]))

    # Larger images get a deeper octree for finer detail.
    adjusted_depth = depth + int(min(w, h) / 500)
    mesh_raw, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=adjusted_depth, width=0, scale=1.2, linear_fit=True
    )

    # Simplify with a voxel 1/400 of the largest bounding-box extent, then smooth.
    voxel_size = max(mesh_raw.get_max_bound() - mesh_raw.get_min_bound()) / 400
    mesh = mesh_raw.simplify_vertex_clustering(
        voxel_size=voxel_size,
        contraction=o3d.geometry.SimplificationContraction.Average,
    )
    mesh = mesh.filter_smooth_simple(number_of_iterations=3)

    # Poisson extrapolates past the data; crop back to the point cloud's bounds.
    mesh_crop = mesh.crop(pcd.get_axis_aligned_bounding_box())

    gltf_path = f"./{image_path.stem}.gltf"
    if not o3d.io.write_triangle_mesh(gltf_path, mesh_crop, write_triangle_uvs=True):
        raise RuntimeError(f"Failed to write mesh to {gltf_path}")
    return gltf_path
# Gradio Interface
title = "Depth Estimation & 3D Reconstruction Demo"
description = "Upload an image and optionally a depth map (in .npy or image format) to generate a 3D model. If no depth map is provided, the DPT model will generate it."

EXAMPLES_DIR = "examples/"
# One row per file in EXAMPLES_DIR, sorted for a stable order. Each row sets
# only the image input; None leaves the optional depth-map input empty.
# There are no examples (instead of a crash) when the folder is missing or empty.
examples = None
if os.path.isdir(EXAMPLES_DIR):
    examples = [
        [EXAMPLES_DIR + name, None]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if os.path.isfile(os.path.join(EXAMPLES_DIR, name))
    ] or None

iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.File(type="filepath", label="Input Depth Map (optional)"),
    ],
    outputs=[
        gr.Image(label="Predicted Depth", type="pil"),
        gr.Model3D(label="3D Mesh Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(label="3D gLTF File"),
    ],
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch(debug=True, show_api=True, share=True)