diff --git a/benchmarks/benchmark_functions.py b/benchmarks/benchmark_functions.py
new file mode 100644
index 000000000..4ebca6111
--- /dev/null
+++ b/benchmarks/benchmark_functions.py
@@ -0,0 +1,131 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+import time
+import tqdm
+
+from pytorch_grad_cam import GradCAM
+
+from torch import nn
+import torch.nn.functional as F
+
+import torchvision  # You may need to install separately
+from torchvision import models
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+# Simple model to test
+class SimpleCNN(nn.Module):
+    def __init__(self):
+        super(SimpleCNN, self).__init__()
+
+        # Grad-CAM interface
+        self.target_layer = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
+        self.target_layers = [self.target_layer]
+        self.layer4 = self.target_layer
+
+        self.cnn_stack = nn.Sequential(
+            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
+            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=True),
+            self.target_layer,
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d((2, 2)),
+            nn.Flatten(),
+            nn.Linear(122880, 10),  # 32 channels x 128 x 30 after pooling a 3x256x60 input
+            nn.Linear(10, 1)
+        )
+        self.features = self.cnn_stack
+
+    def forward(self, x):
+        logits = self.cnn_stack(x)
+        logits = F.normalize(logits, dim=0)
+
+        return logits
+
+def xavier_uniform_init(layer):
+    if type(layer) in (nn.Linear, nn.Conv2d):
+        gain = nn.init.calculate_gain('relu')
+
+        if layer.bias is not None:
+            nn.init.zeros_(layer.bias)
+
+        nn.init.xavier_uniform_(layer.weight, gain=gain)
+
+def last_cnn_layer(model):
+    if hasattr(model, 'layer4'):
+        return model.layer4
+
+    if hasattr(model, 'conv3'):
+        return model.conv3
+
+    for feature in model.features:
+        if isinstance(feature, nn.Conv2d):
+            return feature
+
+    return None
+
+def save_image(image, path):
+    return torchvision.utils.save_image(image, path)
+
+# Code to run the benchmark
+def run_gradcam(model, number_of_inputs, batch_size=1, use_cuda=False,
+                workflow_test=False, progress_bar=True, method=GradCAM,
+                input_image=None):
+    min_time = float('inf')
+    max_time = 0.0
+    sum_of_times = 0.0
+
+    dev = torch.device('cpu')
+    if use_cuda:
+        dev = torch.device('cuda:0')
+
+    # TODO: Use real data?
+    # TODO: Configurable dimensions?
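+    # The 3x256x60 input shape below is an arbitrary research default rather
+    # than an ImageNet size; pass a real batch via the `input_image` argument
+    # if a representative workload is needed.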
+
+    # Some defaults I use in research code
+    input_tensor = torch.rand((number_of_inputs, 3, 256, 60))
+    targets = None  # [ClassifierOutputTarget(None)]
+
+    model.to(dev)
+    target_layers = [last_cnn_layer(model)]  # Last conv layer of the model (layer4 for ResNets)
+
+    cam_function = method(model=model, target_layers=target_layers, cuda_device=dev, use_cuda=use_cuda)
+    cam_function.batch_size = batch_size
+
+    pbar = tqdm.tqdm(total=number_of_inputs) if progress_bar else None
+
+    # Placeholder outputs so the workflow_test=False path still returns something
+    threshold_plot = torch.rand((number_of_inputs, 3, 256, 60))
+    output_image = torch.rand((number_of_inputs, 3, 256, 60))
+
+    for i in range(0, number_of_inputs, batch_size):
+        start_time = time.time()
+
+        # Actual code to benchmark; keep the caller-supplied image separate
+        # from the per-iteration batch, so later iterations advance through
+        # the random inputs instead of reusing the first slice
+        if input_image is None:
+            batch = input_tensor[i:i + batch_size]
+        else:
+            batch = input_image
+        batch = batch.to(dev)
+
+        heatmap = cam_function(input_tensor=batch, targets=targets)
+
+        if workflow_test:
+            for j in range(heatmap.shape[0]):
+                # Create a binary map on whichever device is in use
+                threshold_plot = torch.where(torch.as_tensor(heatmap[j]).to(dev) > 0.5, 1, 0)
+                output_image = batch * threshold_plot
+
+        # CUDA kernels launch asynchronously, so wait for the device before
+        # reading the clock
+        if use_cuda:
+            torch.cuda.synchronize()
+
+        end_time = time.time()
+        time_difference = end_time - start_time
+
+        sum_of_times += time_difference
+
+        if time_difference > max_time:
+            max_time = time_difference
+
+        if time_difference < min_time:
+            min_time = time_difference
+
+        if progress_bar:
+            pbar.update(batch_size)
+
+    avg_time = sum_of_times / number_of_inputs  # average per image
+    return [min_time, max_time, avg_time, [threshold_plot, output_image]]
diff --git a/benchmarks/methods_benchmark.py b/benchmarks/methods_benchmark.py
new file mode 100644
index 000000000..5660e3a79
--- /dev/null
+++ b/benchmarks/methods_benchmark.py
@@ -0,0 +1,64 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+import time
+import tqdm
+
+from pytorch_grad_cam import GradCAM, \
+    ScoreCAM, \
+    GradCAMPlusPlus, \
+    AblationCAM, \
+    XGradCAM, \
+    EigenCAM, \
+    EigenGradCAM, \
+    LayerCAM, \
+    FullGrad
+
+from torch import nn
+import torch.nn.functional as F
+
+import torchvision  # You may need to install separately
+from torchvision import models
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+import benchmark_functions
+
+number_of_inputs = 1000
+
+print(f'Benchmarking {number_of_inputs} images for multiple CAM methods...')
+
+methods_to_benchmark = [
+    ['GradCAM', GradCAM],
+    ['ScoreCAM', ScoreCAM],
+    ['GradCAMPlusPlus', GradCAMPlusPlus],
+    ['AblationCAM', AblationCAM],
+    ['XGradCAM', XGradCAM],
+    ['EigenCAM', EigenCAM],
+    ['EigenGradCAM', EigenGradCAM],
+    ['LayerCAM', LayerCAM],
+    ['FullGrad', FullGrad]
+]
+
+model = benchmark_functions.SimpleCNN()
+# model = models.resnet18()
+
+model.apply(benchmark_functions.xavier_uniform_init)  # Randomise more weights
+
+for method_name, method in tqdm.tqdm(methods_to_benchmark):
+    print('==============================================================================\n\n')
+    print(f'Simple Workflow for method {method_name}:\n')
+
+    cpu_min_time, cpu_max_time, cpu_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=8, use_cuda=False, workflow_test=True, progress_bar=False, method=method)
+    cuda_min_time, cuda_max_time, cuda_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=8, use_cuda=True, workflow_test=True, progress_bar=False, method=method)
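+    # Rough expectation: perturbation-heavy methods (ScoreCAM, AblationCAM)
+    # run many forward passes per image, so their averages should sit well
+    # above the gradient-based methods in the numbers below.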
+
+    print(f'Cuda Min time: {cuda_min_time}\n')
+    print(f'Cuda Max time: {cuda_max_time}\n')
+    print(f'Cuda Avg time: {cuda_avg_time}\n\n')
+    print(f'CPU Min time: {cpu_min_time}\n')
+    print(f'CPU Max time: {cpu_max_time}\n')
+    print(f'CPU Avg time: {cpu_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Done!')
diff --git a/benchmarks/models_benchmark.py b/benchmarks/models_benchmark.py
new file mode 100644
index 000000000..9c6fdfa89
--- /dev/null
+++ b/benchmarks/models_benchmark.py
@@ -0,0 +1,53 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+import time
+import tqdm
+
+from pytorch_grad_cam import GradCAM
+
+from torch import nn
+import torch.nn.functional as F
+
+import torchvision  # You may need to install separately
+from torchvision import models
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+import benchmark_functions
+
+number_of_inputs = 1000
+
+print(f'Benchmarking GradCAM using {number_of_inputs} images for multiple models...')
+
+models_to_benchmark = [
+    ["SimpleCNN", benchmark_functions.SimpleCNN()],
+    ["resnet18", models.resnet18()],
+    ["resnet34", models.resnet34()],
+    ["resnet50", models.resnet50()],
+    ["alexnet", models.alexnet()],
+    ["vgg16", models.vgg16()],
+    ["googlenet", models.googlenet()],
+    ["mobilenet_v2", models.mobilenet_v2()],
+    ["densenet161", models.densenet161()]
+]
+
+for model_name, model in tqdm.tqdm(models_to_benchmark):
+    print('==============================================================================\n\n')
+    print(f'Simple Workflow for model {model_name}:\n')
+
+    model.apply(benchmark_functions.xavier_uniform_init)  # Randomise more weights
+    cpu_min_time, cpu_max_time, cpu_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=8, use_cuda=False, workflow_test=True, progress_bar=False)
+    cuda_min_time, cuda_max_time, cuda_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=8, use_cuda=True, workflow_test=True, progress_bar=False)
+
+    print(f'Cuda Min time: {cuda_min_time}\n')
+    print(f'Cuda Max time: {cuda_max_time}\n')
+    print(f'Cuda Avg time: {cuda_avg_time}\n\n')
+    print(f'CPU Min time: {cpu_min_time}\n')
+    print(f'CPU Max time: {cpu_max_time}\n')
+    print(f'CPU Avg time: {cpu_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Done!')
diff --git a/benchmarks/single_image_benchmark.py b/benchmarks/single_image_benchmark.py
new file mode 100644
index 000000000..2d8368442
--- /dev/null
+++ b/benchmarks/single_image_benchmark.py
@@ -0,0 +1,131 @@
+import argparse
+import os
+
+import cv2
+import numpy as np
+import torch
+import time
+import tqdm
+
+from pytorch_grad_cam import GradCAM, \
+    ScoreCAM, \
+    GradCAMPlusPlus, \
+    AblationCAM, \
+    XGradCAM, \
+    EigenCAM, \
+    EigenGradCAM, \
+    LayerCAM, \
+    FullGrad
+
+from torch import nn
+import torch.nn.functional as F
+
+import torchvision  # You may need to install separately
+from torchvision import models
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+import benchmark_functions
+
+number_of_inputs = 1
+model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
+
+# Just hard-coding a path for now; read_image does not expand '~' itself
+image_path = os.path.expanduser('~/image.jpg')
+input_tensor = torchvision.io.read_image(image_path)
+
+print(f'Benchmarking GradCAM using {number_of_inputs} image for ResNet50...')
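+# Note: `input_tensor` is not yet wired into the runs below, which use random
+# data. A minimal (untested) sketch to benchmark the real image instead:
+#   image = input_tensor.unsqueeze(0).float() / 255.0
+#   benchmark_functions.run_gradcam(model, 1, use_cuda=True, input_image=image)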
+
+# Run on CPU with profiler (save the profile to print later)
+# print('Profile list of images on CPU...')
+# with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
+#     cpu_profile_min_time, cpu_profile_max_time, cpu_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=False)
+# cpu_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+# Run on CUDA with profiler (save the profile to print later).
+# ProfilerActivity.CUDA is included so GPU kernels show up in the table.
+print('Profile list of images on Cuda...')
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+    cuda_profile_min_time, cuda_profile_max_time, cuda_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True)
+cuda_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+# Run on CUDA with extra workflow
+print('Profile list of images on Cuda and then run workflow...')
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+    cuda_profile_min_time, cuda_profile_max_time, cuda_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+work_flow_cuda_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+# Run on CUDA with extra workflow and a simple CNN
+print('Profile list of images on Cuda and then run workflow with a simple CNN...')
+model = benchmark_functions.SimpleCNN()
+model.apply(benchmark_functions.xavier_uniform_init)  # Randomise more weights
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+    cuda_profile_min_time, cuda_profile_max_time, cuda_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+simple_work_flow_cuda_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+model = models.resnet50()
+# Run on CPU x1000 (get min, max, and avg times)
+# print('Run list of images on CPU...')
+# cpu_min_time, cpu_max_time, cpu_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=False)
+
+# Run on CUDA x1000
+print('Run list of images on Cuda...')
+cuda_min_time, cuda_max_time, cuda_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True)
+
+# Run Workflow
+print('Run list of images on Cuda with a workflow...')
+workflow_cuda_min_time, workflow_cuda_max_time, workflow_cuda_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+
+print('Run list of images on Cuda with a workflow using simple CNN...')
+model = benchmark_functions.SimpleCNN()
+model.apply(benchmark_functions.xavier_uniform_init)  # Randomise more weights
+simple_workflow_cuda_min_time, simple_workflow_cuda_max_time, simple_workflow_cuda_avg_time, output = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+
+print('Complete!')
+
+# print('==============================================================================\n\n')
+# print('CPU Profile:\n')
+# print(cpu_profile)
+
+print('==============================================================================\n\n')
+print('Cuda Profile:\n')
+print(cuda_profile)
+
+print('==============================================================================\n\n')
+print('Workflow Cuda Profile:\n')
+print(work_flow_cuda_profile)
+
+print('==============================================================================\n\n')
+print('Simple Workflow Cuda Profile:\n')
+print(simple_work_flow_cuda_profile)
+
+# print('==============================================================================\n\n')
+# print('CPU Timing (No Profiler):\n')
+# print(f'Min time: {cpu_min_time}\n')
+# print(f'Max time: {cpu_max_time}\n')
+# print(f'Avg time: {cpu_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Cuda Timing (No Profiler):\n')
+print(f'Min time: {cuda_min_time}\n')
+print(f'Max time: {cuda_max_time}\n')
+print(f'Avg time: {cuda_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Workflow Cuda Timing (No Profiler):\n')
+print(f'Min time: {workflow_cuda_min_time}\n')
+print(f'Max time: {workflow_cuda_max_time}\n')
+print(f'Avg time: {workflow_cuda_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Simple Workflow Cuda Timing (No Profiler):\n')
+print(f'Min time: {simple_workflow_cuda_min_time}\n')
+print(f'Max time: {simple_workflow_cuda_max_time}\n')
+print(f'Avg time: {simple_workflow_cuda_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Output the resultant heat-map')
+threshold_plot, output_image = output
+
+# save_image expects float tensors in [0, 1], so cast to float32 rather than uint8
+benchmark_functions.save_image(threshold_plot.to("cpu", torch.float32), os.path.expanduser('~/threshold.png'))
+benchmark_functions.save_image(output_image.to("cpu", torch.float32), os.path.expanduser('~/output_image.png'))
+
+print('==============================================================================\n\n')
+print('Done!')
diff --git a/benchmarks/torch_benchmark.py b/benchmarks/torch_benchmark.py
new file mode 100644
index 000000000..793f099c5
--- /dev/null
+++ b/benchmarks/torch_benchmark.py
@@ -0,0 +1,126 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+import time
+import tqdm
+
+from pytorch_grad_cam import GradCAM, \
+    ScoreCAM, \
+    GradCAMPlusPlus, \
+    AblationCAM, \
+    XGradCAM, \
+    EigenCAM, \
+    EigenGradCAM, \
+    LayerCAM, \
+    FullGrad
+
+from torch import nn
+import torch.nn.functional as F
+
+import torchvision  # You may need to install separately
+from torchvision import models
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+import benchmark_functions
+
+number_of_inputs = 1000
+model = models.resnet50()
+
+print(f'Benchmarking GradCAM using {number_of_inputs} images for ResNet50...')
+
+# TODOs:
+# Test against the numpy implementation, v1.4.6 (master)
+# Test against the torch implementation, v1.4.7 (this branch)
+# Test other CAMs besides GradCAM
+# Nice output
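+# Note on the profiler runs below: ProfilerActivity.CUDA is listed alongside
+# ProfilerActivity.CPU so that GPU kernels appear in the tables; with
+# CPU-only activities the CUDA runs would report host-side time only.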
+
+# Run on CPU with profiler (save the profile to print later)
+print('Profile list of images on CPU...')
+with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
+    cpu_profile_min_time, cpu_profile_max_time, cpu_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=False)
+cpu_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+# Run on CUDA with profiler (save the profile to print later)
+print('Profile list of images on Cuda...')
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+    cuda_profile_min_time, cuda_profile_max_time, cuda_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True)
+cuda_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+# Run on CUDA with extra workflow
+print('Profile list of images on Cuda and then run workflow...')
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+    cuda_profile_min_time, cuda_profile_max_time, cuda_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+work_flow_cuda_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+# Run on CUDA with extra workflow and a simple CNN
+print('Profile list of images on Cuda and then run workflow with a simple CNN...')
+model = benchmark_functions.SimpleCNN()
+model.apply(benchmark_functions.xavier_uniform_init)  # Randomise more weights
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+    cuda_profile_min_time, cuda_profile_max_time, cuda_profile_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+simple_work_flow_cuda_profile = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=15)
+
+model = models.resnet50()
+# Run on CPU x1000 (get min, max, and avg times)
+print('Run list of images on CPU...')
+cpu_min_time, cpu_max_time, cpu_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=False)
+
+# Run on CUDA x1000
+print('Run list of images on Cuda...')
+cuda_min_time, cuda_max_time, cuda_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True)
+
+# Run Workflow
+print('Run list of images on Cuda with a workflow...')
+workflow_cuda_min_time, workflow_cuda_max_time, workflow_cuda_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+
+print('Run list of images on Cuda with a workflow using simple CNN...')
+model = benchmark_functions.SimpleCNN()
+model.apply(benchmark_functions.xavier_uniform_init)  # Randomise more weights
+simple_workflow_cuda_min_time, simple_workflow_cuda_max_time, simple_workflow_cuda_avg_time, _output_image = benchmark_functions.run_gradcam(model, number_of_inputs, batch_size=64, use_cuda=True, workflow_test=True)
+
+print('Complete!')
+
+print('==============================================================================\n\n')
+print('CPU Profile:\n')
+print(cpu_profile)
+
+print('==============================================================================\n\n')
+print('Cuda Profile:\n')
+print(cuda_profile)
+
+print('==============================================================================\n\n')
+print('Workflow Cuda Profile:\n')
+print(work_flow_cuda_profile)
+
+print('==============================================================================\n\n')
+print('Simple Workflow Cuda Profile:\n')
+print(simple_work_flow_cuda_profile)
+
+print('==============================================================================\n\n')
+print('CPU Timing (No Profiler):\n')
+print(f'Min time: {cpu_min_time}\n')
+print(f'Max time: {cpu_max_time}\n')
+print(f'Avg time: {cpu_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Cuda Timing (No Profiler):\n')
+print(f'Min time: {cuda_min_time}\n')
+print(f'Max time: {cuda_max_time}\n')
+print(f'Avg time: {cuda_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Workflow Cuda Timing (No Profiler):\n')
+print(f'Min time: {workflow_cuda_min_time}\n')
+print(f'Max time: {workflow_cuda_max_time}\n')
+print(f'Avg time: {workflow_cuda_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Simple Workflow Cuda Timing (No Profiler):\n')
+print(f'Min time: {simple_workflow_cuda_min_time}\n')
+print(f'Max time: {simple_workflow_cuda_max_time}\n')
+print(f'Avg time: {simple_workflow_cuda_avg_time}\n')
+
+print('==============================================================================\n\n')
+print('Done!')
diff --git a/pytorch_grad_cam/activations_and_gradients.py b/pytorch_grad_cam/activations_and_gradients.py
index 0c2071e59..957c976e1 100644
--- a/pytorch_grad_cam/activations_and_gradients.py
+++ b/pytorch_grad_cam/activations_and_gradients.py
@@ -2,12 +2,16 @@ class ActivationsAndGradients:
     """ Class for extracting activations and
     registering gradients from targetted intermediate layers """

-    def __init__(self, model, target_layers, reshape_transform):
+    def __init__(self, model, target_layers, reshape_transform, use_cuda: bool = False, cuda_device=None):
         self.model = model
         self.gradients = []
         self.activations = []
         self.reshape_transform = reshape_transform
         self.handles = []
+
+        self.use_cuda = use_cuda
+        self.cuda_device = cuda_device
+
         for target_layer in target_layers:
             self.handles.append(
                 target_layer.register_forward_hook(self.save_activation))
@@ -21,7 +25,11 @@ def save_activation(self, module, input, output):
         if self.reshape_transform is not None:
             activation = self.reshape_transform(activation)
-        self.activations.append(activation.cpu().detach())
+
+        if self.use_cuda:
+            # Still detach, so stored activations do not keep the autograd graph alive
+            self.activations.append(activation.detach().to(self.cuda_device))
+        else:
+            self.activations.append(activation.cpu().detach())

     def save_gradient(self, module, input, output):
         if not hasattr(output, "requires_grad") or not output.requires_grad:
@@ -32,13 +40,18 @@ def save_gradient(self, module, input, output):
         def _store_grad(grad):
             if self.reshape_transform is not None:
                 grad = self.reshape_transform(grad)
-            self.gradients = [grad.cpu().detach()] + self.gradients
+
+            if self.use_cuda:
+                self.gradients = [grad.detach().to(self.cuda_device)] + self.gradients
+            else:
+                self.gradients = [grad.cpu().detach()] + self.gradients

         output.register_hook(_store_grad)

     def __call__(self, x):
         self.gradients = []
         self.activations = []
+
         return self.model(x)

     def release(self):
diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py
index 7ee192971..73ee0b47d 100644
--- a/pytorch_grad_cam/base_cam.py
+++ b/pytorch_grad_cam/base_cam.py
@@ -6,26 +6,34 @@ from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
 from pytorch_grad_cam.utils.image import scale_cam_image
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

 class BaseCAM:
     def __init__(self,
                  model: torch.nn.Module,
                  target_layers: List[torch.nn.Module],
                  use_cuda: bool = False,
+                 cuda_device=None,
                  reshape_transform: Callable = None,
                  compute_input_gradient: bool = False,
                  uses_gradients: bool = True) -> None:
         self.model = model.eval()
         self.target_layers = target_layers
+
         self.cuda = use_cuda
-        if self.cuda:
+        self.cuda_device = cuda_device
+
+        if self.cuda_device and self.cuda:
+            self.model.to(self.cuda_device)
+        elif self.cuda:
             self.model = model.cuda()
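+        # Note: when both are given, `cuda_device` takes precedence over the
+        # bare `use_cuda` flag, so the CAM can be pinned to a specific GPU,
+        # e.g. GradCAM(model, layers, use_cuda=True, cuda_device=torch.device('cuda:1'))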
+
         self.reshape_transform = reshape_transform
         self.compute_input_gradient = compute_input_gradient
         self.uses_gradients = uses_gradients
+
         self.activations_and_grads = ActivationsAndGradients(
-            self.model, target_layers, reshape_transform)
+            self.model, target_layers, reshape_transform, use_cuda=use_cuda, cuda_device=cuda_device)

     """ Get a vector of weights for every channel in the target layer.
         Methods that return weights channels,
@@ -36,7 +44,7 @@ def get_cam_weights(self,
                         target_layers: List[torch.nn.Module],
                         targets: List[torch.nn.Module],
                         activations: torch.Tensor,
-                        grads: torch.Tensor) -> np.ndarray:
+                        grads: torch.Tensor) -> torch.Tensor:
         raise Exception("Not Implemented")

     def get_cam_image(self,
@@ -45,7 +53,7 @@ def get_cam_image(self,
                       targets: List[torch.nn.Module],
                       activations: torch.Tensor,
                       grads: torch.Tensor,
-                      eigen_smooth: bool = False) -> np.ndarray:
+                      eigen_smooth: bool = False) -> torch.Tensor:

         weights = self.get_cam_weights(input_tensor,
                                        target_layer,
@@ -62,9 +70,11 @@ def get_cam_image(self,
     def forward(self,
                 input_tensor: torch.Tensor,
                 targets: List[torch.nn.Module],
-                eigen_smooth: bool = False) -> np.ndarray:
+                eigen_smooth: bool = False) -> torch.Tensor:

-        if self.cuda:
+        if self.cuda_device and self.cuda:
+            input_tensor = input_tensor.to(self.cuda_device)
+        elif self.cuda:
             input_tensor = input_tensor.cuda()

         if self.compute_input_gradient:
@@ -73,7 +83,7 @@ def forward(self,
         outputs = self.activations_and_grads(input_tensor)
         if targets is None:
-            target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
+            target_categories = torch.argmax(outputs.data, dim=-1)
             targets = [ClassifierOutputTarget(
                 category) for category in target_categories]

@@ -106,10 +116,10 @@ def compute_cam_per_layer(
             self,
             input_tensor: torch.Tensor,
             targets: List[torch.nn.Module],
-            eigen_smooth: bool) -> np.ndarray:
-        activations_list = [a.cpu().data.numpy()
+            eigen_smooth: bool) -> torch.Tensor:
+        activations_list = [a.data
                             for a in self.activations_and_grads.activations]
-        grads_list = [g.cpu().data.numpy()
+        grads_list = [g.data
                       for g in self.activations_and_grads.gradients]
         target_size = self.get_target_width_height(input_tensor)

@@ -117,8 +127,10 @@ def compute_cam_per_layer(
         # Loop over the saliency image from every layer
         for i in range(len(self.target_layers)):
             target_layer = self.target_layers[i]
+
             layer_activations = None
             layer_grads = None
+
             if i < len(activations_list):
                 layer_activations = activations_list[i]
             if i < len(grads_list):
@@ -130,7 +142,8 @@ def compute_cam_per_layer(
                 layer_activations,
                 layer_grads,
                 eigen_smooth)
-            cam = np.maximum(cam, 0)
+
+            cam = torch.clamp(cam, min=0)  # stays on the current device
             scaled = scale_cam_image(cam, target_size)
             cam_per_target_layer.append(scaled[:, None, :])

@@ -138,16 +151,16 @@ def compute_cam_per_layer(

     def aggregate_multi_layers(
             self,
-            cam_per_target_layer: np.ndarray) -> np.ndarray:
-        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
-        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
-        result = np.mean(cam_per_target_layer, axis=1)
+            cam_per_target_layer: torch.Tensor) -> torch.Tensor:
+        cam_per_target_layer = torch.cat(cam_per_target_layer, dim=1)
+        cam_per_target_layer = torch.clamp(cam_per_target_layer, min=0)
+        result = torch.mean(cam_per_target_layer, dim=1)
         return scale_cam_image(result)

     def forward_augmentation_smoothing(self,
                                        input_tensor: torch.Tensor,
                                        targets: List[torch.nn.Module],
-                                       eigen_smooth: bool = False) -> np.ndarray:
+                                       eigen_smooth: bool = False) -> torch.Tensor:
         transforms = tta.Compose(
             [
                 tta.HorizontalFlip(),
@@ -167,18 +180,18 @@ def forward_augmentation_smoothing(self,
                 cam = transform.deaugment_mask(cam)

-                # Back to numpy float32, HxW
-                cam = cam.numpy()
+                # Keep as a torch float32 tensor, BxHxW
                 cam = cam[:, 0, :, :]
                 cams.append(cam)

-        cam = np.mean(np.float32(cams), axis=0)
+        # cams is a Python list of tensors, so stack before averaging
+        cam = torch.mean(torch.stack(cams).to(torch.float32), dim=0)
         return cam

     def __call__(self,
                  input_tensor: torch.Tensor,
                  targets: List[torch.nn.Module] = None,
                  aug_smooth: bool = False,
-                 eigen_smooth: bool = False) -> np.ndarray:
+                 eigen_smooth: bool = False) -> torch.Tensor:

         # Smooth the CAM result with test time augmentation
         if aug_smooth is True:
diff --git a/pytorch_grad_cam/fullgrad_cam.py b/pytorch_grad_cam/fullgrad_cam.py
index 1a2685eff..d2b438a4c 100644
--- a/pytorch_grad_cam/fullgrad_cam.py
+++ b/pytorch_grad_cam/fullgrad_cam.py
@@ -9,7 +9,7 @@ class FullGrad(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers, use_cuda=False, cuda_device=None,
                  reshape_transform=None):
         if len(target_layers) > 0:
             print(
@@ -28,6 +28,7 @@ def layer_with_2D_bias(layer):
             model,
             target_layers,
             use_cuda,
+            cuda_device,
             reshape_transform,
             compute_input_gradient=True)
         self.bias_data = [self.get_bias_data(
diff --git a/pytorch_grad_cam/grad_cam.py b/pytorch_grad_cam/grad_cam.py
index 025bf45dd..2a48bdd8b 100644
--- a/pytorch_grad_cam/grad_cam.py
+++ b/pytorch_grad_cam/grad_cam.py
@@ -1,9 +1,9 @@
-import numpy as np
+import torch
 from pytorch_grad_cam.base_cam import BaseCAM

 class GradCAM(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers, use_cuda=False, cuda_device=None,
                  reshape_transform=None):
         super(
             GradCAM,
@@ -11,6 +11,7 @@ def __init__(self, model, target_layers, use_cuda=False,
             model,
             target_layers,
             use_cuda,
+            cuda_device,
             reshape_transform)

     def get_cam_weights(self,
@@ -19,4 +20,4 @@ def get_cam_weights(self,
                         target_category,
                         activations,
                         grads):
-        return np.mean(grads, axis=(2, 3))
+        return torch.mean(grads, dim=(2, 3))
diff --git a/pytorch_grad_cam/utils/image.py b/pytorch_grad_cam/utils/image.py
index 34d92ba6f..fd9a6f3a3 100644
--- a/pytorch_grad_cam/utils/image.py
+++ b/pytorch_grad_cam/utils/image.py
@@ -4,7 +4,7 @@ import cv2
 import numpy as np
 import torch
-from torchvision.transforms import Compose, Normalize, ToTensor
+from torchvision.transforms import Compose, Normalize, ToTensor, Resize
 from typing import List, Dict
 import math

@@ -158,16 +158,26 @@ def show_factorization_on_image(img: np.ndarray,

 def scale_cam_image(cam, target_size=None):
-    result = []
-    for img in cam:
-        img = img - np.min(img)
-        img = img / (1e-7 + np.max(img))
+    # Normalise each map to [0, 1] and, when target_size (width, height) is
+    # given, resize with bilinear interpolation, keeping everything in torch
+    if target_size is not None:
+        result = torch.zeros([cam.shape[0], target_size[1], target_size[0]], device=cam.device)
+    else:
+        result = torch.zeros(cam.shape, device=cam.device)
+
+    for i in range(cam.shape[0]):
+        img = cam[i]
+        img = img - torch.min(img)
+        img = img / (1e-7 + torch.max(img))
+
         if target_size is not None:
-            img = cv2.resize(img, target_size)
-        result.append(img)
-    result = np.float32(result)
+            # interpolate expects an (N, C, H, W) input and a (height, width) size
+            img = torch.nn.functional.interpolate(
+                img[None, None, :, :],
+                size=(target_size[1], target_size[0]),
+                mode='bilinear', align_corners=False)[0, 0]

-    return result
+        result[i] = img
+
+    return result.to(torch.float32)


 def scale_accross_batch_and_channels(tensor, target_size):
diff --git a/pytorch_grad_cam/utils/svd_on_activations.py b/pytorch_grad_cam/utils/svd_on_activations.py
index a406aeea8..91bfab0e1 100644
--- a/pytorch_grad_cam/utils/svd_on_activations.py
+++ b/pytorch_grad_cam/utils/svd_on_activations.py
@@ -1,9 +1,8 @@
-import numpy as np
+import torch


 def get_2d_projection(activation_batch):
-    # TBD: use pytorch batch svd implementation
-    activation_batch[np.isnan(activation_batch)] = 0
+    activation_batch[torch.isnan(activation_batch)] = 0
     projections = []
     for activations in activation_batch:
         reshaped_activations = (activations).reshape(
@@ -12,8 +11,8 @@ def get_2d_projection(activation_batch):
         # Otherwise the image returned is negative
         reshaped_activations = reshaped_activations - \
             reshaped_activations.mean(axis=0)
-        U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
+        U, S, VT = torch.linalg.svd(reshaped_activations, full_matrices=True)
         projection = reshaped_activations @ VT[0, :]
         projection = projection.reshape(activations.shape[1:])
         projections.append(projection)
-    return np.float32(projections)
+    # torch.tensor() cannot build a tensor from a list of tensors, so stack instead
+    return torch.stack(projections).to(torch.float32)
diff --git a/setup.cfg b/setup.cfg
index 203e6a636..dceb4f5bc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = grad-cam
-version = 1.1.0
+version = 1.4.7
 author = Jacob Gildenblat
 author_email = jacob.gildenblat@gmail.com
 description = Many Class Activation Map methods implemented in Pytorch. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM
@@ -16,4 +16,4 @@ classifiers =

 [options]
 packages = find:
-python_requires = >=3.6
\ No newline at end of file
+python_requires = >=3.6
diff --git a/setup.py b/setup.py
index 1d8ace600..ea87b563d 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@

 setuptools.setup(
     name='grad-cam',
-    version='1.4.6',
+    version='1.4.7',
     author='Jacob Gildenblat',
     author_email='jacob.gildenblat@gmail.com',
     description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',