Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multitask ASE interface #224

Merged
merged 19 commits into from
May 28, 2024

Conversation

laserkelvin
Copy link
Collaborator

This PR closes #217 by providing an extended interface for mapping multitask models to ase workflows.

The main changes include:

  • Introduces the matsciml.interfaces.ase.multitask.AbstractStrategy class, whose children allows us to implement ways to aggregate the outputs of multitasks.
  • Implements the sole concrete class (for now), AverageTasks, which does a mean aggregation across the different dataset heads to compute energies/forces. Essentially, if you had ForceRegressionTask trained on multiple datasets, we outputs from all.
  • Implements a custom ase_calculate method for MultiTaskModule, which effectively streamlines inference using multitask modules in scenarios like as an ASE calculator. It's named specifically for ase, but in the future we can use it as a dedicated inference method.

The AbstractStrategy should be flexible enough to do other potentially more intelligent aggregations, e.g. a weighted average, MoE approaches, etc.

Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
This isn't actually originally part of the PR scope, but apparently this typo never triggered as
an issue until now!
This substantially simplifies the workflow, albeit adds yet another method to
multitask modules. The new method simply passes input data into the encoder,
and maps it to every single subtask regardless, instead of requiring that
the batch shares the dataset keys.
@laserkelvin laserkelvin added enhancement New feature or request inference Issues related to model inference and testing labels May 24, 2024
@laserkelvin laserkelvin requested a review from melo-gonzo May 24, 2024 15:28
@laserkelvin
Copy link
Collaborator Author

Ergh, issue with one of the tests...

This adjusts the logic, albeit maybe inconsistent with the rest of multitask, where
we check the incoming batch for dataset names at the top level to determine if it's
a multidata batch, instead of relying on the model expectations.

This fixes the ase calculate behavior, which would have been mismatched since the
module is inherently multidata but the incoming batch is not.
@@ -2107,7 +2107,7 @@ def __init__(
if index != 0:
task.encoder = self.encoder
# nest the task based on its category
task_map[dset_name][task.__task__] = task
task_map[dset_name][task.__class__.__name__] = task
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this added for the single data multi-task case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% sure what you mean - I don't think this part of the code matters if it's single data or not?

Copy link
Collaborator

@melo-gonzo melo-gonzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, I'll get some practical testing done once this is merged. Can you add an example similar to ase_from_pretrained.py? Otherwise good to merge!

@laserkelvin
Copy link
Collaborator Author

@melo-gonzo just wanted you to take a look at the example before I merge, and if you wanted to clarify what you meant in your question?

@melo-gonzo
Copy link
Collaborator

@laserkelvin thanks for the example! I am just curious why you made the change in that line. Feel free to merge ✅

@laserkelvin laserkelvin merged commit 92f1600 into IntelLabs:main May 28, 2024
3 of 4 checks passed
@laserkelvin laserkelvin deleted the multitask-ase-interface branch May 28, 2024 15:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request inference Issues related to model inference and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature request]: Reconciling multi task models with ase Calculator interface.
2 participants