Fix penultimate layer search condition #85

Open
wants to merge 3 commits into
base: master

@srwi srwi commented Apr 2, 2022

Hi there!

Currently, GradCAM automatically uses the last Conv layer unless a different layer is specified. We noticed that for some models, such as MobileNetV3, this selects the wrong layer (as already mentioned in #61).

For MobileNetV3 specifically, this causes two problems:

  1. The logits layer is implemented as a Conv2D layer, so this layer, with shape (None, 1, 1, 1024), is selected as the penultimate layer. This results in incorrect, useless GradCAM images.
  2. Manually selecting the previous Conv layer, however, leaves out some important activations, which occasionally causes inverted GradCAM images (see "GradCam Issues with MobileNetV3", #61).

We propose a different search condition: select the last layer whose output has four dimensions and a width and height greater than 1. This addresses both problems mentioned above:

  1. The selected layer will have spatial dimensions greater than 1 by 1.
  2. Important activations will still be included, since the penultimate layer no longer has to be a Conv layer, resulting in non-inverted GradCAM images.
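The proposed search condition can be sketched roughly as follows. This is only an illustration, not the code in this PR: `StubLayer`, `has_spatial_extent_2d`, and `find_last_matching` are hypothetical names, the stub stands in for a real Keras layer, and the layer names and most shapes are made up (only `(None, 1, 1, 1024)` comes from the description above). Treating a dynamic (`None`) dimension as non-matching is also an assumption.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple


@dataclass
class StubLayer:
    """Hypothetical stand-in for a Keras layer: a name and an output shape."""
    name: str
    output_shape: Tuple[Optional[int], ...]


def has_spatial_extent_2d(layer: StubLayer) -> bool:
    """True for 4D outputs (batch, height, width, channels) with height and width > 1."""
    shape = layer.output_shape
    if len(shape) != 4:
        return False
    # Assumption: a dynamic (None) dimension is treated as not matching.
    return all(d is not None and d > 1 for d in shape[1:3])


def find_last_matching(
    layers: Sequence[StubLayer], condition: Callable[[StubLayer], bool]
) -> Optional[StubLayer]:
    """Walk the layers from the output side and return the first match."""
    for layer in reversed(layers):
        if condition(layer):
            return layer
    return None


# Illustrative tail of a MobileNetV3-like model (names and most shapes are made up).
mobilenet_like = [
    StubLayer("last_spatial_conv", (None, 7, 7, 576)),
    StubLayer("last_spatial_activation", (None, 7, 7, 576)),
    StubLayer("global_pool", (None, 1, 1, 576)),
    StubLayer("conv_head", (None, 1, 1, 1024)),
    StubLayer("flatten", (None, 1024)),
]

selected = find_last_matching(mobilenet_like, has_spatial_extent_2d)
print(selected.name)  # last_spatial_activation
```

In this toy model the 1 by 1 Conv head is skipped, and the activation that follows the last spatial Conv layer is picked instead of the Conv layer itself.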

A similar implementation is used in sicara/tf-explain. Their implementation, however, would still be affected by the first problem.

I added some quick tests which could be extended in the future.

Let me know what you think!

@keisen
Owner

keisen commented Apr 4, 2022

Hi @stnkl, thank you so much for your great PR!
I'm going to approve and merge this PR because I believe your idea basically makes a lot of sense!

However, before that, please address a few points.
First, tf-keras-vis supports N-dim image inputs, so GradCAM should support Conv1D, Conv2D, Conv3D, and higher-dimensional layers.
Also, the target feature (CAM) shape doesn't have to be square.
So I think the condition could be improved, for example, as below:

def _penultimate_layer_condition(layer):
    return len(layer.output_shape) > 2 and any(d > 1 for d in layer.output.shape[1:-1])

In addition, is it possible to also keep the Conv condition (i.e., lambda l: isinstance(l, Conv))?

If not, then in VGG16, for example, the CAM will be generated from the feature map output by the max-pooling layer (block5_pool, whose output shape is (None, 7, 7, 512)) instead of the expected output of the layer before it (block5_conv3, whose output shape is (None, 14, 14, 512)). So I assume this would be a compatibility issue with earlier versions.
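The VGG16 concern can be reproduced on shapes alone. The sketch below uses a hypothetical `StubLayer` with a `kind` field in place of real Keras layers and `isinstance` checks; the layer names and shapes for block5_conv3 and block5_pool are the ones quoted above. The `d is not None` guard is an addition that the condition snippet does not have, included here so that dynamic dimensions do not raise an error.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class StubLayer:
    """Hypothetical stand-in for a Keras layer."""
    name: str
    kind: str  # e.g. "conv", "pool", "flatten"
    output_shape: Tuple[Optional[int], ...]


def shape_condition(layer):
    """N-dim variant: more than 2 dims and at least one spatial dim greater than 1."""
    shape = layer.output_shape
    # Added guard (assumption): skip dynamic (None) dimensions.
    return len(shape) > 2 and any(d is not None and d > 1 for d in shape[1:-1])


def last_matching(layers, condition):
    """Return the last layer satisfying the condition, or None."""
    return next((l for l in reversed(layers) if condition(l)), None)


# Tail of VGG16 as described above, reduced to names, kinds, and shapes.
vgg_tail = [
    StubLayer("block5_conv3", "conv", (None, 14, 14, 512)),
    StubLayer("block5_pool", "pool", (None, 7, 7, 512)),
    StubLayer("flatten", "flatten", (None, 25088)),
]

by_shape = last_matching(vgg_tail, shape_condition)
by_shape_and_conv = last_matching(
    vgg_tail, lambda l: l.kind == "conv" and shape_condition(l)
)
print(by_shape.name)           # block5_pool
print(by_shape_and_conv.name)  # block5_conv3
```

With the shape condition alone the pooling output is chosen; adding the Conv check restores the previous choice of block5_conv3.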

Thanks!

@srwi
Author

srwi commented Apr 5, 2022

Thanks for having a look at the PR! You brought up some very valid points.

First of all, I totally agree with your updated search condition. It would make the search compatible with Conv1D, Conv2D and Conv3D. I don't see any issue there.

Secondly, I see the problem with VGG16, and you are correct that it should select the Conv layer here. However, if we keep the Conv layer condition, this would again break the behaviour for MobileNetV3, as the subsequent layers with matching dimensions are actually necessary to get correct GradCAM images.

Thus, I could imagine one further method of selecting the target layer:

  1. Find the last Conv layer l that matches _penultimate_layer_condition.
  2. Identify the last subsequent layer with the same output shape as l and use it as the target layer.
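A rough sketch of these two steps, under some assumptions: layers are modelled by a hypothetical `StubLayer` in a flat, ordered list (a real model graph may branch), "subsequent" is read as any later layer in that list rather than only directly adjacent ones, and the MobileNetV3-like layer names and most shapes are illustrative.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass
class StubLayer:
    """Hypothetical stand-in for a Keras layer."""
    name: str
    kind: str  # e.g. "conv", "activation", "pool"
    output_shape: Tuple[Optional[int], ...]


def shape_condition(layer: StubLayer) -> bool:
    """More than 2 dims and at least one spatial dim greater than 1."""
    shape = layer.output_shape
    return len(shape) > 2 and any(d is not None and d > 1 for d in shape[1:-1])


def select_target_layer(layers: Sequence[StubLayer]) -> Optional[StubLayer]:
    # Step 1: find the last Conv layer that satisfies the shape condition.
    conv_index = None
    for i in range(len(layers) - 1, -1, -1):
        if layers[i].kind == "conv" and shape_condition(layers[i]):
            conv_index = i
            break
    if conv_index is None:
        return None

    # Step 2: among later layers, keep the last one with the same output shape.
    target = layers[conv_index]
    for layer in layers[conv_index + 1:]:
        if layer.output_shape == target.output_shape:
            target = layer
    return target


vgg_tail = [
    StubLayer("block5_conv3", "conv", (None, 14, 14, 512)),
    StubLayer("block5_pool", "pool", (None, 7, 7, 512)),
]
mobilenet_like = [
    StubLayer("last_spatial_conv", "conv", (None, 7, 7, 576)),
    StubLayer("last_spatial_activation", "activation", (None, 7, 7, 576)),
    StubLayer("global_pool", "pool", (None, 1, 1, 576)),
    StubLayer("conv_head", "conv", (None, 1, 1, 1024)),
]

print(select_target_layer(vgg_tail).name)        # block5_conv3
print(select_target_layer(mobilenet_like).name)  # last_spatial_activation
```

On these toy inputs the VGG16-style tail keeps block5_conv3, while the MobileNetV3-style tail moves past the Conv layer to the activation with the same shape.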

This would, however, make the target layer search a little more implicit. Again, let me know what you think, and I would be happy to implement it!
