
jmayank23/ClassActivationMaps-PyTorch


Class Activation Maps in PyTorch


This code provides an easy-to-use implementation for generating Class Activation Maps (CAMs) with several state-of-the-art methods. The CAM computation itself is handled by the pytorch-grad-cam library. This code lays out the sequence of steps and plots multiple CAMs side by side, so you don't have to deal with the intricacies of plotting functions. CAMs are visualizations that highlight the regions of an image that contribute most to a convolutional neural network's (CNN's) prediction for a specific class.
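To make concrete what these methods compute, here is a minimal NumPy sketch of the original Grad-CAM weighting. It assumes you already have the target layer's activations and the gradients of the class score with respect to them (pytorch-grad-cam captures both with hooks). The function name `grad_cam_from_arrays` is hypothetical and not part of that library. Each channel is weighted by its spatially averaged gradient, the weighted channels are summed, negative values are clipped, and the result is min-max scaled to [0, 1].

```python
import numpy as np


def grad_cam_from_arrays(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Hypothetical helper: Grad-CAM map from (C, H, W) activations and gradients."""
    # Channel weights: global-average-pool the gradients over the spatial dims
    weights = gradients.mean(axis=(1, 2))                  # shape (C,)
    # Weighted sum of the activation channels -> shape (H, W)
    cam = np.tensordot(weights, activations, axes=(0, 0))
    # Keep only features with a positive influence on the class score
    cam = np.maximum(cam, 0)
    # Min-max scale to [0, 1]; the epsilon avoids division by zero
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-7)


# Toy example: 2 channels on a 3x3 grid
acts = np.zeros((2, 3, 3))
acts[0, 0, 0] = 1.0   # channel 0 fires in the top-left corner
acts[1, 2, 2] = 1.0   # channel 1 fires in the bottom-right corner
# Channel 0 raises the class score, channel 1 lowers it
grads = np.stack([np.full((3, 3), 1.0), np.full((3, 3), -1.0)])
print(np.round(grad_cam_from_arrays(acts, grads), 3))
```

Only the top-left cell lights up, because the negatively weighted channel is removed by the clipping step.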

Installation

Before running the code, install the grad-cam library (the PyPI package for pytorch-grad-cam):

```bash
pip install grad-cam
```

Usage

1. Import the necessary modules:

```python
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
import cv2
import torchvision
import matplotlib.pyplot as plt
```
2. Define the CAM methods you want to use:

```python
cam_methods = {
    'GradCAM': GradCAM,
    'GradCAM++': GradCAMPlusPlus,
    'HiResCAM': HiResCAM,
    'ScoreCAM': ScoreCAM,
    'AblationCAM': AblationCAM,
    'XGradCAM': XGradCAM,
    'EigenCAM': EigenCAM,
    'FullGrad': FullGrad
}
```

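Remove or comment out any methods you don't need. ScoreCAM and AblationCAM run many extra forward passes per image, so they are considerably slower than the gradient-based methods.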
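Not all of these methods use gradients. As an illustration, the following minimal NumPy sketch shows the idea behind EigenCAM, which projects the target layer's activations onto their first principal component. The helper name `eigen_cam_from_activations` is hypothetical; the sign handling and the final clip-and-scale step are assumptions added to keep the sketch self-contained, and the library's own implementation may differ in such details.

```python
import numpy as np


def eigen_cam_from_activations(activations: np.ndarray) -> np.ndarray:
    """Hypothetical helper: EigenCAM-style map from (C, H, W) activations."""
    c, h, w = activations.shape
    # One row per spatial location, one column per channel: (H*W, C)
    flat = activations.reshape(c, -1).T
    # Center each channel before the SVD, as in PCA
    flat = flat - flat.mean(axis=0)
    # First right-singular vector = first principal direction in channel space
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    projection = flat @ vt[0]                      # shape (H*W,)
    # Singular vectors are only defined up to sign (assumption: make the
    # strongest response positive)
    if projection[np.argmax(np.abs(projection))] < 0:
        projection = -projection
    # Clip negatives and min-max scale to [0, 1]
    cam = np.maximum(projection, 0).reshape(h, w)
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-7)


# Toy example: both channels respond only at the center of a 3x3 grid
acts = np.zeros((2, 3, 3))
acts[:, 1, 1] = [3.0, 4.0]
print(np.round(eigen_cam_from_activations(acts), 3))
```

No gradients or target class are involved, which is why EigenCAM produces the same map regardless of `targets`.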

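3. Load the input image:

```python
# cv2.imread returns a uint8 array in BGR order, or None if the path is wrong
bgr_img = cv2.imread('/path/to/image.jpg')
# Convert to RGB (expected by the ImageNet normalization and by
# show_cam_on_image(..., use_rgb=True)) and scale to float64 values in [0, 1]
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) / 255.0
```

Make sure to specify the correct path to your image.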
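`cv2.imread` returns a `uint8` array in BGR channel order (or `None` when it cannot read the file), while ImageNet-trained models and `show_cam_on_image(..., use_rgb=True)` expect RGB. The following NumPy-only sketch of that conversion uses a hypothetical helper name, `to_rgb_unit_float`, and reverses the channel axis, which has the same effect as `cv2.COLOR_BGR2RGB` for 3-channel images.

```python
import numpy as np


def to_rgb_unit_float(bgr_uint8):
    """Hypothetical helper: OpenCV-style BGR uint8 image -> RGB float64 in [0, 1]."""
    if bgr_uint8 is None:
        raise FileNotFoundError("cv2.imread returned None; check the image path")
    # Reverse the channel axis: (B, G, R) -> (R, G, B). This is a view with a
    # negative stride, which torch.from_numpy would reject.
    rgb = bgr_uint8[..., ::-1]
    # Dividing by 255.0 allocates a new float64 array with values in [0, 1]
    return rgb / 255.0


pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)  # pure blue in BGR order
print(to_rgb_unit_float(pixel))  # blue moves to the last channel in RGB order
```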

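4. Prepare the image for input to the model:

```python
resize_to = (224, 224)
transform_norm = torchvision.transforms.Compose([
    # (H, W, C) float64 array -> (C, H, W) float64 tensor; float inputs are not rescaled
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(size=resize_to, antialias=True),
    # ImageNet per-channel means and standard deviations
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
input_tensor = transform_norm(rgb_img).unsqueeze(0)  # add a batch dimension
```

Adjust the resize_to dimensions according to your model's input size and preprocessing requirements. Note that torchvision's Resize takes (height, width) while cv2.resize (used later for plotting) takes (width, height), so swap the order for one of them if you use a non-square size.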
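The arithmetic behind `ToTensor` and `Normalize` can be sketched in plain NumPy: move the channel axis first, then compute `(x - mean) / std` per channel, and finally add a batch axis as `unsqueeze(0)` does. The sketch below skips the resize step, and the names `normalize_like_torchvision`, `IMAGENET_MEAN` and `IMAGENET_STD` are hypothetical; it illustrates the math rather than reproducing torchvision's code.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize_like_torchvision(rgb_float: np.ndarray) -> np.ndarray:
    """Hypothetical sketch of ToTensor + Normalize on an (H, W, 3) image in [0, 1]."""
    # ToTensor: (H, W, C) -> (C, H, W); float inputs keep their values
    chw = np.transpose(rgb_float, (2, 0, 1))
    # Normalize: (x - mean) / std, broadcast over each channel's H x W plane
    return (chw - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]


img = np.full((2, 2, 3), 0.5)        # a tiny uniform gray image
out = normalize_like_torchvision(img)
batch = out[np.newaxis]              # same role as .unsqueeze(0)
print(out.shape, batch.shape)        # (3, 2, 2) (1, 3, 2, 2)
print(out[:, 0, 0])                  # per-channel normalized value of one pixel
```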

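5. Load your pre-trained model and specify the target layers:

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# weights= replaces the deprecated pretrained= argument (torchvision 0.13+);
# IMAGENET1K_V1 are the original ImageNet weights.
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
# .double() matches the float64 input tensor; .eval() puts BatchNorm in inference mode
model = model.double().eval().to(device)
input_tensor = input_tensor.to(device)
target_layers = [model.layer4[-1]]
```

Substitute your own model here, and modify target_layers based on the layer(s) you want to visualize. For ResNet-style models, the last convolutional block (here model.layer4[-1]) is a common choice.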
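When you swap in your own model, you need a reference to a layer inside it. One simple heuristic, sketched below under the assumption that the layer you want is the last `nn.Conv2d` registered in the model, walks `model.modules()` and keeps the final convolution. The helper `last_conv_layer` and the toy network are hypothetical. Registration order is not always execution order, so inspect `print(model)` to confirm the choice. For torchvision's ResNet-50 this heuristic returns `layer4[-1].conv3` rather than the whole block `layer4[-1]` used above.

```python
import torch.nn as nn


def last_conv_layer(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the last nn.Conv2d in registration order."""
    convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    if not convs:
        raise ValueError("model has no nn.Conv2d layers")
    return convs[-1]


# Tiny stand-in model; with your own model you would use [last_conv_layer(model)]
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
)
target_layers = [last_conv_layer(toy)]
print(target_layers[0])  # the second Conv2d (8 -> 16 channels)
```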

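6. Create CAM objects for each method:

```python
cams = {}
for cam_name, cam_method in cam_methods.items():
    # Recent grad-cam releases run on whatever device the model is on
    cams[cam_name] = cam_method(model=model, target_layers=target_layers)
```

Recent releases of grad-cam infer the device from the model and no longer accept the use_cuda argument. On older releases, pass use_cuda=torch.cuda.is_available() as well.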
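7. Specify the targets (optional):

```python
targets = None  # None: explain each image's highest-scoring class
# from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# targets = [ClassifierOutputTarget(281)]
```

If you want to generate CAMs for a specific class, uncomment the last two lines and change the class index accordingly (in the ImageNet-1k labels used by this ResNet-50, index 281 is "tabby cat").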
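Conceptually, a target reduces the model's output to the single score that the CAM explains. With `targets=None`, pytorch-grad-cam uses each image's highest-scoring class, and `ClassifierOutputTarget(i)` selects the score for class `i`. The NumPy sketch below mirrors that selection on a batch of logits. The helper name `pick_target_scores` is hypothetical and this is not the library's code.

```python
import numpy as np


def pick_target_scores(logits: np.ndarray, categories=None) -> np.ndarray:
    """Hypothetical helper mirroring targets=None vs. explicit class indices.

    logits: (N, num_classes) model outputs for a batch of N images.
    categories: optional sequence of N class indices; None means "top class".
    """
    if categories is None:
        # Like targets=None: use each image's highest-scoring class
        categories = np.argmax(logits, axis=-1)
    # Gather one score per image; this is the scalar whose CAM is computed
    return logits[np.arange(len(logits)), categories]


logits = np.array([[0.1, 2.0, 0.5],
                   [1.5, 0.2, 0.3]])
print(pick_target_scores(logits))          # top-class scores 2.0 and 1.5
print(pick_target_scores(logits, [2, 2]))  # class-2 scores 0.5 and 0.3
```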

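8. Generate visualizations for each CAM method:

```python
visualizations = {}
for cam_name, cam in cams.items():
    # Returns a NumPy array of shape (batch, H, W) with values in [0, 1]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    grayscale_cam = grayscale_cam[0, :]  # first (and only) image in the batch
    # Overlay the CAM on the resized RGB image as a colored heatmap
    visualization = show_cam_on_image(cv2.resize(rgb_img, resize_to), grayscale_cam, use_rgb=True)
    visualizations[cam_name] = visualization
```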
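`show_cam_on_image` colorizes the grayscale CAM with an OpenCV colormap (JET by default) and blends it with the base image, which must be a float image in [0, 1]. The NumPy sketch below illustrates that kind of overlay under two stated assumptions: a plain red ramp stands in for the JET colormap, and the blend uses a hypothetical `image_weight` factor. The helper `overlay_mask` is not part of the library.

```python
import numpy as np


def overlay_mask(rgb_img: np.ndarray, mask: np.ndarray, image_weight: float = 0.5) -> np.ndarray:
    """Hypothetical sketch of a CAM overlay (not the library's exact code).

    rgb_img: (H, W, 3) float image in [0, 1]; mask: (H, W) CAM in [0, 1].
    """
    if rgb_img.max() > 1 or rgb_img.min() < 0:
        raise ValueError("rgb_img must be a float image in [0, 1]")
    # Stand-in colormap: CAM intensity goes into the red channel only
    heatmap = np.zeros_like(rgb_img)
    heatmap[..., 0] = mask
    # Weighted blend of heatmap and image, rescaled so the brightest value is 1
    blended = (1 - image_weight) * heatmap + image_weight * rgb_img
    blended = blended / max(blended.max(), 1e-7)
    return np.uint8(255 * blended)  # uint8 image ready for plt.imshow


gray = np.full((2, 2, 3), 0.5)
mask = np.array([[1.0, 0.0],
                 [0.0, 0.0]])
print(overlay_mask(gray, mask)[..., 0])  # red channel: strongest at the top-left
```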
9. Plot the original image and the visualizations:

```python
num_visualizations = len(visualizations)
fig, axs = plt.subplots(1, num_visualizations + 1, figsize=(4 * (num_visualizations + 1), 4))
axs[0].imshow(cv2.resize(rgb_img, resize_to))
axs[0].set_title('Original Image')

for i, (cam_name, visualization) in enumerate(visualizations.items()):
    axs[i + 1].imshow(visualization)
    axs[i + 1].set_title(cam_name)

plt.tight_layout()
plt.show()
```

Adjust the plot settings according to your preferences.
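With all eight methods enabled, a single row of 4-inch panels is 36 inches wide. One possible adjustment is to wrap the panels onto a grid. The sketch below is a hypothetical helper, `plot_cam_grid`, that uses standard matplotlib calls and hides the axes of unused cells. It selects the non-interactive Agg backend only so that it also runs headless; drop the `matplotlib.use` line to display the figure with `plt.show()`.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_cam_grid(original, visualizations, ncols=4, panel_size=4):
    """Hypothetical helper: original image plus CAM overlays on an ncols-wide grid."""
    panels = [("Original Image", original)] + list(visualizations.items())
    nrows = math.ceil(len(panels) / ncols)
    # squeeze=False keeps axs 2-D even when there is only one row
    fig, axs = plt.subplots(nrows, ncols, figsize=(panel_size * ncols, panel_size * nrows),
                            squeeze=False)
    for ax in axs.flat:
        ax.axis("off")  # hide ticks everywhere, including unused cells
    for ax, (title, image) in zip(axs.flat, panels):
        ax.imshow(image)
        ax.set_title(title)
    fig.tight_layout()
    return fig


dummy = np.zeros((8, 8, 3))
fig = plot_cam_grid(dummy, {f"CAM {i}": dummy for i in range(8)})
print(len(fig.axes))  # 9 panels on a 3x4 grid -> 12 axes
```

With the variables from the steps above, the call would be `plot_cam_grid(cv2.resize(rgb_img, resize_to), visualizations)`.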

Acknowledgments

This code uses the pytorch-grad-cam library developed by Jacob Gildenblat. See the library's documentation for more advanced usage and additional features.

License

The code is provided under the MIT License. Feel free to modify and use it according to your needs.

About

Plotting Class Activation Maps in PyTorch
