From 4485fe64fa917a1c381e5cdbefc61d80ef5313c6 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 8 Jul 2024 13:50:51 +0100 Subject: [PATCH] NumPy 2.0 fix --- numpy_adjoint/array.py | 2 +- tests/pyadjoint/test_numpy.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 tests/pyadjoint/test_numpy.py diff --git a/numpy_adjoint/array.py b/numpy_adjoint/array.py index 7b824498..282101b3 100644 --- a/numpy_adjoint/array.py +++ b/numpy_adjoint/array.py @@ -11,7 +11,7 @@ def __init__(self, *args, **kwargs): @classmethod def _ad_init_object(cls, obj): - return cls(obj.shape, numpy.float_, buffer=obj) + return cls(obj.shape, obj.dtype, buffer=obj) def _ad_create_checkpoint(self): return self.copy() diff --git a/tests/pyadjoint/test_numpy.py b/tests/pyadjoint/test_numpy.py new file mode 100644 index 00000000..dc582cdc --- /dev/null +++ b/tests/pyadjoint/test_numpy.py @@ -0,0 +1,10 @@ +import numpy as np +from pyadjoint import * +from numpy_adjoint import * + + +def test_ndarray_getitem_single(): + a = create_overloaded_object(np.array([-2.0])) + J = ReducedFunctional(a[0], Control(a)) + dJ = J.derivative() + assert dJ == 1.0