Attempt to change all Numpy calls to Torch calls #357

Draft · wants to merge 72 commits into master

Commits (72, changes shown from 8 commits)
59aef83
Attempt to compute all base and grad-cam class operations using torch…
TRex22 Nov 5, 2022
8bf7524
Bump other version :cop:
TRex22 Nov 5, 2022
fa8c8d7
Fix import to use torch over Numpy :cop:
TRex22 Nov 5, 2022
ce52619
Convert max pos 2 to a tensor :cop:
TRex22 Nov 5, 2022
73f720b
Begin to find and migrate more numpy calls to torch calls. Also fix s…
TRex22 Nov 5, 2022
f96e21f
Make use of one strategy for resizing tensors over the cv2 call :cop:
TRex22 Nov 5, 2022
ef3dcf5
Go back to CPU for cv2 resizing :cop:
TRex22 Nov 5, 2022
bf9b9db
Attempt to get the image scaling function working with torch :cop:
TRex22 Nov 5, 2022
df9b035
Use torch tensor only on the list of numpy arrays :cop:
TRex22 Nov 5, 2022
fe6cb92
Use the correct torch function :cop:
TRex22 Nov 5, 2022
f58e88c
Attempt to fix torch resizing :cop:
TRex22 Nov 5, 2022
594bb0c
Use a transpose for the experiment. Investigate a proper resize later…
TRex22 Nov 5, 2022
f4739c2
Remove the resize for now :cop:
TRex22 Nov 5, 2022
995550e
Disable the scaling function from changing dimensions (for now)
TRex22 Nov 5, 2022
02b9451
Create a simple benchmark :cop:
TRex22 Nov 6, 2022
cc557d1
Add in basic GradCAM :cop:
TRex22 Nov 6, 2022
71f51d0
Continue to write a simple GradCAM :cop:
TRex22 Nov 6, 2022
06efbc4
Properly name the variable :cop:
TRex22 Nov 6, 2022
f2578d7
Fix the tensor stack :cop:
TRex22 Nov 6, 2022
a4d2750
Fix the dimensions needed for Resnet :cop:
TRex22 Nov 6, 2022
7058122
Change target layer :cop:
TRex22 Nov 6, 2022
1a77f74
Add in cuda profiling :cop:
TRex22 Nov 6, 2022
f4b759a
Create the large loop :cop:
TRex22 Nov 6, 2022
15ca2be
Refactor code to share some algorithm :cop:
TRex22 Nov 6, 2022
fb1b50d
Fix batching :cop:
TRex22 Nov 6, 2022
62c1709
Add in proper output :cop:
TRex22 Nov 6, 2022
045d200
Add in loading bar :cop:
TRex22 Nov 6, 2022
20d7ebd
Reduce to 100 images :cop:
TRex22 Nov 6, 2022
d9dbc85
Attempt using a bigger batchsize :cop:
TRex22 Nov 6, 2022
2756d71
Bump batch_size :cop:
TRex22 Nov 6, 2022
4200b99
Add workflow test :cop:
TRex22 Nov 8, 2022
1489ea3
Fix tensor issue in 1.4.6 :cop:
TRex22 Nov 8, 2022
8833ee1
Add inner loop :cop:
TRex22 Nov 8, 2022
d492c36
Force cuda device :cop:
TRex22 Nov 8, 2022
5bbdf8f
Fix loop range :cop:
TRex22 Nov 8, 2022
b8a8a46
Make use of the tensor resize transform :cop:
TRex22 Nov 10, 2022
77b19da
Add in a different model to benchmark too :cop:
TRex22 Nov 10, 2022
85f196b
handle the tensor list size :cop:
TRex22 Nov 10, 2022
a56647d
Correct the dimensions in the resize :cop:
TRex22 Nov 10, 2022
922d2d3
Update using the correct models in the benchmark :cop:
TRex22 Nov 10, 2022
5dbc8bc
Fix output :cop:
TRex22 Nov 10, 2022
20ab49f
Improve benchmarking and make a functions file to store reusable comp…
TRex22 Nov 18, 2022
3eceb84
Make use of shared functions :cop:
TRex22 Nov 18, 2022
901391e
Attempt to fix device memory issues :cop:
TRex22 Nov 18, 2022
2748c5c
Select the last CNN model as the GradCAM target layer :scientist:
TRex22 Nov 18, 2022
b77aa5b
Fix spelling mistake :cop:
TRex22 Nov 18, 2022
65f1b1f
Attempt another way to iterate through model params :cop:
TRex22 Nov 18, 2022
9c274be
Handle multiple models :cop:
TRex22 Nov 18, 2022
eaaf0a9
Fix feature bug :cop:
TRex22 Nov 18, 2022
dc5db2e
Cleanup progress :cop:
TRex22 Nov 18, 2022
c75bbef
Add in method benchmark :cop:
TRex22 Nov 18, 2022
915b99f
Patch in cuda device support :cop:
TRex22 Feb 17, 2023
0305eec
Fix cuda device call :cop:
TRex22 Feb 17, 2023
0da12a0
Fix cuda device call :cop:
TRex22 Feb 17, 2023
82f71e7
Work on a single image benchmark :cop:
TRex22 Mar 9, 2023
199815c
Disable cpu benchmarking :cop:
TRex22 Mar 9, 2023
cf020cf
Output the resultant image and allow inputting an image :cop:
TRex22 Mar 9, 2023
9513c86
Allow for output saving for a sanity check :cop:
TRex22 Mar 9, 2023
186c14b
Allow for output saving for a sanity check :cop:
TRex22 Mar 9, 2023
b6d2202
Allow for output saving for a sanity check :cop:
TRex22 Mar 9, 2023
84a0689
Allow for output saving for a sanity check :cop:
TRex22 Mar 9, 2023
18f8d8e
Open image :cop:
TRex22 Mar 9, 2023
8eaf1b7
Open image :cop:
TRex22 Mar 9, 2023
b805465
Change to simple model :cop:
TRex22 Mar 9, 2023
7393e77
try with trained weights :cop:
TRex22 Mar 9, 2023
193c9f2
try with trained weights :cop:
TRex22 Mar 9, 2023
a3d327b
Add in cuda device support for gradients and activations :cop:
TRex22 Apr 17, 2023
15de5ed
Fix typo :cop:
TRex22 Apr 17, 2023
9e38599
Attempt to force a different FakeTensorMode :scientist:
TRex22 Apr 17, 2023
ed64d06
Attempt to patch issue with pytorch 2.0 :cop:
TRex22 Apr 17, 2023
273bd81
Make more meaningful changes extracted from another branch :cop:
TRex22 Apr 26, 2023
9ed4c9b
Remove TODO :cop:
TRex22 Apr 26, 2023
34 changes: 17 additions & 17 deletions pytorch_grad_cam/base_cam.py
@@ -36,7 +36,7 @@ def get_cam_weights(self,
target_layers: List[torch.nn.Module],
targets: List[torch.nn.Module],
activations: torch.Tensor,
grads: torch.Tensor) -> np.ndarray:
grads: torch.Tensor) -> torch.Tensor:
TRex22 (Author) commented:
For now I'm trying to make minimal changes to create a proof of concept I can run. I've left this as a draft PR and it's still definitely a WIP.

I just find the PR user interface great for observing changes.

TRex22 (Author) commented:
Related to: #356

raise Exception("Not Implemented")

def get_cam_image(self,
@@ -45,7 +45,7 @@ def get_cam_image(self,
targets: List[torch.nn.Module],
activations: torch.Tensor,
grads: torch.Tensor,
eigen_smooth: bool = False) -> np.ndarray:
eigen_smooth: bool = False) -> torch.Tensor:

weights = self.get_cam_weights(input_tensor,
target_layer,
@@ -62,7 +62,7 @@ def forward(self,
def forward(self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module],
eigen_smooth: bool = False) -> np.ndarray:
eigen_smooth: bool = False) -> torch.Tensor:

if self.cuda:
input_tensor = input_tensor.cuda()
@@ -73,7 +73,7 @@ def forward(self,

outputs = self.activations_and_grads(input_tensor)
if targets is None:
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
target_categories = torch.argmax(outputs.data, axis=-1)
targets = [ClassifierOutputTarget(
category) for category in target_categories]

@@ -106,10 +106,10 @@ def compute_cam_per_layer(
self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module],
eigen_smooth: bool) -> np.ndarray:
activations_list = [a.cpu().data.numpy()
eigen_smooth: bool) -> torch.Tensor:
activations_list = [a.data
for a in self.activations_and_grads.activations]
grads_list = [g.cpu().data.numpy()
grads_list = [g.data
for g in self.activations_and_grads.gradients]
target_size = self.get_target_width_height(input_tensor)

@@ -130,24 +130,24 @@ def compute_cam_per_layer(
layer_activations,
layer_grads,
eigen_smooth)
cam = np.maximum(cam, 0)
cam = torch.maximum(cam, torch.tensor(0))
scaled = scale_cam_image(cam, target_size)
cam_per_target_layer.append(scaled[:, None, :])

return cam_per_target_layer

def aggregate_multi_layers(
self,
cam_per_target_layer: np.ndarray) -> np.ndarray:
cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
result = np.mean(cam_per_target_layer, axis=1)
cam_per_target_layer: torch.Tensor) -> torch.Tensor:
cam_per_target_layer = torch.concatenate(cam_per_target_layer, axis=1)
cam_per_target_layer = torch.maximum(cam_per_target_layer, torch.tensor(0))
result = torch.mean(cam_per_target_layer, axis=1)
return scale_cam_image(result)

def forward_augmentation_smoothing(self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module],
eigen_smooth: bool = False) -> np.ndarray:
eigen_smooth: bool = False) -> torch.Tensor:
transforms = tta.Compose(
[
tta.HorizontalFlip(),
@@ -167,18 +167,18 @@ def forward_augmentation_smoothing(self,
cam = transform.deaugment_mask(cam)

# Back to numpy float32, HxW
cam = cam.numpy()
# cam = cam.numpy()
cam = cam[:, 0, :, :]
cams.append(cam)
cams.append(cam) # TODO: Handle this for torch tensors
TRex22 (Author) commented:
Basically, pre-initialise a tensor. I've found that to be drastically faster than lists when dealing with CUDA / non-CPU devices.
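
A minimal sketch of the pre-allocation idea this comment describes (the shapes and the `torch.rand` stand-in are illustrative, not taken from this PR):

```python
import torch

# Illustrative shapes only, not from the PR.
num_transforms, batch, h, w = 6, 8, 224, 224
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# List-append approach: each iteration appends a Python object,
# and torch.stack at the end copies everything once more.
cams_list = []
for _ in range(num_transforms):
    cams_list.append(torch.rand(batch, h, w, device=device))
cams = torch.stack(cams_list)

# Pre-initialised approach: allocate once on the device, write in place.
cams = torch.empty(num_transforms, batch, h, w, device=device)
for i in range(num_transforms):
    cams[i] = torch.rand(batch, h, w, device=device)
```

The second form skips the final stack copy, which is presumably where the speed-up the author observed comes from.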


cam = np.mean(np.float32(cams), axis=0)
cam = torch.mean(cams.to(torch.float32), axis=0)
return cam

def __call__(self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module] = None,
aug_smooth: bool = False,
eigen_smooth: bool = False) -> np.ndarray:
eigen_smooth: bool = False) -> torch.Tensor:

# Smooth the CAM result with test time augmentation
if aug_smooth is True:
4 changes: 2 additions & 2 deletions pytorch_grad_cam/grad_cam.py
@@ -1,4 +1,4 @@
import numpy as np
import torch
from pytorch_grad_cam.base_cam import BaseCAM


@@ -19,4 +19,4 @@ def get_cam_weights(self,
target_category,
activations,
grads):
return np.mean(grads, axis=(2, 3))
return torch.mean(grads, axis=(2, 3))
18 changes: 14 additions & 4 deletions pytorch_grad_cam/utils/image.py
@@ -4,6 +4,7 @@
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import Compose, Normalize, ToTensor
from typing import List, Dict
import math
@@ -160,12 +161,21 @@ def show_factorization_on_image(img: np.ndarray,
def scale_cam_image(cam, target_size=None):
result = []
for img in cam:
img = img - np.min(img)
img = img / (1e-7 + np.max(img))
img = img - torch.min(img)
img = img / (1e-7 + torch.max(img))
TRex22 (Author) commented:
The cv2 resize will need to be done via a torch tensor. I'll investigate once I get the concept working.
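
One torch-native path for the resize this comment refers to is `torch.nn.functional.interpolate`. A sketch, assuming `img` is a 2-D HxW map and that `target_size` follows cv2's (width, height) ordering; whether bilinear `interpolate` matches cv2's output exactly is untested here:

```python
import torch
import torch.nn.functional as nnf  # separate alias: F is torchvision's functional in this file

def resize_cam_torch(img: torch.Tensor, target_size) -> torch.Tensor:
    # cv2.resize takes (width, height); interpolate expects (height, width).
    size = (target_size[1], target_size[0])
    # Bilinear interpolate wants a 4-D (N, C, H, W) input, so add
    # batch and channel dims, resize, then strip them again.
    img = img.unsqueeze(0).unsqueeze(0)
    img = nnf.interpolate(img, size=size, mode="bilinear", align_corners=False)
    return img.squeeze(0).squeeze(0)
```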

if target_size is not None:
img = cv2.resize(img, target_size)
# There seem to be many different ways to resize a torch tensor
# with varying results
# TODO: Investigate these
# For now going to convert to cpu numpy and back just to get
# the crude experiment working - and then begin to tune and refine
# Possible way:
# img = F.resize(img, target_size) # TODO: Investigate better resizing techniques - Keeping defaults for now

# Convert to numpy
img = torch.tensor(cv2.resize(img.cpu().numpy(), target_size))
result.append(img)
result = np.float32(result)
result = torch.tensor(np.array(result)).to(torch.float32) # TODO: Optimise this to use pre-initialised torch tensor
TRex22 (Author) commented:
This is quite crude. My aim is to get torch working first, then go back to these hacks and make them optimal with a torch-tensor approach once I know the rest of the changes to torch tensors are working.

It also helps identify the problem areas that need more attention.

None of this work is release-ready.
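
For reference, a sketch of where this could land once the hacks are removed: batched min/max normalisation plus `interpolate`, with no CPU round-trip and no Python list. The function name and the (width, height) ordering assumed for `target_size` are illustrative, not from this PR:

```python
import torch
import torch.nn.functional as nnf

def scale_cam_image_torch(cam: torch.Tensor, target_size=None) -> torch.Tensor:
    # cam: (N, H, W) activation maps, on any device.
    # Per-image normalisation, matching the original numpy logic:
    # subtract the min first, then divide by the max of the shifted map.
    cam = cam - cam.flatten(1).min(dim=1).values.view(-1, 1, 1)
    cam = cam / (1e-7 + cam.flatten(1).max(dim=1).values.view(-1, 1, 1))
    if target_size is not None:
        # Assumes target_size uses cv2's (width, height) convention.
        cam = nnf.interpolate(cam.unsqueeze(1),
                              size=(target_size[1], target_size[0]),
                              mode="bilinear",
                              align_corners=False).squeeze(1)
    return cam.to(torch.float32)
```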


return result

4 changes: 2 additions & 2 deletions setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = grad-cam
version = 1.1.0
version = 1.4.7
author = Jacob Gildenblat
author_email = [email protected]
description = Many Class Activation Map methods implemented in Pytorch. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM
@@ -16,4 +16,4 @@ classifiers =

[options]
packages = find:
python_requires = >=3.6
python_requires = >=3.6
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@

setuptools.setup(
name='grad-cam',
version='1.4.6',
version='1.4.7',
TRex22 (Author) commented:
I just did this so that on my compute node I know which version I'm working with. As for what the real version bump should be, I'm open to a more major bump, as the final version of these changes would be substantial and potentially breaking.

author='Jacob Gildenblat',
author_email='[email protected]',
description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',