diff --git a/docs/source/inference.rst b/docs/source/inference.rst
new file mode 100644
index 00000000..c00c8cda
--- /dev/null
+++ b/docs/source/inference.rst
@@ -0,0 +1,62 @@
+Inference
+=========
+
+"Inference" can be a bit of an overloaded term, and this page is broken down into different possible
+downstream use cases for trained models.
+
+Parity plots and model evaluations
+----------------------------------
+
+The simplest and most straightforward way to check the performance of a model is to look beyond reduced metrics, i.e. anything that
+has been averaged over batches, epochs, etc. Parity plots help verify a linear relationship between predictions and ground truths
+by simply iterating over an evaluation subset of the data without any averaging.
+
+The ``ParityInferenceTask`` helps perform this analysis by using the PyTorch Lightning ``predict`` pipelines. With a pre-trained
+``matsciml`` task checkpoint, you simply need to run the following:
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+    from matsciml.models.inference import ParityInferenceTask
+    from matsciml.lightning import MatSciMLDataModule
+
+    # configure data module the way that you need to
+    dm = MatSciMLDataModule(
+        dataset="NameofDataset",
+        pred_split="/path/to/lmdb/split",
+        batch_size=64  # this is just to amortize model calls
+    )
+    task = ParityInferenceTask.from_pretrained_checkpoint("/path/to/checkpoint")
+
+    trainer = pl.Trainer()  # optionally, configure logger/limit_predict_batches
+    trainer.predict(task, datamodule=dm)
+
+
+The default ``Trainer`` settings will create a ``lightning_logs`` directory, followed by an experiment
+number. Within it, once your inference run completes, there will be an ``inference_data.json`` file that you
+can then load in. The data is keyed by the name of the target (e.g. ``energy``, ``bandgap``), and under
+each of these keys are ``predictions`` and ``targets`` entries. Note that ``pred_split`` does not necessarily have to be
+a completely different hold-out set: you can pass your training LMDB path if you wish to double-check the
+performance of your model after training, or you can use it with unseen samples.
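+
+As a quick sketch of what you can do with that file (this assumes ``matplotlib`` is installed,
+uses the default ``lightning_logs/version_0`` output location, and picks ``energy`` as an
+illustrative target name; substitute your own log directory and target):
+
+.. code-block:: python
+
+    import json
+
+    import matplotlib.pyplot as plt
+
+    with open("lightning_logs/version_0/inference_data.json") as read_file:
+        data = json.load(read_file)
+
+    # each target key holds parallel lists of model predictions and ground truths
+    energy = data["energy"]
+    fig, ax = plt.subplots()
+    ax.scatter(energy["targets"], energy["predictions"], alpha=0.6)
+    ax.set_xlabel("Ground truth energy")
+    ax.set_ylabel("Predicted energy")
+    fig.savefig("energy_parity.png", dpi=150)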
+
+.. note::
+
+    For developers, this is handled by the ``matsciml.models.inference.ParityData`` class. This is
+    mainly to standardize the output and provide a means to serialize the data as JSON.
+
+
+
+.. autoclass:: matsciml.models.inference.ParityInferenceTask
+    :members:
+
+
+
+Performing molecular dynamics simulations
+-----------------------------------------
+
+Currently, the main method of interfacing with dynamical simulations is through the ``ase`` package.
+Documentation for this is ongoing, but examples can be found under ``examples/interfaces``.
+
+.. autoclass:: matsciml.interfaces.ase.MatSciMLCalculator
+    :members:
diff --git a/matsciml/lightning/data_utils.py b/matsciml/lightning/data_utils.py
index 90fc3ff4..fbe12d90 100644
--- a/matsciml/lightning/data_utils.py
+++ b/matsciml/lightning/data_utils.py
@@ -104,6 +104,7 @@ def __init__(
         num_workers: int = 0,
         val_split: str | Path | float | None = 0.0,
         test_split: str | Path | float | None = 0.0,
+        pred_split: str | Path | None = None,
         seed: int | None = None,
         dset_kwargs: dict[str, Any] | None = None,
         persistent_workers: bool | None = None,
@@ -111,7 +112,7 @@
         super().__init__()
         # make sure we have something to work with
         assert any(
-            [i for i in [dataset, train_path, val_split, test_split]],
+            [i for i in [dataset, train_path, val_split, test_split, pred_split]],
         ), "No splits provided to datamodule."
         # if floats are passed to splits, make sure dataset is provided for inference
         if any([isinstance(i, float) for i in [val_split, test_split]]):
@@ -122,7 +123,7 @@
             assert any(
                 [
                     isinstance(p, (str, Path))
-                    for p in [train_path, val_split, test_split]
+                    for p in [train_path, val_split, test_split, pred_split]
                 ],
             ), "Dataset type passed, but no paths to construct with."
         self.dataset = dataset
@@ -248,6 +249,17 @@ def setup(self, stage: str | None = None) -> None:
            if isinstance(split_path, (str, Path)):
                dset = self._make_dataset(split_path, self.dataset)
                splits[key] = dset
+        # special case for 'inference' or prediction runs
+        if isinstance(self.hparams.pred_split, (str, Path)):
+            pred_split_path = self.hparams.pred_split
+            if isinstance(pred_split_path, str):
+                pred_split_path = Path(pred_split_path)
+            if not pred_split_path.exists():
+                raise FileNotFoundError(
+                    f"Prediction split provided, but not found: {pred_split_path}"
+                )
+            dset = self._make_dataset(pred_split_path, self.dataset)
+            splits["pred"] = dset
        # the last case assumes only the dataset is passed, we will treat it as train
        if len(splits) == 0:
            splits["train"] = self.dataset
@@ -268,8 +280,12 @@ def predict_dataloader(self):
        """
        Predict behavior just assumes the whole dataset is used for inference.
        """
+        if "pred" in self.splits:
+            target = self.splits["pred"]
+        else:
+            target = self.dataset
        return DataLoader(
-            self.dataset,
+            target,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            collate_fn=self.dataset.collate_fn,
diff --git a/matsciml/models/inference.py b/matsciml/models/inference.py
index 38556502..2f97cce0 100644
--- a/matsciml/models/inference.py
+++ b/matsciml/models/inference.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Union
+from typing import Any
+from logging import getLogger
 
 import pytorch_lightning as pl
 import torch
@@ -10,6 +12,83 @@
 from matsciml.common.registry import registry
 from matsciml.common.types import BatchDict, DataDict
+from matsciml.models.base import BaseTaskModule, MultiTaskLitModule
+
+
+class ParityData:
+    def __init__(self, name: str) -> None:
+        """
+        Class to help accumulate inference results.
+
+        This class should be created per target, and uses property
+        setters to accumulate target and prediction tensors; at the
+        final step, they are aggregated into single tensors and, with
+        the `to_json` method, turned into serializable data.
+
+        Parameters
+        ----------
+        name : str
+            Name of the target property being tracked.
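+
+        Examples
+        --------
+        A minimal sketch of the intended accumulation pattern, using
+        made-up values for a single ``energy`` target::
+
+            acc = ParityData("energy")
+            acc.targets = torch.tensor([1.0, 2.0])
+            acc.predictions = torch.tensor([1.1, 1.9])
+            blob = acc.to_json()  # {"predictions": ..., "targets": ..., "name": "energy"}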
+        """
+        super().__init__()
+        self.name = name
+        self.logger = getLogger(f"matsciml.inference.{name}-parity")
+
+    @property
+    def ndim(self) -> int:
+        if not hasattr(self, "_targets"):
+            raise RuntimeError("No data set to accumulator yet.")
+        sample = self._targets[0]
+        if isinstance(sample, torch.Tensor):
+            return sample.ndim
+        else:
+            return 0
+
+    @property
+    def targets(self) -> torch.Tensor:
+        return torch.vstack(self._targets)
+
+    @targets.setter
+    def targets(self, values: torch.Tensor) -> None:
+        if not hasattr(self, "_targets"):
+            self._targets = []
+        if isinstance(values, torch.Tensor):
+            # remove erroneous "empty" dimensions
+            values.squeeze_()
+        self._targets.append(values)
+
+    @property
+    def predictions(self) -> torch.Tensor:
+        return torch.vstack(self._predictions)
+
+    @predictions.setter
+    def predictions(self, values: torch.Tensor) -> None:
+        if not hasattr(self, "_predictions"):
+            self._predictions = []
+        if isinstance(values, torch.Tensor):
+            values.squeeze_()
+        self._predictions.append(values)
+
+    def to_json(self) -> dict[str, list]:
+        return_dict = {}
+        targets = self.targets.cpu()
+        predictions = self.predictions.cpu()
+        # do some preliminary checks on the data
+        if targets.ndim != predictions.ndim:
+            self.logger.warning(
+                "Target/prediction dimensionality mismatch\n"
+                f"  Target: {targets.ndim}, predictions: {predictions.ndim}"
+            )
+        if targets.shape != predictions.shape:
+            self.logger.warning(
+                "Target/prediction shape mismatch\n"
+                f"  Target: {targets.shape}, predictions: {predictions.shape}."
+            )
+        return_dict["predictions"] = predictions.tolist()
+        return_dict["targets"] = targets.tolist()
+        return_dict["name"] = self.name
+        return return_dict
 
 
 class BaseInferenceTask(ABC, pl.LightningModule):
@@ -17,14 +96,19 @@ def __init__(self, pretrained_model: nn.Module, *args, **kwargs):
         super().__init__()
         self.model = pretrained_model
 
+    def training_step(self, *args, **kwargs) -> None:
+        """Overrides Lightning method to prevent task being used for training."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} is not intended for training."
+        )
+
     @abstractmethod
     def predict_step(
         self,
         batch: BatchDict,
         batch_idx: int,
         dataloader_idx: int = 0,
-    ) -> Any:
-        ...
+    ) -> Any: ...
 
     @classmethod
     def from_pretrained_checkpoint(
@@ -58,7 +142,7 @@
         task_ckpt_path = Path(task_ckpt_path)
         assert (
             task_ckpt_path.exists()
-        ), f"Encoder checkpoint filepath specified but does not exist."
+        ), "Encoder checkpoint filepath specified but does not exist."
         ckpt = torch.load(task_ckpt_path)
         select_kwargs = {}
         for key in ["encoder_class", "encoder_kwargs"]:
@@ -117,3 +201,95 @@ def predict_step(
         for key in ["targets", "symmetry"]:
             return_dict[key] = batch.get(key)
         return return_dict
+
+
+@registry.register_task("ParityInferenceTask")
+class ParityInferenceTask(BaseInferenceTask):
+    def __init__(self, pretrained_model: BaseTaskModule):
+        """
+        Use a pretrained model to produce parity plot data, i.e. predicted vs.
+        ground truth values.
+
+        Example usage
+        -------------
+        The intended usage is to load a pretrained model, define a data module
+        that points to some data to perform predictions with, then call Lightning
+        Trainer's ``predict`` method.
+
+        >>> task = ParityInferenceTask.from_pretrained_checkpoint(...)
+        >>> dm = MatSciMLDataModule("DatasetName", pred_split=...)
+        >>> trainer = pl.Trainer()
+        >>> trainer.predict(task, datamodule=dm)
+
+        Parameters
+        ----------
+        pretrained_model : BaseTaskModule
+            An instance of a subclass of ``BaseTaskModule``, e.g. a
+            ``ForceRegressionTask`` object.
+
+        Raises
+        ------
+        NotImplementedError
+            Currently, multitask modules are not yet supported.
+        """
+        if isinstance(pretrained_model, MultiTaskLitModule):
+            raise NotImplementedError(
+                "ParityInferenceTask currently only supports single task modules."
+            )
+        assert hasattr(pretrained_model, "predict") and callable(
+            pretrained_model.predict
+        ), "Model passed does not have a `predict` method; is it a `matsciml` task?"
+        super().__init__(pretrained_model)
+        self.common_keys = set()
+        self.accumulators = {}
+
+    def forward(self, batch: BatchDict) -> dict[str, float | torch.Tensor]:
+        """
+        Forward call for the inference task. This wraps the underlying
+        ``matsciml`` task module's ``predict`` function to ensure that
+        normalization is 'reversed', i.e. predictions are reported in
+        the original unit space.
+
+        Parameters
+        ----------
+        batch : BatchDict
+            Batch of samples to process.
+
+        Returns
+        -------
+        dict[str, float | torch.Tensor]
+            Prediction output, which should correspond to a key/tensor
+            mapping of output head/task name and the associated outputs.
+        """
+        preds = self.model.predict(batch)
+        return preds
+
+    def on_predict_start(self) -> None:
+        """Verify that logging is enabled, as it is needed."""
+        if not self.trainer.log_dir:
+            raise RuntimeError(
+                "ParityInferenceTask requires logging to be enabled; no `log_dir` detected in Trainer."
+            )
+
+    def predict_step(
+        self, batch: BatchDict, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        predictions = self(batch)
+        pred_keys = set(predictions.keys())
+        batch_keys = set(batch["targets"].keys())
+        self.common_keys = pred_keys.intersection(batch_keys)
+        # loop over keys that are mutually available in predictions and data
+        for key in self.common_keys:
+            if key not in self.accumulators:
+                self.accumulators[key] = ParityData(key)
+            acc = self.accumulators[key]
+            acc.targets = batch["targets"][key].detach()
+            acc.predictions = predictions[key].detach()
+
+    def on_predict_epoch_end(self) -> None:
+        """At the end of the dataset, write results to ``<log_dir>/inference_data.json``."""
+        log_dir = Path(self.trainer.log_dir)
+        output_file = log_dir.joinpath("inference_data.json")
+        with open(output_file, "w+") as write_file:
+            data = {key: acc.to_json() for key, acc in self.accumulators.items()}
+            json.dump(data, write_file, indent=2)
diff --git a/matsciml/models/tests/test_parity_inference.py b/matsciml/models/tests/test_parity_inference.py
new file mode 100644
index 00000000..b80cde68
--- /dev/null
+++ b/matsciml/models/tests/test_parity_inference.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import shutil
+import json
+from pathlib import Path
+
+import pytest
+import pytorch_lightning as pl
+
+from matsciml.models.inference import ParityInferenceTask
+from matsciml.models.base import ScalarRegressionTask
+from matsciml.models.pyg import EGNN
+from matsciml.lightning import MatSciMLDataModule
+from matsciml.datasets.transforms import (
+    PeriodicPropertiesTransform,
+    PointCloudToGraphTransform,
+)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "dset_params",
+    [
+        (
+            "MaterialsProjectDataset",
+            [
+                "efermi",
+            ],
+        ),
+        (
+            "LiPSDataset",
+            [
+                "energy",
+            ],
+        ),
+        ("OQMDDataset", ["stability", "band_gap"]),
+    ],
+)
+def test_parity_inference_workflow(dset_params):
+    dataset_name, keys = dset_params
+    dm = MatSciMLDataModule.from_devset(
+        dataset_name,
+        dset_kwargs={
+            "transforms": [
+                PeriodicPropertiesTransform(6.0, True),
+                PointCloudToGraphTransform("pyg"),
+            ]
+        },
+        batch_size=8,
+    )
+    task = ScalarRegressionTask(
+        encoder_class=EGNN,
+        encoder_kwargs={"hidden_dim": 16, "output_dim": 16, "num_conv": 2},
+        output_kwargs={"hidden_dim": 16},
+        task_keys=keys,
+    )
+    # train the model briefly to initialize output heads
+    trainer = pl.Trainer(max_epochs=1, limit_train_batches=10, limit_val_batches=0)
+    trainer.fit(task, dm)
+    # now do the inference part
+    wrapper = ParityInferenceTask(task)
+    trainer.predict(wrapper, datamodule=dm)
+    assert trainer.log_dir is not None
+    log_dir = Path(trainer.log_dir)
+    assert log_dir.exists()
+    # open the result and make sure it's not empty
+    with open(log_dir.joinpath("inference_data.json"), "r") as read_file:
+        data = json.load(read_file)
+    assert len(data) != 0
+    assert sorted(list(data.keys())) == sorted(task.task_keys)
+    # make sure there are actually predictions and targets available
+    for subdict in data.values():
+        assert len(subdict["predictions"]) == len(subdict["targets"])
+    shutil.rmtree("lightning_logs", ignore_errors=True)
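+
+
+def test_parity_inference_rejects_training():
+    """Sketch of a guard-rail check: the wrapper is inference-only.
+
+    Reuses the same minimal EGNN/ScalarRegressionTask setup as above;
+    ``BaseInferenceTask.training_step`` is overridden to raise, so calling
+    it on the wrapper should error out immediately.
+    """
+    task = ScalarRegressionTask(
+        encoder_class=EGNN,
+        encoder_kwargs={"hidden_dim": 16, "output_dim": 16, "num_conv": 2},
+        output_kwargs={"hidden_dim": 16},
+        task_keys=["energy"],
+    )
+    wrapper = ParityInferenceTask(task)
+    with pytest.raises(NotImplementedError):
+        wrapper.training_step()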