| | """
|
| | Wrapper class to call the stable-diffusion.cpp shared library for GGUF support
|
| | """
|
| |
|
| | import ctypes
|
| | import platform
|
| | from ctypes import (
|
| | POINTER,
|
| | c_bool,
|
| | c_char_p,
|
| | c_float,
|
| | c_int,
|
| | c_int64,
|
| | c_void_p,
|
| | )
|
| | from dataclasses import dataclass
|
| | from os import path
|
| | from typing import List, Any
|
| |
|
| | import numpy as np
|
| | from PIL import Image
|
| |
|
| | from backend.gguf.sdcpp_types import (
|
| | RngType,
|
| | SampleMethod,
|
| | Schedule,
|
| | SDCPPLogLevel,
|
| | SDImage,
|
| | SdType,
|
| | )
|
| |
|
| |
|
@dataclass
class ModelConfig:
    """Options used by GGUFDiffusion to create a stable-diffusion.cpp context.

    Every field is forwarded, in declaration order, as one positional argument
    of the native ``new_sd_ctx`` call. String fields are UTF-8 encoded first;
    an empty string is sent as ``b""``.
    """

    # Path to a complete model file; GGUFDiffusion does not validate it.
    model_path: str = ""
    # CLIP model file; GGUFDiffusion raises ValueError if empty or missing.
    clip_l_path: str = ""
    # T5XXL model file; GGUFDiffusion raises ValueError if empty or missing.
    t5xxl_path: str = ""
    # Diffusion model file; GGUFDiffusion raises ValueError if empty or missing.
    diffusion_model_path: str = ""
    # VAE model file; GGUFDiffusion raises ValueError if empty or missing.
    vae_path: str = ""
    # Optional paths/directories below are forwarded unchanged (empty -> b"").
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""
    # Boolean flags forwarded as c_bool arguments of new_sd_ctx.
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False
    # Forwarded as c_int; presumably the native worker thread count - confirm in stable-diffusion.h.
    n_threads: int = 4
    # Enum values forwarded as-is; their meaning is defined by backend.gguf.sdcpp_types.
    wtype: SdType = SdType.SD_TYPE_Q4_0
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT
    # Device-placement flags forwarded as the last three c_bool arguments.
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
|
| |
|
| |
|
@dataclass
class Txt2ImgConfig:
    """Per-request settings for ``GGUFDiffusion.generate_text2mg``.

    Each field is forwarded, in declaration order, as one positional argument
    of the native ``txt2img`` call.
    """

    # Prompt text; UTF-8 encoded before the call.
    prompt: str = "a man wearing sun glasses, highly detailed"
    # Negative prompt; UTF-8 encoded, an empty string is sent as b"".
    negative_prompt: str = ""
    # Forwarded as c_int; -1 presumably selects the library default - confirm.
    clip_skip: int = -1
    # Guidance values forwarded as c_float.
    cfg_scale: float = 2.0
    guidance: float = 3.5
    # Requested output size, forwarded as c_int.
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    # Number of sampling steps, forwarded as c_int.
    sample_steps: int = 1
    # Forwarded as c_int64; -1 presumably requests a random seed - confirm.
    seed: int = -1
    # Number of images requested; generate_text2mg reads this many results back.
    batch_count: int = 2
    # Forwarded as POINTER(SDImage): must be None or an SDImage (or pointer to
    # one). A PIL Image is rejected by ctypes, hence the annotation is Any.
    control_cond: Any = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    # Already bytes: passed directly as c_char_p without encoding.
    input_id_images_path: bytes = b""
|
| |
|
| |
|
class GGUFDiffusion:
    """GGUF diffusion backend built on the stable-diffusion.cpp shared library.

    The platform-specific shared library is loaded with ``ctypes``, a native
    context is created from a ``ModelConfig``, and text-to-image results are
    converted to PIL images. The ctypes signatures declared here follow
    stable-diffusion.h.

    GGUF format: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
    """

    def __init__(
        self,
        libpath: str,
        config: "ModelConfig",
        logging_enabled: bool = False,
    ):
        """Load the shared library and create a stable-diffusion.cpp context.

        Args:
            libpath: Directory containing the stable-diffusion.cpp shared library.
            config: Model file paths and context options.
            logging_enabled: When True, native log messages are printed to
                stdout. The callback is installed before the models are loaded,
                so loading messages are captured as well.

        Raises:
            ValueError: If the platform is unsupported, the library cannot be
                loaded, a required model file is missing, or the native context
                could not be created.
        """
        # Initialise first so terminate() is safe even if construction fails.
        self.libsdcpp = None
        self.sd_ctx = None

        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        if not sdcpp_shared_lib_path:
            # Without this, ctypes.CDLL("") would fail in a confusing way.
            raise ValueError(
                "stable-diffusion.cpp shared library is not supported on this platform"
            )
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            raise ValueError(f"Error: {e}") from e

        self._validate_model_paths(config)
        self.model_config = config

        if logging_enabled:
            # sd_set_log_callback takes no context, so install it before loading.
            self._set_logcallback()

        self.sd_ctx = self._create_sd_ctx()
        if not self.sd_ctx:
            # A NULL context would later be dereferenced by txt2img in native code.
            raise ValueError(
                "Failed to create the stable-diffusion.cpp context, "
                "please check the GGUF model files"
            )

    @staticmethod
    def _validate_model_paths(config: "ModelConfig") -> None:
        """Raise ValueError if a model file required for GGUF inference is missing."""
        required_files = (
            ("CLIP", config.clip_l_path),
            ("T5XXL", config.t5xxl_path),
            ("Diffusion", config.diffusion_model_path),
            ("VAE", config.vae_path),
        )
        for label, file_path in required_files:
            if not file_path or not path.exists(file_path):
                raise ValueError(
                    f"{label} model file not found,please check readme.md for GGUF model usage"
                )

    def _create_sd_ctx(self):
        """Declare the ``new_sd_ctx`` signature and create the native context.

        Returns:
            The context pointer; a NULL (falsy) pointer means creation failed.
        """
        self.libsdcpp.new_sd_ctx.argtypes = [
            c_char_p,  # model_path
            c_char_p,  # clip_l_path
            c_char_p,  # t5xxl_path
            c_char_p,  # diffusion_model_path
            c_char_p,  # vae_path
            c_char_p,  # taesd_path
            c_char_p,  # control_net_path
            c_char_p,  # lora_model_dir
            c_char_p,  # embed_dir
            c_char_p,  # stacked_id_embed_dir
            c_bool,  # vae_decode_only
            c_bool,  # vae_tiling
            c_bool,  # free_params_immediately
            c_int,  # n_threads
            SdType,  # wtype
            RngType,  # rng_type
            Schedule,  # schedule
            c_bool,  # keep_clip_on_cpu
            c_bool,  # keep_control_net_cpu
            c_bool,  # keep_vae_on_cpu
        ]
        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)

        cfg = self.model_config
        return self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(cfg.model_path),
            self._str_to_bytes(cfg.clip_l_path),
            self._str_to_bytes(cfg.t5xxl_path),
            self._str_to_bytes(cfg.diffusion_model_path),
            self._str_to_bytes(cfg.vae_path),
            self._str_to_bytes(cfg.taesd_path),
            self._str_to_bytes(cfg.control_net_path),
            self._str_to_bytes(cfg.lora_model_dir),
            self._str_to_bytes(cfg.embed_dir),
            self._str_to_bytes(cfg.stacked_id_embed_dir),
            cfg.vae_decode_only,
            cfg.vae_tiling,
            cfg.free_params_immediately,
            cfg.n_threads,
            cfg.wtype,
            cfg.rng_type,
            cfg.schedule,
            cfg.keep_clip_on_cpu,
            cfg.keep_control_net_cpu,
            cfg.keep_vae_on_cpu,
        )

    def _set_logcallback(self):
        """Route native log messages to ``log_callback``."""
        print("Setting logging callback")

        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )

        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None

        # Keep a reference on self: if the CFUNCTYPE object were garbage
        # collected, the library would call into freed memory.
        self.c_log_callback = SdLogCallbackType(self.log_callback)
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the shared-library path under ``root_path`` for this OS.

        Returns an empty string (after printing a notice) on unknown platforms.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")
        lib_names = {
            "Windows": "stable-diffusion.dll",
            "Linux": "libstable-diffusion.so",
            "Darwin": "libstable-diffusion.dylib",
        }
        lib_name = lib_names.get(system_name)
        if lib_name is None:
            print("Unknown platform.")
            return ""
        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print a native log message without a trailing newline.

        This runs inside a C callback, where exceptions cannot propagate. For
        that reason NULL text is ignored and invalid UTF-8 is replaced rather
        than raising.
        """
        if text:
            print(text.decode("utf-8", errors="replace"), end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode ``in_str`` for a ``c_char_p`` argument; falsy input becomes ``b""``."""
        return in_str.encode(encoding) if in_str else b""

    def generate_text2mg(self, txt2img_cfg: "Txt2ImgConfig") -> List[Any]:
        """Run text-to-image generation and return the results as PIL images.

        Args:
            txt2img_cfg: Prompt and sampling settings; ``batch_count`` images
                are requested.

        Returns:
            A list of PIL images. The list is empty if the library returned no
            result buffer.

        Raises:
            RuntimeError: If there is no native context (e.g. after terminate()).
        """
        if getattr(self, "libsdcpp", None) is None or not getattr(
            self, "sd_ctx", None
        ):
            raise RuntimeError(
                "GGUF diffusion context is not initialized or was terminated"
            )

        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,  # sd_ctx
            c_char_p,  # prompt
            c_char_p,  # negative_prompt
            c_int,  # clip_skip
            c_float,  # cfg_scale
            c_float,  # guidance
            c_int,  # width
            c_int,  # height
            SampleMethod,  # sample_method
            c_int,  # sample_steps
            c_int64,  # seed
            c_int,  # batch_count
            POINTER(SDImage),  # control_cond
            c_float,  # control_strength
            c_float,  # style_strength
            c_bool,  # normalize_input
            c_char_p,  # input_id_images_path
        ]

        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            txt2img_cfg.input_id_images_path,
        )

        # NOTE(review): the native result buffer is never freed here. Freeing
        # it needs the same C runtime as the library, so it is left as-is.
        # Because the images below are copies, adding a free would be safe.
        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the first ``batch_count`` SDImage results into PIL images.

        Entries without pixel data (NULL ``data``) are skipped.

        Raises:
            ValueError: If an image has a channel count other than 1, 3 or 4.
        """
        channel_modes = {1: "L", 3: "RGB", 4: "RGBA"}
        images = []
        if not image_buffer:
            return images

        for index in range(batch_count):
            image = image_buffer[index]
            width, height, channels = image.width, image.height, image.channel
            print(f"Generated image: {width}x{height} with {channels} channels")

            if not image.data:
                print(f"Image {index} has no pixel data, skipping")
                continue

            mode = channel_modes.get(channels)
            if mode is None:
                raise ValueError(f"Unsupported number of channels: {channels}")

            # Copy so the PIL image owns its pixels instead of aliasing native memory.
            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            ).copy()
            if channels == 1:
                # reshape, not squeeze: squeeze would also drop a 1-pixel dimension.
                pixel_data = pixel_data.reshape(height, width)

            images.append(Image.fromarray(pixel_data, mode=mode))
        return images

    def terminate(self):
        """Free the native context, if any, and drop the library reference.

        Safe to call more than once, and on a partially initialised instance.
        """
        lib = getattr(self, "libsdcpp", None)
        ctx = getattr(self, "sd_ctx", None)
        if lib is not None and ctx:
            lib.free_sd_ctx.argtypes = [c_void_p]
            lib.free_sd_ctx.restype = None
            lib.free_sd_ctx(ctx)
        self.sd_ctx = None
        self.libsdcpp = None
|
| |
|