Skip to content

Commit

Permalink
Removed the eval model path parameter and performed additional cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
carzh committed Jul 19, 2024
1 parent 2ed4dfe commit ec3d37d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 15 deletions.
16 changes: 2 additions & 14 deletions orttraining/orttraining/training_api/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Status TransformModelOutputsForInference(Graph& inference_graph,
inference_graph.SetOutputs(inference_graph_output_node_args);
ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args));

ORT_THROW_IF_ERROR(inference_graph.Resolve());
ORT_RETURN_IF_ERROR(inference_graph.Resolve());

return Status::OK();
}
Expand Down Expand Up @@ -114,7 +114,7 @@ Status TransformModelInputsForInference(Graph& inference_graph,
}

inference_graph.SetInputs(user_graph_inputs);
ORT_THROW_IF_ERROR(inference_graph.Resolve());
ORT_RETURN_IF_ERROR(inference_graph.Resolve());

return Status::OK();
}
Expand Down Expand Up @@ -388,11 +388,6 @@ Module::Module(const ModelIdentifiers& model_identifiers,
}
ORT_THROW_IF_ERROR(eval_sess_->Initialize());
utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_);
// TODO: remove this
// std::shared_ptr<onnxruntime::Model> inference_model = eval_sess_->GetModel();
// Graph& inference_graph = inference_model->MainGraph();

// ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph));

// Eval model validation
// We are making certain assumptions: Like the order in which parameters occur will be same between train and eval
Expand All @@ -417,14 +412,8 @@ Module::Module(const ModelIdentifiers& model_identifiers,
eval_user_input_count_ = eval_user_input_names.size();
eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());

// TODO: remove this
// ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph));
// Keep a copy of the eval model path to be able to later export the model for inferencing.
// The inference model will be reconstructed from the eval model.
// TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer.
if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
eval_model_path_ = std::get<std::optional<std::string>>(model_identifiers.eval_model);
}
}

Module::~Module() {
Expand Down Expand Up @@ -676,7 +665,6 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path
// since the eval session graph will have been modified.
finished_training_ = true;

// Model& inference_model = const_cast<Model&>(eval_sess_->GetModel());
std::shared_ptr<onnxruntime::Model> inference_model = eval_sess_->GetModel();

Check warning on line 668 in orttraining/orttraining/training_api/module.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for shared_ptr<> [build/include_what_you_use] [4] Raw Output: orttraining/orttraining/training_api/module.cc:668: Add #include <memory> for shared_ptr<> [build/include_what_you_use] [4]
Graph& inference_graph = inference_model->MainGraph();

Expand Down
1 change: 0 additions & 1 deletion orttraining/orttraining/training_api/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ struct Module {
CheckpointState* state_; // Non owning pointer to the state.

bool accumulate_gradient_ = false;
std::optional<std::string> eval_model_path_;
size_t eval_user_input_count_{0U};
};

Expand Down

0 comments on commit ec3d37d

Please sign in to comment.