Introducing structured model outputs #316
Conversation
…finition. While having the argument means we have a choice, it does make it more difficult to control the logic
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
CI passed on relevant things
Looks good overall! Will be a nice improvement to have. Left one general comment about downstream workflows.
```python
if isinstance(encoder_outputs, Embeddings):
    embeddings = encoder_outputs
# for BYO output head cases
elif isinstance(encoder_outputs, ModelOutput):
```
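For readers skimming the diff excerpt above, a minimal sketch of what this dispatch amounts to, assuming the structured output exposes its embeddings as an attribute; the helper function and the attribute name are illustrative, not code from this PR:

```python
def extract_embeddings(encoder_outputs, embeddings_cls, model_output_cls):
    """Sketch: accept either bare embeddings or a structured model output."""
    if isinstance(encoder_outputs, embeddings_cls):
        # legacy path: the encoder returned Embeddings directly
        return encoder_outputs
    if isinstance(encoder_outputs, model_output_cls):
        # BYO output head path: embeddings are assumed to live on the
        # structured output (attribute name is an assumption)
        return encoder_outputs.embeddings
    raise TypeError(f"unexpected encoder output type: {type(encoder_outputs)!r}")
```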
A comment that may not be super important to consider right at this moment, but should be noted: some workflows, such as serving models with OpenKIM's kusp to run benchmarks, require that node energies are present. Some models (such as mace) can output these directly, and they may be worth hanging onto as well.
I think they're included in ModelOutput: in the wrapper output they're stashed as node_energies. That's what you mean, right?
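For concreteness, a small sketch of how a downstream consumer could hang onto those per-node energies, assuming only that the wrapper's output carries an optional node_energies attribute as described above; everything else here is illustrative:

```python
def collect_node_energies(model_output):
    """Sketch: keep per-node energies when the wrapper provides them."""
    # models such as MACE emit per-node energies directly; per the discussion
    # above, the wrapper stashes them on the output as `node_energies`
    node_energies = getattr(model_output, "node_energies", None)
    if node_energies is not None:
        return node_energies
    # nothing provided: a serving/benchmark workflow (e.g. via kusp) would
    # need to derive node energies elsewhere, e.g. from a task output head
    return None
```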
This PR is semi-related to #104, allowing flexibility for wrappers and full implementations to potentially have a common output format that is also "backwards" compatible. It depends on ModelOutput from #315, so please review and merge that before this PR.

This PR adds an attribute to the AbstractTask class (which graph models inherit) called __skip_output_heads__, which BaseTaskLitModules will check to determine whether output heads are needed or not. If we do not use them, we do not initialize the OutputHeads, to save on parameters and to ensure we don't have unused parameters. I've updated MACEWrapper as an example of how to do this, and we should do so for other wrappers as well (e.g. CHGNet). I've also updated the sphinx documentation to detail how this can be used. I've added tests that use MACEWrapper to run the forward pass, and an end-to-end test with ForceRegressionTask.
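As a rough sketch of how these pieces fit together: only __skip_output_heads__, node_energies, and the roles of the classes come from this PR description; the class and method names below are illustrative stand-ins, not the actual matsciml code:

```python
class WrapperSketch:
    """Stand-in for a MACEWrapper-style model that builds its own outputs."""

    # class-level flag the task module checks; True means the model returns a
    # structured ModelOutput itself, so no task-side output heads are needed
    __skip_output_heads__ = True

    def forward(self, batch):
        # run the wrapped model and pack energies/forces/node_energies
        # into a structured output (details omitted in this sketch)
        ...


class TaskSketch:
    """Stand-in for a BaseTaskLitModule subclass such as ForceRegressionTask."""

    def __init__(self, encoder):
        self.encoder = encoder
        if getattr(type(encoder), "__skip_output_heads__", False):
            # the encoder supplies structured outputs directly, so skip
            # constructing OutputHeads and avoid unused parameters
            self.output_heads = None
        else:
            self.output_heads = self._build_output_heads()

    def _build_output_heads(self):
        # placeholder for the usual output-head construction
        ...
```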