From b63247856dc6b9ff5fddc3afc6f882fca885dac3 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 27 Sep 2024 08:52:53 -0700 Subject: [PATCH 1/6] fix: mapping the right private variable to predictions property Signed-off-by: Kin Long Kelvin Lee --- matsciml/models/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/models/inference.py b/matsciml/models/inference.py index 2f97cce0..b620badd 100644 --- a/matsciml/models/inference.py +++ b/matsciml/models/inference.py @@ -60,7 +60,7 @@ def targets(self, values: torch.Tensor) -> None: @property def predictions(self) -> torch.Tensor: - return torch.vstack(self._targets) + return torch.vstack(self._predictions) @predictions.setter def predictions(self, values: torch.Tensor) -> None: From b1d3941ce8c8daf87d8c521d4e38fbc3042bd20a Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 27 Sep 2024 09:09:45 -0700 Subject: [PATCH 2/6] refactor: allowing inference tasks to be created from specific classes Signed-off-by: Kin Long Kelvin Lee --- matsciml/models/inference.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/matsciml/models/inference.py b/matsciml/models/inference.py index b620badd..e27c7a4e 100644 --- a/matsciml/models/inference.py +++ b/matsciml/models/inference.py @@ -112,8 +112,7 @@ def predict_step( @classmethod def from_pretrained_checkpoint( - cls, - task_ckpt_path: str | Path, + cls, task_ckpt_path: str | Path, ckpt_class_name: str | None = None ) -> BaseInferenceTask: """ Instantiate a ``BaseInferenceTask`` from an existing Lightning checkpoint @@ -124,9 +123,15 @@ def from_pretrained_checkpoint( Parameters ---------- - task_ckpt_path : Union[str, Path] + task_ckpt_path : str | Path Path to an existing task checkpoint file. Typically, this would be a PyTorch Lightning checkpoint. + ckpt_class_name : str, optional + If specified, this will load the task based on its native + ``load_from_checkpoint`` method. 
This is a good alternative + if this method is unable to resolve parameter naming, etc, + and if your inference task depends on specific methods in + the task. Examples -------- @@ -143,6 +148,13 @@ def from_pretrained_checkpoint( assert ( task_ckpt_path.exists() ), "Encoder checkpoint filepath specified but does not exist." + # if a task name for the checkpoint is given, use that task's + # loading method directly + if ckpt_class_name: + task_cls = registry.get_task_class(ckpt_class_name) + if not task_cls: + raise KeyError(f"Requested {task_cls}, which is not a registered task.") + return cls(task_cls.load_from_checkpoint(str(task_ckpt_path))) ckpt = torch.load(task_ckpt_path) select_kwargs = {} for key in ["encoder_class", "encoder_kwargs"]: From a23db74d1d2196d368b553c00ae8153c6d6a0a9c Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 27 Sep 2024 09:10:04 -0700 Subject: [PATCH 3/6] docs: added note in inference docs about disabling inference mode Signed-off-by: Kin Long Kelvin Lee --- docs/source/inference.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/inference.rst b/docs/source/inference.rst index c00c8cda..08032fcc 100644 --- a/docs/source/inference.rst +++ b/docs/source/inference.rst @@ -40,6 +40,12 @@ these keys, ``predictions`` and ``targets``. Note that ``pred_split`` does not n a completely different hold out: you can pass your training LMDB path if you wish to double check the performance of your model after training, or you can use it with unseen samples. +Note that by default, ``predict`` triggers PyTorch's inference mode, which is a specialized case where +absolutely no autograd is enabled. ``ForceRegressionTask`` uses automatic differentiation to evaluate +forces, and so for inference tasks that require gradients, you **must** pass ``inference_mode=False`` to +``pl.Trainer``. + + .. note:: For developers, this is handled by the ``matsciml.models.inference.ParityData`` class. 
This is From 05a3c93ab1bdf2eaba87cb85b14678e0b06dd092 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 27 Sep 2024 09:10:23 -0700 Subject: [PATCH 4/6] docs: added missing inference page to index Signed-off-by: Kin Long Kelvin Lee --- docs/source/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/index.rst b/docs/source/index.rst index fd62fe86..aa66676a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,6 +17,7 @@ The Open MatSciML Toolkit training callbacks experiment-interface + inference best-practices how-to developers From 402c6dd6ca91d2bb2c5f0b6583bfb1b1703f41fe Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 27 Sep 2024 09:12:45 -0700 Subject: [PATCH 5/6] refactor: setting pred split to sole dataset Signed-off-by: Kin Long Kelvin Lee --- matsciml/lightning/data_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/matsciml/lightning/data_utils.py b/matsciml/lightning/data_utils.py index fbe12d90..ec045c84 100644 --- a/matsciml/lightning/data_utils.py +++ b/matsciml/lightning/data_utils.py @@ -259,6 +259,9 @@ def setup(self, stage: str | None = None) -> None: f"Prediction split provided, but not found: {pred_split_path}" ) dset = self._make_dataset(pred_split_path, self.dataset) + # assumes that if we're providing a predict set, we're not going + # to be doing training in the same run + self.dataset = dset splits["pred"] = dset # the last case assumes only the dataset is passed, we will treat it as train if len(splits) == 0: From 855a06dbddf4b8c8e164b387dad00b01a35e4ce7 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 27 Sep 2024 09:13:05 -0700 Subject: [PATCH 6/6] fix: making predict split use the correct collate func Signed-off-by: Kin Long Kelvin Lee --- matsciml/lightning/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/lightning/data_utils.py b/matsciml/lightning/data_utils.py index ec045c84..0e82c1ff 100644 --- 
a/matsciml/lightning/data_utils.py +++ b/matsciml/lightning/data_utils.py @@ -291,7 +291,7 @@ def predict_dataloader(self): target, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, - collate_fn=self.dataset.collate_fn, + collate_fn=target.collate_fn, persistent_workers=self.persistent_workers, )