Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General predict method for tasks #232

Merged
merged 6 commits into from
Jun 3, 2024

Conversation

laserkelvin
Copy link
Collaborator

This PR adds a base predict method to tasks, which is intended to be used for inference workflows. This method acts as a wrapper for the task forward call, but includes the additional step of applying the inverse normalization to the outputs, so that they are rescaled appropriately.

A summary of changes:

  1. The predict definition in BaseTaskModule, which is inherited by all tasks except MultiTaskModule
  2. Refactored the MultiTaskModule.ase_calculate to be predict, so that the interface is consistent. The workflow is more or less unchanged from before, but relies on subtask.predict.
  3. For ForceRegressionTask, we override the predict logic slightly by rescaling both forces and node energies using the energy normalization factor if dedicated normalizers are not provided by those values separately.
  4. Updated the ase calculator code to rely on the predict interface, regardless of whether it is multitask or not.

@laserkelvin laserkelvin added enhancement New feature or request inference Issues related to model inference and testing labels Jun 3, 2024
@laserkelvin laserkelvin requested a review from melo-gonzo June 3, 2024 18:18
Copy link
Collaborator

@melo-gonzo melo-gonzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@laserkelvin laserkelvin merged commit af6c5c1 into IntelLabs:main Jun 3, 2024
3 of 4 checks passed
@laserkelvin laserkelvin deleted the predict-task-method branch June 3, 2024 19:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request inference Issues related to model inference and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants