Skip to content

Commit

Permalink
Remove ref struct return usage (#20132)
Browse files Browse the repository at this point in the history
### Description
Removes ref struct return usage on netstandard 2.0 builds.

### Motivation and Context
Unblocks .NET native compilation
  • Loading branch information
tommcdon authored May 16, 2024
1 parent 47a178b commit 1e1b3f9
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
32 changes: 31 additions & 1 deletion csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
namespace Microsoft.ML.OnnxRuntime
{
[StructLayout(LayoutKind.Sequential)]
#if NETSTANDARD2_0
public class OrtApiBase
#else
public struct OrtApiBase
#endif
{
public IntPtr GetApi;
public IntPtr GetVersionString;
Expand All @@ -17,7 +21,11 @@ public struct OrtApiBase
// OrtApi ort_api_1_to_<latest_version> (onnxruntime/core/session/onnxruntime_c_api.cc)
// If syncing your new C API, any other C APIs before yours also need to be synced here if haven't
[StructLayout(LayoutKind.Sequential)]
#if NETSTANDARD2_0
public class OrtApi
#else
public struct OrtApi
#endif
{
public IntPtr CreateStatus;
public IntPtr GetErrorCode;
Expand Down Expand Up @@ -300,8 +308,13 @@ internal static class NativeMethods
{
static OrtApi api_;

#if NETSTANDARD2_0
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr DOrtGetApi(UInt32 version);
#else
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate ref OrtApi DOrtGetApi(UInt32 version);
#endif

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr DOrtGetVersionString();
Expand All @@ -310,11 +323,24 @@ internal static class NativeMethods

static NativeMethods()
{
#if NETSTANDARD2_0
IntPtr ortApiBasePtr = OrtGetApiBase();
OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase));
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi));
#else
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetApi, typeof(DOrtGetApi));
#endif

const uint ORT_API_VERSION = 14;
#if NETSTANDARD2_0
IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION);
api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi));
OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetVersionString, typeof(DOrtGetVersionString));
#else
// TODO: Make this save the pointer, and not copy the whole structure across
api_ = (OrtApi)OrtGetApi(14 /*ORT_API_VERSION*/);
api_ = (OrtApi)OrtGetApi(ORT_API_VERSION);
OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetVersionString, typeof(DOrtGetVersionString));
#endif

OrtCreateEnv = (DOrtCreateEnv)Marshal.GetDelegateForFunctionPointer(api_.CreateEnv, typeof(DOrtCreateEnv));
OrtCreateEnvWithCustomLogger = (DOrtCreateEnvWithCustomLogger)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLogger, typeof(DOrtCreateEnvWithCustomLogger));
Expand Down Expand Up @@ -530,7 +556,11 @@ internal class NativeLib
}

[DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
#if NETSTANDARD2_0
public static extern IntPtr OrtGetApiBase();
#else
public static extern ref OrtApiBase OrtGetApiBase();
#endif

#region Runtime / Environment API

Expand Down
5 changes: 5 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,12 @@ public void RegisterOrtExtensions()
{
try
{
#if NETSTANDARD2_0
var ortApiBasePtr = NativeMethods.OrtGetApiBase();
var ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase));
#else
var ortApiBase = NativeMethods.OrtGetApiBase();
#endif
NativeApiStatus.VerifySuccess(
OrtExtensionsNativeMethods.RegisterCustomOps(this.handle, ref ortApiBase)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,40 @@ internal static class NativeTrainingMethods
static OrtTrainingApi trainingApi_;
static IntPtr trainingApiPtr;

#if NETSTANDARD2_0
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr DOrtGetApi(UInt32 version);
#else
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate ref OrtApi DOrtGetApi(UInt32 version);
#endif


[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtTrainingApi* */ DOrtGetTrainingApi(UInt32 version);
public static DOrtGetTrainingApi OrtGetTrainingApi;

static NativeTrainingMethods()
{
#if NETSTANDARD2_0
IntPtr ortApiBasePtr = NativeMethods.OrtGetApiBase();
OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase));
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi));
#else
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));
#endif

const uint ORT_API_VERSION = 19;
#if NETSTANDARD2_0
IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION);
api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi));
#else
// TODO: Make this save the pointer, and not copy the whole structure across
api_ = (OrtApi)OrtGetApi(19 /*ORT_API_VERSION*/);
api_ = (OrtApi)OrtGetApi(ORT_API_VERSION);
#endif

OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi));
trainingApiPtr = OrtGetTrainingApi(19 /*ORT_API_VERSION*/);
trainingApiPtr = OrtGetTrainingApi(ORT_API_VERSION);
if (trainingApiPtr != IntPtr.Zero)
{
trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi));
Expand Down

0 comments on commit 1e1b3f9

Please sign in to comment.