diff --git a/tools/benchmark/trt_benchmark.py b/tools/benchmark/trt_benchmark.py index a650ac0..328d283 100644 --- a/tools/benchmark/trt_benchmark.py +++ b/tools/benchmark/trt_benchmark.py @@ -19,7 +19,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Argument Parser Example') - parser.add_argument('--COCO_dir', + parser.add_argument('--infer_dir', type=str, default='/data/COCO2017/val2017', help="Directory for images to perform inference on.") @@ -180,7 +180,7 @@ def main(): 'orig_target_sizes': torch.tensor([640, 640]).to(im.device), } - engine_files = glob.glob(os.path.join(FLAGS.models_dir, "*.engine")) + engine_files = glob.glob(os.path.join(FLAGS.engine_dir, "*.engine")) results = [] for engine_file in engine_files: