Granitagushi committed
Commit 958c5b7 · verified · 1 Parent(s): 5ae6617

Upload 2 files

Files changed (2)
  1. app.py +49 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,49 @@
+ import gradio as gr
+ from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
+
+ # Load the image processor explicitly!
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
+ model = AutoModelForImageClassification.from_pretrained("Granitagushi/vit-base-fruits-360")
+
+ vit_classifier = pipeline(
+     "image-classification",
+     model=model,
+     image_processor=processor,
+     device=0  # or -1 for CPU
+ )
+
+ clip_detector = pipeline(
+     model="openai/clip-vit-large-patch14",
+     task="zero-shot-image-classification"
+ )
+
+ labels_fruits = [
+     'Orange', 'Strawberry Wedge', 'Banana', 'Cherry', 'Apple Red'
+ ]
+
+ def classify_fruit(image):
+     # Run the fine-tuned ViT classifier and map each label to its score
+     vit_results = vit_classifier(image)
+     vit_output = {result['label']: result['score'] for result in vit_results}
+     # Run zero-shot CLIP against the candidate fruit labels
+     clip_results = clip_detector(image, candidate_labels=labels_fruits)
+     clip_output = {result['label']: result['score'] for result in clip_results}
+     return {"ViT Classification": vit_output, "CLIP Zero-Shot Classification": clip_output}
+
+ example_images = [
+     ["example_images/Apple.jpg"],
+     ["example_images/Banana.jpg"],
+     ["example_images/Cherry.jpg"],
+     ["example_images/orange.jpg"],
+     ["example_images/strawberry.jpg"]
+ ]
+
+ iface = gr.Interface(
+     fn=classify_fruit,
+     inputs=gr.Image(type="filepath"),
+     outputs=gr.JSON(),
+     title="Fruit Classification Comparison",
+     description="Upload an image of a fruit, and compare results from a trained ViT model and a zero-shot CLIP model.",
+     examples=example_images
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
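Note: device=0 pins the ViT pipeline to the first GPU and will fail on a CPU-only host. A minimal sketch of automatic device selection, not part of this commit, assuming torch (already listed in requirements.txt) is installed:

import torch

# Choose GPU 0 when available, otherwise fall back to CPU (-1) for transformers pipelines
device = 0 if torch.cuda.is_available() else -1

vit_classifier = pipeline(
    "image-classification",
    model=model,
    image_processor=processor,
    device=device,
)

The same device value could be passed to clip_detector so both pipelines run on the same hardware.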
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ datasets
+ torch
+ gradio
+ Pillow
+ tqdm
+ scikit-learn