diff --git a/examples/post_training_quantization/torch/ssd300_vgg16/main.py b/examples/post_training_quantization/torch/ssd300_vgg16/main.py index 567e12d44f3..674e8d60291 100644 --- a/examples/post_training_quantization/torch/ssd300_vgg16/main.py +++ b/examples/post_training_quantization/torch/ssd300_vgg16/main.py @@ -274,13 +274,19 @@ def get_image(): def main_detr(): - from transformers import DetrImageProcessor, DetrForObjectDetection + from transformers import DetrImageProcessor, DetrForObjectDetection # noqa + from transformers import AutoImageProcessor, AutoModelForObjectDetection, ConditionalDetrForObjectDetection # noqa + from transformers import OwlViTProcessor, OwlViTForObjectDetection # noqa import torch device = torch.device("cpu") # you can specify the revision tag if you don't want the timm dependency - processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm") - model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm") + # processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm") + # model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm") + processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50") + model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50") + # processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + # model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") model.eval() dataset_path = download_dataset()