diff --git a/alibi/explainers/integrated_gradients.py b/alibi/explainers/integrated_gradients.py
index d2726a8ca..96953da7f 100644
--- a/alibi/explainers/integrated_gradients.py
+++ b/alibi/explainers/integrated_gradients.py
@@ -400,6 +400,11 @@ def _gradients_input(model: Union[tf.keras.models.Model],
 
     grads = tape.gradient(preds, x)
 
+    # If certain inputs don't impact the target, the gradient is None, but we need to return a tensor
+    if isinstance(x, list):
+        for idx, grad in enumerate(grads):
+            if grad is None:
+                grads[idx] = tf.convert_to_tensor(np.zeros(x[idx].shape), dtype=x[idx].dtype)
     return grads
 
 
@@ -497,7 +502,11 @@ def wrapper(*args, **kwargs):
         grads = tape.gradient(preds, layer.inp)
     else:
         grads = tape.gradient(preds, layer.result)
-
+    # If certain inputs don't impact the target, the gradient is None, but we need to return a tensor
+    if isinstance(x, list):
+        for idx, grad in enumerate(grads):
+            if grad is None:
+                grads[idx] = tf.convert_to_tensor(np.zeros(x[idx].shape), dtype=x[idx].dtype)
     delattr(layer, 'inp')
     delattr(layer, 'result')
     layer.call = orig_call
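
Below is a minimal standalone sketch (not part of the patch, assuming TensorFlow 2.x and a toy target rather than a real multi-input Keras model) of the behavior the new check guards against: when a watched input does not influence the selected output, tf.GradientTape.gradient returns None for that input by default, and the patch substitutes a zero tensor of matching shape and dtype.

    import numpy as np
    import tensorflow as tf

    # Two inputs, but the target below only depends on x0.
    x0 = tf.constant(np.random.rand(2, 3), dtype=tf.float32)
    x1 = tf.constant(np.random.rand(2, 3), dtype=tf.float32)

    with tf.GradientTape() as tape:
        tape.watch([x0, x1])
        preds = tf.reduce_sum(x0 ** 2)  # x1 is unconnected to preds

    grads = tape.gradient(preds, [x0, x1])
    print(grads[1])  # None: no gradient path from preds to x1

    # The patch replaces such None entries with a zero tensor, which is the
    # correct gradient for an input that has no effect on the target:
    zero_grad = tf.convert_to_tensor(np.zeros(x1.shape), dtype=x1.dtype)
    print(zero_grad)

Returning zeros rather than None keeps downstream attribution code working on lists of tensors without special-casing unconnected inputs.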