-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multitask ASE interface #224
Multitask ASE interface #224
Conversation
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
This isn't actually originally part of the PR scope, but apparently this typo never triggered as an issue until now!
This substantially simplifies the workflow, albeit adds yet another method to multitask modules. The new method simply passes input data into the encoder, and maps it to every single subtask regardless, instead of requiring that the batch shares the dataset keys.
Ergh, issue with one of the tests... |
This adjusts the logic, albeit maybe inconsistent with the rest of multitask, where we check the incoming batch for dataset names at the top level to determine if it's a multidata batch, instead of relying on the model expectations. This fixes the ase calculate behavior, which would have been mismatched since the module is inherently multidata but the incoming batch is not.
@@ -2107,7 +2107,7 @@ def __init__( | |||
if index != 0: | |||
task.encoder = self.encoder | |||
# nest the task based on its category | |||
task_map[dset_name][task.__task__] = task | |||
task_map[dset_name][task.__class__.__name__] = task |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this added for the single data multi-task case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not 100% sure what you mean - I don't think this part of the code matters if it's single data or not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, I'll get some practical testing done once this is merged. Can you add an example similar to ase_from_pretrained.py
? Otherwise good to merge!
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
@melo-gonzo just wanted you to take a look at the example before I merge, and if you wanted to clarify what you meant in your question? |
@laserkelvin thanks for the example! I am just curious why you made the change in that line. Feel free to merge ✅ |
This PR closes #217 by providing an extended interface for mapping multitask models to
ase
workflows.The main changes include:
matsciml.interfaces.ase.multitask.AbstractStrategy
class, whose children allows us to implement ways to aggregate the outputs of multitasks.AverageTasks
, which does a mean aggregation across the different dataset heads to compute energies/forces. Essentially, if you hadForceRegressionTask
trained on multiple datasets, we outputs from all.ase_calculate
method forMultiTaskModule
, which effectively streamlines inference using multitask modules in scenarios like as an ASE calculator. It's named specifically forase
, but in the future we can use it as a dedicated inference method.The
AbstractStrategy
should be flexible enough to do other potentially more intelligent aggregations, e.g. a weighted average, MoE approaches, etc.