From d50cefad7cf535e99b7ddee1f49617b9cf9c983b Mon Sep 17 00:00:00 2001 From: Lauri Himanen Date: Mon, 14 Aug 2023 21:52:23 +0300 Subject: [PATCH] Improved base classes. --- dscribe/ext/coulombmatrix.cpp | 13 +++++---- dscribe/ext/coulombmatrix.h | 10 +++---- dscribe/ext/descriptorglobal.cpp | 20 +++++++------ dscribe/ext/descriptorglobal.h | 50 ++++++++++++++++++++++---------- dscribe/ext/descriptorlocal.cpp | 6 ++-- dscribe/ext/descriptorlocal.h | 8 ++--- dscribe/ext/ext.cpp | 15 +++++----- dscribe/ext/soap.cpp | 4 +-- dscribe/ext/soap.h | 4 +-- 9 files changed, 77 insertions(+), 53 deletions(-) diff --git a/dscribe/ext/coulombmatrix.cpp b/dscribe/ext/coulombmatrix.cpp index d47bc453..df716cfb 100644 --- a/dscribe/ext/coulombmatrix.cpp +++ b/dscribe/ext/coulombmatrix.cpp @@ -37,14 +37,17 @@ CoulombMatrix::CoulombMatrix( { } -void CoulombMatrix::create_raw( - py::detail::unchecked_mutable_reference &out_mu, - py::detail::unchecked_reference &positions_u, - py::detail::unchecked_reference &atomic_numbers_u, - CellList &cell_list +void CoulombMatrix::create( + py::array_t out, + py::array_t positions, + py::array_t atomic_numbers, + CellList cell_list ) { // Calculate all pairwise distances. + auto out_mu = out.mutable_unchecked<1>(); + auto atomic_numbers_u = atomic_numbers.unchecked<1>(); + auto positions_u = positions.unchecked<2>(); int n_atoms = atomic_numbers_u.shape(0); MatrixXd matrix = distancesEigen(positions_u); diff --git a/dscribe/ext/coulombmatrix.h b/dscribe/ext/coulombmatrix.h index 7a5890ab..aa1cefc5 100644 --- a/dscribe/ext/coulombmatrix.h +++ b/dscribe/ext/coulombmatrix.h @@ -44,11 +44,11 @@ class CoulombMatrix: public DescriptorGlobal { /** * For creating feature vectors. */ - void create_raw( - py::detail::unchecked_mutable_reference &out_mu, - py::detail::unchecked_reference &positions_u, - py::detail::unchecked_reference &atomic_numbers_u, - CellList &cell_list + void create( + py::array_t out, + py::array_t positions, + py::array_t atomic_numbers, + CellList cell_list ); /** diff --git a/dscribe/ext/descriptorglobal.cpp b/dscribe/ext/descriptorglobal.cpp index 331adcfa..6ee7bc78 100644 --- a/dscribe/ext/descriptorglobal.cpp +++ b/dscribe/ext/descriptorglobal.cpp @@ -42,13 +42,18 @@ void DescriptorGlobal::create( positions = system_extended.positions; atomic_numbers = system_extended.atomic_numbers; } + this->create(out, positions, atomic_numbers); +} +void DescriptorGlobal::create( + py::array_t out, + py::array_t positions, + py::array_t atomic_numbers +) +{ // Calculate neighbours with a cell list CellList cell_list(positions, this->cutoff); - auto out_mu = out.mutable_unchecked<1>(); - auto positions_u = positions.unchecked<2>(); - auto atomic_numbers_u = atomic_numbers.unchecked<1>(); - this->create_raw(out_mu, positions_u, atomic_numbers_u, cell_list); + this->create(out, positions, atomic_numbers, cell_list); } void DescriptorGlobal::derivatives_numerical( @@ -66,7 +71,6 @@ void DescriptorGlobal::derivatives_numerical( int n_atoms = atomic_numbers.size(); int n_features = this->get_number_of_features(); auto derivatives_mu = derivatives.mutable_unchecked<3>(); - auto descriptor_mu = descriptor.mutable_unchecked<1>(); auto indices_u = indices.unchecked<1>(); auto pbc_u = pbc.unchecked<1>(); @@ -79,15 +83,13 @@ void DescriptorGlobal::derivatives_numerical( atomic_numbers = system_extension.atomic_numbers; } auto positions_mu = positions.mutable_unchecked<2>(); - auto positions_u = positions.unchecked<2>(); - auto atomic_numbers_u = atomic_numbers.unchecked<1>(); // Pre-calculate cell list for atoms CellList cell_list_atoms(positions, this->cutoff); // Calculate the desciptor value if requested if (return_descriptor) { - this->create_raw(descriptor_mu, positions_u, atomic_numbers_u, cell_list_atoms); + this->create(descriptor, positions, atomic_numbers, cell_list_atoms); } // Central finite difference with error O(h^2) @@ -134,7 +136,7 @@ void DescriptorGlobal::derivatives_numerical( auto d_mu = d.mutable_unchecked<1>(); // Calculate descriptor value - this->create_raw(d_mu, positions_u, atomic_numbers_u, cell_list_atoms); + this->create(d, positions, atomic_numbers, cell_list_atoms); // Add value to final derivative array double coeff = coefficients[i_stencil]; diff --git a/dscribe/ext/descriptorglobal.h b/dscribe/ext/descriptorglobal.h index 76114cf0..cc7a373f 100644 --- a/dscribe/ext/descriptorglobal.h +++ b/dscribe/ext/descriptorglobal.h @@ -30,14 +30,15 @@ using namespace std; class DescriptorGlobal : public Descriptor { public: /** - * Calculates the feature vector. - * - * @param out Numpy output array for the descriptor. - * @param positions Atomic positions as [n_atoms, 3] numpy array. - * @param atomic_numbers Atomic numbers as [n_atoms] numpy array. - * @param cell Simulation cell as [3, 3] numpy array. - * @param pbc Simulation cell periodicity as [3] numpy array. - */ + * @brief Version of 'create' that automatically extends the system + * based on PBC and calculates celllist. + * + * @param out + * @param positions + * @param atomic_numbers + * @param cell + * @param pbc + */ void create( py::array_t out, py::array_t positions, @@ -47,15 +48,32 @@ class DescriptorGlobal : public Descriptor { ); /** - * Called internally. The system should already be extended - * periodically and CellList should be available. + * @brief Version of 'create' that automatically calculates celllist. + * + * @param out + * @param positions + * @param atomic_numbers */ - virtual void create_raw( - py::detail::unchecked_mutable_reference &out_mu, - py::detail::unchecked_reference &positions_u, - py::detail::unchecked_reference &atomic_numbers_u, - CellList &cell_list - ) = 0; + void create( + py::array_t out, + py::array_t positions, + py::array_t atomic_numbers + ); + + /** + * @brief Pure virtual function for calculating the feature vectors. + * + * @param out + * @param positions + * @param atomic_numbers + * @param cell_list + */ + virtual void create( + py::array_t out, + py::array_t positions, + py::array_t atomic_numbers, + CellList cell_list + ) = 0; /** * Calculates the numerical derivates with central finite difference. diff --git a/dscribe/ext/descriptorlocal.cpp b/dscribe/ext/descriptorlocal.cpp index 608a3253..f289ceaa 100644 --- a/dscribe/ext/descriptorlocal.cpp +++ b/dscribe/ext/descriptorlocal.cpp @@ -34,7 +34,7 @@ void DescriptorLocal::create( py::array_t cell, py::array_t pbc, py::array_t centers -) const +) { // Extend system if periodicity is requested. auto pbc_u = pbc.unchecked<1>(); @@ -52,7 +52,7 @@ void DescriptorLocal::create( py::array_t positions, py::array_t atomic_numbers, py::array_t centers -) const +) { // Calculate neighbours with a cell list CellList cell_list(positions, this->cutoff); @@ -93,7 +93,7 @@ void DescriptorLocal::derivatives_numerical( py::array_t indices, bool attach, bool return_descriptor -) const +) { int n_copies = 1; int n_atoms = atomic_numbers.size(); diff --git a/dscribe/ext/descriptorlocal.h b/dscribe/ext/descriptorlocal.h index cca39721..71d630e5 100644 --- a/dscribe/ext/descriptorlocal.h +++ b/dscribe/ext/descriptorlocal.h @@ -47,7 +47,7 @@ class DescriptorLocal : public Descriptor { py::array_t cell, py::array_t pbc, py::array_t centers - ) const; + ); /** * @brief Version of 'create' that automatically calculates celllist. @@ -62,7 +62,7 @@ class DescriptorLocal : public Descriptor { py::array_t positions, py::array_t atomic_numbers, py::array_t centers - ) const; + ); /** * @brief Pure virtual function for calculating the feature vectors. @@ -79,7 +79,7 @@ class DescriptorLocal : public Descriptor { py::array_t atomic_numbers, py::array_t centers, CellList cell_list - ) const = 0; + ) = 0; /** * Calculates the numerical derivates with central finite difference. @@ -109,7 +109,7 @@ class DescriptorLocal : public Descriptor { py::array_t indices, bool attach, bool return_descriptor - ) const; + ); protected: DescriptorLocal(bool periodic, string average="", double cutoff=0); diff --git a/dscribe/ext/ext.cpp b/dscribe/ext/ext.cpp index 04dc75b7..4e318489 100644 --- a/dscribe/ext/ext.cpp +++ b/dscribe/ext/ext.cpp @@ -17,6 +17,7 @@ limitations under the License. #include // Enables easy access to numpy arrays #include // Enables automatic type conversion from C++ containers to python #include "descriptorlocal.h" +#include "descriptorglobal.h" #include "celllist.h" #include "coulombmatrix.h" #include "soap.h" @@ -36,7 +37,7 @@ PYBIND11_MODULE(ext, m) { // CoulombMatrix py::class_(m, "CoulombMatrix") .def(py::init()) - .def("create", &CoulombMatrix::create) + .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, py::array_t >()(&DescriptorGlobal::create)) .def("derivatives_numerical", &CoulombMatrix::derivatives_numerical) .def(py::pickle( [](const CoulombMatrix &p) { @@ -58,16 +59,16 @@ PYBIND11_MODULE(ext, m) { // SOAP py::class_(m, "SOAPGTO") .def(py::init, py::array_t, bool, string, py::array_t, py::array_t >()) - .def("create", overload_cast_, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create, py::const_)) - .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create, py::const_)) - .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, CellList>()(&SOAPGTO::create, py::const_)) + .def("create", overload_cast_, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create)) + .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create)) + .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, CellList>()(&SOAPGTO::create)) .def("derivatives_numerical", &SOAPGTO::derivatives_numerical) .def("derivatives_analytical", &SOAPGTO::derivatives_analytical); py::class_(m, "SOAPPolynomial") .def(py::init, py::array_t, bool, string, py::array_t, py::array_t >()) - .def("create", overload_cast_, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create, py::const_)) - .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create, py::const_)) - .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, CellList>()(&SOAPPolynomial::create, py::const_)) + .def("create", overload_cast_, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create)) + .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, py::array_t, py::array_t >()(&DescriptorLocal::create)) + .def("create", overload_cast_, py::array_t, py::array_t, py::array_t, CellList>()(&SOAPPolynomial::create)) .def("derivatives_numerical", &SOAPPolynomial::derivatives_numerical); // ACSF diff --git a/dscribe/ext/soap.cpp b/dscribe/ext/soap.cpp index 05aac13c..6b7fa1c4 100644 --- a/dscribe/ext/soap.cpp +++ b/dscribe/ext/soap.cpp @@ -55,7 +55,7 @@ void SOAPGTO::create( py::array_t atomic_numbers, py::array_t centers, CellList cell_list -) const +) { // Empty mock arrays since we are not calculating the derivatives py::array_t xd({1, 1, 1, 1, 1}); @@ -204,7 +204,7 @@ void SOAPPolynomial::create( py::array_t atomic_numbers, py::array_t centers, CellList cell_list -) const +) { soapGeneral( out, diff --git a/dscribe/ext/soap.h b/dscribe/ext/soap.h index 8be8460d..d584f760 100644 --- a/dscribe/ext/soap.h +++ b/dscribe/ext/soap.h @@ -51,7 +51,7 @@ class SOAPGTO: public DescriptorLocal { py::array_t atomic_numbers, py::array_t centers, CellList cell_list - ) const; + ); int get_number_of_features() const; @@ -113,7 +113,7 @@ class SOAPPolynomial: public DescriptorLocal { py::array_t atomic_numbers, py::array_t centers, CellList cell_list - ) const; + ); int get_number_of_features() const;