From 815c69c30dcf1e83d81a7dd617ecc6f4339636f4 Mon Sep 17 00:00:00 2001 From: Marcus Wieder <31651017+wiederm@users.noreply.github.com> Date: Mon, 4 Sep 2023 23:23:11 +0200 Subject: [PATCH] Initial Base Classes and base functionality for Neural Network Potentials (#5) * delete notebook * first outline of base NNP class * remove useless example * updating tests * change dtype naming * simplify base model * update base implementation * reference schnet impoementation * adding imports, removing alternative implementstions * adding test * still work in progress * adding comments * adding docstrings * adding tests * bugfix * remove print statments * bugfix --- devtools/conda-envs/test_env.yaml | 1 + modelforge/potential/__initi__.py | 2 + modelforge/potential/features.py | 0 modelforge/potential/models.py | 67 ++++++++ modelforge/potential/schnet.py | 223 +++++++++++++++++++++++++ modelforge/potential/utils.py | 235 +++++++++++++++++++++++++++ modelforge/tests/helper_functinos.py | 33 ++++ modelforge/tests/test_models.py | 30 ++++ modelforge/tests/test_schnet.py | 22 +++ modelforge/tests/test_utils.py | 23 +++ modelforge/utils/__init__.py | 4 +- modelforge/utils/prop.py | 33 ++++ notebooks/dataclass.ipynb | 150 ----------------- 13 files changed, 672 insertions(+), 151 deletions(-) delete mode 100644 modelforge/potential/features.py create mode 100644 modelforge/potential/schnet.py create mode 100644 modelforge/potential/utils.py create mode 100644 modelforge/tests/helper_functinos.py create mode 100644 modelforge/tests/test_models.py create mode 100644 modelforge/tests/test_schnet.py create mode 100644 modelforge/tests/test_utils.py create mode 100644 modelforge/utils/prop.py delete mode 100644 notebooks/dataclass.ipynb diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 6828d078..8527b1db 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -19,6 +19,7 @@ dependencies: - lightning - tensorboard - torchvision + - ase # Testing - pytest diff --git a/modelforge/potential/__initi__.py b/modelforge/potential/__initi__.py index e69de29b..61327ce3 100644 --- a/modelforge/potential/__initi__.py +++ b/modelforge/potential/__initi__.py @@ -0,0 +1,2 @@ +from .schnet import Schnet +from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add diff --git a/modelforge/potential/features.py b/modelforge/potential/features.py deleted file mode 100644 index e69de29b..00000000 diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index e69de29b..cace0b71 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -0,0 +1,67 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn + +from modelforge.utils import Inputs, Properties, SpeciesEnergies + + +class BaseNNP(nn.Module): + """ + Abstract base class for neural network potentials. + This class defines the overall structure and ensures that subclasses + implement the `calculate_energies_and_forces` method. + """ + + def __init__(self, dtype: torch.dtype, device: torch.device): + """ + Initialize the NeuralNetworkPotential class. + + Parameters + ---------- + dtype : torch.dtype + Data type for the PyTorch tensors. + device : torch.device + Device ("cpu" or "cuda") on which computations will be performed. + + """ + super().__init__() + self.dtype = dtype + self.device = device + + def forward( + self, + inputs: Inputs, + ) -> SpeciesEnergies: + """ + Forward pass for the neural network potential. + + Parameters + ---------- + inputs : Inputs + An instance of the Inputs data class containing atomic numbers, positions, etc. + + Returns + ------- + SpeciesEnergies + An instance of the SpeciesEnergies data class containing species and calculated energies. + + """ + + E = self.calculate_energies_and_forces(inputs) + return SpeciesEnergies(inputs.Z, E) + + def calculate_energies_and_forces(self, inputs: Optional[Inputs] = None): + """ + Placeholder for the method that should calculate energies and forces. + This method should be implemented in subclasses. + + Raises + ------ + NotImplementedError + If the method is not overridden in the subclass. + + """ + raise NotImplementedError("Subclasses must implement this method.") + + diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py new file mode 100644 index 00000000..6705bf01 --- /dev/null +++ b/modelforge/potential/schnet.py @@ -0,0 +1,223 @@ +from typing import Tuple +from loguru import logger + +import numpy as np +import torch +import torch.nn as nn +from ase import Atoms +from ase.neighborlist import neighbor_list +from torch import dtype + +from modelforge.utils import Inputs + +from .models import BaseNNP +from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add + + +class Schnet(BaseNNP): + """ + Implementation of the SchNet architecture for quantum mechanical property prediction. + """ + + def __init__( + self, + n_atom_basis: int, # number of features per atom + n_interactions: int, # number of interaction blocks + n_filters: int = 0, # number of filters + dtype: dtype = torch.float32, + device: torch.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ), + ): + """ + Initialize the SchNet model. + + Parameters + ---------- + n_atom_basis : int + Number of features per atom. + n_interactions : int + Number of interaction blocks. + n_filters : int, optional + Number of filters, defaults to None. + dtype : torch.dtype, optional + Data type for PyTorch tensors, defaults to torch.float32. + device : torch.device, optional + Device ("cpu" or "cuda") on which computations will be performed. + + """ + + super().__init__(dtype, device) + + # initialize atom embeddings + max_z: int = 100 # max nuclear charge (i.e. atomic number) + self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0) + + # initialize radial basis functions and other constants + n_rbf = 20 + self.radial_basis = GaussianRBF(n_rbf=n_rbf, cutoff=5.0) + self.cutoff = 5.0 + self.activation = shifted_softplus + self.n_interactions = n_interactions + self.n_atom_basis = n_atom_basis + + # initialize dense yalers for atom feature transformation + # Dense layers are applied consecutively to the initialized atom embeddings x^{l}_{0} + # to generate x_i^l+1 = W^lx^l_i + b^l + self.intput_to_feature = Dense( + n_atom_basis, n_filters, bias=False, activation=None + ) + self.feature_to_output = nn.Sequential( + Dense(n_filters, n_atom_basis, activation=self.activation), + Dense(n_atom_basis, n_atom_basis, activation=None), + ) + + # Initialize filter network + self.filter_network = nn.Sequential( + Dense(n_rbf, n_filters, activation=self.activation), + Dense(n_filters, n_filters), + ) + + def _setup_ase_system(self, inputs: Inputs) -> Atoms: + """ + Transform inputs to an ASE Atoms object. + + Parameters + ---------- + inputs : Inputs + Input features including atomic numbers and positions. + + Returns + ------- + ase.Atoms + Transformed ASE Atoms object. + + """ + _atomic_numbers = torch.clone(inputs.Z) + atomic_numbers = list(_atomic_numbers.detach().cpu().numpy()) + positions = list(inputs.R.detach().cpu().numpy()) + ase_atoms = Atoms(numbers=atomic_numbers, positions=positions) + return ase_atoms + + def _compute_distances( + self, inputs: Inputs + ) -> Tuple[torch.Tensor, np.ndarray, np.ndarray]: + """ + Compute atomic distances using ASE's neighbor list. + + Parameters + ---------- + inputs : Inputs + Input features including atomic numbers and positions. + + Returns + ------- + torch.Tensor, np.ndarray, np.ndarray + Pairwise distances, index of atom i, and index of atom j. + + """ + + ase_atoms = self._setup_ase_system(inputs) + idx_i, idx_j, _, r_ij = neighbor_list( + "ijSD", ase_atoms, 5.0, self_interaction=False + ) + r_ij = torch.from_numpy(r_ij) + return r_ij, idx_i, idx_j + + def _distance_to_radial_basis(self, r_ij): + """ + Transform distances to radial basis functions. + + Parameters + ---------- + r_ij : torch.Tensor + Pairwise distances between atoms. + + Returns + ------- + torch.Tensor, torch.Tensor + Radial basis functions and cutoff values. + + """ + d_ij = torch.norm(r_ij, dim=1) # calculate pairwise distances + f_ij = self.radial_basis(d_ij) + rcut_ij = cosine_cutoff(d_ij, self.cutoff) + return f_ij, rcut_ij + + def _interaction_block(self, inputs: Inputs, f_ij, idx_i, idx_j, rcut_ij): + """ + Compute the interaction block which updates atom features. + + Parameters + ---------- + inputs : Inputs + Input features including atomic numbers and positions. + f_ij : torch.Tensor + Radial basis functions. + idx_i : np.ndarray + Indices of center atoms. + idx_j : np.ndarray + Indices of neighboring atoms. + rcut_ij : torch.Tensor + Cutoff values for each pair of atoms. + + Returns + ------- + torch.Tensor + Updated atom features. + + """ + + # compute atom and pair features (see Fig1 in 10.1063/1.5019779) + # initializing x^{l}_{0} as x^l)0 = aZ_i + logger.debug("Embedding inputs.Z") + x_emb = self.embedding(inputs.Z) + logger.debug("After embedding: x.shape", x_emb.shape) + idx_i = torch.from_numpy(idx_i).to(self.device, torch.int64) + + # interaction blocks + for _ in range(self.n_interactions): + # atom wise update of features + logger.debug(f"Input to feature: x.shape {x_emb.shape}") + x = self.intput_to_feature(x_emb) + logger.debug("After input_to_feature call: x.shape {x.shape}") + + # Filter generation networks + Wij = self.filter_network(f_ij) + Wij = Wij * rcut_ij[:, None] + Wij = Wij.to(dtype=self.dtype) + + # continuous-filter convolutional layers + x_j = x[idx_j] + x_ij = x_j * Wij + logger.debug("After x_j * Wij: x_ij.shape {x_ij.shape}") + x = scatter_add(x_ij, idx_i, dim_size=x.shape[0]) + logger.debug("After scatter_add: x.shape {x.shape}") + # Update features + x = self.feature_to_output(x) + x_emb = x_emb + x + + return x_emb + + def calculate_energies_and_forces(self, inputs: Inputs) -> torch.Tensor: + """ + Compute energies and forces for given atomic configurations. + + Parameters + ---------- + inputs : Inputs + Input features including atomic numbers and positions. + + Returns + ------- + torch.Tensor + Energies and forces for the given configurations. + + """ + logger.debug("Compute distances ...") + r_ij, idx_i, idx_j = self._compute_distances(inputs) + logger.debug("Convert distances to radial basis ...") + f_ij, rcut_ij = self._distance_to_radial_basis(r_ij) + logger.debug("Compute interaction block ...") + x = self._interaction_block(inputs, f_ij, idx_i, idx_j, rcut_ij) + return x diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py new file mode 100644 index 00000000..ea5448a3 --- /dev/null +++ b/modelforge/potential/utils.py @@ -0,0 +1,235 @@ +import torch +from typing import Callable, Union +import torch.nn as nn +import numpy as np +import torch.nn.functional as F + + +def _scatter_add( + x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0 +) -> torch.Tensor: + shape = list(x.shape) + shape[dim] = dim_size + tmp = torch.zeros(shape, dtype=x.dtype, device=x.device) + y = tmp.index_add(dim, idx_i, x) + return y + + +def scatter_add( + x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0 +) -> torch.Tensor: + """ + Sum over values with the same indices. + + Args: + x: input values + idx_i: index of center atom i + dim_size: size of the dimension after reduction + dim: the dimension to reduce + + Returns: + reduced input + + """ + return _scatter_add(x, idx_i, dim_size, dim) + + +class Dense(nn.Linear): + r"""Fully connected linear layer with activation function. + + .. math:: + y = activation(x W^T + b) + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + activation: Union[Callable, nn.Module] = None, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ), + ): + """ + Initialize the Dense layer. + + Parameters + ---------- + in_features : int + Number of input features. + out_features : int + Number of output features. + bias : bool, optional + If False, the layer will not adapt bias. + activation : Callable or nn.Module, optional + Activation function, default is None (Identity). + dtype : torch.dtype, optional + Data type for PyTorch tensors. + device : torch.device, optional + Device ("cpu" or "cuda") on which computations will be performed. + """ + super().__init__(in_features, out_features, bias) + + # Initialize activation function + self.activation = activation if activation is not None else nn.Identity() + + # Initialize weight matrix + self.weight = nn.init.xavier_uniform_(self.weight).to(device, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the layer. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Transformed tensor. + """ + + y = F.linear(x, self.weight, self.bias) + y = self.activation(y) + return y + + +def gaussian_rbf( + inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor +) -> torch.Tensor: + """ + Gaussian radial basis function (RBF) transformation. + + Parameters + ---------- + inputs : torch.Tensor + Input tensor. + offsets : torch.Tensor + Offsets for Gaussian functions. + widths : torch.Tensor + Widths for Gaussian functions. + + Returns + ------- + torch.Tensor + Transformed tensor. + """ + + coeff = -0.5 / torch.pow(widths, 2) + diff = inputs[..., None] - offsets + y = torch.exp(coeff * torch.pow(diff, 2)) + return y.to(dtype=torch.float32) + + +def cosine_cutoff(input: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: + """ + Behler-style cosine cutoff function. + + Parameters + ---------- + inputs : torch.Tensor + Input tensor. + cutoff : torch.Tensor + Cutoff radius. + + Returns + ------- + torch.Tensor + Transformed tensor. + """ + + # Compute values of cutoff function + input_cut = 0.5 * (torch.cos(input * np.pi / cutoff) + 1.0) + # Remove contributions beyond the cutoff radius + input_cut *= input < cutoff + return input_cut + + +def shifted_softplus(x: torch.Tensor) -> torch.Tensor: + """ + Compute shifted soft-plus activation function. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Transformed tensor. + """ + return nn.functional.softplus(x) - np.log(2.0) + + +class GaussianRBF(nn.Module): + """ + Gaussian radial basis functions (RBF). + """ + + def __init__( + self, + n_rbf: int, + cutoff: float, + start: float = 0.0, + trainable: bool = False, + device: torch.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ), + dtype: torch.dtype = torch.float32, + ): + """ + Initialize Gaussian RBF layer. + + Parameters + ---------- + n_rbf : int + Number of radial basis functions. + cutoff : float + Cutoff distance for RBF. + start : float, optional + Starting distance for RBF, defaults to 0.0. + trainable : bool, optional + If True, widths and offsets are trainable parameters. + device : torch.device, optional + Device ("cpu" or "cuda") on which computations will be performed. + dtype : torch.dtype, optional + Data type for PyTorch tensors. + + """ + super().__init__() + self.n_rbf = n_rbf + + # compute offset and width of Gaussian functions + offset = torch.linspace(start, cutoff, n_rbf, dtype=dtype, device=device) + widths = torch.tensor( + torch.abs(offset[1] - offset[0]) * torch.ones_like(offset), + device=device, + dtype=dtype, + ) + if trainable: + self.widths = nn.Parameter(widths) + self.offsets = nn.Parameter(offset) + else: + self.register_buffer("widths", widths) + self.register_buffer("offsets", offset) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the layer. + + Parameters + ---------- + inputs : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Transformed tensor. + """ + return gaussian_rbf(inputs, self.offsets, self.widths) diff --git a/modelforge/tests/helper_functinos.py b/modelforge/tests/helper_functinos.py new file mode 100644 index 00000000..217cccd4 --- /dev/null +++ b/modelforge/tests/helper_functinos.py @@ -0,0 +1,33 @@ +import torch + +from modelforge.dataset.dataset import TorchDataModule +from modelforge.dataset.qm9 import QM9Dataset +from modelforge.utils import Inputs + + +def default_input(): + train_loader = initialize_dataloader() + R, Z, E = train_loader.dataset[0] + padded_values = -Z.eq(-1).sum().item() + Z_ = Z[:padded_values] + R_ = R[:padded_values] + return Inputs(Z_, R_, E) + + +def initialize_dataloader() -> torch.utils.data.DataLoader: + data = QM9Dataset(for_unit_testing=True) + data_module = TorchDataModule(data) + data_module.prepare_data() + data_module.setup("fit") + return data_module.train_dataloader() + + +methane_coordinates = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.63918859, 0.63918859, 0.63918859], + [-0.63918859, -0.63918859, 0.63918859], + [-0.63918859, 0.63918859, -0.63918859], + [0.63918859, -0.63918859, -0.63918859], + ] +) diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py new file mode 100644 index 00000000..25e3ed68 --- /dev/null +++ b/modelforge/tests/test_models.py @@ -0,0 +1,30 @@ +from typing import Optional + +import pytest +import torch + +from modelforge.potential.models import BaseNNP +from modelforge.potential.schnet import Schnet +from .helper_functinos import default_input + +MODELS_TO_TEST = [Schnet] + + +def setup_simple_model(model_class) -> Optional[BaseNNP]: + if model_class is Schnet: + return Schnet(n_atom_basis=128, n_interactions=3, n_filters=64) + else: + raise NotImplementedError + + +def test_BaseNNP(): + nnp = BaseNNP(dtype=torch.float32, device="cpu") + assert nnp.dtype == torch.float32 + assert str(nnp.device) == "cpu" + + +@pytest.mark.parametrize("model_class", MODELS_TO_TEST) +def test_forward_pass(model_class): + initialized_model = setup_simple_model(model_class) + inputs = default_input() + output = initialized_model.forward(inputs) diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py new file mode 100644 index 00000000..dd064c4e --- /dev/null +++ b/modelforge/tests/test_schnet.py @@ -0,0 +1,22 @@ +import torch +from modelforge.potential.schnet import Schnet +from modelforge.utils import Inputs + + +def test_Schnet_init(): + schnet = Schnet(128, 6, 2) + assert schnet.n_atom_basis == 128 + assert schnet.n_interactions == 6 + + +def test_calculate_energies_and_forces(): + # this test will be adopted as soon as we have a + # trained model. Here we want to test the + # energy and force calculatino on Methane + + schnet = Schnet(128, 6, 64) + Z = torch.tensor([1, 8], dtype=torch.int64) + R = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.96]], dtype=torch.float32) + inputs = Inputs(Z, R, torch.tensor([100])) + result = schnet.calculate_energies_and_forces(inputs) + assert result.shape[1] == 128 # Assuming n_atom_basis is 128 diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py new file mode 100644 index 00000000..cea5d45a --- /dev/null +++ b/modelforge/tests/test_utils.py @@ -0,0 +1,23 @@ +import numpy as np +from modelforge.potential.utils import scatter_add, Dense, GaussianRBF +import torch + +def test_scatter_add(): + x = torch.tensor([1, 4, 3, 2], dtype=torch.float32) + idx_i = torch.tensor([0, 2, 2, 1], dtype=torch.int64) + result = scatter_add(x, idx_i, dim_size=3) + assert torch.equal(result, torch.tensor([1.0, 2.0, 7.0])) + + +def test_Dense(): + layer = Dense(2, 3) + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + y = layer(x) + assert y.shape == (2, 3) + + +def test_GaussianRBF(): + layer = GaussianRBF(10, 5.0) + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = layer(x) + assert y.shape == (3, 10) diff --git a/modelforge/utils/__init__.py b/modelforge/utils/__init__.py index 52c5105c..d7489235 100644 --- a/modelforge/utils/__init__.py +++ b/modelforge/utils/__init__.py @@ -1 +1,3 @@ -"""modelforge utilities.""" \ No newline at end of file +"""modelforge utilities.""" + +from .prop import Properties, Inputs, SpeciesEnergies diff --git a/modelforge/utils/prop.py b/modelforge/utils/prop.py new file mode 100644 index 00000000..30f4e428 --- /dev/null +++ b/modelforge/utils/prop.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +import torch +from typing import Optional +from loguru import logger + + +@dataclass +class Properties: + Z: str = "atomic_numbers" + R: str = "positions" + + +@dataclass +class Inputs: + Z: torch.Tensor + R: torch.Tensor + E: torch.Tensor + cell: Optional[torch.Tensor] = (None,) + pbc: Optional[torch.Tensor] = (None,) + dtype: torch.dtype = torch.float32 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def __post_init__(self): + assert self.Z.shape[0] == self.R.shape[0] + logger.info(f"Transforming Z and R to {self.dtype}") + self.Z = self.Z.to(self.device, torch.int32) + self.R = self.R.to(self.device, self.dtype) + + +@dataclass +class SpeciesEnergies: + species: torch.Tensor + energies: torch.Tensor diff --git a/notebooks/dataclass.ipynb b/notebooks/dataclass.ipynb deleted file mode 100644 index 3c714943..00000000 --- a/notebooks/dataclass.ipynb +++ /dev/null @@ -1,150 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import os, sys\n", - "import numpy as np\n", - "import qcportal as ptl\n", - "import pandas as pd\n", - "from collections import defaultdict\n", - "import torch\n", - "from loguru import logger\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "class QM9Dataset(torch.utils.data.Dataset):\n", - " \n", - " def __init__(self, dataset_file: str = None, name: str = \"\"):\n", - " self._name = name\n", - " self.dataset_file = dataset_file\n", - " self.load(dataset_file)\n", - " self.molecules = self.qm9.get_molecules()\n", - " self.records = self.qm9.get_records(method=\"b3lyp\")\n", - "\n", - "\n", - " def load(self, dataset_file):\n", - " \"\"\"\n", - " Loads the raw dataset from qcarchive.\n", - "\n", - " If a valid qcarchive generated hdf5 file is not pass to the\n", - " set_raw_dataset_file function, the code will download the\n", - " raw dataset from qcarchive.\n", - " \"\"\"\n", - " qcp_client = ptl.FractalClient()\n", - " qcportal_data = {\"collection\": \"Dataset\", \"dataset\": \"QM9\"}\n", - "\n", - " try:\n", - " self.qm9 = qcp_client.get_collection(\n", - " qcportal_data[\"collection\"], qcportal_data[\"dataset\"]\n", - " )\n", - " except Exception:\n", - " print(\n", - " f\"Dataset {qcportal_data['dataset']} is not available in collection {qcportal_data['collection']}.\"\n", - " )\n", - "\n", - " if dataset_file and os.path.isfile(dataset_file):\n", - " if not dataset_file.endswith(\".hdf5\"):\n", - " raise ValueError(\"Input file must be an .hdf5 file.\")\n", - " logger.debug(f'Loading from {dataset_file}')\n", - " self.qm9.set_view(dataset_file)\n", - " else:\n", - " logger.debug(f'Downloading from qcportal')\n", - "\n", - " # to get QM9 from qcportal, we need to define which collection and QM9\n", - "\n", - " self.qm9.download(dataset_file)\n", - " self.qm9.to_file(path=dataset_file, encoding=\"hdf5\")\n", - " self.qm9.set_view(self.dataset_file)\n", - "\n", - " def __len__(self):\n", - " molecules = self.qm9.get_molecules()\n", - " return molecules.shape[0]\n", - "\n", - " def __getitem__(self, idx):\n", - " with h5py.File(self.hdf5_file, 'r') as f:\n", - " geometry = torch.tensor(f[self.keys[idx]]['geometry'][:])\n", - " energy = torch.tensor(f[self.keys[idx]]['energy'][()])\n", - " \n", - " if self.transform:\n", - " geometry = self.transform(geometry)\n", - " if self.target_transform:\n", - " energy = self.target_transform(energy)\n", - " \n", - " return geometry, energy\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-08-01 14:37:25.667\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mload\u001b[0m:\u001b[36m34\u001b[0m - \u001b[34m\u001b[1mDownloading from qcportal\u001b[0m\n", - "160MB [01:26, 1.85MB/s] \n" - ] - }, - { - "ename": "AttributeError", - "evalue": "module 'distutils' has no attribute 'version'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m data \u001b[39m=\u001b[39m QM9Dataset(\u001b[39m'\u001b[39;49m\u001b[39mtest.hdf5\u001b[39;49m\u001b[39m'\u001b[39;49m)\n", - "Cell \u001b[0;32mIn[2], line 6\u001b[0m, in \u001b[0;36mQM9Dataset.__init__\u001b[0;34m(self, dataset_file, name)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_name \u001b[39m=\u001b[39m name\n\u001b[1;32m 5\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset_file \u001b[39m=\u001b[39m dataset_file\n\u001b[0;32m----> 6\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mload(dataset_file)\n", - "Cell \u001b[0;32mIn[2], line 39\u001b[0m, in \u001b[0;36mQM9Dataset.load\u001b[0;34m(self, dataset_file)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39m# to get QM9 from qcportal, we need to define which collection and QM9\u001b[39;00m\n\u001b[1;32m 38\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mqm9\u001b[39m.\u001b[39mdownload(dataset_file)\n\u001b[0;32m---> 39\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mqm9\u001b[39m.\u001b[39;49mto_file(path\u001b[39m=\u001b[39;49mdataset_file, encoding\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mhdf5\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m 40\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mqm9\u001b[39m.\u001b[39mset_view(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset_file)\n", - "File \u001b[0;32m/data/shared/software/python_env/mambaforge/envs/modelforge3.10/lib/python3.10/site-packages/qcportal/collections/dataset.py:240\u001b[0m, in \u001b[0;36mDataset.to_file\u001b[0;34m(self, path, encoding)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39melif\u001b[39;00m encoding\u001b[39m.\u001b[39mlower() \u001b[39min\u001b[39;00m [\u001b[39m\"\u001b[39m\u001b[39mhdf5\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mh5\u001b[39m\u001b[39m\"\u001b[39m]:\n\u001b[1;32m 238\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m \u001b[39mimport\u001b[39;00m HDF5View\n\u001b[0;32m--> 240\u001b[0m HDF5View(path)\u001b[39m.\u001b[39;49mwrite(\u001b[39mself\u001b[39;49m)\n\u001b[1;32m 241\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 242\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnsupported encoding: \u001b[39m\u001b[39m{\u001b[39;00mencoding\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n", - "File \u001b[0;32m/data/shared/software/python_env/mambaforge/envs/modelforge3.10/lib/python3.10/site-packages/qcportal/collections/dataset_view.py:260\u001b[0m, in \u001b[0;36mHDF5View.write\u001b[0;34m(self, ds)\u001b[0m\n\u001b[1;32m 257\u001b[0m n_records \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(ds\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mrecords)\n\u001b[1;32m 258\u001b[0m default_shape \u001b[39m=\u001b[39m (n_records,)\n\u001b[0;32m--> 260\u001b[0m \u001b[39mif\u001b[39;00m h5py\u001b[39m.\u001b[39m__version__ \u001b[39m>\u001b[39m\u001b[39m=\u001b[39m distutils\u001b[39m.\u001b[39;49mversion\u001b[39m.\u001b[39mStrictVersion(\u001b[39m\"\u001b[39m\u001b[39m2.10.0\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 261\u001b[0m vlen_double_t \u001b[39m=\u001b[39m h5py\u001b[39m.\u001b[39mvlen_dtype(np\u001b[39m.\u001b[39mdtype(\u001b[39m\"\u001b[39m\u001b[39mfloat64\u001b[39m\u001b[39m\"\u001b[39m))\n\u001b[1;32m 262\u001b[0m utf8_t \u001b[39m=\u001b[39m h5py\u001b[39m.\u001b[39mstring_dtype(encoding\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m)\n", - "\u001b[0;31mAttributeError\u001b[0m: module 'distutils' has no attribute 'version'" - ] - } - ], - "source": [ - "data = QM9Dataset('test.hdf5')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data.__len__()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "modelforge3.10", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -}