diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 29a1231fdce18..1133751d82d65 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -270,6 +270,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { graph.SetGraphResolveNeeded(); } auto status = graph.Resolve(); + if (!status.IsOK()) { + std::cerr << "graph resolve error:" << status.ErrorMessage() << std::endl; + } return status.Code(); }; the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 46fc4ac9b2a5d..74482d8e9ee0e 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -38,7 +38,13 @@ enum TensorProto_DataType : int { TensorProto_DataType_UINT64 = 13, TensorProto_DataType_COMPLEX64 = 14, TensorProto_DataType_COMPLEX128 = 15, - TensorProto_DataType_BFLOAT16 = 16 + TensorProto_DataType_BFLOAT16 = 16, + TensorProto_DataType_FLOAT8E4M3FN = 17, + TensorProto_DataType_FLOAT8E4M3FNUZ = 18, + TensorProto_DataType_FLOAT8E5M2 = 19, + TensorProto_DataType_FLOAT8E5M2FNUZ = 20, + TensorProto_DataType_UINT4 = 21, + TensorProto_DataType_INT4 = 22 }; enum AttributeProto_AttributeType : int { AttributeProto_AttributeType_UNDEFINED = 0, diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 62a7bb602e7e8..3346739890484 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { #define VAIP_ORT_API_MAJOR (3u) -#define VAIP_ORT_API_MINOR (0u) +#define VAIP_ORT_API_MINOR (1u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { uint32_t magic; // 'VAIP' or something else to make sure the following field diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d4c6e3d506f18..408ad7815835f 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -613,8 +613,12 @@ struct ProviderHostImpl : ProviderHost { elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8; } else if (data_type->s() == "int32") { elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32; + } else if (data_type->s() == "uint32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32; } else if (data_type->s() == "int64") { elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } else if (data_type->s() == "uint64") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64; } else if (data_type->s() == "int1") { elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL; } else if (data_type->s() == "bfloat16") { @@ -625,6 +629,26 @@ struct ProviderHostImpl : ProviderHost { elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16; } else if (data_type->s() == "int16") { elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16; + } else if (data_type->s() == "double") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; + } else if (data_type->s() == "string") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING; + } else if (data_type->s() == "complex64") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64; + } else if (data_type->s() == "complex128") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128; + } else if (data_type->s() == "float8e4m3fn") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN; + } else if (data_type->s() == "float8e4m3fnuz") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ; + } else if (data_type->s() == "float8e5m2") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; + } else if (data_type->s() == "float8e5m2funz") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; + } else if (data_type->s() == "uint4") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4; + } else if (data_type->s() == "int4") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4; } else { return; }