Fixes to parity inference task #296

Merged: 6 commits, Sep 27, 2024
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -17,6 +17,7 @@ The Open MatSciML Toolkit
training
callbacks
experiment-interface
inference
best-practices
how-to
developers
6 changes: 6 additions & 0 deletions docs/source/inference.rst
@@ -40,6 +40,12 @@
these keys, ``predictions`` and ``targets``. Note that ``pred_split`` does not need to be
a completely different hold out: you can pass your training LMDB path if you wish to double-check the
performance of your model after training, or you can use it with unseen samples.

Note that by default, ``predict`` triggers PyTorch's inference mode, a specialized mode in which
autograd is disabled entirely. ``ForceRegressionTask`` uses automatic differentiation to evaluate
forces, so for inference tasks that require gradients you **must** pass ``inference_mode=False``
to ``pl.Trainer``.


.. note::

For developers, this is handled by the ``matsciml.models.inference.ParityData`` class. This is
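The distinction matters because ``torch.no_grad`` can be locally overridden while inference mode cannot. A minimal sketch with a toy quadratic energy function (a stand-in, not the actual ``ForceRegressionTask``):

```python
import torch

pos = torch.randn(4, 3)

# With Trainer(inference_mode=False), Lightning's predict loop runs under
# torch.no_grad(), which a task can locally override to differentiate the
# energy with respect to atomic positions.
with torch.no_grad():
    with torch.enable_grad():
        p = pos.clone().requires_grad_(True)
        energy = (p ** 2).sum()  # toy energy, stand-in for a model forward pass
        forces = -torch.autograd.grad(energy, p)[0]

# Under the default torch.inference_mode(), outputs are "inference tensors"
# that can never re-enter autograd, so force evaluation would fail.
with torch.inference_mode():
    energy = (pos ** 2).sum()

print(forces.shape)          # torch.Size([4, 3])
print(energy.is_inference()) # True
```

The nested ``enable_grad`` pattern is why ``no_grad`` suffices for force prediction while inference mode does not.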
5 changes: 4 additions & 1 deletion matsciml/lightning/data_utils.py
@@ -259,6 +259,9 @@ def setup(self, stage: str | None = None) -> None:
f"Prediction split provided, but not found: {pred_split_path}"
)
dset = self._make_dataset(pred_split_path, self.dataset)
# assumes that if we're providing a predict set, we're not going
# to be doing training in the same run
self.dataset = dset
splits["pred"] = dset
# the last case assumes only the dataset is passed, we will treat it as train
if len(splits) == 0:
@@ -288,7 +291,7 @@ def predict_dataloader(self):
target,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
-            collate_fn=self.dataset.collate_fn,
+            collate_fn=target.collate_fn,
persistent_workers=self.persistent_workers,
)

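The one-line change matters when the prediction split is backed by a different dataset class than ``self.dataset``: batches must be collated by the dataset actually being served. A hypothetical plain-Python stand-in (not the real matsciml datasets) showing why the two ``collate_fn``s are not interchangeable:

```python
# Hypothetical datasets with incompatible batching rules.
class GraphDataset:
    @staticmethod
    def collate_fn(samples):
        return {"graph_batch": samples}

class PointCloudDataset:
    @staticmethod
    def collate_fn(samples):
        return {"point_cloud_batch": samples}

self_dataset = GraphDataset()   # what the datamodule was configured with
target = PointCloudDataset()    # the split behind predict_dataloader

# Pre-fix, self_dataset.collate_fn would have batched point-cloud samples
# with graph rules; the fix collates with the prediction split itself.
batch = target.collate_fn(["sample_a", "sample_b"])
print(batch)  # {'point_cloud_batch': ['sample_a', 'sample_b']}
```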
20 changes: 16 additions & 4 deletions matsciml/models/inference.py
@@ -60,7 +60,7 @@ def targets(self, values: torch.Tensor) -> None:

@property
def predictions(self) -> torch.Tensor:
-        return torch.vstack(self._targets)
+        return torch.vstack(self._predictions)

@predictions.setter
def predictions(self, values: torch.Tensor) -> None:
@@ -112,8 +112,7 @@ def predict_step(

@classmethod
def from_pretrained_checkpoint(
-        cls,
-        task_ckpt_path: str | Path,
+        cls, task_ckpt_path: str | Path, ckpt_class_name: str | None = None
) -> BaseInferenceTask:
"""
Instantiate a ``BaseInferenceTask`` from an existing Lightning checkpoint
@@ -124,9 +123,15 @@ def from_pretrained_checkpoint(

Parameters
----------
-        task_ckpt_path : Union[str, Path]
+        task_ckpt_path : str | Path
Path to an existing task checkpoint file. Typically, this
would be a PyTorch Lightning checkpoint.
        ckpt_class_name : str, optional
            If specified, the task is loaded via its own
            ``load_from_checkpoint`` method. This is a good alternative
            when this method cannot resolve parameter naming, or when
            your inference task depends on methods specific to that task.

Examples
--------
@@ -143,6 +148,13 @@
assert (
task_ckpt_path.exists()
), "Encoder checkpoint filepath specified but does not exist."
# if a task name for the checkpoint is given, use that task's
# loading method directly
if ckpt_class_name:
task_cls = registry.get_task_class(ckpt_class_name)
if not task_cls:
                raise KeyError(f"Requested {ckpt_class_name}, which is not a registered task.")
return cls(task_cls.load_from_checkpoint(str(task_ckpt_path)))
ckpt = torch.load(task_ckpt_path)
select_kwargs = {}
for key in ["encoder_class", "encoder_kwargs"]:
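The new ``ckpt_class_name`` fast path can be sketched with a toy registry. All names here are hypothetical stand-ins (matsciml's real registry API may differ), but the control flow mirrors the lines added above, including raising on an unregistered name:

```python
# Toy registry mirroring the ckpt_class_name fallback added in this PR.
TASK_REGISTRY = {}

class ForceRegressionTask:
    @classmethod
    def load_from_checkpoint(cls, path):
        # stand-in for Lightning's load_from_checkpoint
        task = cls()
        task.ckpt_path = path
        return task

TASK_REGISTRY["ForceRegressionTask"] = ForceRegressionTask

def from_pretrained_checkpoint(task_ckpt_path, ckpt_class_name=None):
    if ckpt_class_name:
        task_cls = TASK_REGISTRY.get(ckpt_class_name)
        if not task_cls:
            # Report the requested *name*: the class lookup failed,
            # so task_cls itself is None here.
            raise KeyError(
                f"Requested {ckpt_class_name}, which is not a registered task."
            )
        return task_cls.load_from_checkpoint(str(task_ckpt_path))
    # ...otherwise fall back to inspecting the checkpoint manually, as above.

task = from_pretrained_checkpoint("epoch=9.ckpt", ckpt_class_name="ForceRegressionTask")
print(type(task).__name__)  # ForceRegressionTask
```

Routing through the task's own ``load_from_checkpoint`` sidesteps any mismatch between checkpoint hyperparameter names and the generic reconstruction path.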