diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 43183a8fe8245..8e4fabced730c 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -3,6 +3,8 @@ #include "orttraining/training_api/module.h" +#include + #include "core/common/safeint.h" #include "core/common/string_utils.h" #include "core/framework/execution_provider.h" @@ -425,7 +427,8 @@ size_t Module::GetTrainingModelOutputCount() const noexcept { } size_t Module::GetEvalModelOutputCount() const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. "); return eval_output_names_.size(); } @@ -435,7 +438,8 @@ std::string Module::GetTrainingModelOutputName(size_t index) const { } std::string Module::GetEvalModelOutputName(size_t index) const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. "); ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-", eval_output_names_.size(), "). Actual: ", index); return eval_output_names_.at(index); @@ -611,7 +615,8 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr } Status Module::LazyResetGrad() { - ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + ORT_RETURN_IF(finished_training_, + "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); accumulate_gradient_ = false; return Status::OK(); } @@ -619,7 +624,8 @@ Status Module::LazyResetGrad() { Status Module::TrainStep(const std::vector& inputs, std::vector& outputs) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot perform TrainStep with a nominal state. Please load the model parameters first."); - ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + ORT_RETURN_IF(finished_training_, + "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); std::vector> params; std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -642,7 +648,8 @@ Status Module::TrainStep(const std::vector& inputs, std::vector& inputs, std::vector& outputs) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot perform EvalStep with a nominal state. Please load the model parameters first."); - ORT_RETURN_IF(finished_training_, "Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + ORT_RETURN_IF(finished_training_, + "Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized."); std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -697,13 +704,16 @@ size_t Module::GetTrainingModelInputCount() const noexcept { } size_t Module::GetEvalModelInputCount() const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); return eval_user_input_count_; } std::string Module::GetTrainingModelInputName(size_t index) const { ORT_ENFORCE(index < train_input_names_.UserInputNames().size(), - "Train input name index out of range. Expected in range [0-", train_input_names_.UserInputNames().size(), "). Actual: ", + "Train input name index out of range. Expected in range [0-", + train_input_names_.UserInputNames().size(), + "). Actual: ", index); return train_input_names_.UserInputNames()[index]; } @@ -721,7 +731,8 @@ std::pair Module::GetTrainingModelInputs() } std::pair Module::GetEvalModelInputs() const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. "); return eval_sess_->GetModelInputs(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 237822738782e..13b08beef64ab 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/session/inference_session.h" #include "orttraining/training_api/utils.h"