
Introduce a Nominal Checkpoint for On-Device Training #19232

Merged
8 commits merged into main from baijumeswani/nominal-checkpoint on Jan 31, 2024

Conversation

baijumeswani
Contributor

@baijumeswani baijumeswani commented Jan 23, 2024

This pull request introduces a significant enhancement to ONNX Runtime's On-Device Training checkpoint. The nominal checkpoint feature lets users ship a checkpoint file to the device that contains only the minimum information needed to instantiate the Ort::TrainingSession.

Why Nominal Checkpoints? 🤔

For federated learning scenarios, we noticed that users were doing duplicate work. On the device, they instantiated the TrainingSession and, immediately afterwards, loaded the parameters obtained from the federated learning server. Constructing the TrainingSession already allocates memory for the weights and their gradients, and loading the model parameters then allocates and copies the same data again. Loading the weights is an expensive operation, and allocating and copying the data multiple times is redundancy that should be avoided.
Moreover, the nominal checkpoint also addresses the size of the On-Device Training application package, i.e. the cost the developer has to incur to package the training artifacts within their application. The complete checkpoint can be considerable in size, since it contains all the parameter data. A nominal checkpoint is far smaller in comparison and does not significantly impact the size of the developer's application.

How is the Nominal Checkpoint different from the Complete Checkpoint? 🤔

Both the nominal and the complete checkpoint share the exact same schema. The only significant difference between the two is that the model parameters contained in the nominal checkpoint have no raw data associated with them. To elaborate:

Checkpoint State
 |-- Module Checkpoint State
 |    |-- Parameters Requiring Gradients
 |    |    |-- Parameter Name
 |    |    |-- Parameter Raw Data <Empty for Nominal Checkpoints>
 |    |-- Parameters Not Requiring Gradients 
 |    |    |-- Parameter Name
 |    |    |-- Parameter Raw Data <Empty for Nominal Checkpoints>
 |    |-- Is Nominal Checkpoint Flag <True for Nominal Checkpoints and False for Complete Checkpoints>

All other fields in the checkpoint remain the same.

How does one generate a Nominal Checkpoint? 🤔

The artifacts.generate_artifacts Python utility in onnxblock now has an optional nominal_checkpoint argument. The new signature of the generate_artifacts method is:

def generate_artifacts(
    model: onnx.ModelProto,
    requires_grad: Optional[List[str]] = None,
    frozen_params: Optional[List[str]] = None,
    loss: Optional[Union[LossType, onnxblock.Block]] = None,
    optimizer: Optional[OptimType] = None,
    artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None,
    prefix: str = "",
    ort_format: bool = False,
    custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None,
    additional_output_names: Optional[List[str]] = None,
    nominal_checkpoint: bool = False,
) -> None:

If the nominal_checkpoint argument is passed in as True, two checkpoint files are generated, named checkpoint and nominal_checkpoint. As the name suggests, the nominal_checkpoint file is the one intended for use on the device, while checkpoint is expected to be used on the federated learning server (since at least one of the two needs to hold the complete state information).
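
As a concrete illustration, a minimal sketch of generating both checkpoints might look like the following. The model path and the parameter names are placeholders for the user's own model, not part of this PR:

import onnx
from onnxruntime.training import artifacts

# Load the user's forward-only ONNX model ("my_model.onnx" is a placeholder path).
model = onnx.load("my_model.onnx")

artifacts.generate_artifacts(
    model,
    requires_grad=["fc1.weight", "fc1.bias"],  # placeholder parameter names
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
    nominal_checkpoint=True,
)

# The artifact directory now contains both `checkpoint` (the complete checkpoint,
# kept on the federated learning server) and `nominal_checkpoint` (the one packaged
# with the device application).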

How should one use the Nominal Checkpoint on the device? 🤔

Illustration in C++:

// Load the nominal checkpoint state
Ort::CheckpointState nominal_state = Ort::CheckpointState::LoadCheckpoint(<nominal_checkpoint_path>);
// Instantiate the training session with this nominal state
Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, nominal_state,
                                                             training_model_uri, eval_model_uri,
                                                             optimizer_model_uri);
...

// Before calling any of the methods on the training session that need access to the parameters,
// load the parameters obtained from the federated learning server
// Make sure that the params buffer has all the parameters (not just the trainable parameters).
Ort::Value params_buffer = UsersFLUtility::GetParamsBufferFromFLServer(..., false /* trainable_only */);
// Calling `FromBuffer` does two things:
// 1. Allocate and copy weights from the params buffer to the checkpoint state, thereby making the
//    checkpoint complete
// 2. Signal the `Optimizer` that the checkpoint state is now complete and that it can construct
// the optimizer state based on the available model parameters.
training_session.FromBuffer(params_buffer);
...

// At this point, the `nominal_state` is no longer nominal as all the parameters have been loaded onto the state.
// Call the `training_session` methods as you would normally
auto loss = training_session.TrainStep(...);
training_session.OptimizerStep();
training_session.LazyResetGrad();
...

Illustration in Python:

# Load the nominal checkpoint state
nominal_state = CheckpointState.load_checkpoint(<nominal_checkpoint_path>)
# Instantiate the training session Module and Optimizer with this nominal state
model = Module(training_model_uri, nominal_state)
optim = Optimizer(optimizer_model_file_path, model)

# Before calling any of the methods on the model that need access to the parameters,
# load the parameters obtained from the federated learning server
# Make sure that the params buffer has all the parameters (not just the trainable parameters).
parameter_buffer = users_fl_utility.get_params_buffer_from_fl_server(..., trainable_only=False)
# Calling `copy_buffer_to_parameters` allocates and copies weights from
# the params buffer to the checkpoint state, thereby making the checkpoint complete
model.copy_buffer_to_parameters(parameter_buffer, trainable_only=False)

...

# At this point, the `nominal_state` is no longer nominal as all the parameters have been loaded onto the state.
# Call the `model` methods as you would normally
model.train()
loss = model(...)
# Calling optim.step() will construct the optimizer state based on the available model parameters
# if not already done
optim.step()
model.lazy_reset_grad()
...
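
For context, here is one possible way to construct the flat parameter buffer handed to copy_buffer_to_parameters from raw data received from the server. This is a sketch under assumptions, not part of this PR: the payload_from_fl_server variable and the assumption that the parameters arrive as a single contiguous float32 blob in the layout the training session expects are purely illustrative.

import numpy as np
from onnxruntime import OrtValue

# Assumption: `payload_from_fl_server` holds all model parameters as one contiguous
# float32 blob, in the order and layout the training session expects.
received_params = np.frombuffer(payload_from_fl_server, dtype=np.float32)
parameter_buffer = OrtValue.ortvalue_from_numpy(received_params)

# Copy all (not just the trainable) parameters into the checkpoint state.
model.copy_buffer_to_parameters(parameter_buffer, trainable_only=False)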

@baijumeswani baijumeswani added the training label (issues related to ONNX Runtime training; typically submitted using template) on Jan 23, 2024
@AdamLouly
Contributor

Is there a scenario where only the nominal checkpoint is required?
If so, can we export the nominal checkpoint independently, without the necessity of exporting both the nominal and the complete checkpoint?

@baijumeswani
Contributor Author

baijumeswani commented Jan 25, 2024

Is there a scenario where only the nominal checkpoint is required? If so, can we export the nominal checkpoint independently, without the necessity of exporting both the nominal and the complete checkpoint?

For the scenario I have in mind, I do not see how users could use only the nominal checkpoint. The nominal checkpoint is useful particularly in the case of federated learning, where the device can benefit from not loading the model weights on construction. However, the server must have the complete checkpoint information so it can send it to the device as and when requested. So, the nominal checkpoint and the complete checkpoint must work in tandem to achieve the expected benefit.

Having said that, there might be some use for a utility that takes in a complete checkpoint and returns/saves the nominal checkpoint equivalent of that complete checkpoint. That is probably outside the scope of this pull request though, as this PR targets enabling the functionality.

Contributor

@AdamLouly AdamLouly left a comment


Some suggestions.
Not included in these changes but in the previous code:
Would suggest changing the RuntimeError to ValueError when we're checking for a null value.

askhade
askhade previously approved these changes Jan 30, 2024
Contributor

@askhade askhade left a comment


LGTM!

@baijumeswani
Contributor Author

Some suggestions.
Not included in these changes but in the previous code:
Would suggest changing the RuntimeError to ValueError when we're checking for a null value.

Would it be ok if I addressed this in another PR?

@AdamLouly
Contributor

Some suggestions.
Not included in these changes but in the previous code:
Would suggest changing the RuntimeError to ValueError when we're checking for a null value.

Would it be ok if I addressed this in another PR?

Yes, this can be part of a refactoring PR.

@baijumeswani baijumeswani merged commit 3262e8d into main Jan 31, 2024
98 checks passed
@baijumeswani baijumeswani deleted the baijumeswani/nominal-checkpoint branch January 31, 2024 06:11