Skip to content

Commit

Permalink
[VitisAI] Bug fixes in model_clone
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenze Wang committed Sep 2, 2024
1 parent 8c53364 commit 8d05ec0
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,7 @@ struct ProviderHost {
virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0;
virtual const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const = 0;
virtual IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const = 0;
virtual bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) = 0;

// GraphViewer
virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,7 @@ struct Graph final {
Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); }
const NodeArg* GetNodeArg(const std::string& name) const { return g_host->Graph__GetNodeArg(this, name); }
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->Graph__GetSchemaRegistry(this); }
bool SetOpSchemaFromRegistryForNode(Node& node) { return g_host->Graph__SetOpSchemaFromRegistryForNode(this, node); }

PROVIDER_DISALLOW_ALL(Graph)
};
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/vitisai/imp/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,11 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold)
}
}
auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger);
auto status = ret->MainGraph().Resolve();
auto& graph = ret->MainGraph();
for (auto node : graph.Nodes()) {
graph.SetOpSchemaFromRegistryForNode(*graph.GetNode(node->Index()));
}
auto status = graph.Resolve();
vai_assert(status.IsOK(), status.ErrorMessage());
return ret.release();
}
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,7 @@ struct ProviderHostImpl : ProviderHost {
const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); }
const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const override { return p->GetNodeArg(name); }
IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const override { return p->GetSchemaRegistry(); }
bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) override { return p->SetOpSchemaFromRegistryForNode(node); }

// GraphViewer (wrapped)
void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }
Expand Down

0 comments on commit 8d05ec0

Please sign in to comment.