From e576cdbe52937940b740533ef14ad35bf9f14fc4 Mon Sep 17 00:00:00 2001
From: Mikel Broström
Date: Thu, 21 Nov 2024 15:54:57 +0100
Subject: [PATCH] update tensorrt backend

---
 .../appearance/backends/tensorrt_backend.py | 94 ++++++++++---------
 1 file changed, 50 insertions(+), 44 deletions(-)

diff --git a/boxmot/appearance/backends/tensorrt_backend.py b/boxmot/appearance/backends/tensorrt_backend.py
index d737918a8e..32c840dbdb 100644
--- a/boxmot/appearance/backends/tensorrt_backend.py
+++ b/boxmot/appearance/backends/tensorrt_backend.py
@@ -1,77 +1,83 @@
 import torch
 import numpy as np
 from pathlib import Path
+from collections import OrderedDict, namedtuple

 from boxmot.utils import logger as LOGGER
-
 from boxmot.appearance.backends.base_backend import BaseModelBackend
-
-
 class TensorRTBackend(BaseModelBackend):
-
     def __init__(self, weights, device, half):
         super().__init__(weights, device, half)
         self.nhwc = False
         self.half = half
+        self.device = device
+        self.weights = weights
+        self.fp16 = False  # Will be updated in load_model
+        self.load_model(self.weights)

     def load_model(self, w):
-
         LOGGER.info(f"Loading {w} for TensorRT inference...")
         self.checker.check_packages(("nvidia-tensorrt",))
-        import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
+        try:
+            import tensorrt as trt  # TensorRT library
+        except ImportError:
+            raise ImportError("Please install tensorrt to use this backend.")
+
+        if self.device.type == "cpu":
+            if torch.cuda.is_available():
+                self.device = torch.device("cuda:0")
+            else:
+                raise ValueError("CUDA device not available for TensorRT inference.")
-        if device.type == "cpu":
-            device = torch.device("cuda:0")
         Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
         logger = trt.Logger(trt.Logger.INFO)
+
+        # Deserialize the engine
         with open(w, "rb") as f, trt.Runtime(logger) as runtime:
             self.model_ = runtime.deserialize_cuda_engine(f.read())
+
+        # Execution context
         self.context = self.model_.create_execution_context()
         self.bindings = OrderedDict()
-        self.fp16 = False  # default updated below
-        # dynamic = False
+
+        # Parse bindings
         for index in range(self.model_.num_bindings):
             name = self.model_.get_binding_name(index)
             dtype = trt.nptype(self.model_.get_binding_dtype(index))
-            if self.model_.binding_is_input(index):
-                if -1 in tuple(self.model_.get_binding_shape(index)):  # dynamic
-                    # dynamic = True
-                    self.context.set_binding_shape(
-                        index, tuple(self.model_.get_profile_shape(0, index)[2])
-                    )
-                if dtype == np.float16:
-                    self.fp16 = True
+            is_input = self.model_.binding_is_input(index)
+
+            # Handle dynamic shapes
+            if is_input and -1 in self.model_.get_binding_shape(index):
+                profile_index = 0
+                min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
+                self.context.set_binding_shape(index, opt_shape)
+
+            if is_input and dtype == np.float16:
+                self.fp16 = True
+
             shape = tuple(self.context.get_binding_shape(index))
-            im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
-            self.bindings[name] = Binding(
-                name, dtype, shape, im, int(im.data_ptr())
-            )
-        self.binding_addrs = OrderedDict(
-            (n, d.ptr) for n, d in self.bindings.items()
-        )
-        # batch_size = self.bindings["images"].shape[
-        #     0
-        # ]  # if dynamic, this is instead max batch size
+            data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
+            self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
+
+        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())

     def forward(self, im_batch):
-        if True and im_batch.shape != self.bindings["images"].shape:
-            i_in, i_out = (
-                self.model_.get_binding_index(x) for x in ("images", "output")
-            )
-            self.context.set_binding_shape(
-                i_in, im_batch.shape
-            )  # reshape if dynamic
-            self.bindings["images"] = self.bindings["images"]._replace(
-                shape=im_batch.shape
-            )
-            self.bindings["output"].data.resize_(
-                tuple(self.context.get_binding_shape(i_out))
-            )
+        # Adjust for dynamic shapes
+        if im_batch.shape != self.bindings["images"].shape:
+            i_in = self.model_.get_binding_index("images")
+            i_out = self.model_.get_binding_index("output")
+            self.context.set_binding_shape(i_in, im_batch.shape)
+            self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
+            output_shape = tuple(self.context.get_binding_shape(i_out))
+            self.bindings["output"].data.resize_(output_shape)
+
         s = self.bindings["images"].shape
-        assert (
-            im_batch.shape == s
-        ), f"input size {im_batch.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
+        assert im_batch.shape == s, f"Input size {im_batch.shape} does not match model size {s}"
+
+        # Set input buffer
+        self.binding_addrs["images"] = int(im_batch.data_ptr())
+
+        # Execute inference
+        self.context.execute_v2(list(self.binding_addrs.values()))
         features = self.bindings["output"].data
         return features
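
Not part of the patch: a minimal sketch of how the updated backend might be exercised directly, assuming a CUDA machine and a prebuilt ReID `.engine` file whose input/output bindings are named `images`/`output` as the code above expects. The engine filename, crop size, and direct `forward()` call are illustrative placeholders; in boxmot the backend is normally driven through the higher-level ReID wrapper, which handles cropping and preprocessing.

```python
# Sketch only: drive TensorRTBackend with a dummy, preprocessed crop batch.
# Engine path and crop resolution below are hypothetical.
import torch
from pathlib import Path

from boxmot.appearance.backends.tensorrt_backend import TensorRTBackend

weights = Path("osnet_x0_25_msmt17.engine")  # placeholder engine file
device = torch.device("cuda:0")

# Constructor signature matches the patched __init__(weights, device, half);
# load_model() is now called from __init__ and sets self.fp16 from the bindings.
backend = TensorRTBackend(weights, device, half=True)

# Dummy NCHW batch standing in for preprocessed person crops.
crops = torch.zeros((8, 3, 256, 128), device=device)
if backend.fp16:
    crops = crops.half()

# For a static engine the batch shape must match the "images" binding;
# for a dynamic engine forward() rebinds the shapes before execute_v2().
features = backend.forward(crops)
print(features.shape)
```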