-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
63 lines (44 loc) · 2.02 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from pathlib import Path

import cv2
import numpy as np
import streamlit as st
from PIL import Image
@st.cache_resource()
def load_model(model_dir=None):
    """Load the Caffe colorization network and install its extra layer weights.

    Wrapped in ``st.cache_resource`` so the network is built once per distinct
    ``model_dir`` argument rather than on every Streamlit rerun.

    Args:
        model_dir: Directory containing ``colorization_deploy_v2.prototxt``,
            ``colorization_release_v2.caffemodel`` and ``pts_in_hull.npy``.
            Defaults to the ``model`` directory next to this file, so the app
            works no matter which directory Streamlit is launched from.

    Returns:
        The ``cv2.dnn`` network, ready for ``setInput`` / ``forward``.

    Raises:
        FileNotFoundError: If any of the three model files is missing.
    """
    if model_dir is None:
        model_dir = Path(__file__).resolve().parent / "model"
    model_dir = Path(model_dir)

    prototxt_path = model_dir / "colorization_deploy_v2.prototxt"
    model_path = model_dir / "colorization_release_v2.caffemodel"
    kernel_path = model_dir / "pts_in_hull.npy"

    # Fail with a readable message instead of an opaque cv2 error.
    missing = [str(p) for p in (prototxt_path, model_path, kernel_path) if not p.is_file()]
    if missing:
        raise FileNotFoundError("Colorization model file(s) not found: " + ", ".join(missing))

    net = cv2.dnn.readNetFromCaffe(str(prototxt_path), str(model_path))

    # The .npy file supplies 2 x 313 values (presumably the quantized ab bin
    # centres); they are installed as the weights of the `class8_ab` layer,
    # shaped as a 1x1 kernel.
    points = np.load(kernel_path)
    points = points.transpose().reshape(2, 313, 1, 1)
    net.getLayer(net.getLayerId("class8_ab")).blobs = [points.astype(np.float32)]

    # Constant weights for `conv8_313_rh`; 2.606 matches the value used in
    # OpenCV's reference colorization sample.
    net.getLayer(net.getLayerId("conv8_313_rh")).blobs = [np.full([1, 313], 2.606, dtype="float32")]
    return net
def colorize_image(image, net):
    """Colorize *image* using the network returned by ``load_model``.

    The image is converted to L*a*b*. The network receives a 224x224 copy of
    the L (lightness) channel shifted by -50, and predicts the a/b channels.
    Those are upscaled back to the original size and recombined with the
    full-resolution L channel, so the original detail is preserved.

    Args:
        image: Array with values in 0-255 (it is divided by 255 below).
            Accepted layouts: H x W or H x W x 1 (grayscale), H x W x 3 (RGB),
            H x W x 4 (RGBA; the alpha channel is ignored).
        net: ``cv2.dnn`` network from ``load_model``.

    Returns:
        H x W x 3 uint8 RGB array.

    Raises:
        ValueError: If *image* does not have one of the accepted layouts.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.ndim != 3 or image.shape[2] not in (1, 3, 4):
        raise ValueError(
            "expected an H x W, H x W x 1, H x W x 3 or H x W x 4 image, "
            f"got shape {image.shape}"
        )
    if image.shape[2] == 1:
        # Replicate gray into R, G and B so RGB->LAB yields its lightness.
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]  # drop alpha

    height, width = image.shape[:2]
    scaled = image.astype("float32") / 255.0
    lab = cv2.cvtColor(scaled, cv2.COLOR_RGB2LAB)

    # Network input: 224x224 L channel, shifted by 50 (for float input OpenCV's
    # L lies in 0-100, so this centres it on zero).
    net_l = cv2.split(cv2.resize(lab, (224, 224)))[0]
    net_l -= 50
    net.setInput(cv2.dnn.blobFromImage(net_l))

    # Output is channel-first; move channels last and upscale to the input size.
    ab = net.forward()[0, :, :, :].transpose((1, 2, 0))
    ab = cv2.resize(ab, (width, height))

    # Recombine with the full-resolution L channel.
    full_l = cv2.split(lab)[0]
    colorized = np.concatenate((full_l[:, :, np.newaxis], ab), axis=2)
    colorized = cv2.cvtColor(colorized, cv2.COLOR_LAB2RGB)
    colorized = np.clip(colorized, 0, 1)
    # Round rather than truncate so every channel is not biased down by up to one level.
    return np.rint(colorized * 255).astype("uint8")
def main():
    """Render the app: upload an image, preview it, and colorize it on request."""
    # Kept as the first Streamlit call; older Streamlit releases require it.
    st.set_page_config(page_title='Colorize It!', page_icon="🖌️", layout="centered")
    st.title("Black and White Image Colorizer")
    uploaded_file = st.file_uploader("Choose a black and white image...", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; decode now so corrupt or truncated uploads are
        # reported here instead of crashing later with a traceback.
        image.load()
    except OSError:
        st.error("Could not read the uploaded file as an image.")
        return

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; kept for the targeted Streamlit version.
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Colorize Image'):
        net = load_model()
        # Convert the PIL image (any mode) to an H x W x 3 uint8 RGB array.
        image_array = np.array(image.convert('RGB'))
        colorized_image = colorize_image(image_array, net)
        st.image(colorized_image, caption='Colorized Image', use_column_width=True)


if __name__ == "__main__":
    main()