fcakyon commited on
Commit
3dd1121
·
verified ·
1 Parent(s): 8e076e3

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +166 -0
utils.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import gradio as gr
4
+ from typing import Any, Dict, Tuple
5
+ from urllib.request import urlopen, Request
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from functools import lru_cache
9
+
10
# Process-wide cache of loaded moderation models, keyed by model id.
# Populated lazily by _load_model so each model is only loaded once per process.
_MODEL_CACHE: Dict[str, Any] = {}

# Gallery examples shown in the UI. Each entry is a 3-tuple:
#   (image URL to analyse, model id to run it with, caption shown in the gallery).
# run_example_by_index indexes into this list when the user clicks an example.
EXAMPLE_ITEMS = [
    (
        "https://assets.clevelandclinic.org/transform/LargeFeatureImage/cd71f4bd-81d4-45d8-a450-74df78e4477a/Apples-184940975-770x533-1_jpg",
        "viddexa/nsfw-mini",
        "Apples (mini)",
    ),
    (
        "https://img.freepik.com/free-photo/breast-screening-is-very-important-every-woman_329181-14953.jpg",
        "viddexa/nsfw-nano",
        "Breast screening (nano)",
    ),
    (
        "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSbRwt56NYsiHwrT8oS-igzgeEzp7p3Jbe2dw&s",
        "viddexa/nsfw-mini",
        "Thumbnail (mini)",
    ),
    (
        "https://img.freepik.com/premium-photo/portrait-beautiful-young-woman_1048944-5548042.jpg",
        "viddexa/nsfw-nano",
        "Portrait (nano)",
    ),
]
34
+
35
+
36
@lru_cache(maxsize=32)
def _download_image_bytes(image_url: str) -> bytes:
    """Fetch and return the raw bytes at *image_url*, memoizing up to 32 URLs.

    A custom User-Agent header is sent because some CDNs reject the default
    urllib agent. Failed downloads raise, and lru_cache does not memoize
    exceptions, so transient network errors can be retried on the next call.
    """
    request = Request(
        image_url,
        headers={"User-Agent": "viddexa-gradio-demo/1.0"},
    )
    with urlopen(request, timeout=20) as response:
        payload = response.read()
    return payload
42
+
43
+
44
def _load_model(model_id: str, token: str | None = None) -> Any:
    """Return a moderation model for *model_id*, loading and caching on first use.

    Args:
        model_id: Hugging Face model identifier (e.g. "viddexa/nsfw-mini").
        token: Optional Hugging Face access token for private models.

    Raises:
        gr.Error: with a user-facing message if the model cannot be loaded;
            a hint about private models is appended when the error mentions 401.
    """
    cached = _MODEL_CACHE.get(model_id)
    if cached is not None:
        return cached
    try:
        # Imported lazily so the module can be used without the dependency
        # until a model is actually requested.
        from moderators.auto_model import AutoModerator

        loaded = AutoModerator.from_pretrained(model_id, token=token, use_fast=True)
    except Exception as exc:
        message = f"Failed to load model: {model_id}. Error: {exc}"
        if "401" in str(exc):
            message += "\n\nThis model may be private. Please ensure you have provided a valid Hugging Face token if required."
        raise gr.Error(message)
    _MODEL_CACHE[model_id] = loaded
    return loaded
58
+
59
+
60
def _get_image_input(image_path: str | None, image_url: str | None) -> Image.Image:
    """Resolve the user's input into an RGB PIL image.

    A URL takes precedence over an uploaded file when both are supplied.

    Raises:
        gr.Error: if the URL cannot be downloaded/decoded, or if neither
            an upload nor a URL was provided.
    """
    if image_url:
        try:
            picture = Image.open(BytesIO(_download_image_bytes(image_url)))
            return picture.convert("RGB")
        except Exception as fetch_err:
            raise gr.Error(f"Could not download or open the image from the URL: {fetch_err}")
    if image_path:
        # Uploaded file path from the Gradio image component.
        return Image.open(image_path).convert("RGB")
    raise gr.Error("Please upload an image or provide an image URL.")
74
+
75
+
76
+ def _format_results(results: list) -> Tuple[str, Dict[str, float], str, Dict]:
77
+ """Format the model output for the Gradio interface."""
78
+ if not results or "classifications" not in results[0]:
79
+ return "<div class='verdict-card'>No classifications found.</div>", {}, "No classifications found.", {}
80
+
81
+ classifications = results[0]["classifications"]
82
+
83
+ label_output: Dict[str, float]
84
+ if isinstance(classifications, dict):
85
+ label_output = {str(k): float(v) for k, v in classifications.items()}
86
+ else:
87
+ try:
88
+ label_output = {str(item['label']): float(item['score']) for item in classifications}
89
+ except Exception:
90
+ label_output = {}
91
+
92
+ scores = {label.lower(): score for label, score in label_output.items()}
93
+ nsfw_score = scores.get("nsfw", 0.0)
94
+
95
+ if nsfw_score > 0.7:
96
+ verdict_text = "HIGH RISK: NSFW"
97
+ verdict_class = "verdict-nsfw"
98
+ elif nsfw_score > 0.2:
99
+ verdict_text = "MEDIUM RISK: SENSITIVE"
100
+ verdict_class = "verdict-sensitive"
101
+ else:
102
+ verdict_text = "LOW RISK: SAFE"
103
+ verdict_class = "verdict-safe"
104
+
105
+ verdict_html = f"<div class='verdict-card {verdict_class}'>{verdict_text}</div>"
106
+
107
+ markdown_output = "### All Scores\n---\n"
108
+ for label, score in sorted(label_output.items(), key=lambda kv: kv[1], reverse=True):
109
+ markdown_output += f"- **{label.capitalize()}**: {score:.4f}\n"
110
+
111
+ return verdict_html, label_output, markdown_output, results[0]
112
+
113
+
114
def analyze_image(image_path: str | None, image_url: str | None, model_choice: str,
                  token: str | None = None, progress=gr.Progress(track_tqdm=True)):
    """Run the full inference pipeline: resolve image, load model, classify, format.

    Args:
        image_path: Path to an uploaded image file, if any.
        image_url: URL of an image to download, if any (takes precedence).
        model_choice: Hugging Face model id to run.
        token: Optional Hugging Face access token.
        progress: Gradio progress tracker (injected by Gradio at call time).

    Returns:
        The 4-tuple produced by _format_results.
    """
    progress(0, desc="Initializing Analysis...")

    progress(0.2, desc="Processing Image...")
    input_image = _get_image_input(image_path, image_url)

    progress(0.5, desc=f"Loading Model: {os.path.basename(model_choice)}...")
    moderator = _load_model(model_choice, token)

    progress(0.8, desc="Running Inference...")
    raw_results = moderator(input_image)

    # Normalize result objects into plain JSON-safe dicts: pull the
    # .classifications attribute when present, then round-trip through JSON
    # to strip any non-serializable wrapper types.
    serializable = [
        {"classifications": getattr(item, "classifications", item)}
        for item in raw_results
    ]
    serializable = json.loads(json.dumps(serializable, ensure_ascii=False))

    progress(1, desc="Complete!")
    return _format_results(serializable)
133
+
134
+
135
def analyze_image_with_status(image_path: str | None, image_url: str | None, model_choice: str,
                              token: str | None = None, progress=gr.Progress(track_tqdm=True)):
    """Run analyze_image and append a human-readable status line to its outputs.

    Returns:
        (verdict HTML, label scores, markdown scores, raw JSON dict, status string).
    """
    outputs = analyze_image(image_path, image_url, model_choice, token, progress)
    # Describe which input was analysed, mirroring the URL-over-upload precedence.
    if image_url:
        status = f"Last analysed URL: {image_url}"
    elif image_path:
        status = "Last analysed uploaded image."
    else:
        status = "Last analysed: —"
    return (*outputs, status)
146
+
147
+
148
def run_example_by_index(evt: gr.SelectData, token: str | None = None):
    """Handle a gallery click: analyse the chosen example and sync the UI inputs.

    The selection index is coerced to int and clamped into range so a
    malformed or stale event cannot raise IndexError.

    Returns:
        (verdict HTML, label scores, markdown scores, raw JSON dict,
        model-dropdown update, URL-textbox update, status string).
    """
    raw_index = getattr(evt, "index", 0)
    try:
        selected = int(raw_index)
    except Exception:
        selected = 0
    selected = min(max(selected, 0), len(EXAMPLE_ITEMS) - 1)

    url, model, caption = EXAMPLE_ITEMS[selected]
    verdict_html, label_scores, md_scores, json_obj = analyze_image(None, url, model, token)

    return (
        verdict_html,
        label_scores,
        md_scores,
        json_obj,
        gr.update(value=model),
        gr.update(value=url),
        f"Last analysed example: {caption}",
    )