diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 57843c689f9b5..43183a8fe8245 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -86,7 +86,7 @@ Status TransformModelOutputsForInference(Graph& inference_graph, inference_graph.SetOutputs(inference_graph_output_node_args); ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args)); - ORT_THROW_IF_ERROR(inference_graph.Resolve()); + ORT_RETURN_IF_ERROR(inference_graph.Resolve()); return Status::OK(); } @@ -114,7 +114,7 @@ Status TransformModelInputsForInference(Graph& inference_graph, } inference_graph.SetInputs(user_graph_inputs); - ORT_THROW_IF_ERROR(inference_graph.Resolve()); + ORT_RETURN_IF_ERROR(inference_graph.Resolve()); return Status::OK(); } @@ -388,11 +388,6 @@ Module::Module(const ModelIdentifiers& model_identifiers, } ORT_THROW_IF_ERROR(eval_sess_->Initialize()); utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); - // TODO: remove this - // std::shared_ptr inference_model = eval_sess_->GetModel(); - // Graph& inference_graph = inference_model->MainGraph(); - - // ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph)); // Eval model validation // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval @@ -417,14 +412,8 @@ Module::Module(const ModelIdentifiers& model_identifiers, eval_user_input_count_ = eval_user_input_names.size(); eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); - // TODO: remove this - // ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph)); // Keep a copy of the eval model path to be able to later export the model for inferencing. // The inference model will be reconstructed from the eval model. - // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer. - if (std::holds_alternative>(model_identifiers.eval_model)) { - eval_model_path_ = std::get>(model_identifiers.eval_model); - } } Module::~Module() { @@ -676,7 +665,6 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path // since the eval session graph will have been modified. finished_training_ = true; - // Model& inference_model = const_cast(eval_sess_->GetModel()); std::shared_ptr inference_model = eval_sess_->GetModel(); Graph& inference_graph = inference_model->MainGraph(); diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index e4a784b95bcfd..237822738782e 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -199,7 +199,6 @@ struct Module { CheckpointState* state_; // Non owning pointer to the state. bool accumulate_gradient_ = false; - std::optional eval_model_path_; size_t eval_user_input_count_{0U}; };