From e66740dfa6ddc024e66568c9b34b5fb63f6e54de Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 25 Jul 2024 15:44:09 -0700 Subject: [PATCH] moved wrapper class from header file to locally defined in export function --- orttraining/orttraining/training_api/module.cc | 13 +++++++++++++ orttraining/orttraining/training_api/module.h | 13 ------------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 3293fcea1fd65..68e6cb8ae04a7 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -671,6 +671,19 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path "Cannot export the model with a nominal state. Please load the model parameters first."); ORT_RETURN_IF(!eval_sess_, "Eval model was not provided. Cannot export a model for inferencing."); + class EvalSessionWrapper : public InferenceSession { + public: + using InferenceSession::InferenceSession; + + Graph& GetMutableGraph() const { + return model_->MainGraph(); + } + + Model& GetMutableModel() { + return *model_; + } + }; + // Once finished_training is set to true, will no longer be able to train or evaluate with this module // since the eval session graph will have been modified. finished_training_ = true; diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 3ce4aeaaedd60..13b08beef64ab 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -12,19 +12,6 @@ namespace onnxruntime { namespace training { namespace api { -class EvalSessionWrapper : public InferenceSession { - public: - using InferenceSession::InferenceSession; - - Graph& GetMutableGraph() const { - return model_->MainGraph(); - } - - Model& GetMutableModel() { - return *model_; - } -}; - struct Parameter { public: Parameter(const std::string& name, const OrtValue& data, const bool requires_grad)