diff --git a/onnxruntime/core/flatbuffers/checkpoint_version.h b/onnxruntime/core/flatbuffers/checkpoint_version.h
index 6cad27c35024b..e6ee20bf508ce 100644
--- a/onnxruntime/core/flatbuffers/checkpoint_version.h
+++ b/onnxruntime/core/flatbuffers/checkpoint_version.h
@@ -13,7 +13,9 @@ namespace onnxruntime {
 // The format includes support for the ModuleState (stores the module parameters), OptimizerGroups
 // (stores the optimizer states), and PropertyBag
 // (stores custom user properties with support for int64, float and strings).
-constexpr const int kCheckpointVersion = 1;
+// Version 2: Introduces the On-Device Training nominal checkpoint state.
+// Changes include the addition of the is_nominal_state field in the checkpoint's ModuleState.
+constexpr const int kCheckpointVersion = 2;
 
 /**
  * @brief Check if the given checkpoint version is supported in this build
diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py
index 2be826fee2cc3..19c6b1b6f2753 100644
--- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py
+++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py
@@ -74,9 +74,17 @@ def FrozenParamsIsNone(self):
         o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
         return o == 0
 
-def ModuleStateStart(builder): builder.StartObject(2)
+    # ModuleState
+    def IsNominalState(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+        return False
+
+def ModuleStateStart(builder): builder.StartObject(3)
 def ModuleStateAddRequiresGradParams(builder, requiresGradParams): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams), 0)
 def ModuleStateStartRequiresGradParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
 def ModuleStateAddFrozenParams(builder, frozenParams): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams), 0)
 def ModuleStateStartFrozenParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ModuleStateAddIsNominalState(builder, isNominalState): builder.PrependBoolSlot(2, isNominalState, 0)
 def ModuleStateEnd(builder): return builder.EndObject()
diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs
index c8244b0a426f3..94757fa6d5bf5 100644
--- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs
+++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs
@@ -8,6 +8,10 @@ namespace onnxruntime.fbs;
 table ModuleState {
   requires_grad_params:[Tensor];
   frozen_params:[Tensor];
+  // Nominal state just means that the Tensors in the ModuleState
+  // are empty, i.e., the tensors are treated as named entities
+  // without any meaningful data.
+  is_nominal_state:bool;
 }
 
 table ParameterOptimizerState {
diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h
index 48feebb197694..d205c5eb8f409 100644
--- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h
+++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h
@@ -39,7 +39,8 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   typedef ModuleStateBuilder Builder;
   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
     VT_REQUIRES_GRAD_PARAMS = 4,
-    VT_FROZEN_PARAMS = 6
+    VT_FROZEN_PARAMS = 6,
+    VT_IS_NOMINAL_STATE = 8
   };
   const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *requires_grad_params() const {
     return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *>(VT_REQUIRES_GRAD_PARAMS);
@@ -47,6 +48,9 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *frozen_params() const {
     return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *>(VT_FROZEN_PARAMS);
   }
+  bool is_nominal_state() const {
+    return GetField<uint8_t>(VT_IS_NOMINAL_STATE, 0) != 0;
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_REQUIRES_GRAD_PARAMS) &&
@@ -55,6 +59,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            VerifyOffset(verifier, VT_FROZEN_PARAMS) &&
            verifier.VerifyVector(frozen_params()) &&
            verifier.VerifyVectorOfTables(frozen_params()) &&
+           VerifyField<uint8_t>(verifier, VT_IS_NOMINAL_STATE) &&
            verifier.EndTable();
   }
 };
@@ -69,6 +74,9 @@ struct ModuleStateBuilder {
   void add_frozen_params(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> frozen_params) {
     fbb_.AddOffset(ModuleState::VT_FROZEN_PARAMS, frozen_params);
   }
+  void add_is_nominal_state(bool is_nominal_state) {
+    fbb_.AddElement<uint8_t>(ModuleState::VT_IS_NOMINAL_STATE, static_cast<uint8_t>(is_nominal_state), 0);
+  }
   explicit ModuleStateBuilder(flatbuffers::FlatBufferBuilder &_fbb)
       : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -84,23 +92,27 @@ struct ModuleStateBuilder {
 inline flatbuffers::Offset<ModuleState> CreateModuleState(
     flatbuffers::FlatBufferBuilder &_fbb,
     flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> requires_grad_params = 0,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> frozen_params = 0) {
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>> frozen_params = 0,
+    bool is_nominal_state = false) {
   ModuleStateBuilder builder_(_fbb);
   builder_.add_frozen_params(frozen_params);
   builder_.add_requires_grad_params(requires_grad_params);
+  builder_.add_is_nominal_state(is_nominal_state);
   return builder_.Finish();
 }
 
 inline flatbuffers::Offset<ModuleState> CreateModuleStateDirect(
     flatbuffers::FlatBufferBuilder &_fbb,
     const std::vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *requires_grad_params = nullptr,
-    const std::vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *frozen_params = nullptr) {
+    const std::vector<flatbuffers::Offset<onnxruntime::fbs::Tensor>> *frozen_params = nullptr,
+    bool is_nominal_state = false) {
   auto requires_grad_params__ = requires_grad_params ? _fbb.CreateVector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>(*requires_grad_params) : 0;
   auto frozen_params__ = frozen_params ? _fbb.CreateVector<flatbuffers::Offset<onnxruntime::fbs::Tensor>>(*frozen_params) : 0;
   return onnxruntime::fbs::CreateModuleState(
       _fbb,
       requires_grad_params__,
-      frozen_params__);
+      frozen_params__,
+      is_nominal_state);
 }
 
 struct ParameterOptimizerState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc
index 8e962403556dd..6d7ed94b2956d 100644
--- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc
+++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc
@@ -392,6 +392,14 @@ Status LoadOrtTensorOrtFormat(const fbs::Tensor& fbs_tensor, const AllocatorPtr
   ort_tensor = onnxruntime::Tensor(
       tensor_dtype, TensorShape(tensor_dims->data(), tensor_dims->size()), allocator);
 
+  if (fbs_tensor.raw_data()->size() == 0U) {
+    // Empty tensor. Nothing to unpack.
+    // This check is necessary because an empty ort tensor will return a size of 1.
+    // As a result, the following call to UnpackTensor will fail since the src and
+    // dst sizes do not match (0 and 1 elements).
+    return Status::OK();
+  }
+
   // The tensor proto is used as a dummy here. The actual data is stored in the raw_data field of the flatbuffer.
   // The data is copied from the raw_data field to the ort_tensor.
   ONNX_NAMESPACE::TensorProto unused_tensor_proto;
diff --git a/onnxruntime/test/testdata/training_api/checkpoint.ckpt b/onnxruntime/test/testdata/training_api/checkpoint.ckpt
index d0b7d0deb654c..d1bc1f121c8e6 100644
Binary files a/onnxruntime/test/testdata/training_api/checkpoint.ckpt and b/onnxruntime/test/testdata/training_api/checkpoint.ckpt differ
diff --git a/onnxruntime/test/testdata/training_api/custom_ops/checkpoint b/onnxruntime/test/testdata/training_api/custom_ops/checkpoint
index 753b24af63ba2..ce23d149e9499 100644
Binary files a/onnxruntime/test/testdata/training_api/custom_ops/checkpoint and b/onnxruntime/test/testdata/training_api/custom_ops/checkpoint differ
diff --git a/onnxruntime/test/testdata/training_api/nominal_checkpoint b/onnxruntime/test/testdata/training_api/nominal_checkpoint
new file mode 100644
index 0000000000000..2eadfeece2ed9
Binary files /dev/null and b/onnxruntime/test/testdata/training_api/nominal_checkpoint differ
diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint
index ab35c9ad5acde..83ef6aa4c30de 100644
Binary files a/onnxruntime/test/testdata/training_api/ort_format/checkpoint and b/onnxruntime/test/testdata/training_api/ort_format/checkpoint differ
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 0c2bfa19e1671..4ab8db8565bf9 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -802,6 +802,9 @@ void addObjectMethodsForTraining(py::module& m) {
       .def("copy_parameter_from",
            [](onnxruntime::training::api::CheckpointState* state,
               const std::string& parameter_name, OrtValue& value) -> void {
+             if (state->module_checkpoint_state.is_nominal_state) {
+               ORT_THROW("Cannot copy parameter to a nominal state.
Please load all the parameter states first"); + } auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == state->module_checkpoint_state.named_parameters.end()) { ORT_THROW("Parameter with name ", parameter_name, " does not exist."); @@ -811,6 +814,9 @@ void addObjectMethodsForTraining(py::module& m) { }) .def("get_parameter", [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + if (state->module_checkpoint_state.is_nominal_state) { + ORT_THROW("Cannot get parameter from a nominal state. Please load the parameter states first"); + } auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == state->module_checkpoint_state.named_parameters.end()) { ORT_THROW("Parameter with name ", parameter_name, " does not exist."); @@ -851,6 +857,9 @@ void addObjectMethodsForTraining(py::module& m) { return std::make_unique(optimizer_model_uri, state, providers, session_options); })) .def("optimizer_step", [](PyOptimizer* optimizer) -> void { + // In case the optimizer was constructed using a nominal checkpoint, + // the optimizer state construction is delayed until the first call to Optimizer::Step(). + // It is expected that the model parameter state is available at this point. ORT_THROW_IF_ERROR(optimizer->optimizer_->Step()); }) .def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void { @@ -893,7 +902,7 @@ void addObjectMethodsForTraining(py::module& m) { "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, const std::vector& non_trainable_tensor_protos_pybytes, - const std::string& checkpoint_path) { + const std::string& checkpoint_path, const bool nominal_checkpoint) { std::vector trainable_tensor_protos(trainable_tensor_protos_pybytes.size()); std::vector non_trainable_tensor_protos(non_trainable_tensor_protos_pybytes.size()); @@ -914,7 +923,8 @@ void addObjectMethodsForTraining(py::module& m) { ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos, non_trainable_tensor_protos, - ToPathString(checkpoint_path))); + ToPathString(checkpoint_path), + nominal_checkpoint)); }); m.def("save_checkpoint", diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index ba95cd04fce7e..cc4e84111c47c 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -222,6 +222,8 @@ def __init__(self, state: C.CheckpointState): def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: """Loads the checkpoint state from the checkpoint file + The checkpoint file can either be the complete checkpoint or the nominal checkpoint. + Args: checkpoint_uri: The path to the checkpoint file. diff --git a/orttraining/orttraining/python/training/api/module.py b/orttraining/orttraining/python/training/api/module.py index f8f6b4322ce79..a87cd6fdd93cf 100644 --- a/orttraining/orttraining/python/training/api/module.py +++ b/orttraining/orttraining/python/training/api/module.py @@ -178,6 +178,9 @@ def get_parameters_size(self, trainable_only: bool = True) -> int: def copy_buffer_to_parameters(self, buffer: OrtValue, trainable_only: bool = True) -> None: """Copies the OrtValue buffer to the training session parameters. 
+ In case the module was loaded from a nominal checkpoint, invoking this function is required + to load the updated parameters onto the checkpoint to complete it. + Args: buffer: The OrtValue buffer to copy to the training session parameters. """ diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index a57105545e114..7a4eb251bc5bc 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -43,7 +43,11 @@ def generate_artifacts( loss: Optional[Union[LossType, onnxblock.Block]] = None, optimizer: Optional[OptimType] = None, artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None, - **extra_options, + prefix: str = "", + ort_format: bool = False, + custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, + additional_output_names: Optional[List[str]] = None, + nominal_checkpoint: bool = False, ) -> None: """Generates artifacts required for training with ORT training api. @@ -63,11 +67,16 @@ def generate_artifacts( optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated. artifact_directory: The directory to save the generated artifacts. If None, the current working directory is used. - prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used. - ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False. - custom_op_library (str | os.PathLike): The path to the custom op library. - If not specified, no custom op library is used. - additional_output_names (List[str]): List of additional output names to be added to the training/eval model. + prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used. + ort_format: Whether to save the generated artifacts in ORT format or not. Default is False. + custom_op_library: The path to the custom op library. + If not specified, no custom op library is used. + additional_output_names: List of additional output names to be added to the training/eval model in addition + to the loss output. Default is None. + nominal_checkpoint: Whether to generate the nominal checkpoint in addition to the complete checkpoint. + Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model + parameters. It can be used on the device to reduce overhead while constructing the training model + as well as to reduce the size of the checkpoint packaged with the on-device application. Raises: RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` @@ -107,19 +116,19 @@ def __init__(self, _loss): self._loss = _loss def build(self, *inputs_to_loss): - if "additional_output_names" in extra_options: + if additional_output_names: # If additional output names is not a list, raise an error - if not isinstance(extra_options["additional_output_names"], list): + if not isinstance(additional_output_names, list): raise RuntimeError( - f"Unknown type provided for additional output names {type(extra_options['additional_output_names'])}. " + f"Unknown type provided for additional output names {type(additional_output_names)}. " "Expected additional output names to be a list of strings." 
) loss_output = self._loss(*inputs_to_loss) if isinstance(loss_output, tuple): - return (*loss_output, *tuple(extra_options["additional_output_names"])) + return (*loss_output, *tuple(additional_output_names)) else: - return (loss_output, *tuple(extra_options["additional_output_names"])) + return (loss_output, *tuple(additional_output_names)) return self._loss(*inputs_to_loss) @@ -143,58 +152,57 @@ def build(self, *inputs_to_loss): eval_model = None model_params = None - custom_op_library = extra_options.get("custom_op_library", None) + custom_op_library_path = None if custom_op_library is not None: logging.info("Custom op library provided: %s", custom_op_library) - custom_op_library = pathlib.Path(custom_op_library) + custom_op_library_path = pathlib.Path(custom_op_library) with onnxblock.base(model), onnxblock.custom_op_library( - custom_op_library + custom_op_library_path ) if custom_op_library is not None else contextlib.nullcontext(): _ = training_block(*[output.name for output in model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() - def _export_to_ort_format(model_path, output_dir, extra_options): - if extra_options.get("ort_format", False): - custom_op_library = extra_options.get("custom_op_library", None) - if custom_op_library is not None: - custom_op_library = pathlib.Path(custom_op_library) + def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_path): + if ort_format: convert_onnx_models_to_ort( model_path, output_dir=output_dir, - custom_op_library_path=custom_op_library, + custom_op_library_path=custom_op_library_path, optimization_styles=[OptimizationStyle.Fixed], ) if artifact_directory is None: artifact_directory = pathlib.Path.cwd() - prefix = "" - if "prefix" in extra_options: - prefix = extra_options["prefix"] - logging.info("Using prefix %s for generated artifacts.", prefix) - artifact_directory = pathlib.Path(artifact_directory) + if prefix: + logging.info("Using prefix %s for generated artifacts.", prefix) + training_model_path = artifact_directory / f"{prefix}training_model.onnx" if os.path.exists(training_model_path): logging.info("Training model path %s already exists. Overwriting.", training_model_path) onnx.save(training_model, training_model_path) - _export_to_ort_format(training_model_path, artifact_directory, extra_options) + _export_to_ort_format(training_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved training model to %s", training_model_path) eval_model_path = artifact_directory / f"{prefix}eval_model.onnx" if os.path.exists(eval_model_path): logging.info("Eval model path %s already exists. Overwriting.", eval_model_path) onnx.save(eval_model, eval_model_path) - _export_to_ort_format(eval_model_path, artifact_directory, extra_options) + _export_to_ort_format(eval_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved eval model to %s", eval_model_path) checkpoint_path = artifact_directory / f"{prefix}checkpoint" if os.path.exists(checkpoint_path): logging.info("Checkpoint path %s already exists. 
Overwriting.", checkpoint_path) - onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path) + onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path, nominal_checkpoint=False) logging.info("Saved checkpoint to %s", checkpoint_path) + if nominal_checkpoint: + nominal_checkpoint_path = artifact_directory / f"{prefix}nominal_checkpoint" + onnxblock.save_checkpoint(training_block.parameters(), nominal_checkpoint_path, nominal_checkpoint=True) + logging.info("Saved nominal checkpoint to %s", nominal_checkpoint_path) # If optimizer is not specified, skip creating the optimizer model if optimizer is None: @@ -225,5 +233,5 @@ def _export_to_ort_format(model_path, output_dir, extra_options): optimizer_model_path = artifact_directory / f"{prefix}optimizer_model.onnx" onnx.save(optim_model, optimizer_model_path) - _export_to_ort_format(optimizer_model_path, artifact_directory, extra_options) + _export_to_ort_format(optimizer_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved optimizer model to %s", optimizer_model_path) diff --git a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py index bc50d4afa2fe1..de3453c630f9c 100644 --- a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py @@ -6,18 +6,21 @@ import onnx -from onnxruntime.capi._pybind_state import get_model_after_loading_checkpoint as _internal_load_checkpoint_to_model -from onnxruntime.capi._pybind_state import save_checkpoint as _internal_save_checkpoint +from onnxruntime.capi._pybind_state import get_model_after_loading_checkpoint as _load_checkpoint_to_model +from onnxruntime.capi._pybind_state import save_checkpoint as _save_checkpoint def save_checkpoint( - parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], path_to_checkpoint: Union[str, os.PathLike] + parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], + path_to_checkpoint: Union[str, os.PathLike], + nominal_checkpoint: bool = False, ) -> None: """Saves the parameters to the checkpoint directory path_to_checkpoint. Args: parameters tuple(trainable_params, non_trainable_params): The parameters to save to the checkpoint file. - path_to_checkpoint (str): The path to the checkpoint directory. + path_to_checkpoint: The path to the checkpoint directory. + nominal_checkpoint: If True, the checkpoint is saved as a nominal checkpoint. Default is False. """ if parameters is None: @@ -26,7 +29,7 @@ def save_checkpoint( trainable_params, non_trainable_params = parameters trainable_params = [param.SerializeToString() for param in trainable_params] non_trainable_params = [param.SerializeToString() for param in non_trainable_params] - _internal_save_checkpoint(trainable_params, non_trainable_params, os.fspath(path_to_checkpoint)) + _save_checkpoint(trainable_params, non_trainable_params, os.fspath(path_to_checkpoint), nominal_checkpoint) def load_checkpoint_to_model(path_to_checkpoint: Union[str, os.PathLike], model: onnx.ModelProto) -> None: @@ -37,4 +40,4 @@ def load_checkpoint_to_model(path_to_checkpoint: Union[str, os.PathLike], model: model (onnx.ModelProto): The model to load the checkpoint to. 
""" - model.ParseFromString(_internal_load_checkpoint_to_model(os.fspath(path_to_checkpoint), model.SerializeToString())) + model.ParseFromString(_load_checkpoint_to_model(os.fspath(path_to_checkpoint), model.SerializeToString())) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 910ddb34e2b52..3d41c8678278c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1047,3 +1047,26 @@ def build(self, input1, input2): with tempfile.TemporaryDirectory() as temp_dir: artifacts.generate_artifacts(onnx_model, loss=CustomLossBlock(), artifact_directory=temp_dir) + + +def test_save_nominal_checkpoint(): + device = "cpu" + batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 + _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size) + + with tempfile.TemporaryDirectory() as temp_dir: + artifacts.generate_artifacts( + base_model, + requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"], + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + nominal_checkpoint=True, + ) + + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) + assert os.path.exists(os.path.join(temp_dir, "nominal_checkpoint")) + assert ( + os.stat(os.path.join(temp_dir, "checkpoint")).st_size + > os.stat(os.path.join(temp_dir, "nominal_checkpoint")).st_size + ) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index 34d8c24ccfab4..ce251b98447bf 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -6,6 +6,7 @@ import os import pathlib import tempfile +from dataclasses import dataclass import numpy as np import onnx @@ -28,11 +29,22 @@ def build(self, output_name): return self.loss(output_name) +@dataclass +class Artifacts: + checkpoint_file_path: str + training_model_file_path: str + eval_model_file_path: str + optimizer_model_file_path: str + pt_model: torch.nn.Module + nominal_checkpoint_file_path: str | None = None + + def _create_training_artifacts( artifact_directory: str | os.PathLike, requires_grad: list[str] | None = None, frozen_params: list[str] | None = None, optimizer_type=artifacts.OptimType.AdamW, + nominal_checkpoint: bool = False, ): device = "cpu" batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 @@ -51,14 +63,20 @@ def _create_training_artifacts( requires_grad=requires_grad, frozen_params=frozen_params, artifact_directory=artifact_directory, + nominal_checkpoint=nominal_checkpoint, ) training_model_file = os.path.join(artifact_directory, "training_model.onnx") eval_model_file = os.path.join(artifact_directory, "eval_model.onnx") optimizer_model_file = os.path.join(artifact_directory, "optimizer_model.onnx") checkpoint_file = os.path.join(artifact_directory, "checkpoint") + nominal_checkpoint_file = None + if nominal_checkpoint: + nominal_checkpoint_file = os.path.join(artifact_directory, "nominal_checkpoint") - return checkpoint_file, training_model_file, eval_model_file, optimizer_model_file, pt_model + return Artifacts( + checkpoint_file, training_model_file, eval_model_file, optimizer_model_file, 
pt_model, nominal_checkpoint_file + ) def test_train_step(): @@ -67,22 +85,16 @@ def test_train_step(): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - pt_model, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state) + model = Module(artifacts.training_model_file_path, state) model.train() ort_loss = model(inputs, labels) # Calculate loss using pytorch model to compare it with Module's output. - pt_outputs = pt_model(torch.from_numpy(inputs)) + pt_outputs = artifacts.pt_model(torch.from_numpy(inputs)) loss_fn = torch.nn.CrossEntropyLoss() pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long()) @@ -95,17 +107,11 @@ def test_eval_step(): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) model.train() model(inputs, labels) @@ -121,18 +127,12 @@ def test_optimizer_step(optimizer_type): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. - model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) model.train() old_flatten_params = model.get_contiguous_parameters() @@ -147,18 +147,12 @@ def test_optimizer_step(optimizer_type): @pytest.mark.parametrize("optimizer_type", [artifacts.OptimType.SGD, artifacts.OptimType.AdamW]) def test_get_and_set_lr(optimizer_type): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. 
- model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) # Test get and set learning rate. lr = optimizer.get_learning_rate() @@ -178,18 +172,11 @@ def test_scheduler_step(optimizer_type): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) - # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. - model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) scheduler = LinearLRScheduler(optimizer, 1, 2, 0.2) # Test get and set learning rate. @@ -212,17 +199,11 @@ def test_training_module_checkpoint(): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Training Module and Training Optimizer. - model = Module(training_model_file_path, state) + model = Module(artifacts.training_model_file_path, state) model.train() model(inputs, labels) @@ -237,7 +218,7 @@ def test_training_module_checkpoint(): # Assert the checkpoint parameters remain after saving. new_state = CheckpointState.load_checkpoint(checkpoint_save_path) - new_model = Module(training_model_file_path, new_state) + new_model = Module(artifacts.training_model_file_path, new_state) new_params = new_model.get_contiguous_parameters() @@ -252,23 +233,17 @@ def test_copy_buffer_to_parameters(trainable_only, optimizer_type): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts( + artifacts = _create_training_artifacts( temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"], optimizer_type=optimizer_type, ) - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. - model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) # Keep a copy of the parameters. 
old_output_params = model.get_contiguous_parameters(trainable_only=trainable_only) @@ -295,19 +270,13 @@ def test_copy_buffer_to_parameters(trainable_only, optimizer_type): def test_export_model_for_inferencing(): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) # Export inference model inference_model_file_path = os.path.join(temp_dir, "inference_model.onnx") @@ -317,18 +286,12 @@ def test_export_model_for_inferencing(): def test_cuda_execution_provider(): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state, device="cuda") + model = Module(artifacts.training_model_file_path, state, device="cuda") params = model.get_contiguous_parameters() # Check if parameters are moved to cuda. @@ -341,19 +304,13 @@ def test_cuda_execution_provider(): ) def test_add_get_property(property_value): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - _ = Module(training_model_file_path, state) + _ = Module(artifacts.training_model_file_path, state) # Float values in python are double precision. # Convert to float32 to match the type of the property. @@ -367,8 +324,8 @@ def test_add_get_property(property_value): assert state.properties["property"] == property_value assert len(state.properties) == 1 - CheckpointState.save_checkpoint(state, checkpoint_file_path) - new_state = CheckpointState.load_checkpoint(checkpoint_file_path) + CheckpointState.save_checkpoint(state, artifacts.checkpoint_file_path) + new_state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) assert "property" in new_state.properties assert new_state.properties["property"] == property_value assert len(new_state.properties) == 1 @@ -376,21 +333,15 @@ def test_add_get_property(property_value): def test_get_input_output_names(): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. 
- model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) - training_model = onnx.load(training_model_file_path) + training_model = onnx.load(artifacts.training_model_file_path) assert model.input_names() == [input.name for input in training_model.graph.input][:2] assert model.output_names() == [output.name for output in training_model.graph.output][:1] @@ -518,23 +469,18 @@ def test_train_step_with_ort_values(): labels = OrtValue.ortvalue_from_numpy(labels_np) with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - pt_model, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) + # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state) + model = Module(artifacts.training_model_file_path, state) model.train() ort_loss = model(inputs, labels) assert isinstance(ort_loss, OrtValue) # Calculate loss using pytorch model to compare it with Module's output. - pt_outputs = pt_model(torch.from_numpy(inputs_np)) + pt_outputs = artifacts.pt_model(torch.from_numpy(inputs_np)) loss_fn = torch.nn.CrossEntropyLoss() pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels_np).long()) @@ -549,17 +495,11 @@ def test_eval_step_with_ort_values(): labels = OrtValue.ortvalue_from_numpy(labels_np) with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. 
- model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) model.train() model(inputs, labels) @@ -572,26 +512,20 @@ def test_eval_step_with_ort_values(): @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_get_and_set_parameter_values(device): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - pt_model, - ) = _create_training_artifacts( + artifacts = _create_training_artifacts( temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] ) - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) - model = Module(training_model_file_path, state, eval_model_file_path, device=device) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path, device=device) - state_dict = pt_model.state_dict() + state_dict = artifacts.pt_model.state_dict() assert len(state_dict) == len(state.parameters) for parameter_name, _ in state.parameters: assert parameter_name in state_dict - for name, pt_param in pt_model.named_parameters(): + for name, pt_param in artifacts.pt_model.named_parameters(): ort_param = state.parameters[name] assert ort_param.name == name assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) @@ -612,7 +546,7 @@ def test_get_and_set_parameter_values(device): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() loss = model(inputs, labels) assert loss is not None - for name, _ in pt_model.named_parameters(): + for name, _ in artifacts.pt_model.named_parameters(): ort_param = state.parameters[name] assert ort_param.name == name if name in ["fc1.weight", "fc1.bias"]: @@ -624,3 +558,111 @@ def test_get_and_set_parameter_values(device): state.parameters["fc1.weight"] = original_param assert np.allclose(state.parameters["fc1.weight"].data, original_param) + + +def test_model_construction_with_nominal_checkpoint(): + with tempfile.TemporaryDirectory() as temp_dir: + artifacts = _create_training_artifacts(temp_dir, nominal_checkpoint=True) + + nominal_state = CheckpointState.load_checkpoint(artifacts.nominal_checkpoint_file_path) + model_with_nominal_state = Module( + artifacts.training_model_file_path, nominal_state, artifacts.eval_model_file_path + ) + optimizer_with_nominal_state = Optimizer(artifacts.optimizer_model_file_path, model_with_nominal_state) + + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + err_msg = "Please load the parameter states first" + + # Accessing the checkpoint parameter raises + state_dict = artifacts.pt_model.state_dict() + for param_name in state_dict: + assert param_name in nominal_state.parameters + with pytest.raises(Exception) as exc_info: + _ = nominal_state.parameters["fc1.weight"] + + assert err_msg in str(exc_info.value) + + err_msg = "Please load all the parameter states first" + with pytest.raises(Exception) as exc_info: + nominal_state.parameters["fc1.weight"] = np.ones((10, 10), dtype=np.float32) + + assert err_msg in str(exc_info.value) + + err_msg = "Please load the model parameters first." 
+ + # Getting contiguous parameters raises + with pytest.raises(Exception) as exc_info: + _ = model_with_nominal_state.get_contiguous_parameters() + + assert err_msg in str(exc_info.value) + + # Train step raises + with pytest.raises(Exception) as exc_info: + model_with_nominal_state.train() + model_with_nominal_state(inputs, labels) + + assert err_msg in str(exc_info.value) + + # Optimizer step raises + with pytest.raises(Exception) as exc_info: + optimizer_with_nominal_state.step() + + assert err_msg in str(exc_info.value) + + # Eval step raises + with pytest.raises(Exception) as exc_info: + model_with_nominal_state.eval() + model_with_nominal_state(inputs, labels) + + assert err_msg in str(exc_info.value) + + # Get parameters size does not raise + params_size = model_with_nominal_state.get_parameters_size() + assert params_size > 0 + + +def test_train_with_nominal_checkpoint(): + with tempfile.TemporaryDirectory() as temp_dir: + artifacts = _create_training_artifacts(temp_dir, nominal_checkpoint=True) + + # Create Checkpoint State with nominal checkpoint as well as the complete checkpoint. + complete_state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) + nominal_state = CheckpointState.load_checkpoint(artifacts.nominal_checkpoint_file_path) + + # Create a Module with both complete and nominal checkpoint states. + model_with_complete_state = Module(artifacts.training_model_file_path, complete_state) + model_with_nominal_state = Module(artifacts.training_model_file_path, nominal_state) + + optimizer_with_complete_state = Optimizer(artifacts.optimizer_model_file_path, model_with_complete_state) + optimizer_with_nominal_state = Optimizer(artifacts.optimizer_model_file_path, model_with_nominal_state) + + parameter_buffer = model_with_complete_state.get_contiguous_parameters() + model_with_nominal_state.copy_buffer_to_parameters(parameter_buffer, trainable_only=False) + + model_with_complete_state.train() + model_with_nominal_state.train() + + # Generate random data for testing. + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + ort_loss_1 = model_with_complete_state(inputs, labels) + ort_loss_2 = model_with_nominal_state(inputs, labels) + + # Calculate loss using pytorch model to compare it with both the Modules' output. + pt_outputs = artifacts.pt_model(torch.from_numpy(inputs)) + loss_fn = torch.nn.CrossEntropyLoss() + pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long()) + + assert np.allclose(ort_loss_1, ort_loss_2) + assert np.allclose(ort_loss_1, pt_loss.detach().numpy()) + + optimizer_with_complete_state.step() + optimizer_with_nominal_state.step() + + new_params_1 = model_with_complete_state.get_contiguous_parameters() + new_params_2 = model_with_nominal_state.get_contiguous_parameters() + + assert np.allclose(new_params_1.numpy(), new_params_2.numpy()) diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 1369c9c69865a..5c53addb853e4 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -95,7 +95,8 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { // Call Save APIs. 
   PathString checkpoint_path{
       ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))};
-  ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path));
+  ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path,
+                                  false /* nominal checkpoint */));
 
   /// Phase 3 - Run load checkpoint APIs.
   /// And check the result comparable with initial parameter values.
@@ -193,7 +194,8 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpointThenLoadFromBufferCPU) {
   // Call Save APIs.
   PathString checkpoint_path{
       ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))};
-  ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path));
+  ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path,
+                                  false /* nominal checkpoint */));
 
   /// Phase 3 - Run load checkpoint APIs.
   /// And check the result comparable with initial parameter values.
@@ -435,4 +437,37 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) {
   std::string restored_s_data = restored_property_bag.GetProperty<std::string>(s_property_name);
   ASSERT_EQ(s_data, restored_s_data);
 }
+
+/**
+ * Loads a nominal checkpoint. Checks for nominal flag, and that the state is empty.
+ * Saves the checkpoint, and loads it again. Checks for nominal flag, and that the state is empty.
+ */
+TEST(CheckpointApiTest, LoadAndSaveNominalCheckpoint) {
+  PathString nominal_checkpoint_path{ORT_TSTR("testdata/training_api/nominal_checkpoint")};
+
+  CheckpointState checkpoint_state;
+  ASSERT_STATUS_OK(LoadCheckpoint(nominal_checkpoint_path, checkpoint_state));
+  ASSERT_TRUE(checkpoint_state.module_checkpoint_state.is_nominal_state);
+  for (auto& [name, param] : checkpoint_state.module_checkpoint_state.named_parameters) {
+    ASSERT_TRUE(param->Data().IsTensor());
+    // An empty tensor will have size 1.
+    ASSERT_EQ(param->Data().Get<Tensor>().Shape().Size(), 1);
+  }
+
+  // Remove the temporary directory if it already exists.
+  auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
+  TemporaryDirectory tmp_dir{ckpt_test_root_dir};
+  PathString checkpoint_path{
+      ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("nominal_checkpoint_2"))};
+  ASSERT_STATUS_OK(SaveCheckpoint(checkpoint_state, checkpoint_path, false));
+
+  CheckpointState checkpoint_state_2;
+  ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_path, checkpoint_state_2));
+  ASSERT_TRUE(checkpoint_state_2.module_checkpoint_state.is_nominal_state);
+  for (auto& [name, param] : checkpoint_state_2.module_checkpoint_state.named_parameters) {
+    ASSERT_TRUE(param->Data().IsTensor());
+    // An empty tensor will have size 1.
+    ASSERT_EQ(param->Data().Get<Tensor>().Shape().Size(), 1);
+  }
+}
 }  // namespace onnxruntime::training::test
diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
index 2170f7957e6a6..e2232687d0b07 100644
--- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
@@ -537,6 +537,167 @@ TEST(TrainingApiTest, OptimStep) {
   }
 }
 
+TEST(TrainingApiTest, ModuleAndOptimizerWithNominalState) {
+  auto model_uri = MODEL_FOLDER "training_model.onnx";
+  auto eval_model_uri = MODEL_FOLDER "eval_model.onnx";
+  auto optim_uri = MODEL_FOLDER "adamw.onnx";
+
+  onnxruntime::training::api::CheckpointState complete_state;
+  onnxruntime::training::api::CheckpointState nominal_state;
+  auto complete_checkpoint_path = MODEL_FOLDER "checkpoint.ckpt";
+  auto nominal_checkpoint_path = MODEL_FOLDER "nominal_checkpoint";
+  ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(complete_checkpoint_path, complete_state));
+  ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(nominal_checkpoint_path, nominal_state));
+
+  ASSERT_FALSE(complete_state.module_checkpoint_state.is_nominal_state);
+  ASSERT_TRUE(nominal_state.module_checkpoint_state.is_nominal_state);
+
+  onnxruntime::SessionOptions session_option;
+  std::unique_ptr<Environment> env;
+  std::vector<std::shared_ptr<IExecutionProvider>> providers;
+#if defined(USE_CUDA)
+  providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider());
+#endif
+  ASSERT_STATUS_OK(Environment::Create(nullptr, env));
+
+  auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+                                           std::optional<std::string>(onnxruntime::ToUTF8String(eval_model_uri)),
+                                           std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
+  auto model_with_complete_state = std::make_unique<Module>(
+      model_identifier, &complete_state, session_option,
+      *env, providers);
+  auto model_with_nominal_state = std::make_unique<Module>(
+      model_identifier, &nominal_state, session_option,
+      *env, providers);
+  auto optim_with_complete_state = std::make_unique<Optimizer>(
+      model_identifier, &complete_state, session_option,
+      *env, providers);
+  auto optim_with_nominal_state = std::make_unique<Optimizer>(
+      model_identifier, &nominal_state, session_option,
+      *env, providers);
+
+  // Before running the test, copy all the parameters to the nominal module.
+ ASSERT_EQ(model_with_complete_state->GetParametersSize(), model_with_nominal_state->GetParametersSize()); + int64_t params_size = static_cast(model_with_nominal_state->GetParametersSize()); + OrtValue params_buffer; + Tensor::InitOrtValue(DataTypeImpl::GetType(), {params_size}, + onnxruntime::test::TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + params_buffer); + ASSERT_STATUS_OK(model_with_complete_state->CopyParametersToBuffer(params_buffer, false)); + ASSERT_STATUS_OK(model_with_nominal_state->CopyBufferToParameters(params_buffer, false)); + + ASSERT_STATUS_OK(optim_with_nominal_state->ConstructOptimizerStateAndInputs()); + + OrtValue input, target; + GenerateRandomInput(std::array{2, 784}, input); + target = onnxruntime::test::CreateInputOrtValueOnCPU( + std::array{2}, std::vector(2, 1)); + auto data_loader = std::vector>(4, std::vector{input, target}); + + for (auto it = data_loader.begin(); it != data_loader.end(); ++it) { + std::vector& inputs = *it; + std::vector complete_fetches; + std::vector nominal_fetches; + ASSERT_STATUS_OK(model_with_complete_state->TrainStep(inputs, complete_fetches)); + ASSERT_STATUS_OK(model_with_nominal_state->TrainStep(inputs, nominal_fetches)); + + ASSERT_GT(complete_fetches.size(), 0); + for (size_t i = 0; i < complete_fetches.size(); ++i) { + ASSERT_TRUE(complete_fetches[i].IsTensor()); + ASSERT_TRUE(nominal_fetches[i].IsTensor()); + const Tensor& complete_tensor = complete_fetches[i].Get(); + const Tensor& nominal_tensor = nominal_fetches[i].Get(); + ASSERT_EQ(complete_tensor.Shape(), nominal_tensor.Shape()); + ASSERT_EQ(complete_tensor.DataType(), nominal_tensor.DataType()); + + std::vector complete_fetches_vec; + std::vector nominal_fetches_vec; +#if defined(USE_CUDA) + CudaOrtValueToCpuVec(complete_fetches[i], complete_fetches_vec); + CudaOrtValueToCpuVec(nominal_fetches[i], nominal_fetches_vec); +#else + CpuOrtValueToVec(complete_fetches[i], complete_fetches_vec); + CpuOrtValueToVec(nominal_fetches[i], nominal_fetches_vec); +#endif + + for (size_t j = 0; j < complete_fetches_vec.size(); ++j) { + ASSERT_EQ(complete_fetches_vec[j], nominal_fetches_vec[j]); + } + } + + ASSERT_STATUS_OK(optim_with_complete_state->Step()); + ASSERT_STATUS_OK(optim_with_nominal_state->Step()); + + for (auto& [name, param] : model_with_complete_state->NamedParameters()) { + ASSERT_TRUE(param->Data().IsTensor()); + ASSERT_TRUE(param->Gradient().IsTensor()); + ASSERT_TRUE(model_with_nominal_state->NamedParameters().at(name)->Data().IsTensor()); + ASSERT_TRUE(model_with_nominal_state->NamedParameters().at(name)->Gradient().IsTensor()); + + const Tensor& complete_data = param->Data().Get(); + const Tensor& complete_grad = param->Gradient().Get(); + const Tensor& nominal_data = model_with_nominal_state->NamedParameters().at(name)->Data().Get(); + const Tensor& nominal_grad = model_with_nominal_state->NamedParameters().at(name)->Gradient().Get(); + + ASSERT_EQ(complete_data.Shape(), nominal_data.Shape()); + ASSERT_EQ(complete_data.DataType(), nominal_data.DataType()); + ASSERT_EQ(complete_grad.Shape(), nominal_grad.Shape()); + ASSERT_EQ(complete_grad.DataType(), nominal_grad.DataType()); + + std::vector complete_data_vec; + std::vector complete_grad_vec; + std::vector nominal_data_vec; + std::vector nominal_grad_vec; + +#if defined(USE_CUDA) + CudaOrtValueToCpuVec(param->Data(), complete_data_vec); + CudaOrtValueToCpuVec(param->Gradient(), complete_grad_vec); + CudaOrtValueToCpuVec(model_with_nominal_state->NamedParameters().at(name)->Data(), 
nominal_data_vec); + CudaOrtValueToCpuVec(model_with_nominal_state->NamedParameters().at(name)->Gradient(), nominal_grad_vec); +#else + CpuOrtValueToVec(param->Data(), complete_data_vec); + CpuOrtValueToVec(param->Gradient(), complete_grad_vec); + CpuOrtValueToVec(model_with_nominal_state->NamedParameters().at(name)->Data(), nominal_data_vec); + CpuOrtValueToVec(model_with_nominal_state->NamedParameters().at(name)->Gradient(), nominal_grad_vec); +#endif + + for (size_t j = 0; j < complete_data_vec.size(); ++j) { + ASSERT_EQ(complete_data_vec[j], nominal_data_vec[j]); + ASSERT_EQ(complete_grad_vec[j], nominal_grad_vec[j]); + } + } + + std::vector complete_eval_fetches; + std::vector nominal_eval_fetches; + ASSERT_STATUS_OK(model_with_complete_state->EvalStep(inputs, complete_eval_fetches)); + ASSERT_STATUS_OK(model_with_nominal_state->EvalStep(inputs, nominal_eval_fetches)); + + ASSERT_GT(complete_eval_fetches.size(), 0); + for (size_t i = 0; i < complete_eval_fetches.size(); ++i) { + ASSERT_TRUE(complete_eval_fetches[i].IsTensor()); + ASSERT_TRUE(nominal_eval_fetches[i].IsTensor()); + const Tensor& complete_tensor = complete_eval_fetches[i].Get(); + const Tensor& nominal_tensor = nominal_eval_fetches[i].Get(); + ASSERT_EQ(complete_tensor.Shape(), nominal_tensor.Shape()); + ASSERT_EQ(complete_tensor.DataType(), nominal_tensor.DataType()); + + std::vector complete_eval_fetches_vec; + std::vector nominal_eval_fetches_vec; +#if defined(USE_CUDA) + CudaOrtValueToCpuVec(complete_eval_fetches[i], complete_eval_fetches_vec); + CudaOrtValueToCpuVec(nominal_eval_fetches[i], nominal_eval_fetches_vec); +#else + CpuOrtValueToVec(complete_eval_fetches[i], complete_eval_fetches_vec); + CpuOrtValueToVec(nominal_eval_fetches[i], nominal_eval_fetches_vec); +#endif + + for (size_t j = 0; j < complete_eval_fetches_vec.size(); ++j) { + ASSERT_EQ(complete_eval_fetches_vec[j], nominal_eval_fetches_vec[j]); + } + } + } +} + } // namespace test } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index e46952d87c2bf..8f25e1e4c92b8 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -420,4 +420,79 @@ TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { } #endif +TEST(TrainingCApiTest, ModuleAndOptimizerWithNominalState) { + auto training_model_uri = MODEL_FOLDER "training_model.onnx"; + auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; + auto optimizer_model_uri = MODEL_FOLDER "adamw.onnx"; + + Ort::Env env; + Ort::SessionOptions session_options_for_complete_state; + Ort::SessionOptions session_options_for_nominal_state; + Ort::CheckpointState complete_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::CheckpointState nominal_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "nominal_checkpoint"); + +#ifdef USE_CUDA + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_for_complete_state, 0)); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_for_nominal_state, 0)); +#endif + + Ort::TrainingSession complete_training_session = Ort::TrainingSession(env, session_options_for_complete_state, complete_state, + training_model_uri, eval_model_uri, optimizer_model_uri); + Ort::TrainingSession nominal_training_session = Ort::TrainingSession(env, 
session_options_for_nominal_state, nominal_state, + training_model_uri, eval_model_uri, + optimizer_model_uri); + + Ort::Value params_buffer = complete_training_session.ToBuffer(false); + nominal_training_session.FromBuffer(params_buffer); + + for (size_t i = 0; i < 4U; ++i) { + std::vector x(2 * 784); + std::vector x_shape{2, 784}; + GenerateRandomData(x); + + std::vector labels{0, 8}; + std::vector labels_shape{2}; + + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + std::vector ort_inputs; + ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, x.data(), + x.size() * sizeof(float), + x_shape.data(), x_shape.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels.data(), + labels.size() * sizeof(int32_t), + labels_shape.data(), labels_shape.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); + + std::vector complete_fetches = complete_training_session.TrainStep(ort_inputs); + std::vector nominal_fetches = nominal_training_session.TrainStep(ort_inputs); + + ASSERT_EQ(complete_fetches.size(), nominal_fetches.size()); + ASSERT_GT(complete_fetches.size(), 0U); + for (size_t j = 0; j < complete_fetches.size(); ++j) { + ASSERT_TRUE(complete_fetches[j].IsTensor()); + ASSERT_TRUE(nominal_fetches[j].IsTensor()); + + auto complete_tensor_info = complete_fetches[j].GetTensorTypeAndShapeInfo(); + auto nominal_tensor_info = nominal_fetches[j].GetTensorTypeAndShapeInfo(); + + ASSERT_EQ(complete_tensor_info.GetShape(), nominal_tensor_info.GetShape()); + ASSERT_EQ(complete_tensor_info.GetElementType(), nominal_tensor_info.GetElementType()); + + gsl::span complete_data = gsl::span(complete_fetches[j].GetTensorMutableData(), + complete_tensor_info.GetElementCount()); + gsl::span nominal_data = gsl::span(nominal_fetches[j].GetTensorMutableData(), + nominal_tensor_info.GetElementCount()); + + ASSERT_EQ(complete_data, nominal_data); + } + + complete_training_session.OptimizerStep(); + nominal_training_session.OptimizerStep(); + + complete_training_session.LazyResetGrad(); + nominal_training_session.LazyResetGrad(); + } +} + } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index dbcef78c3965c..720bdd7e68dd3 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -174,7 +174,7 @@ Status ToFile(const PathString& checkpoint_path, flatbuffers::FlatBufferBuilder& Status FromTensorProtos( gsl::span trainable_tensor_protos, gsl::span non_trainable_tensor_protos, - const PathString& checkpoint_path) { + const PathString& checkpoint_path, const bool nominal_checkpoint) { const auto check_unique = [](gsl::span tensor_protos, InlinedHashSet& unique_names) { for (const auto& tensor_proto : tensor_protos) { @@ -230,6 +230,7 @@ Status FromTensorProtos( fbs::ModuleStateBuilder module_state_builder(builder); module_state_builder.add_requires_grad_params(fbs_trainable_tensors); module_state_builder.add_frozen_params(fbs_non_trainable_tensors); + module_state_builder.add_is_nominal_state(nominal_checkpoint); flatbuffers::Offset fbs_module_state = module_state_builder.Finish(); fbs::CheckpointBuilder checkpoint_builder(builder); @@ -294,6 +295,7 @@ Status FromModuleState(const ModuleCheckpointState& module_state, fbs::ModuleStateBuilder module_state_builder(builder); module_state_builder.add_requires_grad_params(fbs_trainable_tensors); 
   module_state_builder.add_frozen_params(fbs_non_trainable_tensors);
+  module_state_builder.add_is_nominal_state(nominal_checkpoint);
   flatbuffers::Offset<fbs::ModuleState> fbs_module_state = module_state_builder.Finish();
 
   fbs::CheckpointBuilder checkpoint_builder(builder);
@@ -294,6 +295,7 @@ Status FromModuleState(const ModuleCheckpointState& module_state,
   fbs::ModuleStateBuilder module_state_builder(builder);
   module_state_builder.add_requires_grad_params(fbs_trainable_tensors);
   module_state_builder.add_frozen_params(fbs_non_trainable_tensors);
+  module_state_builder.add_is_nominal_state(module_state.is_nominal_state);
   fbs_module_state = module_state_builder.Finish();
 
   return Status::OK();
@@ -513,6 +515,8 @@ Status ToModuleState(
     module_state.named_parameters.insert({name, param});
   }
 
+  module_state.is_nominal_state = fbs_module_state.is_nominal_state();
+
   return Status::OK();
 }
 
@@ -646,6 +650,10 @@ Status ToModelProto(gsl::span<const uint8_t> checkpoint_bytes,
   ORT_RETURN_IF_NOT(frozen_params,
                     "Checkpoint is invalid. Expected: Valid non-trainable params flatbuffer. Actual: nullptr.");
 
+  ORT_RETURN_IF(module_state->is_nominal_state(),
+                "Cannot load a nominal checkpoint to a model proto. "
+                "Expected: Complete checkpoint. Actual: Nominal checkpoint.");
+
   InlinedHashMap<std::string, ONNX_NAMESPACE::TensorProto> param_tensor_protos;
   param_tensor_protos.reserve(
       static_cast<size_t>(requires_grad_params->size()) + static_cast<size_t>(frozen_params->size()));
@@ -717,14 +725,33 @@ Status ToCheckpointState(gsl::span<const uint8_t> checkpoint_bytes, CheckpointSt
 }  // namespace load
 
+#if !defined(ORT_MINIMAL_BUILD)
+InlinedVector<ONNX_NAMESPACE::TensorProto> Nominalize(gsl::span<const ONNX_NAMESPACE::TensorProto> tensor_protos) {
+  InlinedVector<ONNX_NAMESPACE::TensorProto> nominal_tensor_protos;
+  nominal_tensor_protos.reserve(tensor_protos.size());
+  for (const auto& tensor_proto : tensor_protos) {
+    ONNX_NAMESPACE::TensorProto nominal_tensor_proto;
+    nominal_tensor_proto.set_name(tensor_proto.name());
+    nominal_tensor_proto.set_data_type(tensor_proto.data_type());
+    nominal_tensor_protos.push_back(nominal_tensor_proto);
+  }
+
+  return nominal_tensor_protos;
+}
+#endif
+
 }  // namespace
 
 #if !defined(ORT_MINIMAL_BUILD)
 Status SaveCheckpoint(gsl::span<const ONNX_NAMESPACE::TensorProto> trainable_tensor_protos,
                       gsl::span<const ONNX_NAMESPACE::TensorProto> non_trainable_tensor_protos,
-                      const PathString& checkpoint_path) {
+                      const PathString& checkpoint_path, const bool nominal_checkpoint) {
   ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");
-  return save::FromTensorProtos(trainable_tensor_protos, non_trainable_tensor_protos, checkpoint_path);
+  return nominal_checkpoint
+             ? save::FromTensorProtos(Nominalize(trainable_tensor_protos), Nominalize(non_trainable_tensor_protos),
+                                      checkpoint_path, nominal_checkpoint)
+             : save::FromTensorProtos(trainable_tensor_protos, non_trainable_tensor_protos, checkpoint_path,
+                                      nominal_checkpoint);
 }
 #endif
diff --git a/orttraining/orttraining/training_api/checkpoint.h b/orttraining/orttraining/training_api/checkpoint.h
index 5d8554662f48d..95d3820a33a70 100644
--- a/orttraining/orttraining/training_api/checkpoint.h
+++ b/orttraining/orttraining/training_api/checkpoint.h
@@ -49,11 +49,12 @@ Status SaveCheckpoint(const CheckpointState& state, const PathString& checkpoint
 * @param trainable_tensor_protos trainable parameters in TensorProto format.
 * @param non_trainable_tensor_protos non-trainable parameters in TensorProto format.
 * @param checkpoint_path file where checkpoint is saved.
+ * @param nominal_checkpoint flag indicating whether to save the complete checkpoint or the nominal checkpoint.
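+ *
+ * A minimal usage sketch (illustrative only; variable names are hypothetical):
+ * @code
+ * ORT_RETURN_IF_ERROR(SaveCheckpoint(trainable_tensor_protos, non_trainable_tensor_protos,
+ *                                    path_to_checkpoint, /*nominal_checkpoint*/ true));
+ * @endcode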
 * @return Status
 */
Status SaveCheckpoint(gsl::span<const ONNX_NAMESPACE::TensorProto> trainable_tensor_protos,
                      gsl::span<const ONNX_NAMESPACE::TensorProto> non_trainable_tensor_protos,
-                     const PathString& checkpoint_path);
+                     const PathString& checkpoint_path, const bool nominal_checkpoint);
 #endif
 
 /**
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
index 0e8544a7639ba..ed6d151a595b4 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
@@ -132,6 +132,7 @@ struct OrtTrainingApi {
   * \note Note that the training session created with a checkpoint state uses this state to store the entire
   * training state (including model parameters, its gradients, the optimizer states and the properties).
   * As a result, it is required that the checkpoint state outlive the lifetime of the training session.
+  * \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint.
   *
   * \param[in] checkpoint_path Path to the checkpoint file
   * \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
@@ -463,10 +464,12 @@ struct OrtTrainingApi {
   *
   * The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call,
   * with matching setting for trainable_only argument. All the target parameters must be of the same
-  * datatype. This is a complementary function to OrtTrainingApi::CopyBufferToParameters
+  * datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer
   * and can be used to load updated buffer values onto the training state.
   * Parameter ordering is preserved.
   * User is responsible for allocating and freeing the resources used by the parameters_buffer.
+  * In case the training session was created with a nominal checkpoint, invoking this function is required
+  * to load the updated parameters onto the checkpoint to complete it.
   *
   * \param[in] sess The `this` pointer to the training session.
   * \param[in] trainable_only Whether to skip non-trainable parameters
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
index 218bef524200c..e78c16136ab3f 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
@@ -58,6 +58,8 @@ using Property = std::variant<int64_t, float, std::string>;
 * training state (including model parameters, its gradients, the optimizer states and the properties).
 * The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required
 * that the checkpoint state outlive the lifetime of the training session.
+ * \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint
+ * state, depending on the checkpoint file provided when loading the checkpoint.
 *
 */
class CheckpointState : public detail::Base<OrtCheckpointState> {
@@ -386,6 +388,9 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
   Value ToBuffer(const bool only_trainable);
 
   /** \brief Loads the training session model parameters from a contiguous buffer
+   *
+   * In case the training session was created with a nominal checkpoint, invoking this function is required
+   * to load the updated parameters onto the checkpoint to complete it.
   *
   * \param[in] buffer Contiguous buffer to load the parameters from.
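+   *
+   * A minimal usage sketch (illustrative; the session variables are hypothetical):
+   * \code
+   * Ort::Value params = complete_session.ToBuffer(false);  // all parameters, not just trainable ones
+   * nominal_session.FromBuffer(params);                    // completes the nominal checkpoint state
+   * \endcode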
   */
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
index 7d1326a10f8f8..397cba0b0f9de 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
@@ -168,22 +168,23 @@ inline void TrainingSession::FromBuffer(Value& buffer) {
 
   auto buffer_size = buffer_shape.front();
 
+  size_t session_buffer_size = 0U;
+  ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));
+
+  if (buffer_size == static_cast<int64_t>(session_buffer_size)) {
+    ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
+    return;
+  }
+
   size_t session_buffer_size_trainable_only = 0U;
   ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true));
 
   if (buffer_size == static_cast<int64_t>(session_buffer_size_trainable_only)) {
     ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true));
     return;
-  }
-
-  size_t session_buffer_size = 0U;
-  ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));
-
-  if (buffer_size != static_cast<int64_t>(session_buffer_size)) {
+  } else {
     ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
   }
-
-  ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
 }
 
 inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint) {
diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc
index cf49a01517d6b..41ed79d285533 100644
--- a/orttraining/orttraining/training_api/module.cc
+++ b/orttraining/orttraining/training_api/module.cc
@@ -6,6 +6,8 @@
 #include "core/common/safeint.h"
 #include "core/common/string_utils.h"
 #include "core/framework/execution_provider.h"
+#include "core/framework/mldata_type_utils.h"
+#include "core/framework/tensorprotoutils.h"
 #include "core/session/inference_session.h"
 #include "core/session/environment.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
@@ -117,6 +119,75 @@ Status TransformModelInputsForInference(Graph& inference_graph,
   return Status::OK();
 }
 #endif
+
+InlinedHashMap<std::string, const NodeArg*> BuildParameterToInputNodeArgMap(const ModuleCheckpointState& state,
+                                                                            const InputDefList* model_inputs) {
+  ORT_ENFORCE(model_inputs != nullptr, "Model inputs are not defined.");
+  InlinedHashMap<std::string, const NodeArg*> parameter_to_input_node_arg_map;
+  parameter_to_input_node_arg_map.reserve(state.named_parameters.size());
+  for (const auto& input_def : *model_inputs) {
+    const std::string& input_name = input_def->Name();
+    const auto param_it = state.named_parameters.find(input_name);
+    if (param_it == state.named_parameters.end()) {
+      continue;
+    }
+    parameter_to_input_node_arg_map[input_name] = input_def;
+  }
+  return parameter_to_input_node_arg_map;
+}
+
+InlinedHashMap<std::string, size_t> BuildParameterToGradInputIndexMap(gsl::span<const std::string> grad_names) {
+  InlinedHashMap<std::string, size_t> param_name_to_grad_input_index_map;
+  param_name_to_grad_input_index_map.reserve(grad_names.size());
+  for (size_t i = 0; i < grad_names.size(); ++i) {
+    std::string param_name;
+    utils::GetParamNameFromGradient(grad_names[i], param_name);
+    param_name_to_grad_input_index_map.insert({param_name, i});
+  }
+  return param_name_to_grad_input_index_map;
+}
+
+Status LoadParameter(const std::string& param_name, const Tensor& src_weight_tensor,
+                     const SessionState& session_state, const bool force_load,
+                     const InlinedHashMap<std::string, size_t>& param_to_grad_index,
+                     gsl::span<const std::string> grad_names, Parameter& param) {
+  InlinedVector<SessionState::NodeInfo> node_info_vec;
+  ORT_THROW_IF_ERROR(session_state.GetInputNodeInfo(param_name, node_info_vec));
+  const auto& node_info = node_info_vec.front();
+  const auto target_device = *node_info.device;
+  for (auto it = node_info_vec.begin(); it != node_info_vec.end(); ++it) {
+    ORT_ENFORCE(target_device == *(it->device), "Inconsistent device requirements found for input: ", param_name);
+  }
+
+  if (force_load || src_weight_tensor.Location().device.Type() != target_device.Type()) {
+    auto weight_allocator = session_state.GetAllocator(target_device);
+    ORT_ENFORCE(weight_allocator != nullptr);
+
+    // Create a new tensor on the target_device and switch the source_ortvalue to point to this new tensor
+    auto dst_weight_tensor = std::make_unique<Tensor>(src_weight_tensor.DataType(), src_weight_tensor.Shape(),
+                                                      weight_allocator);
+    ORT_THROW_IF_ERROR(session_state.GetDataTransferMgr().CopyTensor(src_weight_tensor, *dst_weight_tensor.get()));
+    auto ml_tensor_type = DataTypeImpl::GetType<Tensor>();
+    param.Data().Init(dst_weight_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc());
+  }
+
+  if (param.RequiresGrad()) {
+    // Create gradient accumulation buffer.
+    auto grad_it = param_to_grad_index.find(param_name);
+    ORT_ENFORCE(grad_it != param_to_grad_index.end(), "Gradient buffer input not provided for param: ",
+                param_name);
+
+    const size_t grad_input_index = grad_it->second;
+    auto& param_grad_name = grad_names[grad_input_index];
+
+    OrtValue param_grad;
+    ORT_THROW_IF_ERROR(utils::CreateZeroValuedOrtValueLike(session_state, param.Data(), param_grad));
+    ORT_THROW_IF_ERROR(param.SetGrad(param_grad_name, param_grad));
+  }
+
+  return Status::OK();
+}
+
 }  // namespace
 
 Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const {
@@ -251,7 +322,6 @@ Module::Module(const ModelIdentifiers& model_identifiers,
   // user inputs, weights, gradients, reset_grad
   InlinedVector<std::string> user_input_names, param_input_names, grad_input_names, reset_grad_name;
-  std::unordered_map<std::string, size_t> param_name_to_grad_input_index_map;
   for (const auto& input_name : train_input_names) {
     auto it = state_->module_checkpoint_state.named_parameters.find(input_name);
     if (it != state_->module_checkpoint_state.named_parameters.end()) {
@@ -259,7 +329,6 @@ Module::Module(const ModelIdentifiers& model_identifiers,
     } else if (input_name == ACCUMULATE_GRAD_CONTROL_INPUT_NAME) {
       reset_grad_name.emplace_back(input_name);
     } else if (std::string param_name; utils::GetParamNameFromGradient(input_name, param_name)) {
-      param_name_to_grad_input_index_map.insert({param_name, grad_input_names.size()});
       grad_input_names.emplace_back(input_name);
     } else {
       user_input_names.emplace_back(input_name);
@@ -268,11 +337,7 @@ Module::Module(const ModelIdentifiers& model_identifiers,
 
   gradients_.resize(grad_input_names.size());
 
-  train_input_names_ = user_input_names;
-  train_user_input_count_ = user_input_names.size();
-  train_input_names_.insert(train_input_names_.end(), param_input_names.begin(), param_input_names.end());
-  train_input_names_.insert(train_input_names_.end(), grad_input_names.begin(), grad_input_names.end());
-  train_input_names_.insert(train_input_names_.end(), reset_grad_name.begin(), reset_grad_name.end());
+  train_input_names_ = TrainInputNames(user_input_names, param_input_names, grad_input_names);
 
   for (const auto& output_name : train_output_names) {
     if (std::string param_name;
        !utils::GetParamNameFromGradient(output_name, param_name)) {
@@ -280,58 +345,24 @@
     }
   }
 
-  // Loop each parameter, and allocate its memory based on the user-specified device.
-  auto& train_sess_state = train_sess_->GetSessionState();
-  for (auto& param_name : param_input_names) {
-    auto params_iter = state_->module_checkpoint_state.named_parameters.find(param_name);
-    ORT_ENFORCE(params_iter != state_->module_checkpoint_state.named_parameters.end());
-
-    // Retrieve the target device for "param_name".
-    InlinedVector<SessionState::NodeInfo> node_info_vec;
-    ORT_THROW_IF_ERROR(train_sess_state.GetInputNodeInfo(param_name, node_info_vec));
-    const auto& node_info = node_info_vec.front();
-    const auto target_device = *node_info.device;
-    for (auto it = node_info_vec.begin(); it != node_info_vec.end(); ++it) {
-      ORT_ENFORCE(target_device == *(it->device), "Inconsistent device requirements found for input: ", param_name);
-    }
-
-    // Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning)
-    // Only copies data if the target device is not the same as the current device the buffer is placed on
-    OrtValue& param_data = params_iter->second->Data();
-    ORT_ENFORCE(param_data.IsTensor());
-    const Tensor& param_data_tensor = param_data.Get<Tensor>();
-    // If the source device type is already the same as target device skip copy
-    if (param_data_tensor.Location().device.Type() != target_device.Type()) {
-      // TODO: move this outside of the for loop?
-      auto target_allocator = train_sess_state.GetAllocator(target_device);
-      ORT_ENFORCE(target_allocator != nullptr);
-
-      // Create a new tensor on the target_device and switch the source_ortvalue to point to this new tensor
-      auto target_tensor = std::make_unique<Tensor>(param_data_tensor.DataType(), param_data_tensor.Shape(),
-                                                    target_allocator);
-      ORT_THROW_IF_ERROR(train_sess_state.GetDataTransferMgr().CopyTensor(param_data_tensor, *target_tensor.get()));
-      auto ml_tensor_type = DataTypeImpl::GetType<Tensor>();
-      param_data.Init(target_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc());
-    }
-
-    weights_.push_back(param_data);
-    weight_names_.push_back(param_name);
-
-    // Create gradient buffer when parameter requires gradient.
-    if (params_iter->second->RequiresGrad()) {
-      // Create gradient accumulation buffer.
-      auto it = param_name_to_grad_input_index_map.find(param_name);
-      ORT_ENFORCE(it != param_name_to_grad_input_index_map.end(), "Gradient buffer input not provided for param: ",
-                  param_name);
-
-      const size_t grad_input_index = it->second;
-      auto& param_grad_name = grad_input_names[grad_input_index];
-      // TODO: don't pre-allocate the gradient buffer.
-      // Gradient usually stays on the same device of its parameter.
-      OrtValue param_grad;
-      ORT_THROW_IF_ERROR(utils::CreateZeroValuedOrtValueLike(train_sess_state, param_data, param_grad));
-      ORT_THROW_IF_ERROR(params_iter->second->SetGrad(param_grad_name, param_grad));
-      gradients_[grad_input_index] = params_iter->second->Gradient();
+  if (!state_->module_checkpoint_state.is_nominal_state) {
+    // Loop each parameter, and allocate its memory based on the user-specified device.
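+    // For a complete checkpoint, parameters are materialized eagerly here: LoadParameter
+    // copies each named parameter to its assigned device (force_load is false, so the copy
+    // is skipped when the tensor is already on the right device) and creates a zero-valued
+    // gradient buffer for every parameter that requires grad. For a nominal checkpoint this
+    // block is skipped; weights_ and gradients_ are populated later by CopyBufferToParameters.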
+    const auto param_to_grad_index = BuildParameterToGradInputIndexMap(train_input_names_.GradientInputNames());
+    for (auto& param_name : train_input_names_.WeightsInputNames()) {
+      auto params_iter = state_->module_checkpoint_state.named_parameters.find(param_name);
+      ORT_ENFORCE(params_iter != state_->module_checkpoint_state.named_parameters.end());
+
+      OrtValue& param_data = params_iter->second->Data();
+      ORT_ENFORCE(param_data.IsTensor(), "Expected: Parameter data should be of tensor type. Actual: ",
+                  params_iter->second->Name(), " is not a tensor.");
+      ORT_THROW_IF_ERROR(LoadParameter(param_name, param_data.Get<Tensor>(), train_sess_->GetSessionState(),
+                                       false /* force_load */, param_to_grad_index,
+                                       train_input_names_.GradientInputNames(), *params_iter->second));
+      weights_.push_back(param_data);
+      if (params_iter->second->RequiresGrad()) {
+        gradients_[param_to_grad_index.at(param_name)] = params_iter->second->Gradient();
+      }
     }
   }
 
@@ -414,16 +445,24 @@ std::string Module::GetEvalModelOutputName(size_t index) const {
 
 size_t Module::GetParametersSize(const bool trainable_only) const {
   SafeInt<size_t> parameters_size = 0;
-  for (const auto& it : state_->module_checkpoint_state.named_parameters) {
-    if (trainable_only && !it.second->RequiresGrad()) {
+  const auto model_inputs_with_error = GetTrainingModelInputs();
+  ORT_THROW_IF_ERROR(model_inputs_with_error.first);
+  ORT_ENFORCE(model_inputs_with_error.second, "Training model graph inputs are not defined.");
+  for (const auto& input_def : *model_inputs_with_error.second) {
+    const std::string& input_name = input_def->Name();
+    const auto param_it = state_->module_checkpoint_state.named_parameters.find(input_name);
+    if (param_it == state_->module_checkpoint_state.named_parameters.end() ||
+        (trainable_only && !param_it->second->RequiresGrad())) {
       continue;
     }
-    parameters_size += it.second->Data().Get<Tensor>().Shape().Size();
+    parameters_size += onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_def->Shape()).Size();
  }
   return parameters_size;
 }
 
 std::vector<std::shared_ptr<Parameter>> Module::Parameters() const {
+  ORT_ENFORCE(!state_->module_checkpoint_state.is_nominal_state,
+              "Cannot fetch parameters from a nominal checkpoint state. Please load the model parameters first.");
   std::vector<std::shared_ptr<Parameter>> params;
   for (auto& it : state_->module_checkpoint_state.named_parameters) {
     params.push_back(it.second);
@@ -432,23 +471,27 @@ std::vector<std::shared_ptr<Parameter>> Module::Parameters() const {
 }
 
 std::unordered_map<std::string, std::shared_ptr<Parameter>> Module::NamedParameters() const {
+  ORT_ENFORCE(!state_->module_checkpoint_state.is_nominal_state,
+              "Cannot fetch named parameters from a nominal checkpoint state. Please load the model parameters first.");
   return state_->module_checkpoint_state.named_parameters;
 }
 
 Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only) {
-  ORT_ENFORCE(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated.");
-  ORT_ENFORCE(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type.");
+  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
Please load the model parameters first."); + ORT_RETURN_IF_NOT(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); + ORT_RETURN_IF_NOT(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); auto* init_tensor = parameters_buffer.GetMutable(); ORT_ENFORCE(nullptr != init_tensor); auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); - ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, - "Parameters buffer size incorrect. Expected:", expected_buffer_size, - ", Actual:", init_tensor->Shape().Size()); + ORT_RETURN_IF(init_tensor->Shape().Size() != expected_buffer_size, + "Parameters buffer size incorrect. Expected:", expected_buffer_size, + ", Actual:", init_tensor->Shape().Size()); const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); size_t offset = 0; - for (const auto& param_name : weight_names_) { + for (const auto& param_name : train_input_names_.WeightsInputNames()) { auto& param = state_->module_checkpoint_state.named_parameters.at(param_name); if (trainable_only && !param->RequiresGrad()) { continue; @@ -458,7 +501,7 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr const TensorShape& shape = weight_tensor->Shape(); auto element_type = init_tensor->DataType(); - ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); + ORT_RETURN_IF(weight_tensor->DataType() != element_type, "Data types must match."); const OrtMemoryInfo& info = init_tensor->Location(); std::unique_ptr p_tensor; @@ -470,54 +513,102 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr data_buffer + offset, info); } else { - ORT_THROW("Unsupported type: ", element_type); + ORT_THROW("Unsupported type: ", element_type, " encountered while copying parameters to buffer. ", + "Only float is supported."); } - ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*weight_tensor, *p_tensor.get())); + ORT_RETURN_IF_ERROR(sess_data_transfer_manager.CopyTensor(*weight_tensor, *p_tensor.get())); offset += shape.Size(); } return Status::OK(); } Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only) { - ORT_ENFORCE(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); - ORT_ENFORCE(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); - auto* init_tensor = parameters_buffer.GetMutable(); - ORT_ENFORCE(nullptr != init_tensor); + // In case of a nominal checkpoint state, all parameters need to be loaded into the model. + // i.e. trainable_only must be false. + ORT_RETURN_IF(trainable_only && state_->module_checkpoint_state.is_nominal_state, + "For nominal checkpoint state, all parameters need to be loaded into the model " + "(trainable_only = false)."); + ORT_RETURN_IF_NOT(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); + ORT_RETURN_IF_NOT(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); + auto* buffer_tensor = parameters_buffer.GetMutable(); + ORT_RETURN_IF(nullptr == buffer_tensor, "Expected valid parameter buffer. Actual: nullptr."); auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); - ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, - "Parameters buffer size incorrect. 
Expected:", expected_buffer_size, - ", Actual:", init_tensor->Shape().Size()); + ORT_RETURN_IF(buffer_tensor->Shape().Size() != expected_buffer_size, + "Parameters buffer size incorrect. Expected:", expected_buffer_size, + ", Actual:", buffer_tensor->Shape().Size()); + auto& train_sess_state = train_sess_->GetSessionState(); const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); + const auto model_inputs_with_error = GetTrainingModelInputs(); + ORT_RETURN_IF_ERROR(model_inputs_with_error.first); + ORT_RETURN_IF_NOT(model_inputs_with_error.second, "Training model graph inputs are not defined."); + const auto param_to_node_arg = BuildParameterToInputNodeArgMap(state_->module_checkpoint_state, + model_inputs_with_error.second); + const auto param_to_grad_index = BuildParameterToGradInputIndexMap(train_input_names_.GradientInputNames()); + + if (state_->module_checkpoint_state.is_nominal_state) { + // weights_ vector is not initialized for a nominal state. This function is expected to + // initialize the weights_. + ORT_ENFORCE(weights_.empty(), "Weights vector should be empty for a nominal state."); + } size_t offset = 0; - for (const auto& param_name : weight_names_) { + for (const auto& param_name : train_input_names_.WeightsInputNames()) { auto& param = state_->module_checkpoint_state.named_parameters.at(param_name); if (trainable_only && !param->RequiresGrad()) { continue; } OrtValue& weight = param->Data(); - auto* weight_tensor = weight.GetMutable(); - const TensorShape& shape = weight_tensor->Shape(); - auto element_type = init_tensor->DataType(); - ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); + auto param_it = param_to_node_arg.find(param_name); + const TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto( + *(param_it->second->Shape())); + const auto element_type = static_cast( + onnxruntime::utils::GetMLDataType(*param_it->second)) + ->GetElementType(); - const OrtMemoryInfo& info = init_tensor->Location(); - std::unique_ptr p_tensor; + const OrtMemoryInfo& info = buffer_tensor->Location(); + std::unique_ptr src_tensor; if (onnxruntime::utils::IsPrimitiveDataType(element_type)) { - float* data_buffer = init_tensor->MutableData(); - p_tensor = std::make_unique(element_type, - shape, - data_buffer + offset, - info); + float* data_buffer = buffer_tensor->MutableData(); + src_tensor = std::make_unique(element_type, + shape, + data_buffer + offset, + info); + } else { + ORT_THROW("Unsupported type: ", element_type, " encountered while copying buffer to parameters. ", + "Only float is supported."); + } + + if (state_->module_checkpoint_state.is_nominal_state) { + // If state is a nominal state, then we first need to allocate the memory for + // parameters and their gradients in the checkpoint state before copying the data. + ORT_RETURN_IF_ERROR(LoadParameter(param_name, *src_tensor, train_sess_state, true, + param_to_grad_index, train_input_names_.GradientInputNames(), + *param)); + weights_.push_back(param->Data()); + if (param->RequiresGrad()) { + // It is expected that the gradients_ vector is already initialized with the correct size + // in the Module constructor (even though the OrtValues contained in the vector are empty). + gradients_[param_to_grad_index.at(param_name)] = param->Gradient(); + } } else { - ORT_THROW("Unsupported type: ", element_type); + // If state is not a nominal state, then we can directly copy the data to the existing + // parameters in the checkpoint state. 
+      auto* weight_tensor = weight.GetMutable<Tensor>();
+      ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match.");
+      ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*src_tensor.get(), *weight_tensor));
     }
-    ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*p_tensor.get(), *weight_tensor));
+
     offset += shape.Size();
   }
+
+  if (state_->module_checkpoint_state.is_nominal_state) {
+    // Once the parameters are loaded, the state is no longer a nominal state.
+    state_->module_checkpoint_state.is_nominal_state = false;
+  }
+
   return Status::OK();
 }
 
@@ -527,6 +618,9 @@ Status Module::LazyResetGrad() {
 }
 
 Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
+  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
+                "Cannot perform TrainStep with a nominal state. Please load the model parameters first.");
+
   std::vector<std::shared_ptr<Parameter>> params;
   std::vector<OrtValue> feeds{inputs};
   feeds.insert(feeds.end(), weights_.begin(), weights_.end());
   feeds.insert(feeds.end(), gradients_.begin(), gradients_.end());
@@ -535,7 +629,7 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVa
   utils::WrapInOrtValue<bool>(!accumulate_gradient_, &reset_grad_input);
   feeds.push_back(reset_grad_input);
 
-  ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_, feeds, train_output_names_, &outputs));
+  ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_.AllInputNames(), feeds, train_output_names_, &outputs));
 
   // Reset the flag after every step. In case the ResetGrad was called before running
   // the current step, it will have done the effective resetting during the
@@ -546,6 +640,8 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVa
 }
 
 Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
+  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
+                "Cannot perform EvalStep with a nominal state. Please load the model parameters first.");
   ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized.");
   std::vector<OrtValue> feeds{inputs};
   feeds.insert(feeds.end(), weights_.begin(), weights_.end());
@@ -560,6 +656,8 @@ Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
                                          gsl::span<const std::string> graph_output_names) const {
+  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
+                "Cannot export the model with a nominal state. Please load the model parameters first.");
   ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(),
                 "Eval model was not provided. Cannot export a model for inferencing.");
@@ -586,7 +684,7 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path
 #endif
 
 size_t Module::GetTrainingModelInputCount() const noexcept {
-  return train_user_input_count_;
+  return train_input_names_.UserInputNames().size();
 }
 
 size_t Module::GetEvalModelInputCount() const noexcept {
@@ -594,10 +692,10 @@ size_t Module::GetEvalModelInputCount() const noexcept {
 }
 
 std::string Module::GetTrainingModelInputName(size_t index) const {
-  ORT_ENFORCE(index < train_user_input_count_,
-              "Train input name index out of range. Expected in range [0-", train_user_input_count_, "). Actual: ",
+  ORT_ENFORCE(index < train_input_names_.UserInputNames().size(),
Actual: ", index); - return train_input_names_.at(index); + return train_input_names_.UserInputNames()[index]; } std::string Module::GetEvalModelInputName(size_t index) const { @@ -615,6 +713,43 @@ std::pair Module::GetEvalModelInputs() cons return eval_sess_->GetModelInputs(); } +Module::TrainInputNames::TrainInputNames(gsl::span user_input_names, + gsl::span weights_input_names, + gsl::span gradient_input_names) { + train_input_names_.reserve(user_input_names.size() + + weights_input_names.size() + + gradient_input_names.size() + + 1U); // +1 for the reset gradient flag input + train_input_index_offsets_.reserve(3); + + train_input_names_.insert(train_input_names_.end(), + user_input_names.begin(), user_input_names.end()); + train_input_index_offsets_.push_back(train_input_names_.size()); + train_input_names_.insert(train_input_names_.end(), + weights_input_names.begin(), weights_input_names.end()); + train_input_index_offsets_.push_back(train_input_names_.size()); + train_input_names_.insert(train_input_names_.end(), + gradient_input_names.begin(), gradient_input_names.end()); + train_input_index_offsets_.push_back(train_input_names_.size()); + train_input_names_.push_back(ACCUMULATE_GRAD_CONTROL_INPUT_NAME); +} + +gsl::span Module::TrainInputNames::AllInputNames() const { return train_input_names_; } + +gsl::span Module::TrainInputNames::UserInputNames() const { + return gsl::span{train_input_names_.begin(), train_input_index_offsets_[0]}; +} + +gsl::span Module::TrainInputNames::WeightsInputNames() const { + return gsl::span{train_input_names_.begin() + train_input_index_offsets_[0], + train_input_index_offsets_[1] - train_input_index_offsets_[0]}; +} + +gsl::span Module::TrainInputNames::GradientInputNames() const { + return gsl::span{train_input_names_.begin() + train_input_index_offsets_[1], + train_input_index_offsets_[2] - train_input_index_offsets_[1]}; +} + } // namespace api } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index f323e6be72d49..917887404217f 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -53,6 +53,7 @@ struct ModuleCheckpointState { public: std::unordered_map> named_parameters; const DataTransferManager* train_session_data_transfer_mgr; + bool is_nominal_state = false; }; struct CheckpointState; @@ -87,19 +88,28 @@ struct Module { ~Module(); // Return the trainable/nontrainable parameters + // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will raise an exception. std::vector> Parameters() const; + // Return the trainable/nontrainable parameters as a map + // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will raise an exception. std::unordered_map> NamedParameters() const; // Reset and release the gradient buffer of all trainable params lazily. Status LazyResetGrad(); // Train Step – does forward and backward computation. The outputs will be the forward’s outputs. - // Gradients will be accumulated within the Parameter object + // Gradients will be accumulated within the Parameter object. + // If the parameter state is not available; i.e. 
+  // checkpoint and the state has not been loaded yet), then this function will return an error.
   Status TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs);
 
   // Eval Step – does forward computation. This will use a separate inference session
   // and take in a separate inference graph, while sharing the parameters
+  // If the parameter state is not available (i.e. the module was created using a nominal checkpoint
+  // and the state has not been loaded yet), then this function will return an error.
   Status EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs);
 
   // Returns the output count for training graph
@@ -118,14 +128,20 @@ struct Module {
   size_t GetParametersSize(const bool trainable_only = true) const;
 
   // Copy parameters onto contiguous buffer held by parameters_buffer.
+  // If the parameter state is not available (i.e. the module was created using a nominal checkpoint
+  // and the state has not been loaded yet), then this function will return an error.
   Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true);
 
   // Copy parameter values from contiguous buffer held by parameters_buffer onto parameters.
+  // This function is responsible for completing the nominal checkpoint state. The checkpoint
+  // state will no longer be nominal after the successful completion of this function.
   Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true);
 
 #if !defined(ORT_MINIMAL_BUILD)
   // Load the eval model from eval_model_path_or_bytes and transform it for the purpose of
-  // inferencing, and serialize to given path
+  // inferencing, and serialize to given path.
+  // If the parameter state is not available (i.e. the module was created using a nominal checkpoint
+  // and the state has not been loaded yet), then this function will return an error.
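+  // A typical sequence for a module created from a nominal checkpoint (illustrative):
+  //   1. Construct the Module with the nominal checkpoint state.
+  //   2. Call CopyBufferToParameters() with a buffer holding all parameter values.
+  //   3. TrainStep/EvalStep/ExportModelForInferencing can then be invoked as usual.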
   Status ExportModelForInferencing(const std::string& inference_model_path,
                                    gsl::span<const std::string> graph_output_names) const;
 #endif
 
@@ -152,11 +168,28 @@ struct Module {
   std::unique_ptr<onnxruntime::InferenceSession> train_sess_{nullptr};
   std::unique_ptr<onnxruntime::InferenceSession> eval_sess_{nullptr};
 
-  InlinedVector<std::string> train_input_names_;
+  struct TrainInputNames {
+   private:
+    InlinedVector<std::string> train_input_names_;
+    InlinedVector<size_t> train_input_index_offsets_;  // offset range [[0], [1]) = user input names
+                                                       // offset range [[1], [2]) = weights input names
+                                                       // offset range [[2], [3]) = gradient input names
+   public:
+    TrainInputNames() = default;
+    TrainInputNames(gsl::span<const std::string> user_input_names,
+                    gsl::span<const std::string> weights_input_names,
+                    gsl::span<const std::string> gradient_input_names);
+
+    gsl::span<const std::string> AllInputNames() const;
+    gsl::span<const std::string> UserInputNames() const;
+    gsl::span<const std::string> WeightsInputNames() const;
+    gsl::span<const std::string> GradientInputNames() const;
+  };
+
+  TrainInputNames train_input_names_;
   InlinedVector<std::string> train_output_names_;
   InlinedVector<std::string> eval_input_names_;
   InlinedVector<std::string> eval_output_names_;
-  InlinedVector<std::string> weight_names_;
   InlinedVector<OrtValue> weights_;
   InlinedVector<OrtValue> gradients_;
 
@@ -165,7 +198,6 @@ struct Module {
   bool accumulate_gradient_ = false;
   std::optional<std::string> eval_model_path_;
-  size_t train_user_input_count_{0U};
   size_t eval_user_input_count_{0U};
 };
 
diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
index 38a9aad9640ea..0ed41f670f9e3 100644
--- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
+++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
@@ -568,9 +568,16 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtChe
   API_IMPL_BEGIN
 
   auto chkpt_state = reinterpret_cast<CheckpointState*>(checkpoint_state);
+  if (chkpt_state->module_checkpoint_state.is_nominal_state) {
+    const std::string err_msg =
+        "Parameter type and shape cannot be retrieved from nominal checkpoint state. "
+        "Please load the parameter states first.";
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
+  }
+
   auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
   if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
-    std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
+    const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
     return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
   }
 
@@ -586,9 +593,15 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState
   }
 
   auto chkpt_state = reinterpret_cast<CheckpointState*>(checkpoint_state);
+  if (chkpt_state->module_checkpoint_state.is_nominal_state) {
+    const std::string err_msg =
Please load all the parameter states first."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { - std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( @@ -608,9 +621,15 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState } auto chkpt_state = reinterpret_cast(checkpoint_state); + if (chkpt_state->module_checkpoint_state.is_nominal_state) { + const std::string err_msg = + "Parameter cannot be retrieved from nominal checkpoint state. Please load the parameter states first."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { - std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index 7f583ce8f6e76..84c35e6100385 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -21,8 +21,8 @@ namespace { constexpr char GROUP_ZERO_NAME[] = "group0"; static constexpr std::array CommonOptimizerInputs{"learning_rate", "step", "params", "gradients"}; -Status GraphInputsAreExpected(gsl::span actual_graph_inputs, - gsl::span expected_graph_inputs) { +Status GraphInputsAreExpected(gsl::span actual_graph_inputs, + gsl::span expected_graph_inputs) { const auto stringify = [](const auto& container) { if (container.empty()) { return std::string("[]"); @@ -245,8 +245,17 @@ Optimizer::Optimizer(const ModelIdentifiers& model_identifiers, if (!find_group_zero) state_->optimizer_checkpoint_state.group_named_optimizer_states.insert( {GROUP_ZERO_NAME, std::make_shared()}); - ORT_THROW_IF_ERROR(GenerateMomentumNamedStates(state_->optimizer_checkpoint_state)); - ORT_THROW_IF_ERROR(ConstructInputs()); + if (!state_->module_checkpoint_state.is_nominal_state) { + // Construct the optimizer state and inputs only if the complete state + // is available. + // For a nominal state, delay the construction of the optimizer state + // and inputs until the complete state is available. Once the complete + // state is available, the optimizer state and inputs can be constructed + // by invoking ConstructOptimizerStateAndInputs(). 
+      ORT_THROW_IF_ERROR(ConstructOptimizerStateAndInputs());
+    } else {
+      delay_optimizer_state_construction_ = true;
+    }
   } else {
     ORT_THROW_IF_ERROR(LoadStateDict(state_->optimizer_checkpoint_state));
   }
@@ -298,6 +307,10 @@ void Optimizer::Initialize(const ModelIdentifiers& model_identifiers,
 }
 
 Status Optimizer::Step() {
+  if (delay_optimizer_state_construction_) {
+    ORT_RETURN_IF_ERROR(ConstructOptimizerStateAndInputs());
+  }
+
   OrtValue learning_rate_input, step_input;
   utils::WrapInOrtValue(optimizer_state_->learning_rate, &learning_rate_input);
   // Use step count + 1 before running optimizer step.
@@ -375,6 +388,17 @@ Status Optimizer::LoadStateDict(OptimizerCheckpointState& optimizer_checkpoint_s
   return Status::OK();
 }
 
+Status Optimizer::ConstructOptimizerStateAndInputs() {
+  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
+                "The optimizer state cannot be constructed. Please load the model parameters first.");
+  ORT_RETURN_IF_ERROR(GenerateMomentumNamedStates(state_->optimizer_checkpoint_state));
+  ORT_RETURN_IF_ERROR(ConstructInputs());
+
+  delay_optimizer_state_construction_ = false;
+
+  return Status::OK();
+}
+
 }  // namespace api
 }  // namespace training
 }  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h
index d9bc4870bb7ed..031b11426539b 100644
--- a/orttraining/orttraining/training_api/optimizer.h
+++ b/orttraining/orttraining/training_api/optimizer.h
@@ -123,6 +123,15 @@ struct Optimizer {
     return Status::OK();
   }
 
+  // Constructs the optimizer state and prepares the model inputs.
+  // This is called once during the construction of the Optimizer if the model state is available.
+  // In case the optimizer was instantiated with a nominal checkpoint, this function must be
+  // called once the model state becomes available.
+  // Optimizer::Step() checks whether the optimizer state still needs to be constructed and,
+  // if so, constructs it. However, this is also exposed as a public function in case the user
+  // wants to construct the optimizer state before the train step function is called.
+  Status ConstructOptimizerStateAndInputs();
+
  private:
   void Initialize(const ModelIdentifiers& model_identifiers,
                   const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
@@ -134,8 +143,7 @@ struct Optimizer {
   // Generates optimizer momentum states for parameters that require grad.
   Status GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states);
 
-  // Constructs the ortvalue inputs to be fed to the graph
-  // at each step.
+  // Constructs the ortvalue inputs to be fed to the graph at each step.
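+  // The constructed inputs follow the ordering expected by the optimizer graph
+  // (see CommonOptimizerInputs in optimizer.cc): learning_rate, step, params,
+  // gradients, followed by any optimizer-specific states (e.g. momentums).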
   Status ConstructInputs();
 
   /**
@@ -160,6 +168,8 @@ struct Optimizer {
   InlinedVector<OrtValue> inputs_;
 
   int32_t group_count_{0};
+
+  bool delay_optimizer_state_construction_{false};
 };
 
 }  // namespace api
diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc
index 45f0f0ddcf7f4..78619947b8b18 100644
--- a/orttraining/orttraining/training_api/training_session.cc
+++ b/orttraining/orttraining/training_api/training_session.cc
@@ -112,7 +112,16 @@ Status TrainingSession::CopyParametersToBuffer(OrtValue& parameters_buffer, cons
 }
 
 Status TrainingSession::CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only) {
-  return module_->CopyBufferToParameters(parameters_buffer, trainable_only);
+  const bool was_nominal_state = state_->module_checkpoint_state.is_nominal_state;
+  ORT_RETURN_IF_ERROR(module_->CopyBufferToParameters(parameters_buffer, trainable_only));
+
+  // If the checkpoint state was nominal before loading the params, then we need to construct the
+  // optimizer state and inputs.
+  if (was_nominal_state) {
+    ORT_RETURN_IF_ERROR(optimizer_->ConstructOptimizerStateAndInputs());
+  }
+
+  return Status::OK();
 }
 
 #if !defined(ORT_MINIMAL_BUILD)
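
Taken together, these changes enable the following end-to-end flow with the C++ training API: save a nominal checkpoint that carries only parameter names and types, create a training session from it, and complete the state by loading real parameter values before the first TrainStep. A minimal sketch (illustrative only; the file paths, the input preparation, and the parameter-buffer source are hypothetical):

  // Load the nominal checkpoint; its ModuleState tensors are named but empty.
  Ort::Env env;
  Ort::SessionOptions session_options;
  Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("nominal_checkpoint"));
  Ort::TrainingSession session(env, session_options, state,
                               ORT_TSTR("training_model.onnx"), ORT_TSTR("eval_model.onnx"),
                               ORT_TSTR("optimizer_model.onnx"));

  // Complete the state: FromBuffer with a full (not trainable-only) parameter buffer
  // clears is_nominal_state and constructs the delayed optimizer state, after which
  // TrainStep/OptimizerStep are permitted.
  Ort::Value params = GetFullParameterBuffer();  // hypothetical source of parameter values
  session.FromBuffer(params);

  std::vector<Ort::Value> fetches = session.TrainStep(inputs);  // inputs prepared as usual
  session.OptimizerStep();
  session.LazyResetGrad();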