diff --git a/pyadjoint/reduced_functional.py b/pyadjoint/reduced_functional.py index cd8bcbe1..993615df 100644 --- a/pyadjoint/reduced_functional.py +++ b/pyadjoint/reduced_functional.py @@ -1,7 +1,7 @@ from .drivers import compute_gradient, compute_hessian from .enlisting import Enlist from .tape import get_working_tape, stop_annotating, no_annotations -from .overloaded_type import OverloadedType +from .overloaded_type import OverloadedType, create_overloaded_object def _get_extract_derivative_components(derivative_components): @@ -123,7 +123,9 @@ def derivative(self, adj_input=1.0, options={}): # Scale adjoint input with stop_annotating(): - adj_value = self.scale * adj_input + # Make sure `adj_input` is an OverloadedType + adj_input = create_overloaded_object(adj_input) + adj_value = adj_input._ad_mul(self.scale) derivatives = compute_gradient(self.functional, controls,