General predict method for tasks #232

Merged 6 commits on Jun 3, 2024
3 changes: 1 addition & 2 deletions matsciml/interfaces/ase/base.py
@@ -248,13 +248,12 @@ def calculate(
        # get into format ready for matsciml model
        data_dict = self._format_pipeline(atoms)
        # run the data structure through the model
+       output = self.task_module.predict(data_dict)
        if isinstance(self.task_module, MultiTaskLitModule):
-           output = self.task_module.ase_calculate(data_dict)
            # use a more complicated parser for multitasks
            results = self.multitask_strategy(output, self.task_module)
            self.results = results
        else:
-           output = self.task_module(data_dict)
            # add outputs to self.results as expected by ase
            if "energy" in output:
                self.results["energy"] = output["energy"].detach().item()
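The net effect of this change is that `calculate` no longer needs to branch on the module type before running inference: both single- and multi-task modules now expose the same `predict` entry point, and only the post-processing of the output differs. A minimal sketch of the pattern, with illustrative stand-in classes rather than the real matsciml API:

```python
class SingleTask:
    """Stand-in for a single-task module exposing predict()."""

    def predict(self, batch: dict) -> dict:
        return {"energy": batch["x"] * 2.0}


class MultiTask:
    """Stand-in for MultiTaskLitModule: results nested by dataset/subtask."""

    def predict(self, batch: dict) -> dict:
        return {"DatasetA": {"EnergyHead": {"energy": batch["x"] * 2.0}}}


def calculate(task_module, batch: dict) -> dict:
    # a single entry point works for both module types
    output = task_module.predict(batch)
    if isinstance(task_module, MultiTask):
        # stands in for the multitask strategy, which parses nested results
        output = output["DatasetA"]["EnergyHead"]
    return output


print(calculate(SingleTask(), {"x": 1.0})["energy"])  # 2.0
```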
85 changes: 73 additions & 12 deletions matsciml/models/base.py
@@ -1015,6 +1015,38 @@ def _make_normalizers(self) -> dict[str, Normalizer]:
            normalizers[key] = Normalizer(mean=mean, std=std, device=self.device)
        return normalizers

    def predict(self, batch: BatchDict) -> dict[str, torch.Tensor]:
        """
        Implements what is effectively the 'inference' logic of the task:
        we run the forward pass on a batch of samples, and if normalizers
        were used during training, we also apply the inverse transform to
        return values at the original scale.

        Not to be confused with `predict_step`, which is used by Lightning as
        part of the prediction workflow. Since there is no one-size-fits-all
        inference workflow we can define, this provides a convenient function
        for users to call as a replacement.

        Parameters
        ----------
        batch : BatchDict
            Batch of samples to pass to the model.

        Returns
        -------
        dict[str, torch.Tensor]
            Output dictionary as provided by the forward pass; if a
            normalizer is available for a given task key, we apply the
            inverse transform to its value.
        """
        outputs = self(batch)
        if self.uses_normalizers:
            for key in self.task_keys:
                if key in self.normalizers:
                    # apply the inverse transform if provided
                    outputs[key] = self.normalizers[key].denorm(outputs[key])
        return outputs

    @classmethod
    def from_pretrained_encoder(cls, task_ckpt_path: str | Path, **kwargs):
        """
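The denormalization applied by `predict` can be illustrated with a minimal stand-in for the `Normalizer` interface; the mean/std affine form here is an assumption based on how `denorm` is used above, and the `device` argument is omitted for brevity:

```python
class Normalizer:
    """Minimal sketch of a mean/std normalizer with a denorm() method."""

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def norm(self, value: float) -> float:
        # map a raw target onto the normalized training scale
        return (value - self.mean) / self.std

    def denorm(self, value: float) -> float:
        # invert the transform, as predict() does for each task key
        return value * self.std + self.mean


energy_norm = Normalizer(mean=-3.5, std=2.0)
model_output = 0.25                      # hypothetical normalized prediction
print(energy_norm.denorm(model_output))  # -3.0
```

Because `denorm` inverts `norm` exactly, calling `predict` on a model trained against normalized targets yields values directly comparable to the raw labels.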
@@ -1706,6 +1738,36 @@ def energy_and_force(
        outputs["node_energies"] = node_energies
        return outputs

    def predict(self, batch: BatchDict) -> dict[str, torch.Tensor]:
        """
        Similar to the base method, but with two minor modifications to the
        denormalization logic: the energy normalizer's rescaling may also be
        applied to the forces and node-level energies.

        Parameters
        ----------
        batch : BatchDict
            Batch of samples to evaluate on.

        Returns
        -------
        dict[str, torch.Tensor]
            Output dictionary as provided by the forward call. For this task
            in particular, we may also apply the energy rescaling to forces
            and node energies if separate normalizers for them are not
            provided.
        """
        output = super().predict(batch)
        # for forces, in the event that a dedicated normalizer wasn't provided
        # but we have an energy normalizer, we apply the same factors to the force
        if self.uses_normalizers:
            if "force" not in self.normalizers and "energy" in self.normalizers:
                output["force"] = self.normalizers["energy"].denorm(output["force"])
            if "node_energies" not in self.normalizers and "energy" in self.normalizers:
                output["node_energies"] = self.normalizers["energy"].denorm(
                    output["node_energies"]
                )
        return output

    def _get_targets(
        self,
        batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]],
@@ -2471,20 +2533,18 @@ def forward(
            results[task_type] = subtask(batch)
        return results

-    def ase_calculate(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
+    def predict(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
         """
-        Currently "specialized" function that runs a set of data through
-        every single output head, ignoring the nominal dataset/subtask
-        unique mapping.
+        Similar logic to the `BaseTaskModule.predict` method, but implemented
+        for the multitask setting.

-        This is designed for ASE usage primarily, but ostensibly could be
-        used as _the_ inference call for a multitask module. Basically,
-        when the input data doesn't come from the same "datasets" used
-        for initialization/training, and we want to provide a "mixture of
-        experts" response.
+        The workflow is a linear combination of the two: we run the joint
+        embedder once, and then subsequently rely on the `predict` method
+        for each subtask to get outputs at their expected scales.

-        TODO: this could potentially be used as a template to redesign
-        the forward call to substantially simplify the multitask mapping.
+        This method also behaves a little differently from the other multitask
+        operations, as it runs a set of data through every single output head,
+        ignoring the nominal dataset/subtask unique mapping.

        Parameters
        ----------
@@ -2511,7 +2571,8 @@ def ase_calculate(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
        # now loop through every dataset/output head pair
        for dset_name, subtask_name in self.dataset_task_pairs:
            subtask = self.task_map[dset_name][subtask_name]
-            output = subtask(batch)
+            # use the predict method to get rescaled outputs
+            output = subtask.predict(batch)
            # now add it to the rest of the results
            if dset_name not in results:
                results[dset_name] = {}
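The loop above can be sketched in isolation. The pair list and task map mirror the structures referenced in the diff, with toy subtasks standing in for real output heads; note that every head sees the same batch, regardless of which dataset it was paired with during training:

```python
class StubTask:
    """Toy subtask whose predict() scales the batch by a fixed factor."""

    def __init__(self, factor: float):
        self.factor = factor

    def predict(self, batch: dict) -> dict:
        return {"energy": batch["x"] * self.factor}


dataset_task_pairs = [("DatasetA", "EnergyHead"), ("DatasetB", "EnergyHead")]
task_map = {
    "DatasetA": {"EnergyHead": StubTask(2.0)},
    "DatasetB": {"EnergyHead": StubTask(3.0)},
}

results: dict = {}
# run the same batch through every dataset/output head pair,
# collecting results nested by dataset then subtask name
for dset_name, subtask_name in dataset_task_pairs:
    subtask = task_map[dset_name][subtask_name]
    output = subtask.predict(batch={"x": 1.0})
    results.setdefault(dset_name, {})[subtask_name] = output

print(results["DatasetB"]["EnergyHead"]["energy"])  # 3.0
```

This nested `dataset -> subtask -> outputs` shape is what the multitask strategy in the ASE calculator then parses into flat `results` entries.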