This repository is a user-friendly implementation of GradCam (Gradient-weighted Class Activation Mapping) for understanding the behaviour of any CNN-based computer vision network. Paper: https://arxiv.org/pdf/1610.02391
The inputs are the pre-trained network in question and at least one image to run inference on. By default, the last convolutional layer is visualized, since it contains the highest-level features. The output is a superposition of the image and a heatmap indicating where the CNN layer is looking.
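For intuition, below is a minimal sketch of the Grad-CAM computation this repository automates, written against a torchvision VGG16. It is an illustration, not the repository's actual code; the layer index and the random input stand in for whatever model and image you pass in.

```python
import torch
import torch.nn.functional as F
from torchvision import models

model = models.vgg16(weights=models.VGG16_Weights.DEFAULT).eval()
target_layer = model.features[28]  # last Conv2d in VGG16, i.e. conv2d_backcount=1

# Capture the target layer's activations and gradients with hooks
activations, gradients = {}, {}
target_layer.register_forward_hook(lambda mod, inp, out: activations.update(value=out))
target_layer.register_full_backward_hook(lambda mod, gin, gout: gradients.update(value=gout[0]))

x = torch.randn(1, 3, 224, 224)           # stands in for a preprocessed image
scores = model(x)
scores[0, scores[0].argmax()].backward()  # backprop the top class score

# Channel weights = global average pooling of the gradients (Grad-CAM, Eq. 1-2)
weights = gradients["value"].mean(dim=(2, 3), keepdim=True)
cam = F.relu((weights * activations["value"]).sum(dim=1, keepdim=True))
cam = F.interpolate(cam, size=x.shape[2:], mode="bilinear", align_corners=False)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # heatmap in [0, 1]
# The repository then overlays this heatmap on the original image.
```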
There are two main ways to use this repository: apply GradCam to a single layer on multiple images, or apply GradCam to multiple layers on a single image.
Command for a single layer on multiple images:

```
python main.py --model_path vgg16 --conv2d_backcount 1 --images_path data --labels_path imagenet1000_labels.txt --n_images 3 --show
```
Command for multiple layers on a single image:

```
python main.py --model_path vgg16 --conv2d_backcount 1 3 4 --images_path data --labels_path imagenet1000_labels.txt --show
```
- `model_path`: path to a model saved with `torch.save()` (see the saving sketch below).
- `conv2d_backcount`: a positive integer or a list of positive integers; the convolutional layer(s) to visualize, counted backwards from the last layer. Defaults to 1.
- `images_path`: path to the parent of the directory that contains the images (one level above the images themselves).
- `save_dir`: directory in which to save the images GradCam was applied to.
- `n_images`: number of images used for inference.
- `imageNet_labels`: add `--imageNet_labels` if the labels used are from ImageNet.
- `show`: add `--show` to the command to display the GradCam plots. By default, plots are not shown.
- `labels_path`: path to a `.txt` file containing a dictionary of your labels in the following format:

```
{0: 'cat',
 1: 'dog',
 2: 'person'}
```

`--labels_path` is not needed if the `--imageNet_labels` flag is used (a sketch for reading this file follows).
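Since the labels file is a Python dictionary literal, it can be read with `ast.literal_eval`. This is an illustrative sketch, not the repository's loader, and it assumes the file content matches the format shown above:

```python
import ast

# Read the labels dictionary from the .txt file
# (assumes the file holds a Python dict literal, as in the format above).
with open("imagenet1000_labels.txt") as f:
    labels = ast.literal_eval(f.read())

print(labels[0])  # e.g. 'cat' for the toy dictionary above
```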
All output images are saved in the `--save_dir` directory.
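Note that `--model_path` expects a file produced by `torch.save()`. A minimal sketch of preparing such a file from a pretrained torchvision VGG16 (the filename `vgg16` is chosen only to match the example commands above):

```python
import torch
from torchvision import models

# Save the full model object (not just the state_dict),
# matching what --model_path expects.
model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
torch.save(model, "vgg16")
```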
Example of figures for a single layer with multiple images:
We always show the top three predictions and the worst prediction.
Example of figures for multiple layers on a single image:
About preprocessing
We use the default ImageNet preprocessing, defined in utils.py. If different preprocessing is needed, the function should be replaced inside utils.py.
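For reference, the standard ImageNet preprocessing in torchvision looks like the sketch below; the exact function in utils.py may differ slightly, so check it before swapping in your own:

```python
from torchvision import transforms

# Standard ImageNet preprocessing: resize, center-crop, and normalize
# with the usual ImageNet channel statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```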