Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """Stylegan-nada-ailanta.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE | |
| # Проект "CLIP-Guided Domain Adaptation of Image Generators" | |
| Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946). | |
| Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя: | |
| - Сдвиг генератора по текстовому промпту | |
| - Генерация примеров | |
| - Генерация примеров из готовых пресетов | |
| - Веб-демо | |
| - Стилизация изображения из файла | |
| ## 1. Установка | |
| """ | |
| # @title | |
| # Импорт нужных библиотек | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torchvision import transforms | |
| from torchvision.utils import save_image | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
| import subprocess | |
| import gdown | |
| # Настройка устройства | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if not os.path.exists("stylegan2-pytorch"): | |
| subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"]) | |
| os.chdir("stylegan2-pytorch") | |
| gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT') | |
| gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL') | |
| sys.path.append("/home/user/app/stylegan2-pytorch") | |
| from model import Generator | |
| # Параметры генератора | |
| latent_dim = 512 | |
| f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device) | |
| state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device) | |
| f_generator.load_state_dict(state_dict['g_ema']) | |
| f_generator.eval() | |
| # Загрузка пресетов | |
| os.makedirs("/content/presets", exist_ok=True) | |
| gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth') | |
| gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth') | |
| gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth') | |
| # Загрузка генератора из файла | |
| def load_model(file_path, latent_dim=512, size=1024): | |
| state_dicts = torch.load(file_path, map_location=device) | |
| # Инициализация | |
| trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device) | |
| # Загрузка весов | |
| trained_generator.load_state_dict(state_dicts) | |
| trained_generator.eval() | |
| return trained_generator | |
| model_paths = { | |
| "Photo -> Pencil Sketch": "/content/presets/sketch.pth", | |
| "Photo -> Modigliani Painting": "/content/presets/modigliani.pth", | |
| "Human -> Werewolf": "/content/presets/werewolf.pth" | |
| } | |
| # Функция обработки | |
| def generate(model_name): | |
| model_path = model_paths[model_name] | |
| g_generator = load_model(model_path) | |
| images = [] | |
| with torch.no_grad(): | |
| w_optimized = f_generator.style(torch.randn(2, latent_dim).to(device)) | |
| w_plus = w_optimized.unsqueeze(1).repeat(1, f_generator.n_latent, 1).clone() | |
| frozen_images = f_generator(w_plus.unsqueeze(0), input_is_latent=True)[0] | |
| frozen_images = (frozen_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] | |
| frozen_images = frozen_images.permute(0, 2, 3, 1).cpu().numpy() | |
| images.extend(frozen_images) | |
| trained_images = g_generator(w_plus.unsqueeze(0), input_is_latent=True)[0] | |
| trained_images = (trained_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] | |
| trained_images = trained_images.permute(0, 2, 3, 1).cpu().numpy() | |
| images.extend(trained_images) | |
| return images | |
| # Интерфейс | |
| iface = gr.Interface( | |
| fn=generate, | |
| inputs=gr.Dropdown(choices=list(model_paths.keys()), label="Выберите пресет"), | |
| outputs=gr.Gallery(label="Результаты генерации", columns=2), | |
| title="Выбор модели", | |
| description="Выберите преобразование из списка." | |
| ) | |
| iface.launch(debug=True) | |