Skip to content

Commit

Permalink
Improved base classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
lauri-codes committed Aug 14, 2023
1 parent c5284a5 commit d50cefa
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 53 deletions.
13 changes: 8 additions & 5 deletions dscribe/ext/coulombmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ CoulombMatrix::CoulombMatrix(
{
}

void CoulombMatrix::create_raw(
py::detail::unchecked_mutable_reference<double, 1> &out_mu,
py::detail::unchecked_reference<double, 2> &positions_u,
py::detail::unchecked_reference<int, 1> &atomic_numbers_u,
CellList &cell_list
void CoulombMatrix::create(
py::array_t<double> out,
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
CellList cell_list
)
{
// Calculate all pairwise distances.
auto out_mu = out.mutable_unchecked<1>();
auto atomic_numbers_u = atomic_numbers.unchecked<1>();
auto positions_u = positions.unchecked<2>();
int n_atoms = atomic_numbers_u.shape(0);
MatrixXd matrix = distancesEigen(positions_u);

Expand Down
10 changes: 5 additions & 5 deletions dscribe/ext/coulombmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class CoulombMatrix: public DescriptorGlobal {
/**
* For creating feature vectors.
*/
void create_raw(
py::detail::unchecked_mutable_reference<double, 1> &out_mu,
py::detail::unchecked_reference<double, 2> &positions_u,
py::detail::unchecked_reference<int, 1> &atomic_numbers_u,
CellList &cell_list
void create(
py::array_t<double> out,
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
CellList cell_list
);

/**
Expand Down
20 changes: 11 additions & 9 deletions dscribe/ext/descriptorglobal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,18 @@ void DescriptorGlobal::create(
positions = system_extended.positions;
atomic_numbers = system_extended.atomic_numbers;
}
this->create(out, positions, atomic_numbers);
}

void DescriptorGlobal::create(
py::array_t<double> out,
py::array_t<double> positions,
py::array_t<int> atomic_numbers
)
{
// Calculate neighbours with a cell list
CellList cell_list(positions, this->cutoff);
auto out_mu = out.mutable_unchecked<1>();
auto positions_u = positions.unchecked<2>();
auto atomic_numbers_u = atomic_numbers.unchecked<1>();
this->create_raw(out_mu, positions_u, atomic_numbers_u, cell_list);
this->create(out, positions, atomic_numbers, cell_list);
}

void DescriptorGlobal::derivatives_numerical(
Expand All @@ -66,7 +71,6 @@ void DescriptorGlobal::derivatives_numerical(
int n_atoms = atomic_numbers.size();
int n_features = this->get_number_of_features();
auto derivatives_mu = derivatives.mutable_unchecked<3>();
auto descriptor_mu = descriptor.mutable_unchecked<1>();
auto indices_u = indices.unchecked<1>();
auto pbc_u = pbc.unchecked<1>();

Expand All @@ -79,15 +83,13 @@ void DescriptorGlobal::derivatives_numerical(
atomic_numbers = system_extension.atomic_numbers;
}
auto positions_mu = positions.mutable_unchecked<2>();
auto positions_u = positions.unchecked<2>();
auto atomic_numbers_u = atomic_numbers.unchecked<1>();

// Pre-calculate cell list for atoms
CellList cell_list_atoms(positions, this->cutoff);

// Calculate the desciptor value if requested
if (return_descriptor) {
this->create_raw(descriptor_mu, positions_u, atomic_numbers_u, cell_list_atoms);
this->create(descriptor, positions, atomic_numbers, cell_list_atoms);
}

// Central finite difference with error O(h^2)
Expand Down Expand Up @@ -134,7 +136,7 @@ void DescriptorGlobal::derivatives_numerical(
auto d_mu = d.mutable_unchecked<1>();

// Calculate descriptor value
this->create_raw(d_mu, positions_u, atomic_numbers_u, cell_list_atoms);
this->create(d, positions, atomic_numbers, cell_list_atoms);

// Add value to final derivative array
double coeff = coefficients[i_stencil];
Expand Down
50 changes: 34 additions & 16 deletions dscribe/ext/descriptorglobal.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ using namespace std;
class DescriptorGlobal : public Descriptor {
public:
/**
* Calculates the feature vector.
*
* @param out Numpy output array for the descriptor.
* @param positions Atomic positions as [n_atoms, 3] numpy array.
* @param atomic_numbers Atomic numbers as [n_atoms] numpy array.
* @param cell Simulation cell as [3, 3] numpy array.
* @param pbc Simulation cell periodicity as [3] numpy array.
*/
* @brief Version of 'create' that automatically extends the system
* based on PBC and calculates celllist.
*
* @param out
* @param positions
* @param atomic_numbers
* @param cell
* @param pbc
*/
void create(
py::array_t<double> out,
py::array_t<double> positions,
Expand All @@ -47,15 +48,32 @@ class DescriptorGlobal : public Descriptor {
);

/**
* Called internally. The system should already be extended
* periodically and CellList should be available.
* @brief Version of 'create' that automatically calculates celllist.
*
* @param out
* @param positions
* @param atomic_numbers
*/
virtual void create_raw(
py::detail::unchecked_mutable_reference<double, 1> &out_mu,
py::detail::unchecked_reference<double, 2> &positions_u,
py::detail::unchecked_reference<int, 1> &atomic_numbers_u,
CellList &cell_list
) = 0;
void create(
py::array_t<double> out,
py::array_t<double> positions,
py::array_t<int> atomic_numbers
);

/**
* @brief Pure virtual function for calculating the feature vectors.
*
* @param out
* @param positions
* @param atomic_numbers
* @param cell_list
*/
virtual void create(
py::array_t<double> out,
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
CellList cell_list
) = 0;

/**
* Calculates the numerical derivates with central finite difference.
Expand Down
6 changes: 3 additions & 3 deletions dscribe/ext/descriptorlocal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void DescriptorLocal::create(
py::array_t<double> cell,
py::array_t<bool> pbc,
py::array_t<double> centers
) const
)
{
// Extend system if periodicity is requested.
auto pbc_u = pbc.unchecked<1>();
Expand All @@ -52,7 +52,7 @@ void DescriptorLocal::create(
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
py::array_t<double> centers
) const
)
{
// Calculate neighbours with a cell list
CellList cell_list(positions, this->cutoff);
Expand Down Expand Up @@ -93,7 +93,7 @@ void DescriptorLocal::derivatives_numerical(
py::array_t<int> indices,
bool attach,
bool return_descriptor
) const
)
{
int n_copies = 1;
int n_atoms = atomic_numbers.size();
Expand Down
8 changes: 4 additions & 4 deletions dscribe/ext/descriptorlocal.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class DescriptorLocal : public Descriptor {
py::array_t<double> cell,
py::array_t<bool> pbc,
py::array_t<double> centers
) const;
);

/**
* @brief Version of 'create' that automatically calculates celllist.
Expand All @@ -62,7 +62,7 @@ class DescriptorLocal : public Descriptor {
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
py::array_t<double> centers
) const;
);

/**
* @brief Pure virtual function for calculating the feature vectors.
Expand All @@ -79,7 +79,7 @@ class DescriptorLocal : public Descriptor {
py::array_t<int> atomic_numbers,
py::array_t<double> centers,
CellList cell_list
) const = 0;
) = 0;

/**
* Calculates the numerical derivates with central finite difference.
Expand Down Expand Up @@ -109,7 +109,7 @@ class DescriptorLocal : public Descriptor {
py::array_t<int> indices,
bool attach,
bool return_descriptor
) const;
);

protected:
DescriptorLocal(bool periodic, string average="", double cutoff=0);
Expand Down
15 changes: 8 additions & 7 deletions dscribe/ext/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include <pybind11/numpy.h> // Enables easy access to numpy arrays
#include <pybind11/stl.h> // Enables automatic type conversion from C++ containers to python
#include "descriptorlocal.h"
#include "descriptorglobal.h"
#include "celllist.h"
#include "coulombmatrix.h"
#include "soap.h"
Expand All @@ -36,7 +37,7 @@ PYBIND11_MODULE(ext, m) {
// CoulombMatrix
py::class_<CoulombMatrix>(m, "CoulombMatrix")
.def(py::init<unsigned int, string, double, int>())
.def("create", &CoulombMatrix::create)
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, py::array_t<bool> >()(&DescriptorGlobal::create))
.def("derivatives_numerical", &CoulombMatrix::derivatives_numerical)
.def(py::pickle(
[](const CoulombMatrix &p) {
Expand All @@ -58,16 +59,16 @@ PYBIND11_MODULE(ext, m) {
// SOAP
py::class_<SOAPGTO>(m, "SOAPGTO")
.def(py::init<double, int, int, double, py::dict, string, double, py::array_t<int>, py::array_t<double>, bool, string, py::array_t<double>, py::array_t<double> >())
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double> >()(&DescriptorLocal::create, py::const_))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, py::array_t<bool>, py::array_t<double> >()(&DescriptorLocal::create, py::const_))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, CellList>()(&SOAPGTO::create, py::const_))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double> >()(&DescriptorLocal::create))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, py::array_t<bool>, py::array_t<double> >()(&DescriptorLocal::create))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, CellList>()(&SOAPGTO::create))
.def("derivatives_numerical", &SOAPGTO::derivatives_numerical)
.def("derivatives_analytical", &SOAPGTO::derivatives_analytical);
py::class_<SOAPPolynomial>(m, "SOAPPolynomial")
.def(py::init<double, int, int, double, py::dict, string, double, py::array_t<int>, py::array_t<double>, bool, string, py::array_t<double>, py::array_t<double> >())
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double> >()(&DescriptorLocal::create, py::const_))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, py::array_t<bool>, py::array_t<double> >()(&DescriptorLocal::create, py::const_))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, CellList>()(&SOAPPolynomial::create, py::const_))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double> >()(&DescriptorLocal::create))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, py::array_t<bool>, py::array_t<double> >()(&DescriptorLocal::create))
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, CellList>()(&SOAPPolynomial::create))
.def("derivatives_numerical", &SOAPPolynomial::derivatives_numerical);

// ACSF
Expand Down
4 changes: 2 additions & 2 deletions dscribe/ext/soap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void SOAPGTO::create(
py::array_t<int> atomic_numbers,
py::array_t<double> centers,
CellList cell_list
) const
)
{
// Empty mock arrays since we are not calculating the derivatives
py::array_t<double> xd({1, 1, 1, 1, 1});
Expand Down Expand Up @@ -204,7 +204,7 @@ void SOAPPolynomial::create(
py::array_t<int> atomic_numbers,
py::array_t<double> centers,
CellList cell_list
) const
)
{
soapGeneral(
out,
Expand Down
4 changes: 2 additions & 2 deletions dscribe/ext/soap.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class SOAPGTO: public DescriptorLocal {
py::array_t<int> atomic_numbers,
py::array_t<double> centers,
CellList cell_list
) const;
);

int get_number_of_features() const;

Expand Down Expand Up @@ -113,7 +113,7 @@ class SOAPPolynomial: public DescriptorLocal {
py::array_t<int> atomic_numbers,
py::array_t<double> centers,
CellList cell_list
) const;
);

int get_number_of_features() const;

Expand Down

0 comments on commit d50cefa

Please sign in to comment.