Introducing structured model outputs #316

Merged: 35 commits, Nov 12, 2024

@laserkelvin (Collaborator) commented Nov 12, 2024

This PR is semi-related to #104, allowing flexibility for wrappers and full implementations to potentially have a common output format that is also "backwards" compatible. It depends on ModelOutput from #315, so please review and merge that before this PR.

  • Added a private attribute, __skip_output_heads__, to the base AbstractTask class (which graph models inherit). BaseTaskLitModules check it to determine whether output heads are needed. If they are not, the OutputHeads are not initialized, which saves parameters and ensures we don't have unused parameters.
  • I have refactored the MACEWrapper as an example of how to do this, and we should do the same for other wrappers (e.g. CHGNet). I've also updated the Sphinx documentation to detail how this can be used.
  • I've ensured that the regular workflow (with PyG EGNN) isn't broken, and included unit tests for MACEWrapper to run the forward pass and an end-to-end with ForceRegressionTask.
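The mechanism in the first bullet can be sketched roughly as follows. This is a minimal, dependency-free illustration and not the code from this PR: the class names SketchBaseModel, SketchEncoder, SketchWrapper and SketchTaskModule, and the string placeholders standing in for real output heads, are all assumptions made for the example. Only the attribute name __skip_output_heads__ comes from the description above.

```python
from dataclasses import dataclass, field


class SketchBaseModel:
    """Hypothetical stand-in for the base class that graph models inherit."""

    # Default: the task module should build its own output heads.
    __skip_output_heads__ = False


class SketchEncoder(SketchBaseModel):
    """Produces embeddings only, so it relies on the task's output heads."""


class SketchWrapper(SketchBaseModel):
    """Hypothetical wrapper around a full model that predicts targets itself."""

    __skip_output_heads__ = True


@dataclass
class SketchTaskModule:
    """Hypothetical stand-in for a task module that owns an encoder."""

    encoder: SketchBaseModel
    target_keys: list
    output_heads: dict = field(default_factory=dict)

    def __post_init__(self):
        # Only build heads when the encoder does not bring its own outputs,
        # so no unused parameters are created for wrapped models.
        if not getattr(self.encoder, "__skip_output_heads__", False):
            # Strings stand in for real head modules in this sketch.
            self.output_heads = {key: f"head_for_{key}" for key in self.target_keys}


with_heads = SketchTaskModule(SketchEncoder(), ["energy", "force"])
without_heads = SketchTaskModule(SketchWrapper(), ["energy", "force"])
print(sorted(with_heads.output_heads))
print(without_heads.output_heads)
```

Because the attribute name ends in two underscores, Python does not apply name mangling to it, so a subclass can override it with a plain class-level assignment.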

…finition

While having the argument means we have choice, it does make it more difficult to control the logic
@laserkelvin added the enhancement and code maintenance labels Nov 12, 2024
@laserkelvin (Collaborator, Author)

CI passed on the relevant checks.

@melo-gonzo (Collaborator) left a comment

Looks good overall! Will be a nice improvement to have. Left one general comment about downstream workflows.

    if isinstance(encoder_outputs, Embeddings):
        embeddings = encoder_outputs
    # for BYO output head cases
    elif isinstance(encoder_outputs, ModelOutput):
Collaborator

This may not be important to address right now, but it should be noted: some workflows, such as serving models with OpenKIM's kusp to then run benchmarks, require node energies to be present. Some models (such as MACE) can output these directly, so they may be worth keeping as well.

Collaborator Author

I think they're included in ModelOutput: in the wrapper output they're stashed as node_energies. That's what you mean, right?
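To make the exchange above concrete, here is a rough sketch of how a task could branch on the encoder's return type and carry node energies through when a wrapper supplies them. It is an illustration under stated assumptions, not the repository's implementation: the two dataclasses are simplified stand-ins that only borrow the names Embeddings and ModelOutput from the diff, the function process_outputs and the fields system_embedding, total_energy and forces are hypothetical, and only the node_energies name is taken from the reply above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Embeddings:
    """Simplified stand-in: encoder features that still need an output head."""

    system_embedding: list


@dataclass
class ModelOutput:
    """Simplified stand-in: predictions produced directly by a wrapped model."""

    total_energy: float
    forces: list
    node_energies: Optional[list] = None  # per-atom energies, when available


def process_outputs(encoder_outputs, output_head):
    """Hypothetical helper: return a dict of predictions for either output type."""
    if isinstance(encoder_outputs, Embeddings):
        # Regular workflow: the task's own output head maps embeddings to a target.
        return {"energy": output_head(encoder_outputs.system_embedding)}
    # for BYO output head cases
    elif isinstance(encoder_outputs, ModelOutput):
        results = {
            "energy": encoder_outputs.total_energy,
            "forces": encoder_outputs.forces,
        }
        # Keep per-node energies if the model provided them, since some
        # downstream workflows need them.
        if encoder_outputs.node_energies is not None:
            results["node_energies"] = encoder_outputs.node_energies
        return results
    raise TypeError(f"Unsupported encoder output: {type(encoder_outputs).__name__}")


# The built-in sum acts as a toy "output head" here.
from_embeddings = process_outputs(Embeddings([0.5, 1.5]), sum)
from_wrapper = process_outputs(
    ModelOutput(-3.5, [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [-1.0, -2.5]), sum
)
print(from_embeddings)
print(from_wrapper["node_energies"])
```

Making node_energies optional keeps the structure usable for models that only report a total energy, while models that expose per-atom values can pass them along unchanged.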

@laserkelvin merged commit b522337 into IntelLabs:main Nov 12, 2024
2 of 3 checks passed
@laserkelvin deleted the byo-outputs branch November 12, 2024 23:43
2 participants