Skip to content

Commit

Permalink
Introduce a Nominal Checkpoint for On-Device Training (#19232)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Jan 31, 2024
1 parent 4562c91 commit 3262e8d
Show file tree
Hide file tree
Showing 30 changed files with 973 additions and 311 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/core/flatbuffers/checkpoint_version.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ namespace onnxruntime {
// The format includes support for the ModuleState (stores the module parameters), OptimizerGroups
// (stores the optimizer states), and PropertyBag
// (stores custom user properties with support for int64, float and strings).
constexpr const int kCheckpointVersion = 1;
// Version 2: Introduces the On-Device Training nominal checkpoint state.
// Changes include the addition of the is_nominal_state field in the checkpoint's ModuleState.
constexpr const int kCheckpointVersion = 2;

/**
* @brief Check if the given checkpoint version is supported in this build
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,17 @@ def FrozenParamsIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
return o == 0

def ModuleStateStart(builder): builder.StartObject(2)
# ModuleState
def IsNominalState(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False

def ModuleStateStart(builder): builder.StartObject(3)
def ModuleStateAddRequiresGradParams(builder, requiresGradParams): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams), 0)
def ModuleStateStartRequiresGradParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def ModuleStateAddFrozenParams(builder, frozenParams): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams), 0)
def ModuleStateStartFrozenParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def ModuleStateAddIsNominalState(builder, isNominalState): builder.PrependBoolSlot(2, isNominalState, 0)
def ModuleStateEnd(builder): return builder.EndObject()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ namespace onnxruntime.fbs;
table ModuleState {
requires_grad_params:[Tensor];
frozen_params:[Tensor];
// Nominal state just means that the Tensors in the ModuleState
// are empty. i.e. The tensors are treated as named entities
// without any meaningful data.
is_nominal_state:bool;
}

table ParameterOptimizerState {
Expand Down
20 changes: 16 additions & 4 deletions onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,18 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ModuleStateBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_REQUIRES_GRAD_PARAMS = 4,
VT_FROZEN_PARAMS = 6
VT_FROZEN_PARAMS = 6,
VT_IS_NOMINAL_STATE = 8
};
const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *requires_grad_params() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *>(VT_REQUIRES_GRAD_PARAMS);
}
const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *frozen_params() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *>(VT_FROZEN_PARAMS);
}
bool is_nominal_state() const {
return GetField<uint8_t>(VT_IS_NOMINAL_STATE, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_REQUIRES_GRAD_PARAMS) &&
Expand All @@ -55,6 +59,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_FROZEN_PARAMS) &&
verifier.VerifyVector(frozen_params()) &&
verifier.VerifyVectorOfTables(frozen_params()) &&
VerifyField<uint8_t>(verifier, VT_IS_NOMINAL_STATE) &&
verifier.EndTable();
}
};
Expand All @@ -69,6 +74,9 @@ struct ModuleStateBuilder {
void add_frozen_params(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> frozen_params) {
fbb_.AddOffset(ModuleState::VT_FROZEN_PARAMS, frozen_params);
}
void add_is_nominal_state(bool is_nominal_state) {
fbb_.AddElement<uint8_t>(ModuleState::VT_IS_NOMINAL_STATE, static_cast<uint8_t>(is_nominal_state), 0);
}
explicit ModuleStateBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
Expand All @@ -84,23 +92,27 @@ struct ModuleStateBuilder {
inline flatbuffers::Offset<ModuleState> CreateModuleState(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> requires_grad_params = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> frozen_params = 0) {
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> frozen_params = 0,
bool is_nominal_state = false) {
ModuleStateBuilder builder_(_fbb);
builder_.add_frozen_params(frozen_params);
builder_.add_requires_grad_params(requires_grad_params);
builder_.add_is_nominal_state(is_nominal_state);
return builder_.Finish();
}

inline flatbuffers::Offset<ModuleState> CreateModuleStateDirect(
flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *requires_grad_params = nullptr,
const std::vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *frozen_params = nullptr) {
const std::vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *frozen_params = nullptr,
bool is_nominal_state = false) {
auto requires_grad_params__ = requires_grad_params ? _fbb.CreateVector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>(*requires_grad_params) : 0;
auto frozen_params__ = frozen_params ? _fbb.CreateVector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>(*frozen_params) : 0;
return onnxruntime::fbs::CreateModuleState(
_fbb,
requires_grad_params__,
frozen_params__);
frozen_params__,
is_nominal_state);
}

struct ParameterOptimizerState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/graph/graph_flatbuffers_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,14 @@ Status LoadOrtTensorOrtFormat(const fbs::Tensor& fbs_tensor, const AllocatorPtr
ort_tensor = onnxruntime::Tensor(
tensor_dtype, TensorShape(tensor_dims->data(), tensor_dims->size()), allocator);

if (fbs_tensor.raw_data()->size() == 0U) {
// Empty tensor. Nothing to unpack.
// This check is necessary because an empty ort tensor will return a size of 1.
// As a result, the following call to UnpackTensor will fail since the src and
// dst sizes do not match (0 and 1 elements).
return Status::OK();
}

// The tensor proto is used as a dummy here. The actual data is stored in the raw_data field of the flatbuffer.
// The data is copied from the raw_data field to the ort_tensor.
ONNX_NAMESPACE::TensorProto unused_tensor_proto;
Expand Down
Binary file modified onnxruntime/test/testdata/training_api/checkpoint.ckpt
Binary file not shown.
Binary file modified onnxruntime/test/testdata/training_api/custom_ops/checkpoint
Binary file not shown.
Binary file not shown.
Binary file modified onnxruntime/test/testdata/training_api/ort_format/checkpoint
Binary file not shown.
14 changes: 12 additions & 2 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,9 @@ void addObjectMethodsForTraining(py::module& m) {
.def("copy_parameter_from",
[](onnxruntime::training::api::CheckpointState* state,
const std::string& parameter_name, OrtValue& value) -> void {
if (state->module_checkpoint_state.is_nominal_state) {
ORT_THROW("Cannot copy parameter to a nominal state. Please load all the parameter states first");
}
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == state->module_checkpoint_state.named_parameters.end()) {
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
Expand All @@ -811,6 +814,9 @@ void addObjectMethodsForTraining(py::module& m) {
})
.def("get_parameter",
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
if (state->module_checkpoint_state.is_nominal_state) {
ORT_THROW("Cannot get parameter from a nominal state. Please load the parameter states first");
}
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == state->module_checkpoint_state.named_parameters.end()) {
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
Expand Down Expand Up @@ -851,6 +857,9 @@ void addObjectMethodsForTraining(py::module& m) {
return std::make_unique<PyOptimizer>(optimizer_model_uri, state, providers, session_options);
}))
.def("optimizer_step", [](PyOptimizer* optimizer) -> void {
// In case the optimizer was constructed using a nominal checkpoint,
// the optimizer state construction is delayed until the first call to Optimizer::Step().
// It is expected that the model parameter state is available at this point.
ORT_THROW_IF_ERROR(optimizer->optimizer_->Step());
})
.def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void {
Expand Down Expand Up @@ -893,7 +902,7 @@ void addObjectMethodsForTraining(py::module& m) {
"save_checkpoint",
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
const std::vector<py::bytes>& non_trainable_tensor_protos_pybytes,
const std::string& checkpoint_path) {
const std::string& checkpoint_path, const bool nominal_checkpoint) {
std::vector<TensorProto> trainable_tensor_protos(trainable_tensor_protos_pybytes.size());
std::vector<TensorProto> non_trainable_tensor_protos(non_trainable_tensor_protos_pybytes.size());

Expand All @@ -914,7 +923,8 @@ void addObjectMethodsForTraining(py::module& m) {

ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos,
non_trainable_tensor_protos,
ToPathString(checkpoint_path)));
ToPathString(checkpoint_path),
nominal_checkpoint));
});

m.def("save_checkpoint",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def __init__(self, state: C.CheckpointState):
def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState:
"""Loads the checkpoint state from the checkpoint file
The checkpoint file can either be the complete checkpoint or the nominal checkpoint.
Args:
checkpoint_uri: The path to the checkpoint file.
Expand Down
3 changes: 3 additions & 0 deletions orttraining/orttraining/python/training/api/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def get_parameters_size(self, trainable_only: bool = True) -> int:
def copy_buffer_to_parameters(self, buffer: OrtValue, trainable_only: bool = True) -> None:
"""Copies the OrtValue buffer to the training session parameters.
In case the module was loaded from a nominal checkpoint, invoking this function is required
to load the updated parameters onto the checkpoint to complete it.
Args:
buffer: The OrtValue buffer to copy to the training session parameters.
"""
Expand Down
66 changes: 37 additions & 29 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def generate_artifacts(
loss: Optional[Union[LossType, onnxblock.Block]] = None,
optimizer: Optional[OptimType] = None,
artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None,
**extra_options,
prefix: str = "",
ort_format: bool = False,
custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None,
additional_output_names: Optional[List[str]] = None,
nominal_checkpoint: bool = False,
) -> None:
"""Generates artifacts required for training with ORT training api.
Expand All @@ -63,11 +67,16 @@ def generate_artifacts(
optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated.
artifact_directory: The directory to save the generated artifacts.
If None, the current working directory is used.
prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used.
ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False.
custom_op_library (str | os.PathLike): The path to the custom op library.
If not specified, no custom op library is used.
additional_output_names (List[str]): List of additional output names to be added to the training/eval model.
prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used.
ort_format: Whether to save the generated artifacts in ORT format or not. Default is False.
custom_op_library: The path to the custom op library.
If not specified, no custom op library is used.
additional_output_names: List of additional output names to be added to the training/eval model in addition
to the loss output. Default is None.
nominal_checkpoint: Whether to generate the nominal checkpoint in addition to the complete checkpoint.
Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model
parameters. It can be used on the device to reduce overhead while constructing the training model
as well as to reduce the size of the checkpoint packaged with the on-device application.
Raises:
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
Expand Down Expand Up @@ -107,19 +116,19 @@ def __init__(self, _loss):
self._loss = _loss

def build(self, *inputs_to_loss):
if "additional_output_names" in extra_options:
if additional_output_names:
# If additional output names is not a list, raise an error
if not isinstance(extra_options["additional_output_names"], list):
if not isinstance(additional_output_names, list):
raise RuntimeError(
f"Unknown type provided for additional output names {type(extra_options['additional_output_names'])}. "
f"Unknown type provided for additional output names {type(additional_output_names)}. "
"Expected additional output names to be a list of strings."
)

loss_output = self._loss(*inputs_to_loss)
if isinstance(loss_output, tuple):
return (*loss_output, *tuple(extra_options["additional_output_names"]))
return (*loss_output, *tuple(additional_output_names))
else:
return (loss_output, *tuple(extra_options["additional_output_names"]))
return (loss_output, *tuple(additional_output_names))

return self._loss(*inputs_to_loss)

Expand All @@ -143,58 +152,57 @@ def build(self, *inputs_to_loss):
eval_model = None
model_params = None

custom_op_library = extra_options.get("custom_op_library", None)
custom_op_library_path = None
if custom_op_library is not None:
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library = pathlib.Path(custom_op_library)
custom_op_library_path = pathlib.Path(custom_op_library)

with onnxblock.base(model), onnxblock.custom_op_library(
custom_op_library
custom_op_library_path
) if custom_op_library is not None else contextlib.nullcontext():
_ = training_block(*[output.name for output in model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()

def _export_to_ort_format(model_path, output_dir, extra_options):
if extra_options.get("ort_format", False):
custom_op_library = extra_options.get("custom_op_library", None)
if custom_op_library is not None:
custom_op_library = pathlib.Path(custom_op_library)
def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_path):
if ort_format:
convert_onnx_models_to_ort(
model_path,
output_dir=output_dir,
custom_op_library_path=custom_op_library,
custom_op_library_path=custom_op_library_path,
optimization_styles=[OptimizationStyle.Fixed],
)

if artifact_directory is None:
artifact_directory = pathlib.Path.cwd()
prefix = ""
if "prefix" in extra_options:
prefix = extra_options["prefix"]
logging.info("Using prefix %s for generated artifacts.", prefix)

artifact_directory = pathlib.Path(artifact_directory)

if prefix:
logging.info("Using prefix %s for generated artifacts.", prefix)

training_model_path = artifact_directory / f"{prefix}training_model.onnx"
if os.path.exists(training_model_path):
logging.info("Training model path %s already exists. Overwriting.", training_model_path)
onnx.save(training_model, training_model_path)
_export_to_ort_format(training_model_path, artifact_directory, extra_options)
_export_to_ort_format(training_model_path, artifact_directory, ort_format, custom_op_library_path)
logging.info("Saved training model to %s", training_model_path)

eval_model_path = artifact_directory / f"{prefix}eval_model.onnx"
if os.path.exists(eval_model_path):
logging.info("Eval model path %s already exists. Overwriting.", eval_model_path)
onnx.save(eval_model, eval_model_path)
_export_to_ort_format(eval_model_path, artifact_directory, extra_options)
_export_to_ort_format(eval_model_path, artifact_directory, ort_format, custom_op_library_path)
logging.info("Saved eval model to %s", eval_model_path)

checkpoint_path = artifact_directory / f"{prefix}checkpoint"
if os.path.exists(checkpoint_path):
logging.info("Checkpoint path %s already exists. Overwriting.", checkpoint_path)
onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path)
onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path, nominal_checkpoint=False)
logging.info("Saved checkpoint to %s", checkpoint_path)
if nominal_checkpoint:
nominal_checkpoint_path = artifact_directory / f"{prefix}nominal_checkpoint"
onnxblock.save_checkpoint(training_block.parameters(), nominal_checkpoint_path, nominal_checkpoint=True)
logging.info("Saved nominal checkpoint to %s", nominal_checkpoint_path)

# If optimizer is not specified, skip creating the optimizer model
if optimizer is None:
Expand Down Expand Up @@ -225,5 +233,5 @@ def _export_to_ort_format(model_path, output_dir, extra_options):

optimizer_model_path = artifact_directory / f"{prefix}optimizer_model.onnx"
onnx.save(optim_model, optimizer_model_path)
_export_to_ort_format(optimizer_model_path, artifact_directory, extra_options)
_export_to_ort_format(optimizer_model_path, artifact_directory, ort_format, custom_op_library_path)
logging.info("Saved optimizer model to %s", optimizer_model_path)
Loading

0 comments on commit 3262e8d

Please sign in to comment.