HAL1993 committed on
Commit b0aa94a · verified · 1 Parent(s): c4df9a8

Update app.py

Files changed (1)
  1. app.py +108 -28
app.py CHANGED
@@ -7,22 +7,97 @@ import spaces
  from PIL import Image, ImageOps
  from typing import Iterable, Dict
  
- # -------------------------- THEME (unchanged) -------------------------- #
+ # --------------------------------------------------------------
+ # 🎨 CUSTOM GRADIO THEME (exactly as you wrote it originally)
+ # --------------------------------------------------------------
  from gradio.themes import Soft
  from gradio.themes.utils import colors, fonts, sizes
  
- # (theme definition omitted for brevity – keep exactly the same as before)
- # ---------------------------------------------------------------------- #
+ # ---- colour palette ------------------------------------------------
+ colors.steel_blue = colors.Color(
+     name="steel_blue",
+     c50="#EBF3F8",
+     c100="#D3E5F0",
+     c200="#A8CCE1",
+     c300="#7DB3D2",
+     c400="#529AC3",
+     c500="#4682B4",
+     c600="#3E72A0",
+     c700="#36638C",
+     c800="#2E5378",
+     c900="#264364",
+     c950="#1E3450",
+ )
+ 
+ # ---- theme class ---------------------------------------------------
+ class SteelBlueTheme(Soft):
+     def __init__(
+         self,
+         *,
+         primary_hue: colors.Color | str = colors.gray,
+         secondary_hue: colors.Color | str = colors.steel_blue,
+         neutral_hue: colors.Color | str = colors.slate,
+         text_size: sizes.Size | str = sizes.text_lg,
+         font: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("Outfit"),
+             "Arial",
+             "sans-serif",
+         ),
+         font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("IBM Plex Mono"),
+             "ui-monospace",
+             "monospace",
+         ),
+     ):
+         super().__init__(
+             primary_hue=primary_hue,
+             secondary_hue=secondary_hue,
+             neutral_hue=neutral_hue,
+             text_size=text_size,
+             font=font,
+             font_mono=font_mono,
+         )
+         super().set(
+             background_fill_primary="*primary_50",
+             background_fill_primary_dark="*primary_900",
+             body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
+             body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
+             button_primary_text_color="white",
+             button_primary_text_color_hover="white",
+             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
+             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
+             button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
+             button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
+             button_secondary_text_color="black",
+             button_secondary_text_color_hover="white",
+             button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
+             button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
+             button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
+             button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
+             slider_color="*secondary_500",
+             slider_color_dark="*secondary_600",
+             block_title_text_weight="600",
+             block_border_width="3px",
+             block_shadow="*shadow_drop_lg",
+             button_primary_shadow="*shadow_drop_lg",
+             button_large_padding="11px",
+             color_accent_soft="*primary_100",
+             block_label_background_fill="*primary_200",
+         )
  
  steel_blue_theme = SteelBlueTheme()
  
- # -------------------------- DEVICE & DTYPE --------------------------- #
+ # --------------------------------------------------------------
+ # 🖥️ DEVICE & DTYPE
+ # --------------------------------------------------------------
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # Prefer fp16 on consumer GPUs; it is ~2× faster than bf16 on most cards.
+ # fp16 is the fastest on most consumer GPUs; fall back to fp32 if no CUDA.
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
  print(f"Using device={device}, dtype={dtype}")
  
- # -------------------------- PIPELINE SETUP --------------------------- #
+ # --------------------------------------------------------------
+ # 🚀 PIPELINE & LoRA SETUP
+ # --------------------------------------------------------------
  from diffusers import FlowMatchEulerDiscreteScheduler
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
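
The bulk of this hunk defines the theme object; it is consumed by the UI at the bottom of the file. As a minimal sketch of how a Gradio app picks it up (the `gr.Blocks(theme=...)` parameter is standard Gradio; the demo body here is illustrative only, not part of the commit):

import gradio as gr

with gr.Blocks(theme=steel_blue_theme) as demo:  # theme defined in this commit
    gr.Markdown("SteelBlue theme preview")       # placeholder content
demo.launch()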
@@ -40,7 +115,7 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
      scheduler=FlowMatchEulerDiscreteScheduler(),
  ).to(device)
  
- # LoRA adapters ---------------------------------------------------------
+ # ----- Load all LoRA adapters ------------------------------------------------
  pipe.load_lora_weights(
      "autoweeb/Qwen-Image-Edit-2509-Photo-to-Anime",
      weight_name="Qwen-Image-Edit-2509-Photo-to-Anime_000001000.safetensors",
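
Only the first `load_lora_weights` call is visible in this hunk. A hedged sketch of how one more adapter would be registered and activated, assuming diffusers' `adapter_name`/`set_adapters` API; the repo id, file name, and adapter name below are placeholders, not taken from the commit:

pipe.load_lora_weights(
    "some-user/Some-Qwen-Edit-LoRA",          # placeholder repo id
    weight_name="some_lora.safetensors",      # placeholder file name
    adapter_name="my-style",                  # hypothetical adapter name
)
pipe.set_adapters(["my-style"], adapter_weights=[1.0])  # make it the active adapter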
@@ -84,7 +159,7 @@ pipe.load_lora_weights(
  
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
  
- # Speed‑up helpers -------------------------------------------------------
+ # ----- Speed‑up helpers --------------------------------------------------------
  if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
      pipe.enable_xformers_memory_efficient_attention()
  if hasattr(pipe, "enable_attention_slicing"):
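
Note that the `hasattr` guards only prove the methods exist; `enable_xformers_memory_efficient_attention()` can still raise at call time if the xformers package is missing. A more defensive variant, sketched with only standard diffusers pipeline methods:

try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as err:  # xformers not installed or unsupported on this GPU
    print(f"xformers unavailable ({err}); falling back to attention slicing")
    pipe.enable_attention_slicing()  # lower peak memory at a small speed cost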
@@ -92,39 +167,43 @@ if hasattr(pipe, "enable_attention_slicing"):
  
  MAX_SEED = np.iinfo(np.int32).max
  
- # -------------------------- UTILITIES --------------------------- #
+ # --------------------------------------------------------------
+ # 🛠️ UTILITIES (aspect‑ratio‑preserving preprocessing)
+ # --------------------------------------------------------------
  def _pad_to_multiple_of(value: int, divisor: int = 8) -> int:
      """Round `value` up to the nearest multiple of `divisor`."""
      return ((value + divisor - 1) // divisor) * divisor
  
  def prepare_image(image: Image.Image, max_side: int = 1024) -> tuple[Image.Image, tuple[int, int]]:
      """
-     1️⃣ Scale the image so that the longest side equals `max_side` (preserving aspect ratio).
-     2️⃣ Pad the scaled image on the right / bottom so that both dimensions are a multiple of 8.
-     3️⃣ Return the padded image **and** the (pad_w, pad_h) that were added – needed to crop the result later.
+     1️⃣ Scale the image so its longest side = `max_side` (keeps AR).
+     2️⃣ Pad the scaled image on the right/bottom to a multiple of 8.
+     3️⃣ Return the padded image **and** the padding that was added.
      """
-     # ---- 1️⃣ Scale ----------------------------------------------------
      w, h = image.size
      scale = max_side / max(w, h)
      new_w, new_h = int(round(w * scale)), int(round(h * scale))
  
-     # ---- 2️⃣ Pad to 8‑multiple -----------------------------------------
+     # Pad up to the nearest 8‑multiple (required by the model)
      pad_w = _pad_to_multiple_of(new_w) - new_w
      pad_h = _pad_to_multiple_of(new_h) - new_h
-     # Pad on the *right* and *bottom* only – easier to crop later
-     padded = ImageOps.expand(image.resize((new_w, new_h), Image.LANCZOS), border=(0, 0, pad_w, pad_h), fill=0)
+ 
+     resized = image.resize((new_w, new_h), Image.LANCZOS)
+     padded = ImageOps.expand(resized, border=(0, 0, pad_w, pad_h), fill=0)  # black padding
  
      return padded, (pad_w, pad_h)
  
  def crop_to_original(pil_img: Image.Image, pad: tuple[int, int]) -> Image.Image:
-     """Remove the padding that `prepare_image` added."""
+     """Remove the padding added by `prepare_image`."""
      pad_w, pad_h = pad
      if pad_w == 0 and pad_h == 0:
          return pil_img
      w, h = pil_img.size
      return pil_img.crop((0, 0, w - pad_w, h - pad_h))
  
- # -------------------------- INFERENCE --------------------------- #
+ # --------------------------------------------------------------
+ # 🤖 INFERENCE
+ # --------------------------------------------------------------
  @spaces.GPU(duration=30)
  def infer(
      input_image,
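
A quick round-trip check of the two helpers above, runnable as-is in the same module; the test size is arbitrary:

from PIL import Image

img = Image.new("RGB", (1000, 600))             # arbitrary test image
padded, pad = prepare_image(img, max_side=1024)
print(padded.size, pad)                         # (1024, 616) (0, 2)
restored = crop_to_original(padded, pad)
print(restored.size)                            # (1024, 614): scaled size, original aspect ratio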
@@ -139,7 +218,7 @@ def infer(
      if input_image is None:
          raise gr.Error("Please upload an image to edit.")
  
-     # ---- LoRA selection (dictionary makes it easy to extend) ----------
+     # ----- LoRA selection via a dict (easier to extend) -----
      lora_map: Dict[str, str] = {
          "Photo-to-Anime": "anime",
          "Multiple-Angles": "multiple-angles",
@@ -154,16 +233,16 @@
      if adapter_name:
          pipe.set_adapters([adapter_name], adapter_weights=[1.0])
  
-     # ---- Seed handling -------------------------------------------------
+     # ----- Seed handling -----
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
      generator = torch.Generator(device=device).manual_seed(seed)
  
-     # ---- Image preprocessing (aspect‑ratio preserving) -----------------
+     # ----- Image preprocessing (keeps AR) -----
      original = input_image.convert("RGB")
-     processed, pad = prepare_image(original, max_side=1024)  # 1024 is the model's native resolution
+     processed, pad = prepare_image(original, max_side=1024)
  
-     # ---- Run the pipeline -----------------------------------------------
+     # ----- Run the pipeline -----
      negative_prompt = (
          "worst quality, low quality, bad anatomy, bad hands, text, error, "
          "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
@@ -180,7 +259,7 @@ def infer(
          true_cfg_scale=guidance_scale,
      ).images[0]
  
-     # ---- Remove the padding so the output matches the original aspect ----
+     # ----- Remove padding so output matches original AR -----
      result = crop_to_original(result, pad)
  
      return result, seed
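
Only the tail of the `pipe(...)` call survives in this hunk. A hedged reconstruction of what the full call likely looks like; apart from `true_cfg_scale=guidance_scale` and the `negative_prompt` shown above, every keyword name here is an assumption modeled on typical diffusers edit pipelines, not confirmed by the commit:

result = pipe(
    image=processed,                  # assumed: the padded image from prepare_image
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=steps,        # assumed parameter name
    generator=generator,
    true_cfg_scale=guidance_scale,    # confirmed by the diff
).images[0]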
@@ -189,8 +268,8 @@ def infer(
  @spaces.GPU(duration=30)
  def infer_example(input_image, prompt, lora_adapter):
      """
-     A tiny wrapper used by the Gradio examples – it forces a deterministic
-     fast run (4 steps, guidance=1.0) and always randomises the seed.
+     Wrapper used by the Gradio examples – always runs a fast
+     (4‑step, guidance=1.0) inference and randomises the seed.
      """
      pil = input_image.convert("RGB")
      result, seed = infer(
@@ -204,8 +283,9 @@ def infer_example(input_image, prompt, lora_adapter):
      )
      return result, seed
  
- 
- # -------------------------- GRADIO UI --------------------------- #
+ # --------------------------------------------------------------
+ # 🎛️ GRADIO UI
+ # --------------------------------------------------------------
  css = """
  #col-container {margin: 0 auto; max-width: 960px;}
  #main-title h1 {font-size: 2.1em !important;}
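
`infer_example` exists to back the example gallery in the UI section that begins here. A sketch of the usual wiring, assuming standard `gr.Examples` arguments; the file name, prompt, and component variable names are placeholders, not taken from the commit:

gr.Examples(
    examples=[["example.jpg", "Turn this photo into anime", "Photo-to-Anime"]],
    inputs=[input_image, prompt, lora_adapter],   # assumed component variables
    outputs=[result_image, seed_number],          # assumed component variables
    fn=infer_example,
    cache_examples=False,
)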