-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce a Nominal Checkpoint for On-Device Training #19232
Conversation
…baijumeswani/nominal-checkpoint
…baijumeswani/nominal-checkpoint
Is there a scenario where only the nominal checkpoint is required? |
For the scenario I foresee, I do not see how users would be able to use only the nominal checkpoint. The nominal checkpoint is of use particularly in the case of federated learning where the device can benefit from not loading the model weights on construction. However, the server must have the complete checkpoint information so it can send that to the device as and when requested. So, the nominal checkpoint and the complete checkpoint must work in tandem to achieve the expected benefit. Having said that, there might be some use of a utility that can take in a complete checkpoint and return/save a nominal checkpoint equivalent of that complete checkpoint. This might be outside the scope of this pull request though, as this PR is targeting enabling the functionality. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some suggestions.
Not included in these changes but in the previous code:
Would suggest changing the RuntimeError to ValueError when we're checking for a null value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
…baijumeswani/nominal-checkpoint
Would it be ok if I addressed this in another PR? |
Yes this can be as a part of a refactoring PR. |
This pull request introduces a significant enhancement to ONNX Runtime's On-Device Training Checkpoint. The nominal checkpoint feature allows users to use a checkpoint file on the device that contains the minimum information needed to instantiate the
Ort::TrainingSession
.Why Nominal Checkpoints? 🤔
For federated learning scenarios, we noticed that users were doing duplicate work. On the device, they instantiated the
TrainingSession
and immediately thereafter, loaded the parameters got from the federated learning server. The construction of theTrainingSession
allocates memory for the weights and their gradients. And loading the model parameters does similar work, causing duplicated work. Loading of the weights is an expensive operation and allocating and copying data multiple times leads to redundancy that should be avoided.Moreover, the nominal checkpoint also addresses the problem of the On-Device Training Application package size. This is the cost the developer has to incur to package the training artifacts within their application. The complete checkpoint size can be considerable given that it contains all the parameter data. This file size of a nominal checkpoint is far smaller in comparison and does not impact the developer's application size significantly.
How is the Nominal Checkpoint different from the Complete Checkpoint? 🤔
Both the nominal and the complete checkpoint have the exact same schema. The only significant difference between the two is that the model parameters contained in the nominal checkpoint have no raw data associated with it. To elaborate:
All other fields in the checkpoint remain the same.
How does one generate a Nominal Checkpoint? 🤔
The
artifacts.generate_artifacts
python utility inonnxblock
has an optional argumentnominal_checkpoint
. The new signature of thegenerate_artifacts
method is:If the
nominal_checkpoint
argument is passed in asTrue
, two checkpoint files will be generated calledcheckpoint
andnominal_checkpoint
. As the name suggests, thenominal_checkpoint
file will be the one that can be used on the device while thecheckpoint
is expected to be used on the federated learning server (since at least one of the two needs to have the complete state information).How should one use the Nominal Checkpoint on the device? 🤔
Illustration in
C++
:Illustration in
Python
: