
Comparing Models for Different Tasks (Image Classification vs. Object Detection)? #9

Open
bryanbocao opened this issue Mar 10, 2023 · 5 comments

Comments

@bryanbocao

Thanks for sharing the code. I see that the models in this repo's examples are all classification models. I was wondering whether we can compare two models that perform different tasks, such as one for classification (ResNet) and one for object detection (YOLO)?

Thanks!
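In principle the task should not matter: CKA only compares the activations two networks produce on the *same* input examples, and the two feature widths do not have to match. The following is a minimal NumPy sketch of linear CKA for illustration only. It is not torch_cka's implementation, and `linear_cka`, `feats_a`, `feats_b` and `feats_c` are hypothetical names with random stand-in data.

```python
import numpy as np


def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Linear CKA between activations X (n, p) and Y (n, q).

    Row i of X and row i of Y must come from the same input example;
    the feature widths p and q may differ.
    """
    if X.shape[0] != Y.shape[0]:
        raise ValueError("X and Y must have the same number of rows (examples)")
    # Center every feature (column) over the n examples.
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    # CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = np.linalg.norm(Y.T @ X, ord="fro") ** 2
    norm_x = np.linalg.norm(X.T @ X, ord="fro")
    norm_y = np.linalg.norm(Y.T @ Y, ord="fro")
    return float(cross / (norm_x * norm_y))


rng = np.random.default_rng(0)
# Hypothetical stand-ins for pooled features of two models on 200 shared images.
feats_a = rng.standard_normal((200, 32))
q, _ = np.linalg.qr(rng.standard_normal((32, 32)))  # random orthogonal matrix
feats_b = 3.0 * feats_a @ q                # same representation, rotated and rescaled
feats_c = rng.standard_normal((200, 16))   # unrelated features of a different width

print(linear_cka(feats_a, feats_b))  # close to 1.0
print(linear_cka(feats_a, feats_c))  # much smaller than 1
```

Because CKA is invariant to orthogonal transforms and isotropic scaling, `feats_b` scores as essentially identical to `feats_a`, while the unrelated `feats_c` scores low even though its width differs.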

@bryanbocao
Author

I tried the following:

```python
import os
import torch
from torchvision.models import resnet18 # edit
from torchvision.models.detection import fasterrcnn_resnet50_fpn # edit
# from torchvision.models.detection.SSD import ssd300_vgg16 # edit
from torchvision.datasets import CIFAR10 # edit
# import torchvision.datasets as dataset
# print('\n dir(dataset): ', dir(dataset))
# from torchvision.datasets import CocoDetection # edit
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import random
from torch_cka import CKA

# print('\n dir(CocoDetection): ', dir(CocoDetection))

if not os.path.exists('../exps'): os.makedirs('../exps')

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)
np.random.seed(0)
random.seed(0)

model1_name, model2_name = 'resnet18', 'F-RCNN' # edit
model1 = resnet18(pretrained=True) # edit
model2 = fasterrcnn_resnet50_fpn(pretrained=True) # edit

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

batch_size = 16 # 64 # 256

dataset = CIFAR10(root='../data/',
                  train=False,
                  download=True,
                  transform=transform)

dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        worker_init_fn=seed_worker,
                        generator=g,)

cka = CKA(model1, model2,
        model1_name=model1_name, model2_name=model2_name,
        device='cuda')

cka.compare(dataloader)
cka.plot_results(save_path="../exps/{}_{}.jpg".format(model1_name, model2_name))
```

but got the following error:

```
python3 resnet18_FRCNN.py
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1`. You can also use `weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Files already downloaded and verified
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torch_cka/cka.py:69: UserWarning: Model 2 seems to have a lot of layers. Consider giving a list of layers whose features you are concerned with through the 'model2_layers' parameter. Your CPU/GPU will thank you :)
  warn("Model 2 seems to have a lot of layers. " \
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torch_cka/cka.py:145: UserWarning: Dataloader for Model 2 is not given. Using the same dataloader for both models.
  warn("Dataloader for Model 2 is not given. Using the same dataloader for both models.")
| Comparing features |:   0%|                                             | 0/625 [00:10<?, ?it/s]
Traceback (most recent call last):
  File "resnet18_FRCNN.py", line 55, in <module>
    cka.compare(dataloader)
  File "/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torch_cka/cka.py", line 172, in compare
    Y = feat2.flatten(1)
AttributeError: 'tuple' object has no attribute 'flatten'
```

Any help would be appreciated @AntixK. Thanks!
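The last traceback frame shows where it breaks: torch_cka calls `feat2.flatten(1)` on each recorded layer output, which only works when that output is a single tensor. As far as I know, several submodules of torchvision's Faster R-CNN return other types (for example `model.transform` and `model.rpn` return tuples, and the FPN `backbone` returns an `OrderedDict` of feature maps), so hooking them records a tuple or dict. One possible workaround, sketched below, is to normalize each hooked output before flattening and drop layers that cannot be reduced to one tensor. `to_feature_matrix` and its `policy` argument are hypothetical, not part of torch_cka; using them would mean patching or wrapping the library's hook code yourself.

```python
import numpy as np


def to_feature_matrix(output, policy="skip"):
    """Return a hooked layer's output as a (batch, features) matrix, or None.

    Works for anything with `.shape` and `.reshape` (NumPy arrays, torch tensors).
    policy="skip":  tuples, lists and dicts give None so the caller can drop the layer.
    policy="first": for a tuple/list, use its first element if it is array-like
                    (e.g. the attention output of nn.MultiheadAttention).
    """
    if hasattr(output, "shape") and hasattr(output, "reshape"):
        return output.reshape(output.shape[0], -1)
    if policy == "first" and isinstance(output, (tuple, list)) and len(output) > 0:
        first = output[0]
        if hasattr(first, "shape") and hasattr(first, "reshape"):
            return first.reshape(first.shape[0], -1)
    return None


conv_out = np.zeros((16, 64, 7, 7))                # single tensor: fine
tuple_out = (np.zeros((16, 10, 4)), {"loss": 0.0})  # e.g. (boxes, losses)
dict_out = {"0": np.zeros((16, 256, 8, 8))}         # e.g. FPN feature maps

print(to_feature_matrix(conv_out).shape)                    # (16, 3136)
print(to_feature_matrix(tuple_out))                         # None
print(to_feature_matrix(tuple_out, policy="first").shape)   # (16, 40)
print(to_feature_matrix(dict_out))                          # None
```

Skipping is the conservative choice; taking the first element is only meaningful when that element really is the layer's main activation.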

@ratom

ratom commented Jul 23, 2023

@bryanbocao, have you solved the error `AttributeError: 'tuple' object has no attribute 'flatten'`?
I got this error when I tried to compare ResNet and ViT.
If there is any solution, please let me know.
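For the ResNet vs. ViT case, a plausible (unverified) culprit is the attention block: `torch.nn.MultiheadAttention` returns a tuple `(attn_output, attn_weights)`, and a forward hook receives that tuple unchanged. If the ViT is built on `nn.MultiheadAttention` (I believe torchvision's `vit_b_16` encoder blocks are), each attention layer would hit the same `flatten` call. A small check, with a hypothetical hook name:

```python
import torch
from torch import nn

recorded = {}


def record_output_type(module, inputs, output):
    # Forward hooks receive exactly what the module's forward() returned.
    recorded["type"] = type(output)


attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
attn.register_forward_hook(record_output_type)

tokens = torch.randn(2, 5, 16)  # (batch, tokens, embed_dim)
attn_output, attn_weights = attn(tokens, tokens, tokens)

print(recorded["type"])    # <class 'tuple'>
print(attn_output.shape)   # torch.Size([2, 5, 16])
print(attn_weights.shape)  # torch.Size([2, 5, 5]), averaged over heads by default
```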

@bryanbocao
Author

@ratom Unfortunately, not yet. I moved on to other projects.

@bryanbocao
Author

@ratom The layer where I got the error seemed to be right after the backbone (I haven't rigorously verified this), so I simply found the layers that raised errors and ignored them.
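The "find the failing layers and ignore them" workaround can be automated: run one batch through the model with a forward hook on every submodule, keep the names whose output is a single tensor, and pass that list to torch_cka. The warning in the log confirms that `CKA` accepts a `model2_layers` argument; I am assuming, without having checked the library source, that it matches entries against names from `model.named_modules()`. `tensor_output_layers`, `TupleHead` and `ToyDetector` are hypothetical.

```python
import torch
from torch import nn


def tensor_output_layers(model: nn.Module, example_input) -> list:
    """Run one forward pass and return names of submodules whose output is a tensor.

    Submodules that return tuples, lists or dicts (or are never called) are left out.
    Note: this switches the model to eval mode.
    """
    returns_tensor = {}
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        def hook(mod, inputs, output, name=name):
            returns_tensor[name] = isinstance(output, torch.Tensor)

        handles.append(module.register_forward_hook(hook))
    model.eval()
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:  # always detach the probing hooks
            handle.remove()
    return [name for name, ok in returns_tensor.items() if ok]


class TupleHead(nn.Module):
    """Hypothetical head that, like many detection heads, returns a tuple."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x), {}  # (predictions, losses)


class ToyDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(3, 8), nn.ReLU())
        self.head = TupleHead()

    def forward(self, x):
        predictions, _ = self.head(self.backbone(x))
        return predictions


layers = tensor_output_layers(ToyDetector(), torch.randn(2, 3))
print(layers)  # ['backbone.0', 'backbone.1', 'backbone', 'head.fc']
```

With the real models this might look like `CKA(model1, model2, model2_layers=tensor_output_layers(model2, images), ...)`, where `images` is one batch from the dataloader. I have not verified that Faster R-CNN is fed inputs the same way inside `cka.compare`, so treat this as a starting point.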

@StarBlue98

> @ratom The layer where I got the error seemed to be right after a backbone (haven't rigorously verified it). So I simply just find those layers with errors and ignore them.

Hi, I am currently working on comparing different object detection models, but I haven't found a suitable implementation for this. Have you come across any projects, tools, or libraries that could help? If you have experience with this kind of comparison, I would be very grateful for your guidance. Thank you in advance for your time and help.


3 participants