Skip to content

Commit

Permalink
moved wrapper class from header file to locally defined in export fun…
Browse files Browse the repository at this point in the history
…ction
  • Loading branch information
carzh committed Jul 25, 2024
1 parent aa97469 commit e66740d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
13 changes: 13 additions & 0 deletions orttraining/orttraining/training_api/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,19 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path
"Cannot export the model with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(!eval_sess_, "Eval model was not provided. Cannot export a model for inferencing.");

class EvalSessionWrapper : public InferenceSession {
public:
using InferenceSession::InferenceSession;

Graph& GetMutableGraph() const {
return model_->MainGraph();
}

Model& GetMutableModel() {
return *model_;
}
};

// Once finished_training is set to true, will no longer be able to train or evaluate with this module
// since the eval session graph will have been modified.
finished_training_ = true;
Expand Down
13 changes: 0 additions & 13 deletions orttraining/orttraining/training_api/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,6 @@ namespace onnxruntime {
namespace training {
namespace api {

class EvalSessionWrapper : public InferenceSession {
public:
using InferenceSession::InferenceSession;

Graph& GetMutableGraph() const {
return model_->MainGraph();
}

Model& GetMutableModel() {
return *model_;
}
};

struct Parameter {
public:
Parameter(const std::string& name, const OrtValue& data, const bool requires_grad)
Expand Down

0 comments on commit e66740d

Please sign in to comment.