Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix penultimate layer search condition #85

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tests/tf_keras_vis/utils/model_modifiers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from tf_keras_vis.activation_maximization import ActivationMaximization
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils.model_modifiers import (ExtractIntermediateLayer, GuidedBackpropagation,
from tf_keras_vis.utils.model_modifiers import (ExtractIntermediateLayer,
ExtractIntermediateLayerForGradcam,
GuidedBackpropagation,
ReplaceToLinear)
from tf_keras_vis.utils.scores import CategoricalScore
from tf_keras_vis.utils.test import (NO_ERROR, assert_raises, dummy_sample, mock_conv_model,
Expand Down Expand Up @@ -47,7 +49,13 @@ def test__call__(self, model, layer, expected_error):


class TestExtractIntermediateLayerForGradcam():
pass
@pytest.mark.parametrize("model", [mock_conv_model(), mock_multiple_outputs_model()])
@pytest.mark.parametrize("layer", [None, -1, "conv_1"])
@pytest.mark.usefixtures("mixed_precision")
def test__call__(self, model, layer):
modified_model = ExtractIntermediateLayerForGradcam(layer)(model)
assert modified_model != model
assert modified_model.outputs[-1].shape.as_list() == [None, 6, 6, 6]


class TestExtractGuidedBackpropagation():
Expand Down
25 changes: 14 additions & 11 deletions tf_keras_vis/utils/model_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@
from typing import Union

import tensorflow as tf
from packaging.version import parse as version

if version(tf.version.VERSION) < version("2.6.0rc0"):
from tensorflow.python.keras.layers.convolutional import Conv
else:
from keras.layers.convolutional import Conv

from . import find_layer

Expand Down Expand Up @@ -107,11 +101,20 @@ def __call__(self, model) -> None:


class ExtractIntermediateLayerForGradcam(ModelModifier):
def __init__(self, penultimate_layer=None, seek_conv_layer=True, include_model_outputs=True):
def __init__(self,
penultimate_layer=None,
seek_penultimate_layer=True,
include_model_outputs=True):
self.penultimate_layer = penultimate_layer
self.seek_conv_layer = seek_conv_layer
self.seek_penultimate_layer = seek_penultimate_layer
self.include_model_outputs = include_model_outputs

@staticmethod
def _penultimate_layer_condition(layer):
return len(layer.output_shape) == 4 and \
layer.output_shape[1] > 1 and \
layer.output_shape[2] > 1

def __call__(self, model):
_layer = self.penultimate_layer
if not isinstance(_layer, tf.keras.layers.Layer):
Expand All @@ -123,10 +126,10 @@ def __call__(self, model):
_layer = find_layer(model, lambda l: l.name == _layer)
else:
raise ValueError(f"Invalid argument. `penultimate_layer`={self.penultimate_layer}")
if _layer is not None and self.seek_conv_layer:
_layer = find_layer(model, lambda l: isinstance(l, Conv), offset=_layer)
if _layer is not None and self.seek_penultimate_layer:
_layer = find_layer(model, self._penultimate_layer_condition, offset=_layer)
if _layer is None:
raise ValueError("Unable to determine penultimate `Conv` layer. "
raise ValueError("Unable to determine penultimate layer. "
f"`penultimate_layer`={self.penultimate_layer}")
penultimate_output = _layer.output
if len(penultimate_output.shape) < 3:
Expand Down