Skip to content

Commit

Permalink
Expose SessionOtions.DisablePerSessionThreads (#19730)
Browse files Browse the repository at this point in the history
### Description

### Motivation and Context
ML.NET needs to run mltiple sessions on a single threadpool.
  • Loading branch information
yuslepukhin authored Mar 4, 2024
1 parent 27b1dc9 commit 0cdf36f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
5 changes: 5 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ static NativeMethods()
OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern));
OrtEnableCpuMemArena = (DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena));
OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena));
OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads));
OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId));
OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel));
OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel));
Expand Down Expand Up @@ -992,6 +993,10 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
public static DOrtDisableCpuMemArena OrtDisableCpuMemArena;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options);
public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId);
public static DOrtSetSessionLogId OrtSetSessionLogId;
Expand Down
9 changes: 9 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,15 @@ public bool EnableCpuMemArena
}
private bool _enableCpuMemArena = true;

/// <summary>
/// Disables the per session threads. Default is true.
/// This makes all sessions in the process use a global TP.
/// </summary>
public void DisablePerSessionThreads()
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle));
}

/// <summary>
/// Log Id to be used for the session. Default is empty string.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ public void TestSessionOptions()
Assert.Equal(0, opt.InterOpNumThreads);
Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel);

// No get, so no verify
opt.DisablePerSessionThreads();

// try setting options
opt.ExecutionMode = ExecutionMode.ORT_PARALLEL;
Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode);
Expand Down Expand Up @@ -98,7 +101,7 @@ public void TestSessionOptions()
Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message);

// SessionOptions.RegisterOrtExtensions can be manually tested by referencing the
// Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.
// Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.RegisterOrtExtensions(); });
Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message);

Expand Down

0 comments on commit 0cdf36f

Please sign in to comment.