Skip to content

Commit

Permalink
Yolov6 prediction codes have been updated.
Browse files Browse the repository at this point in the history
  • Loading branch information
kadirnar committed Jan 22, 2023
1 parent 2fafdf6 commit 90a3286
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 26 deletions.
2 changes: 1 addition & 1 deletion torchyolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from torchyolo.predict import YoloHub

__version__ = "0.4.1"
__version__ = "1.0.0"
2 changes: 1 addition & 1 deletion torchyolo/configs/default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ DETECTOR_CONFIG:
# The threshold for the confidence score
CONF_TH: 0.25
# The size of the image
IMAGE_SIZE: 640
IMAGE_SIZE: 1280
# The device to run the detector
DEVICE: cuda:0
# F16 precision
Expand Down
5 changes: 0 additions & 5 deletions torchyolo/modelhub/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,3 @@ def predict(
)
if self.save:
video_writer.write(frame)

if self.show:
cv2.imshow("frame", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
18 changes: 13 additions & 5 deletions torchyolo/modelhub/yolov6.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import cv2
import torch
from tqdm import tqdm
from yolov6.core.inferer import Inferer
from yolov6.helpers import check_img_size
Expand Down Expand Up @@ -43,8 +41,7 @@ def load_model(self):
model = YOLOV6(self.model_path, device=self.device, hf_model=self.hf_model)
model.conf = self.conf
model.iou = self.iou
model.torchyolo = True
self.model = model
self.model = model.model

except ImportError:
raise ImportError('Please run "pip install yolov6detect" ' "to install YOLOv6 first for YOLOv6 inference.")
Expand Down Expand Up @@ -83,7 +80,7 @@ def predict(
img = img[None]
# expand for batch dim

pred_results = self.model.model(img)
pred_results = self.model(img)
det = non_max_suppression(pred_results, self.conf, self.iou, classes=None, agnostic=False, max_det=1000)[0]

det[:, :4] = Inferer.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
Expand All @@ -110,3 +107,14 @@ def predict(
)
if self.save:
video_writer.write(frame)
else:
for *xyxy, conf, cls in det:
label = f"{COCO_CLASSES[int(cls)]} {float(conf):.2f}"
frame = video_vis(
bbox=xyxy,
label=label,
frame=img_src,
object_id=int(cls),
)
if self.save:
video_writer.write(frame)
5 changes: 0 additions & 5 deletions torchyolo/modelhub/yolov7.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,3 @@ def predict(
)
if self.save:
video_writer.write(frame)

if self.show:
cv2.imshow("frame", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
5 changes: 0 additions & 5 deletions torchyolo/modelhub/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,3 @@ def predict(
video_writer.write(frame)
else:
cv2.imwrite("output.jpg", frame)

if self.show:
cv2.imshow("frame", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
2 changes: 1 addition & 1 deletion torchyolo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def predict(
model = YoloHub(
config_path="torchyolo/configs/default_config.yaml",
model_type="yolov6",
model_path="yolov6s.pt",
model_path="yolov6l.pt",
)
result = model.predict(
source="../test.mp4", tracker_type="NORFAIR", tracker_config_path="torchyolo/configs/tracker/norfair_track.yaml"
Expand Down
3 changes: 0 additions & 3 deletions torchyolo/tracker/tracker_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ def create_tracker(
except ImportError:
raise ImportError("Please install strongsort: pip install strongsort")

else:
raise ValueError(f"No such tracker: {tracker_type}")


def load_tracker(
config_path: str = None,
Expand Down

0 comments on commit 90a3286

Please sign in to comment.