Skip to content

Commit

Permalink
fix: only checking pred path if it's a string or path
Browse files Browse the repository at this point in the history
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
  • Loading branch information
laserkelvin committed Sep 25, 2024
1 parent eabfef8 commit 2faa68d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion docs/source/inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ The ``ParityInferenceTask`` helps perform this task by using the PyTorch Lightni
The default ``Trainer`` settings will create a ``lightning_logs`` directory, followed by an experiment
number. Within it, once your inference run completes, there will be a ``inference_data.json`` that you
can then load in. The data is sorted by the name of the target (e.g. ``energy``, ``bandgap``), under
these keys, ``predictions`` and ``targets``.
these keys, ``predictions`` and ``targets``. Note that ``pred_split`` does not necessarily have to be
a completely different hold out: you can pass your training LMDB path if you wish to double check the
performance of your model after training, or you can use it with unseen samples.

.. note::

Expand Down
2 changes: 1 addition & 1 deletion matsciml/lightning/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def setup(self, stage: str | None = None) -> None:
dset = self._make_dataset(split_path, self.dataset)
splits[key] = dset
# specialty case for 'inference' or prediction runs
if hasattr(self.hparams, "pred_split"):
if isinstance(self.hparams.pred_split, (str, Path)):
pred_split_path = self.hparams.pred_split
if isinstance(pred_split_path, str):
pred_split_path = Path(pred_split_path)
Expand Down

0 comments on commit 2faa68d

Please sign in to comment.