-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
87 lines (67 loc) · 2.93 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
from detector_model import LogoDetector
from description_model import FlorenceModel
from verification_model import LLaVAModel, PaligemmaModel
import os
model = LogoDetector()
# desc_model = FlorenceModel()
verify_model = None
def process_image(hf_key, image, thresh, nms_thresh):
if hf_key:
os.environ["HF_TOKEN"] = hf_key
else:
return "Please set a valid Huggingface API Key"
global verify_model
if not verify_model:
verify_model = PaligemmaModel()
annotated_image, potential_logos = model.process_image(image, thresh, nms_thresh)
logos = [logo for logo in potential_logos if verify_model.verify_image(logo)]
descriptions = [verify_model.run_example(logo, "What is the full brand name of this logo?") for logo in logos]
# logos = verify_model.verify_images(potential_logos)
# descriptions = desc_model.describe_images(logos)
return annotated_image, list(zip(logos, descriptions))
DESCRIPTION = "# Automated Logo Detection - Phase 1"
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(DESCRIPTION)
with gr.Tab(label="Logo Detection"):
hf_key = gr.Textbox(label="Huggingface_API_KEY", placeholder="Huggingface API KEY", type="password")
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Frame", type="pil")
submit_btn = gr.Button(value="Detect Logos")
input_threshold = gr.Slider(
label="Threshold",
info="Larger value will detect fewer logos and vice versa.",
minimum=0.01,
maximum=0.2,
value=0.1,
step=0.01)
input_nms_threshold = gr.Slider(
label="NMS_Threshold",
info="Larger value will detect fewer overlapping logos and vice versa.",
minimum=0.1,
maximum=0.9,
value=0.3,
step=0.1)
with gr.Column():
# output_text = gr.Textbox(label="Output Text")
output_img = gr.Image(label="Output Image")
output_logos = gr.Gallery(columns=1, label="Cropped Logos", preview=True, show_label=True)
with gr.Accordion("Instructions"):
gr.Markdown(
"""
- Upload an image containing logos.
- Select a threshold using the slider.
- Click "Detect Logos" to detect logos in the given image.
- The output image with bounding boxes drawn for each logo will be shown in the top right.
- The cropped logos will be shown in the gallery below the output image.
""")
submit_btn.click(process_image, [hf_key, input_img, input_threshold, input_nms_threshold], [output_img, output_logos])
demo.launch(debug=False)