Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
carzh committed Jul 22, 2024
1 parent ec3d37d commit 985a581
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
27 changes: 19 additions & 8 deletions orttraining/orttraining/training_api/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "orttraining/training_api/module.h"

#include <memory>

#include "core/common/safeint.h"
#include "core/common/string_utils.h"
#include "core/framework/execution_provider.h"
Expand Down Expand Up @@ -425,7 +427,8 @@ size_t Module::GetTrainingModelOutputCount() const noexcept {
}

size_t Module::GetEvalModelOutputCount() const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. ");
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. ");
return eval_output_names_.size();
}

Expand All @@ -435,7 +438,8 @@ std::string Module::GetTrainingModelOutputName(size_t index) const {
}

std::string Module::GetEvalModelOutputName(size_t index) const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. ");
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. ");
ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-",
eval_output_names_.size(), "). Actual: ", index);
return eval_output_names_.at(index);
Expand Down Expand Up @@ -611,15 +615,17 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr
}

Status Module::LazyResetGrad() {
ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");
ORT_RETURN_IF(finished_training_,
"Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");

Check warning on line 619 in orttraining/orttraining/training_api/module.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: orttraining/orttraining/training_api/module.cc:619: Lines should be <= 120 characters long [whitespace/line_length] [2]
accumulate_gradient_ = false;
return Status::OK();
}

Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot perform TrainStep with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");
ORT_RETURN_IF(finished_training_,
"Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");

Check warning on line 628 in orttraining/orttraining/training_api/module.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: orttraining/orttraining/training_api/module.cc:628: Lines should be <= 120 characters long [whitespace/line_length] [2]
std::vector<std::shared_ptr<Parameter>> params;
std::vector<OrtValue> feeds{inputs};
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
Expand All @@ -642,7 +648,8 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot perform EvalStep with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(finished_training_, "Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");
ORT_RETURN_IF(finished_training_,
"Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.");
ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized.");
std::vector<OrtValue> feeds{inputs};

Check warning on line 654 in orttraining/orttraining/training_api/module.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: orttraining/orttraining/training_api/module.cc:654: Add #include <vector> for vector<> [build/include_what_you_use] [4]
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
Expand Down Expand Up @@ -697,13 +704,16 @@ size_t Module::GetTrainingModelInputCount() const noexcept {
}

size_t Module::GetEvalModelInputCount() const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. ");
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. ");
return eval_user_input_count_;
}

std::string Module::GetTrainingModelInputName(size_t index) const {
ORT_ENFORCE(index < train_input_names_.UserInputNames().size(),
"Train input name index out of range. Expected in range [0-", train_input_names_.UserInputNames().size(), "). Actual: ",
"Train input name index out of range. Expected in range [0-",
train_input_names_.UserInputNames().size(),
"). Actual: ",
index);
return train_input_names_.UserInputNames()[index];
}
Expand All @@ -721,7 +731,8 @@ std::pair<common::Status, const InputDefList*> Module::GetTrainingModelInputs()
}

std::pair<common::Status, const InputDefList*> Module::GetEvalModelInputs() const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. ");
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. ");
return eval_sess_->GetModelInputs();
}

Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/training_api/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <string>
#include <utility>
#include "core/session/inference_session.h"
#include "orttraining/training_api/utils.h"

Expand Down

0 comments on commit 985a581

Please sign in to comment.