| | from fastapi import FastAPI, HTTPException, Response |
| | from pydantic import BaseModel |
| | from contextlib import asynccontextmanager |
| | import numpy as np |
| | from PIL import Image |
| | import io |
| | import uuid |
| | from typing import List, Union |
| |
|
| | import axengine |
| | import torch |
| |
|
| | from transformers import CLIPTokenizer, PreTrainedTokenizer |
| | import time |
| | import argparse |
| |
|
| | import os |
| | import traceback |
| | from diffusers import DPMSolverMultistepScheduler |
| | |
# Toggle verbose stdout logging emitted via debug_log().
DEBUG_MODE = True
# When True, debug_log() prefixes each line with a wall-clock timestamp.
LOG_TIMESTAMP = True
| |
|
def debug_log(msg):
    """Print *msg* to stdout as a [DEBUG] line when DEBUG_MODE is enabled.

    The line is optionally prefixed with a ``[YYYY-MM-DD HH:MM:SS] ``
    timestamp, controlled by the module-level LOG_TIMESTAMP flag.
    """
    if not DEBUG_MODE:
        return
    pieces = []
    if LOG_TIMESTAMP:
        pieces.append(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ")
    pieces.append(f"[DEBUG] {msg}")
    print("".join(pieces))
| | |
| | |
# On-disk locations of every artifact loaded once at startup by
# DiffusionModels.load_models().
MODEL_PATHS = {
    "tokenizer": "./models/tokenizer",  # HuggingFace CLIPTokenizer directory
    "text_encoder": "./models/text_encoder/sd15_text_encoder_sim.axmodel",  # SD1.5 text encoder (axengine)
    "unet": "./models/unet.axmodel",  # denoising UNet (axengine)
    "vae": "./models/vae_decoder.axmodel",  # VAE decoder (axengine)
    # Precomputed DPM++ time-step embeddings for 20 steps; subsampled to 10 at load.
    "time_embeddings": "./models/time_input_dpmpp_20steps.npy"
}
| |
|
class DiffusionModels:
    """Owns every inference artifact (tokenizer, encoder, UNet, VAE, time
    embeddings) and loads them into memory once, before serving requests."""

    def __init__(self):
        # Nothing is loaded at construction time; see load_models().
        self.models_loaded = False
        self.tokenizer = None
        self.text_encoder = None
        self.unet = None
        self.vae = None
        self.time_embeddings = None

    def load_models(self):
        """Preload all models into memory.

        Raises:
            Exception: re-raised after logging if any artifact fails to load.
        """
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS["tokenizer"])
            # One axengine session per compiled .axmodel graph.
            self.text_encoder = axengine.InferenceSession(MODEL_PATHS["text_encoder"])
            self.unet = axengine.InferenceSession(MODEL_PATHS["unet"])
            self.vae = axengine.InferenceSession(MODEL_PATHS["vae"])

            # The .npy file holds embeddings for 20 scheduler steps; keep
            # every second entry to match the 10-step pipeline used here.
            self.time_embeddings = np.load(MODEL_PATHS["time_embeddings"])[::2]
            debug_log(f"时间嵌入已从20步采样为10步,形状: {self.time_embeddings.shape}")

            self.models_loaded = True
            print("所有模型已成功加载到内存")
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            raise
| |
|
| | diffusion_models = DiffusionModels() |
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load all models once at process startup.

    Loading is blocking on purpose so the first request never pays the
    model-load cost; failures here abort startup (load_models re-raises).
    """
    diffusion_models.load_models()
    yield
    # No teardown: the sessions live for the whole process lifetime.
| |
|
| | app = FastAPI(lifespan=lifespan) |
| |
|
class GenerationRequest(BaseModel):
    """Request payload for POST /generate.

    Attributes:
        positive_prompt: required text describing the desired image.
        negative_prompt: optional text describing what to avoid.
        seed: optional RNG seed; None lets the generator derive one from
            the current time.
    """

    positive_prompt: str
    negative_prompt: str = ""
    # BUG FIX: was `seed: int = None` — a None default on a plain `int`
    # field is an invalid annotation; declare the field as optional.
    seed: Union[int, None] = None
| |
|
@app.post("/generate")
async def generate_image(request: GenerationRequest):
    """Generate one image from the prompts and return it as PNG bytes.

    Returns:
        Response with media type image/png.

    Raises:
        HTTPException 400: invalid input (prompt too long).
        HTTPException 500: any internal generation failure, tagged with a
            UUID so the log line can be correlated with the client error.
    """
    # Input validation happens before the try block so it cannot be
    # swallowed into a generic 500 below.
    if len(request.positive_prompt) > 1000:
        # BUG FIX: was `raise ValueError(...)` inside the try, which the
        # blanket handler converted into a 500; a too-long prompt is a
        # client error and must surface as 400.
        raise HTTPException(status_code=400, detail="提示词过长")

    try:
        # Steps and CFG are fixed; generate_diffusion_image enforces the
        # same values internally.
        image = generate_diffusion_image(
            positive_prompt=request.positive_prompt,
            negative_prompt=request.negative_prompt,
            num_steps=10,
            guidance_scale=5.4,
            seed=request.seed
        )

        # Serialize the PIL image to an in-memory PNG.
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')

        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    except HTTPException:
        # Let deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        # Hide internals from the client; log the detail under a UUID.
        error_id = str(uuid.uuid4())
        print(f"Error [{error_id}]: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"生成失败,错误ID:{error_id}"
        )
| | |
| | |
| | |
def get_embeds(prompt, negative_prompt):
    """Encode the positive and negative prompts with shape validation.

    Args:
        prompt: positive prompt text.
        negative_prompt: negative prompt text (may be empty).

    Returns:
        Tuple (neg_embeds, pos_embeds), each a float array of shape
        (1, 77, 768) produced by the axengine text encoder.

    Raises:
        Exception: re-raised after logging; callers handle the failure.
    """
    try:
        debug_log(f"开始处理提示词: {prompt}")
        start_time = time.time()

        def process_prompt(prompt_text):
            # CLIP uses a fixed context of 77 tokens: pad and truncate.
            inputs = diffusion_models.tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")

            # The compiled encoder graph expects int32 token ids.
            outputs = diffusion_models.text_encoder.run(
                None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)}
            )[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # Guard against a mis-exported encoder graph.
        if neg_embeds.shape != (1, 77, 768) or pos_embeds.shape != (1, 77, 768):
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")

        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        # BUG FIX: was `exit(1)` — SystemExit is a BaseException, so it
        # bypassed the caller's `except Exception` handler and killed the
        # whole server worker on any encoding failure. Re-raise instead so
        # generate_diffusion_image can convert it into a RuntimeError.
        raise
| |
|
| |
|
def generate_diffusion_image(
    positive_prompt: str,
    negative_prompt: str,
    num_steps: int = 10,
    guidance_scale: float = 5.4,
    seed: int = None
) -> Image.Image:
    """Run the full diffusion pipeline (fixed 10 steps, CFG=5.4).

    Args:
        positive_prompt: text describing the desired image (must be non-empty).
        negative_prompt: text describing what to avoid.
        num_steps: inference steps; overridden to 10 because the precomputed
            time embeddings and the compiled graphs are baked for 10 steps.
        guidance_scale: classifier-free guidance weight; overridden to 5.4.
        seed: RNG seed; None derives one from the current time in ms.

    Returns:
        PIL.Image.Image: the generated image (also written to ./api.png).

    Raises:
        RuntimeError: wraps any failure during the pipeline (the ValueError
            for an empty prompt is wrapped too, preserving prior behavior).
    """
    try:
        if not positive_prompt:
            raise ValueError("正向提示词不能为空")

        # The axmodel graphs and the stored time embeddings are exported for
        # exactly this configuration, so caller-supplied values are forced.
        num_steps = 10
        guidance_scale = 5.4

        debug_log(f"开始生成流程 (固定参数: 10步, CFG=5.4)...")
        start_time = time.time()

        # Seed torch and numpy so latent noise is reproducible per request.
        seed = seed if seed is not None else int(time.time() * 1000) % 0xFFFFFFFF
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"初始随机种子: {seed}")

        # --- Text encoding (get_embeds re-raises on failure) ---
        embed_start = time.time()
        neg_emb, pos_emb = get_embeds(
            positive_prompt,
            negative_prompt,
        )
        debug_log(f"文本编码完成 | 耗时: {time.time()-embed_start:.2f}s")

        # --- Scheduler: DPM-Solver++ with Karras sigmas, matching the
        # configuration the time embeddings were precomputed with ---
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True
        )
        scheduler.set_timesteps(num_steps)

        # Latent tensor; the torch.Generator keeps noise tied to `seed`
        # independently of the global torch RNG state.
        latents_shape = (1, 4, 60, 40)
        latent = torch.randn(latents_shape, generator=torch.Generator().manual_seed(seed))
        latent = latent * scheduler.init_noise_sigma
        latent = latent.numpy().astype(np.float32)
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} sigma:{scheduler.init_noise_sigma:.3f}")

        # Sanity check: embeddings were subsampled to 10 entries at load time.
        if len(diffusion_models.time_embeddings) != num_steps:
            raise ValueError(f"时间嵌入步数不匹配: 需要{num_steps}步 当前{len(diffusion_models.time_embeddings)}步")
        time_steps = diffusion_models.time_embeddings
        debug_log(f"使用预处理的10步时间嵌入,形状: {time_steps.shape}")

        # --- Denoising loop ---
        debug_log("开始10步采样循环...")
        for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
            step_start = time.time()

            # Precomputed embedding for this step; add the batch dim.
            time_emb = np.expand_dims(time_steps[step_idx], axis=0)

            # Two UNet passes (unconditional / conditional). The graph was
            # exported with the timestep embedding fed through this internal
            # tensor name rather than a plain "timestep" input.
            noise_pred_neg = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": neg_emb
            })[0]

            noise_pred_pos = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": pos_emb
            })[0]

            # Classifier-free guidance. FIX: use the guidance_scale variable
            # instead of repeating the literal 5.4 (same value today, but the
            # literal would silently desynchronize from the constant above).
            noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)

            # Scheduler step runs on torch tensors.
            latent_tensor = torch.from_numpy(latent)
            noise_pred_tensor = torch.from_numpy(noise_pred)

            scheduler_start = time.time()
            latent_tensor = scheduler.step(
                model_output=noise_pred_tensor,
                timestep=timestep,
                sample=latent_tensor
            ).prev_sample
            debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

            # Back to float32 numpy for the next axengine UNet call.
            latent = latent_tensor.numpy().astype(np.float32)
            debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

            debug_log(f"步骤 {step_idx+1}/{num_steps} | 耗时: {time.time()-step_start:.2f}s")

        # --- VAE decode ---
        debug_log("开始VAE解码...")
        vae_start = time.time()
        # Undo the SD1.x latent scaling factor before decoding.
        latent = latent / 0.18215
        image = diffusion_models.vae.run(None, {"latent": latent})[0]

        # CHW -> HWC, then map [-1, 1] to [0, 255].
        image = np.transpose(image.squeeze(), (1, 2, 0))
        image = np.clip((image / 2 + 0.5) * 255, 0, 255).astype(np.uint8)
        pil_image = Image.fromarray(image[..., :3])
        # Debug artifact kept for parity with the original behavior.
        pil_image.save("./api.png")
        debug_log(f"VAE解码完成 | 耗时: {time.time()-vae_start:.2f}s")
        debug_log(f"总耗时: {time.time()-start_time:.2f}s (10步优化版)")
        return pil_image

    except Exception as e:
        error_msg = f"生成失败: {str(e)}"
        debug_log(error_msg)
        traceback.print_exc()
        # FIX: chain the cause so the original traceback is preserved in
        # `__cause__` instead of being discarded.
        raise RuntimeError(error_msg) from e