From c4298a2d4949d2253a71690f33e5285311f33a21 Mon Sep 17 00:00:00 2001 From: eclique Date: Fri, 1 Dec 2017 05:35:58 -0500 Subject: [PATCH 1/2] Optimization and refactoring --- grad-cam.py | 211 +++++++++++++++++++++++++++++----------------------- 1 file changed, 118 insertions(+), 93 deletions(-) diff --git a/grad-cam.py b/grad-cam.py index a175db7..bc97549 100644 --- a/grad-cam.py +++ b/grad-cam.py @@ -1,72 +1,60 @@ -from keras.applications.vgg16 import ( - VGG16, preprocess_input, decode_predictions) +from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions from keras.preprocessing import image -from keras.layers.core import Lambda -from keras.models import Sequential from tensorflow.python.framework import ops +from matplotlib import pyplot as plt import keras.backend as K import tensorflow as tf import numpy as np -import keras import sys import cv2 -def target_category_loss(x, category_index, nb_classes): - return tf.multiply(x, K.one_hot([category_index], nb_classes)) -def target_category_loss_output_shape(input_shape): - return input_shape +def build_model(): + """Function returning keras model instance. + + Model can be + - Trained here + - Loaded with load_model + - Loaded from keras.applications + """ + return VGG16(weights='imagenet') -def normalize(x): - # utility function to normalize a tensor by its L2 norm - return x / (K.sqrt(K.mean(K.square(x))) + 1e-5) - -def load_image(path): - img_path = sys.argv[1] - img = image.load_img(img_path, target_size=(224, 224)) - x = image.img_to_array(img) - x = np.expand_dims(x, axis=0) - x = preprocess_input(x) - return x -def register_gradient(): - if "GuidedBackProp" not in ops._gradient_registry._registry: - @ops.RegisterGradient("GuidedBackProp") +def build_guided_model(): + """Function returning modified model. + + Changes gradient function for all ReLu activations + according to Guided Backpropagation. + """ + if 'GuidedBackProp' not in ops._gradient_registry._registry: + @ops.RegisterGradient('GuidedBackProp') def _GuidedBackProp(op, grad): dtype = op.inputs[0].dtype return grad * tf.cast(grad > 0., dtype) * \ - tf.cast(op.inputs[0] > 0., dtype) - -def compile_saliency_function(model, activation_layer='block5_conv3'): - input_img = model.input - layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]]) - layer_output = layer_dict[activation_layer].output - max_output = K.max(layer_output, axis=3) - saliency = K.gradients(K.sum(max_output), input_img)[0] - return K.function([input_img, K.learning_phase()], [saliency]) + tf.cast(op.inputs[0] > 0., dtype) -def modify_backprop(model, name): g = tf.get_default_graph() - with g.gradient_override_map({'Relu': name}): + with g.gradient_override_map({'Relu': 'GuidedBackProp'}): + new_model = build_model() + return new_model - # get layers that have an activation - layer_dict = [layer for layer in model.layers[1:] - if hasattr(layer, 'activation')] - # replace relu activation - for layer in layer_dict: - if layer.activation == keras.activations.relu: - layer.activation = tf.nn.relu +def load_image(path, preprocess=True): + """Function to load and preprocess image.""" + x = image.load_img(path, target_size=(224, 224)) + if preprocess: + x = image.img_to_array(x) + x = np.expand_dims(x, axis=0) + x = preprocess_input(x) + return x - # re-instanciate a new model - new_model = VGG16(weights='imagenet') - return new_model def deprocess_image(x): - ''' + """ Same normalization as in: https://github.com/fchollet/keras/blob/master/examples/conv_filter_visualization.py - ''' + """ + x = x.copy() if np.ndim(x) > 3: x = np.squeeze(x) # normalize tensor: center on 0., ensure std is 0.1 @@ -85,59 +73,96 @@ def deprocess_image(x): x = np.clip(x, 0, 255).astype('uint8') return x -def grad_cam(input_model, image, category_index, layer_name): - model = Sequential() - model.add(input_model) - nb_classes = 1000 - target_layer = lambda x: target_category_loss(x, category_index, nb_classes) - model.add(Lambda(target_layer, - output_shape = target_category_loss_output_shape)) - - loss = K.sum(model.layers[-1].output) - conv_output = [l for l in model.layers[0].layers if l.name is layer_name][0].output - grads = normalize(K.gradients(loss, conv_output)[0]) - gradient_function = K.function([model.layers[0].input], [conv_output, grads]) - - output, grads_val = gradient_function([image]) - output, grads_val = output[0, :], grads_val[0, :, :, :] - - weights = np.mean(grads_val, axis = (0, 1)) - cam = np.ones(output.shape[0 : 2], dtype = np.float32) +def normalize(x): + """Utility function to normalize a tensor by its L2 norm""" + return (x + 1e-10) / (K.sqrt(K.mean(K.square(x))) + 1e-10) - for i, w in enumerate(weights): - cam += w * output[:, :, i] - cam = cv2.resize(cam, (224, 224)) - cam = np.maximum(cam, 0) - heatmap = cam / np.max(cam) +def guided_backprop(model, img, activation_layer): + """Compute gradients of conv. activation w.r.t. the input image. - #Return to BGR [0..255] from the preprocessed image - image = image[0, :] - image -= np.min(image) - image = np.minimum(image, 255) - - cam = cv2.applyColorMap(np.uint8(255*heatmap), cv2.COLORMAP_JET) - cam = np.float32(cam) + np.float32(image) - cam = 255 * cam / np.max(cam) - return np.uint8(cam), heatmap + If model is modified properly this will result in Guided Backpropagation + method for visualizing input saliency. + See https://arxiv.org/abs/1412.6806 """ + input_img = model.input + layer_output = model.get_layer(activation_layer).output + grads = K.gradients(layer_output, input_img)[0] + gradient_fn = K.function([input_img, K.learning_phase()], [grads]) + grads_val = gradient_fn([img, 0])[0] + return grads_val -preprocessed_input = load_image(sys.argv[1]) -model = VGG16(weights='imagenet') +def grad_cam(input_model, img, category_index, activation_layer): + """GradCAM method for visualizing input saliency.""" + loss = input_model.output[0, category_index] + layer_output = input_model.get_layer(activation_layer).output + grads = normalize(K.gradients(loss, layer_output)[0]) + gradient_fn = K.function([input_model.input, K.learning_phase()], [layer_output, grads]) -predictions = model.predict(preprocessed_input) -top_1 = decode_predictions(predictions)[0][0] -print('Predicted class:') -print('%s (%s) with probability %.2f' % (top_1[1], top_1[0], top_1[2])) + conv_output, grads_val = gradient_fn([img, 0]) + conv_output, grads_val = conv_output[0], grads_val[0] -predicted_class = np.argmax(predictions) -cam, heatmap = grad_cam(model, preprocessed_input, predicted_class, "block5_conv3") -cv2.imwrite("gradcam.jpg", cam) + weights = np.mean(grads_val, axis=(0, 1)) + cam = np.dot(conv_output, weights) -register_gradient() -guided_model = modify_backprop(model, 'GuidedBackProp') -saliency_fn = compile_saliency_function(guided_model) -saliency = saliency_fn([preprocessed_input, 0]) -gradcam = saliency[0] * heatmap[..., np.newaxis] -cv2.imwrite("guided_gradcam.jpg", deprocess_image(gradcam)) + cam = cv2.resize(cam, (224, 224), cv2.INTER_LINEAR) + cam = np.maximum(cam, 0) + cam = cam / cam.max() + return cam + + +def compute_saliency(model, guided_model, layer_name, img_path, cls=-1, visualize=True, save=True): + preprocessed_input = load_image(img_path) + + predictions = model.predict(preprocessed_input) + top_n = 5 + top = decode_predictions(predictions, top=top_n)[0] + classes = np.argsort(predictions[0])[-top_n:][::-1] + print('Model prediction:') + for c, p in zip(classes, top): + print('\t({}) {:20s}\twith probability {:.3f}'.format(c, p[1],p[2])) + if cls == -1: + cls = np.argmax(predictions) + nb_classes = 1000 + class_name = decode_predictions(np.eye(1, nb_classes, cls))[0][0][1] + print("Computing saliency for '{}'".format(class_name)) + + gradcam = grad_cam(model, preprocessed_input, cls, activation_layer=layer_name) + gb = guided_backprop(guided_model, img=preprocessed_input, activation_layer=layer_name) + guided_gradcam = gb * gradcam[..., np.newaxis] + + if save: + jetcam = cv2.applyColorMap(np.uint8(255 * gradcam), cv2.COLORMAP_JET) + jetcam = (np.float32(jetcam) + np.float32(cv2.imread(sys.argv[1]))) / 2 + cv2.imwrite('gradcam.jpg', np.uint8(jetcam)) + cv2.imwrite('guided_backprop.jpg', deprocess_image(gb[0])) + cv2.imwrite('guided_gradcam.jpg', deprocess_image(guided_gradcam[0])) + + if visualize: + plt.figure(figsize=(15, 6)) + plt.subplot(131) + plt.title('GradCAM') + plt.axis('off') + plt.imshow(load_image(img_path, preprocess=False)) + plt.imshow(gradcam, cmap='jet', alpha=0.5) + + plt.subplot(132) + plt.title('Guided Backprop') + plt.axis('off') + plt.imshow(np.flip(deprocess_image(gb[0]), -1)) + + plt.subplot(133) + plt.title('Guided GradCAM') + plt.axis('off') + plt.imshow(np.flip(deprocess_image(guided_gradcam[0]), -1)) + plt.show() + + return gradcam, gb, grad_cam + + +if __name__ == '__main__': + model = build_model() + guided_model = build_guided_model() + gradcam, gb, grad_cam = compute_saliency(model, guided_model, layer_name='block5_conv3', + img_path=sys.argv[1], cls=-1) From 772356ff079700175906c8d38b9c4ffbbc04bc54 Mon Sep 17 00:00:00 2001 From: Vitali Petsiuk Date: Mon, 8 Jan 2018 20:06:30 -0500 Subject: [PATCH 2/2] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 6d81c1f..81330d5 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ ## Grad-CAM implementation in Keras ## +--- +#### My more recent version of repository: https://github.com/eclique/keras-gradcam. +--- + Gradient class activation maps are a visualization technique for deep learning networks. See the paper: https://arxiv.org/pdf/1610.02391v1.pdf