From 1e1b3f96892363de0a44f398365e737ddb116678 Mon Sep 17 00:00:00 2001 From: Tom McDonald Date: Thu, 16 May 2024 12:46:19 -0400 Subject: [PATCH] Remove ref struct return usage (#20132) ### Description Removes ref struct return usage on netstandard 2.0 builds. ### Motivation and Context Unblocks .NET native compilation --- .../NativeMethods.shared.cs | 32 ++++++++++++++++++- .../SessionOptions.shared.cs | 5 +++ .../Training/NativeTrainingMethods.shared.cs | 22 +++++++++++-- 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 8a8426a0b3054..13d925e0fc2ee 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -7,7 +7,11 @@ namespace Microsoft.ML.OnnxRuntime { [StructLayout(LayoutKind.Sequential)] +#if NETSTANDARD2_0 + public class OrtApiBase +#else public struct OrtApiBase +#endif { public IntPtr GetApi; public IntPtr GetVersionString; @@ -17,7 +21,11 @@ public struct OrtApiBase // OrtApi ort_api_1_to_ (onnxruntime/core/session/onnxruntime_c_api.cc) // If syncing your new C API, any other C APIs before yours also need to be synced here if haven't [StructLayout(LayoutKind.Sequential)] +#if NETSTANDARD2_0 + public class OrtApi +#else public struct OrtApi +#endif { public IntPtr CreateStatus; public IntPtr GetErrorCode; @@ -300,8 +308,13 @@ internal static class NativeMethods { static OrtApi api_; +#if NETSTANDARD2_0 + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtGetApi(UInt32 version); +#else [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate ref OrtApi DOrtGetApi(UInt32 version); +#endif [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr DOrtGetVersionString(); @@ -310,11 +323,24 @@ internal static class NativeMethods static NativeMethods() { +#if NETSTANDARD2_0 + IntPtr ortApiBasePtr = OrtGetApiBase(); + OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase)); + DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi)); +#else DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetApi, typeof(DOrtGetApi)); +#endif + const uint ORT_API_VERSION = 14; +#if NETSTANDARD2_0 + IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION); + api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi)); + OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetVersionString, typeof(DOrtGetVersionString)); +#else // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(14 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(ORT_API_VERSION); OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetVersionString, typeof(DOrtGetVersionString)); +#endif OrtCreateEnv = (DOrtCreateEnv)Marshal.GetDelegateForFunctionPointer(api_.CreateEnv, typeof(DOrtCreateEnv)); OrtCreateEnvWithCustomLogger = (DOrtCreateEnvWithCustomLogger)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLogger, typeof(DOrtCreateEnvWithCustomLogger)); @@ -530,7 +556,11 @@ internal class NativeLib } [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] +#if NETSTANDARD2_0 + public static extern IntPtr OrtGetApiBase(); +#else public static extern ref OrtApiBase OrtGetApiBase(); +#endif #region Runtime / Environment API diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 30d005b3c4236..6ecfee0a35b60 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -507,7 +507,12 @@ public void RegisterOrtExtensions() { try { +#if NETSTANDARD2_0 + var ortApiBasePtr = NativeMethods.OrtGetApiBase(); + var ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase)); +#else var ortApiBase = NativeMethods.OrtGetApiBase(); +#endif NativeApiStatus.VerifySuccess( OrtExtensionsNativeMethods.RegisterCustomOps(this.handle, ref ortApiBase) ); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 24f85f603c414..1ba5f14641e78 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -53,8 +53,14 @@ internal static class NativeTrainingMethods static OrtTrainingApi trainingApi_; static IntPtr trainingApiPtr; +#if NETSTANDARD2_0 + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtGetApi(UInt32 version); +#else [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate ref OrtApi DOrtGetApi(UInt32 version); +#endif + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtTrainingApi* */ DOrtGetTrainingApi(UInt32 version); @@ -62,13 +68,25 @@ internal static class NativeTrainingMethods static NativeTrainingMethods() { +#if NETSTANDARD2_0 + IntPtr ortApiBasePtr = NativeMethods.OrtGetApiBase(); + OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase)); + DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi)); +#else DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); +#endif + const uint ORT_API_VERSION = 19; +#if NETSTANDARD2_0 + IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION); + api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi)); +#else // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(19 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(ORT_API_VERSION); +#endif OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); - trainingApiPtr = OrtGetTrainingApi(19 /*ORT_API_VERSION*/); + trainingApiPtr = OrtGetTrainingApi(ORT_API_VERSION); if (trainingApiPtr != IntPtr.Zero) { trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi));