-
Notifications
You must be signed in to change notification settings - Fork 601
/
Copy pathDeployYOLOmodel.py
118 lines (98 loc) · 4.09 KB
/
DeployYOLOmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import numpy as np
import cv2
import pafy
from time import time
class ObjectDetection:
    """
    Run a YOLOv5 model over a YouTube video with OpenCV and save an annotated copy.

    The stream URL is resolved with ``pafy``, decoded frame by frame with
    ``cv2.VideoCapture``, scored by ``yolov5s`` loaded from PyTorch Hub,
    annotated with bounding boxes and class names, and written to ``out_file``.
    """

    def __init__(self, url, out_file):
        """
        Initialize the detector with a YouTube URL and an output file.

        :param url: YouTube URL of the video on which predictions are made.
        :param out_file: Path of the output video (written as MJPG at 20 FPS).
        """
        self._URL = url
        self.out_file = out_file
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model()
        # Move the model to the target device once, not on every scored frame.
        self.model.to(self.device)
        self.classes = self.model.names
        print("\n\nDevice Used:", self.device)

    def get_video_from_url(self):
        """
        Open a capture object on the video referenced by ``self._URL``.

        :return: ``cv2.VideoCapture`` on the last entry of pafy's ``streams``
            list (described by the original author as the lowest-quality
            stream; the ordering itself is decided by pafy).
        :raises ValueError: if pafy reports no streams for the URL.
        """
        streams = pafy.new(self._URL).streams
        if not streams:
            # Explicit check: indexing an empty list would raise an opaque
            # IndexError, and an assert would vanish under ``python -O``.
            raise ValueError(f"No video streams found for {self._URL!r}")
        play = streams[-1]
        return cv2.VideoCapture(play.url)

    def load_model(self):
        """
        Load the pretrained YOLOv5s model from PyTorch Hub.

        :return: Trained PyTorch model (downloaded from ``ultralytics/yolov5``).
        """
        model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
        return model

    def score_frame(self, frame):
        """
        Score a single frame with the YOLOv5 model.

        :param frame: Image as decoded by OpenCV (HxWx3, BGR channel order);
            anything ``np.asarray`` accepts is converted first.
        :return: ``(labels, cord)`` sliced from ``results.xyxyn[0]``: the last
            column (class index) and all preceding columns (normalized
            ``xmin, ymin, xmax, ymax, confidence``) of every detection.
        """
        image = np.asarray(frame)
        if image.ndim == 3 and image.shape[-1] == 3:
            # OpenCV decodes BGR, but the YOLOv5 hub model expects RGB numpy
            # images; reverse the channel axis and make the result contiguous.
            image = np.ascontiguousarray(image[..., ::-1])
        results = self.model([image])
        detections = results.xyxyn[0]
        labels, cord = detections[:, -1], detections[:, :-1]
        return labels, cord

    def class_to_label(self, x):
        """
        Map a numeric class index to its string label.

        :param x: Numeric label (int, float or 0-d tensor); truncated with ``int``.
        :return: Corresponding entry of ``self.classes``.
        """
        return self.classes[int(x)]

    def plot_boxes(self, results, frame, conf_threshold=0.2):
        """
        Draw bounding boxes and class labels for the detections onto ``frame``.

        :param results: ``(labels, cord)`` pair as returned by ``score_frame``.
        :param frame: Frame that was scored; it is modified in place.
        :param conf_threshold: Minimum confidence (``cord`` column 4) for a
            detection to be drawn. Defaults to the previous fixed 0.2.
        :return: The same frame with boxes and labels plotted on it.
        """
        labels, cord = results
        height, width = frame.shape[:2]
        color = (0, 255, 0)  # green in OpenCV's BGR order
        for label, row in zip(labels, cord):
            if row[4] >= conf_threshold:
                # Coordinates are normalized to [0, 1]; scale back to pixels.
                x1, y1 = int(row[0] * width), int(row[1] * height)
                x2, y2 = int(row[2] * width), int(row[3] * height)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(frame, self.class_to_label(label), (x1, y1),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        return frame

    @staticmethod
    def _frames_per_second(elapsed):
        """Return ``1 / elapsed`` seconds, or ``inf`` when ``elapsed`` is not positive."""
        return 1.0 / elapsed if elapsed > 0 else float("inf")

    def _annotate_stream(self, player, out):
        """Read every frame from ``player``, annotate it, and write it to ``out``."""
        # Local import: perf_counter is monotonic and high-resolution, which
        # suits measuring short durations better than wall-clock time().
        from time import perf_counter

        while True:
            start_time = perf_counter()
            ret, frame = player.read()
            if not ret:
                break
            results = self.score_frame(frame)
            frame = self.plot_boxes(results, frame)
            fps = self._frames_per_second(perf_counter() - start_time)
            print(f"Frames Per Second : {fps}")
            out.write(frame)

    def __call__(self):
        """
        Run detection on the whole video and write the annotated frames to ``self.out_file``.

        Prints the per-frame processing rate. The capture and the writer are
        always released, even if scoring raises, so the output file is finalized.

        :raises RuntimeError: if the video stream or the output file cannot be opened.
        :return: None
        """
        player = self.get_video_from_url()
        try:
            if not player.isOpened():
                raise RuntimeError(f"Could not open video stream for {self._URL!r}")
            x_shape = int(player.get(cv2.CAP_PROP_FRAME_WIDTH))
            y_shape = int(player.get(cv2.CAP_PROP_FRAME_HEIGHT))
            four_cc = cv2.VideoWriter_fourcc(*"MJPG")
            out = cv2.VideoWriter(self.out_file, four_cc, 20, (x_shape, y_shape))
            try:
                if not out.isOpened():
                    raise RuntimeError(f"Could not open output file {self.out_file!r}")
                self._annotate_stream(player, out)
            finally:
                out.release()
        finally:
            player.release()
# Create a detector for the demo video and run it, only when executed as a
# script, so importing this module does not trigger a download and inference.
if __name__ == "__main__":
    detection = ObjectDetection("https://www.youtube.com/watch?v=EXUQnLyc3yE", "video2.avi")
    detection()