Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some cherry-picks for the 1.16.2 release #18218

Merged
merged 25 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
545036f
update
chilo-ms Sep 28, 2023
e0cd0cc
update
chilo-ms Oct 2, 2023
4534d8f
update
chilo-ms Oct 3, 2023
b2a582a
update
chilo-ms Oct 19, 2023
0ec4b83
Fix comment
chilo-ms Oct 19, 2023
e9a8c88
fix
chilo-ms Oct 12, 2023
e718496
Fix missing attribute. Causes build error on release xamarin iOS build.
skottmckay Oct 12, 2023
5ada1ad
Address pull request review comment
baijumeswani Aug 25, 2023
d30dcfa
Add rocm execution provider for Sign
baijumeswani Aug 28, 2023
1e0736c
Update operator kernels doc
baijumeswani Aug 28, 2023
9c3df04
Add definitions to rocm ep
baijumeswani Aug 28, 2023
9353c91
Update count check in EvalStep C#
baijumeswani Aug 28, 2023
b78f80b
C, C++ functions for Updating and Getting a checkpoint parameter
baijumeswani Aug 29, 2023
0ab37c6
Adding UpdateParameter and GetParameter to C#
baijumeswani Aug 30, 2023
619cb81
Expose model parameters and their gradients in Python
baijumeswani Aug 31, 2023
9262106
Address pull request review comments
baijumeswani Sep 18, 2023
17ae5df
Address C# bindings pull request review comments
baijumeswani Sep 19, 2023
a14546e
Address pull request review comments for C# and C API
baijumeswani Sep 19, 2023
6803374
Address C# comments
baijumeswani Sep 20, 2023
d8c313d
Address C# comments
baijumeswani Sep 20, 2023
e1fb060
Address pull request review comments
baijumeswani Sep 21, 2023
265551e
fix typo
baijumeswani Sep 21, 2023
f5fe5da
Add missing member init
kazssym Sep 12, 2023
321e2ac
Fix illegal opcode error from mlas (#17885)
skottmckay Oct 12, 2023
3744ae3
format code
snnn Nov 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1860,54 +1860,61 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca

public static DOrtFillStringTensor OrtFillStringTensor;

/// \param value A tensor created from OrtCreateTensor... function.
/// \param index The index of the entry in the tensor to resize. <summary>
/// \param length_in_bytes Length to resize the string to.
/// \param buffer The resized buffer.
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer(
IntPtr /* OrtValue */ value,
UIntPtr /* size_t */ index,
UIntPtr /* size_t */ length_in_bytes,
out IntPtr /* char** */ buffer
);
IntPtr /* OrtValue */ value,
UIntPtr /* size_t */ index,
UIntPtr /* size_t */ length_in_bytes,
out IntPtr /* char** */ buffer);

public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent(
IntPtr /*(OrtValue*)*/ value,
byte[] /*(void*)*/ dst_buffer,
UIntPtr dst_buffer_len,
UIntPtr[] offsets,
UIntPtr offsets_len);
IntPtr /*(OrtValue*)*/ value,
byte[] /*(void*)*/ dst_buffer,
UIntPtr dst_buffer_len,
UIntPtr[] offsets,
UIntPtr offsets_len);

public static DOrtGetStringTensorContent OrtGetStringTensorContent;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value,
out UIntPtr /*(size_t*)*/ len);
out UIntPtr /*(size_t*)*/ len);

public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value,
UIntPtr /*(size_t)*/ index,
out UIntPtr /*(size_t*)*/ len);
UIntPtr /*(size_t)*/ index,
out UIntPtr /*(size_t*)*/ len);

public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value,
UIntPtr /*(size_t)*/ bufferLength,
UIntPtr /*(size_t)*/ elementIndex,
byte[] buffer);
UIntPtr /*(size_t)*/ bufferLength,
UIntPtr /*(size_t)*/ elementIndex,
byte[] buffer);

public static DOrtGetStringTensorElement OrtGetStringTensorElement;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/
DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo(
IntPtr /*(struct OrtTypeInfo*)*/ typeInfo,
out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);

public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(
IntPtr /*(OrtValue*)*/ value,
out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);

public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape;

Expand All @@ -1917,12 +1924,16 @@ out IntPtr /* char** */ buffer
public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
out IntPtr /*(TensorElementType*)*/ output);

public static DOrtGetTensorElementType OrtGetTensorElementType;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
out UIntPtr output);

public static DOrtGetDimensionsCount OrtGetDimensionsCount;

Expand Down
133 changes: 89 additions & 44 deletions csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,16 @@ internal enum PropertyType : long
String = 2
}

private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
T[] value = new T[1];
value[0] = propertyValue;
Memory<T> memory = value;
using (var memHandle = memory.Pin())
T[] value = { propertyValue };
unsafe
{
IntPtr memPtr;
unsafe
fixed (T* memPtr = value)
{
memPtr = (IntPtr)memHandle.Pointer;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr));
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}

Expand Down Expand Up @@ -103,56 +99,53 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath,
}

/// <summary>
/// Adds the given int property to the checkpoint state.
/// Adds or updates the given int property to/in the checkpoint state.
///
/// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, long propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
}

/// <summary>
/// Adds the given float property to the checkpoint state.
/// Adds or updates the given float property to/in the checkpoint state.
///
/// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, float propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
}

/// <summary>
/// Adds the given string property to the checkpoint state.
/// Adds or updates the given string property to/in the checkpoint state.
///
/// Runtime properties that are strings such as parameter names, custom strings, and others can be added
/// to the checkpoint state by the user if they desire by calling this function with the appropriate property
/// name and value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, string propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);

IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
try
{
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
}
finally
unsafe
{
Marshal.FreeHGlobal(unmanagedPointer);
fixed (byte* p = propertyValueUtf8)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
}
}
}

Expand All @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue)
/// Gets the property value from an existing entry in the checkpoint state. The property must
/// exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="propertyName">Unique name of the property being retrieved.</param>
/// <param name="propertyName">Name of the property being retrieved.</param>
/// <returns>Property value associated with the given property name.</returns>
public object GetProperty(string propertyName)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var allocator = OrtAllocator.DefaultInstance;
IntPtr propertyValue = IntPtr.Zero;

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));

if (propertyType == PropertyType.Int)
try
{
var longPropertyValue = Marshal.ReadInt64(propertyValue);
allocator.FreeMemory(propertyValue);
return longPropertyValue;
if (propertyType == PropertyType.Int)
{
Int64 value;
unsafe
{
value = *(Int64*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.Float)
{
float value;
unsafe
{
value = *(float*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.String)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
}

throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
else if (propertyType == PropertyType.Float)
finally
{
float[] value = new float[1];
Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
return value[0];
}
else if (propertyType == PropertyType.String)
}

/// <summary>
/// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
///
/// This function updates a model parameter in the checkpoint state with the given parameter data.
/// The training session must be already created with the checkpoint state that contains the parameter
/// being updated. The given parameter is copied over to the registered device for the training session.
/// The parameter must exist in the checkpoint state to be able to update it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <param name="parameter">The parameter data that should replace the existing parameter data.</param>
public void UpdateParameter(string parameterName, OrtValue parameter)
{
if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
}

throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
}

/// <summary>
/// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
///
/// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
/// The parameter is copied over to the provided OrtValue. The training session must be already created
/// with the checkpoint state that contains the parameter being retrieved.
/// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <returns>The parameter data that is retrieved from the checkpoint state.</returns>
public OrtValue GetParameter(string parameterName)
{
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));

return new OrtValue(parameterHandle);
}

#region SafeHandle
Expand Down
Loading
Loading