Skip to content

Commit

Permalink
lintrunner
Browse files Browse the repository at this point in the history
  • Loading branch information
carzh committed Jul 22, 2024
1 parent 985a581 commit 26dc4ec
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions orttraining/orttraining/training_api/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ size_t Module::GetTrainingModelOutputCount() const noexcept {

size_t Module::GetEvalModelOutputCount() const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. ");
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. ");
return eval_output_names_.size();
}

Expand All @@ -439,7 +439,7 @@ std::string Module::GetTrainingModelOutputName(size_t index) const {

std::string Module::GetEvalModelOutputName(size_t index) const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. ");
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. ");
ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-",
eval_output_names_.size(), "). Actual: ", index);
return eval_output_names_.at(index);
Expand Down Expand Up @@ -616,7 +616,8 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr

Status Module::LazyResetGrad() {
ORT_RETURN_IF(finished_training_,
"Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");
"Cannot train after exporting for inferencing. ",
"To continue training from this point, please save the checkpoint and create a new TrainingSession.");
accumulate_gradient_ = false;
return Status::OK();
}
Expand All @@ -625,7 +626,8 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot perform TrainStep with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(finished_training_,
"Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");
"Cannot train after exporting for inferencing. ",
"To continue training from this point, please save the checkpoint and create a new TrainingSession.");
std::vector<std::shared_ptr<Parameter>> params;
std::vector<OrtValue> feeds{inputs};
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
Expand All @@ -649,7 +651,8 @@ Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValu
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot perform EvalStep with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(finished_training_,
"Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");
"Cannot evaluate after exporting for inferencing. ",
"To continue training from this point, please save the checkpoint and create a new TrainingSession.");
ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized.");
std::vector<OrtValue> feeds{inputs};
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
Expand Down Expand Up @@ -705,7 +708,7 @@ size_t Module::GetTrainingModelInputCount() const noexcept {

size_t Module::GetEvalModelInputCount() const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. ");
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. ");
return eval_user_input_count_;
}

Expand All @@ -719,7 +722,8 @@ std::string Module::GetTrainingModelInputName(size_t index) const {
}

std::string Module::GetEvalModelInputName(size_t index) const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. ");
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. ");
ORT_ENFORCE(index < eval_user_input_count_,
"Eval input name index out of range. Expected in range [0-", eval_user_input_count_, "). Actual: ",
index);
Expand All @@ -732,7 +736,7 @@ std::pair<common::Status, const InputDefList*> Module::GetTrainingModelInputs()

std::pair<common::Status, const InputDefList*> Module::GetEvalModelInputs() const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. ");
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. ");
return eval_sess_->GetModelInputs();
}

Expand Down

0 comments on commit 26dc4ec

Please sign in to comment.