
Commit

add trt infer example.
GiovanniFyc authored Nov 22, 2024
1 parent 74e82ee commit 18eee92
Showing 2 changed files with 171 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tools/inference/cppExample/trt/CMakeLists.txt
@@ -0,0 +1,45 @@
cmake_minimum_required(VERSION 3.5.1)

project(
trtExample
LANGUAGES CXX
VERSION 1.0.0
)

set(CMAKE_CXX_STANDARD 17)

if (MSVC)
add_compile_options(-nologo)
add_definitions(-DNOMINMAX)
else()
add_compile_options(-Wall)
endif()

find_package(OpenCV CONFIG REQUIRED core dnn imgcodecs imgproc)
find_package(CUDA REQUIRED)

set(tensorrt_include_DIR $ENV{Tensorrt_DIR}/include)
set(tensorrt_lib_DIR $ENV{Tensorrt_DIR}/lib)
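# Tensorrt_DIR is expected to point at the TensorRT SDK root, i.e. the directory that contains include/ and lib/.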

file(GLOB tensorrt_libs
"${tensorrt_lib_DIR}/*.lib"
)
message("Tensorrt include: ${tensorrt_include_DIR}")
message("Tensorrt lib: ${tensorrt_lib_DIR}")

add_executable(
trtExample
trtExample.cpp
)
target_link_libraries(
trtExample
${tensorrt_libs}
${CUDA_LIBRARIES}
opencv_core
opencv_dnn
opencv_imgcodecs
opencv_imgproc
)
target_include_directories(
trtExample PUBLIC
${tensorrt_include_DIR}
${CUDA_INCLUDE_DIRS}
)
126 changes: 126 additions & 0 deletions tools/inference/cppExample/trt/trtExample.cpp
@@ -0,0 +1,126 @@
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
#include <fstream>
#include <NvInfer.h>
#include <string>
#include <cuda_runtime.h> // cudaMalloc / cudaMemcpy / cudaFree

class TRTLogger : public nvinfer1::ILogger {
public:
void log(Severity severity, const char* msg) noexcept override {
// forward everything except verbose messages to stdout; uncomment kVERBOSE below to see those as well.
switch (severity) {
case Severity::kINTERNAL_ERROR:
case Severity::kERROR:
case Severity::kWARNING:
case Severity::kINFO:
std::cout<<msg<<std::endl;
break;
// case Severity::kVERBOSE:
// std::cout<<msg<<std::endl;
// break;
default:
break;
}
}
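// a single shared logger; it must stay alive for as long as any TensorRT object created with it,
// which the function-local static below guarantees.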
static TRTLogger& getInstance() {
static TRTLogger trt_logger{};
return trt_logger;
}
};

int main(){
// the engine below expects a 1 x 3 x 640 x 640 input, so the image is assumed to already be 640 x 640.
cv::Mat imageMat = cv::imread("your/png/path");
cv::Mat inferMat;
cv::dnn::blobFromImage(imageMat, inferMat, 1.0 / 255.0);
nvinfer1::IRuntime *runtime;
nvinfer1::ICudaEngine *engine;
std::ifstream ifs("your/trt/model/path", std::ios::binary | std::ios::in);
if (!ifs.is_open()){
std::cout<<"failed to open the trt engine file"<<std::endl;
return -1;
}
ifs.seekg(0, std::ios::end);
auto size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::string str(size, '\0');
ifs.read(str.data(), size);
runtime = nvinfer1::createInferRuntime(TRTLogger::getInstance());
engine = runtime->deserializeCudaEngine(str.data(), str.size());
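// note: createInferRuntime and deserializeCudaEngine return nullptr on failure; a robust program should check both before continuing.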
auto num = engine->getNbIOTensors();
for (int32_t i = 0; i < num; i++){
std::cout<<engine->getIOTensorName(i)<<std::endl;
auto shape = engine->getTensorShape(engine->getIOTensorName(i));
std::cout<<"Tensor shape is :";
for (auto j = 0; j < shape.nbDims; j++){
std::cout<<shape.d[j]<<" ";
}
std::cout<<std::endl;
std::cout<<"Tensor format description is :"<<engine->getTensorFormatDesc(engine->getIOTensorName(i))<<std::endl;
}
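// getTensorIOMode(name) additionally reports whether each tensor is an input or an output.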
// change the data type for your model.
float* image = nullptr;
int64_t* imageSize = nullptr;
int64_t* label = nullptr;
float* score = nullptr;
float* boxes = nullptr;
// change the shape for your setting.
cudaMalloc((void**)&image, 1 * 3 * 640 * 640 * sizeof(float));
cudaMalloc((void**)&imageSize, 2 * sizeof(int64_t));
cudaMalloc((void**)&label, 1 * 300 * sizeof(int64_t));
cudaMalloc((void**)&boxes, 1 * 4 * 300 * sizeof(float));
cudaMalloc((void**)&score, 1 * 300 * sizeof(float));
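// note: cudaMalloc and cudaMemcpy return a cudaError_t; production code should check these return values.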
cudaMemcpy(image, inferMat.ptr<float>(), inferMat.total() * sizeof(float), cudaMemcpyHostToDevice);
int64_t h_inputData2[2] = {640, 640};
cudaMemcpy(imageSize, h_inputData2, 2 * sizeof(int64_t), cudaMemcpyHostToDevice);
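// executeV2 takes an array with one device pointer per I/O tensor, in binding-index order;
// the ordering below (image, image size, labels, boxes, scores) is assumed to match the tensors printed above.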

void* buffers[5];
buffers[0] = image;
buffers[1] = imageSize;
buffers[2] = label;
buffers[3] = boxes;
buffers[4] = score;

auto context = engine->createExecutionContext();
context->setInputShape(engine->getIOTensorName(0), nvinfer1::Dims4(1, 3, 640, 640));
context->setInputShape(engine->getIOTensorName(1), nvinfer1::Dims2(1, 2));
context->executeV2(buffers);
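// only the first few of the 300 detections are copied back here for a quick check; copy the full buffers to post-process every detection.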

float* hostScore = new float[1 * 4];
int64_t* hostLabel = new int64_t[1 * 4];
float* hostBoxes = new float[1 * 4 * 4];

cudaMemcpy(hostScore, score, 1 * 4 * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(hostLabel, label, 1 * 4 * sizeof(int64_t), cudaMemcpyDeviceToHost);
cudaMemcpy(hostBoxes, boxes, 1 * 4 * 4 *sizeof(float), cudaMemcpyDeviceToHost);

// print and draw the first two detections (the host buffers above only hold the first four).
for (int i = 0; i < 2; i++){
std::cout<<"score:"<<hostScore[i]<<" ";
std::cout<<"label:"<<hostLabel[i]<<" ";
std::cout<<"x1:"<<hostBoxes[4*i]<<" ";
std::cout<<"y1:"<<hostBoxes[4*i+1]<<" ";
std::cout<<"x2:"<<hostBoxes[4*i+2]<<" ";
std::cout<<"y2:"<<hostBoxes[4*i+3]<<std::endl;
auto x1 = hostBoxes[4*i];
auto y1 = hostBoxes[4*i+1];
auto x2 = hostBoxes[4*i+2];
auto y2 = hostBoxes[4*i+3];
cv::rectangle(imageMat, cv::Rect2f(x1, y1, x2 - x1, y2 - y1), cv::Scalar(0, 255, 0), 1);
}
cv::imwrite("your/save/path",imageMat);
cudaFree(image);
cudaFree(imageSize);
cudaFree(label);
cudaFree(score);
cudaFree(boxes);
delete [] hostScore;
delete [] hostLabel;
delete [] hostBoxes;
delete context;
delete engine;
delete runtime;
std::cout<<"finish"<<std::endl;
}
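The file above drives the engine synchronously through executeV2 and an index-ordered pointer array. As a point of reference, here is a minimal sketch (not part of this commit) of the same step using the name-based tensor API that recent TensorRT releases provide: each device buffer is bound with setTensorAddress and the work is launched on a CUDA stream with enqueueV3. The tensor order and the 1x3x640x640 / 1x2 input shapes are assumptions carried over from the example above.

#include <NvInfer.h>
#include <cuda_runtime.h>

// Sketch only: run one inference with enqueueV3 on a dedicated CUDA stream.
// `engine` and the five device buffers are assumed to be the ones allocated in main above.
bool inferWithStream(nvinfer1::ICudaEngine* engine,
                     float* image, int64_t* imageSize,
                     int64_t* label, float* boxes, float* score)
{
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();
    if (context == nullptr)
        return false;

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // same input shapes as the executeV2 version
    context->setInputShape(engine->getIOTensorName(0), nvinfer1::Dims4(1, 3, 640, 640));
    context->setInputShape(engine->getIOTensorName(1), nvinfer1::Dims2(1, 2));

    // bind every I/O tensor by name instead of relying on the order of a pointer array
    context->setTensorAddress(engine->getIOTensorName(0), image);
    context->setTensorAddress(engine->getIOTensorName(1), imageSize);
    context->setTensorAddress(engine->getIOTensorName(2), label);
    context->setTensorAddress(engine->getIOTensorName(3), boxes);
    context->setTensorAddress(engine->getIOTensorName(4), score);

    // launch asynchronously, then wait so the results can be copied back safely
    bool ok = context->enqueueV3(stream);
    cudaStreamSynchronize(stream);

    cudaStreamDestroy(stream);
    delete context;
    return ok;
}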
