Skip to content

Commit

Permalink
Some cherry-picks for the 1.16.2 release (#18218)
Browse files Browse the repository at this point in the history
Cherry-pick PRs: 
#18026 
#17912 
#17901 “2 lines added whitespace errors when cherry-picking"
#17293 
#17364 
#17505 
#17885

This PR contains all the cherry-picks for the patch release except:
1. The PRs marked with sdxl_llama
2. #17772 which has a merge conflict.

---------

Co-authored-by: Chi Lo <[email protected]>
Co-authored-by: Chi Lo <[email protected]>
Co-authored-by: Scott McKay <[email protected]>
Co-authored-by: Baiju Meswani <[email protected]>
Co-authored-by: Kaz Nishimura <[email protected]>
Co-authored-by: Scott McKay <[email protected]>
  • Loading branch information
7 people authored Nov 2, 2023
1 parent bc533a6 commit 2f57f1e
Show file tree
Hide file tree
Showing 32 changed files with 1,170 additions and 288 deletions.
53 changes: 32 additions & 21 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1860,54 +1860,61 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca

public static DOrtFillStringTensor OrtFillStringTensor;

/// \param value A tensor created from OrtCreateTensor... function.
/// \param index The index of the entry in the tensor to resize. <summary>
/// \param length_in_bytes Length to resize the string to.
/// \param buffer The resized buffer.
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer(
IntPtr /* OrtValue */ value,
UIntPtr /* size_t */ index,
UIntPtr /* size_t */ length_in_bytes,
out IntPtr /* char** */ buffer
);
IntPtr /* OrtValue */ value,
UIntPtr /* size_t */ index,
UIntPtr /* size_t */ length_in_bytes,
out IntPtr /* char** */ buffer);

public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent(
IntPtr /*(OrtValue*)*/ value,
byte[] /*(void*)*/ dst_buffer,
UIntPtr dst_buffer_len,
UIntPtr[] offsets,
UIntPtr offsets_len);
IntPtr /*(OrtValue*)*/ value,
byte[] /*(void*)*/ dst_buffer,
UIntPtr dst_buffer_len,
UIntPtr[] offsets,
UIntPtr offsets_len);

public static DOrtGetStringTensorContent OrtGetStringTensorContent;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value,
out UIntPtr /*(size_t*)*/ len);
out UIntPtr /*(size_t*)*/ len);

public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value,
UIntPtr /*(size_t)*/ index,
out UIntPtr /*(size_t*)*/ len);
UIntPtr /*(size_t)*/ index,
out UIntPtr /*(size_t*)*/ len);

public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value,
UIntPtr /*(size_t)*/ bufferLength,
UIntPtr /*(size_t)*/ elementIndex,
byte[] buffer);
UIntPtr /*(size_t)*/ bufferLength,
UIntPtr /*(size_t)*/ elementIndex,
byte[] buffer);

public static DOrtGetStringTensorElement OrtGetStringTensorElement;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/
DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo(
IntPtr /*(struct OrtTypeInfo*)*/ typeInfo,
out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);

public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(
IntPtr /*(OrtValue*)*/ value,
out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);

public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape;

Expand All @@ -1917,12 +1924,16 @@ out IntPtr /* char** */ buffer
public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
out IntPtr /*(TensorElementType*)*/ output);

public static DOrtGetTensorElementType OrtGetTensorElementType;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
out UIntPtr output);

public static DOrtGetDimensionsCount OrtGetDimensionsCount;

Expand Down
133 changes: 89 additions & 44 deletions csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,16 @@ internal enum PropertyType : long
String = 2
}

private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
T[] value = new T[1];
value[0] = propertyValue;
Memory<T> memory = value;
using (var memHandle = memory.Pin())
T[] value = { propertyValue };
unsafe
{
IntPtr memPtr;
unsafe
fixed (T* memPtr = value)
{
memPtr = (IntPtr)memHandle.Pointer;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr));
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}

Expand Down Expand Up @@ -103,56 +99,53 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath,
}

/// <summary>
/// Adds the given int property to the checkpoint state.
/// Adds or updates the given int property to/in the checkpoint state.
///
/// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, long propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
}

/// <summary>
/// Adds the given float property to the checkpoint state.
/// Adds or updates the given float property to/in the checkpoint state.
///
/// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, float propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
}

/// <summary>
/// Adds the given string property to the checkpoint state.
/// Adds or updates the given string property to/in the checkpoint state.
///
/// Runtime properties that are strings such as parameter names, custom strings, and others can be added
/// to the checkpoint state by the user if they desire by calling this function with the appropriate property
/// name and value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, string propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);

IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
try
{
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
}
finally
unsafe
{
Marshal.FreeHGlobal(unmanagedPointer);
fixed (byte* p = propertyValueUtf8)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
}
}
}

Expand All @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue)
/// Gets the property value from an existing entry in the checkpoint state. The property must
/// exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="propertyName">Unique name of the property being retrieved.</param>
/// <param name="propertyName">Name of the property being retrieved.</param>
/// <returns>Property value associated with the given property name.</returns>
public object GetProperty(string propertyName)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var allocator = OrtAllocator.DefaultInstance;
IntPtr propertyValue = IntPtr.Zero;

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));

if (propertyType == PropertyType.Int)
try
{
var longPropertyValue = Marshal.ReadInt64(propertyValue);
allocator.FreeMemory(propertyValue);
return longPropertyValue;
if (propertyType == PropertyType.Int)
{
Int64 value;
unsafe
{
value = *(Int64*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.Float)
{
float value;
unsafe
{
value = *(float*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.String)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
}

throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
else if (propertyType == PropertyType.Float)
finally
{
float[] value = new float[1];
Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
return value[0];
}
else if (propertyType == PropertyType.String)
}

/// <summary>
/// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
///
/// This function updates a model parameter in the checkpoint state with the given parameter data.
/// The training session must be already created with the checkpoint state that contains the parameter
/// being updated. The given parameter is copied over to the registered device for the training session.
/// The parameter must exist in the checkpoint state to be able to update it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <param name="parameter">The parameter data that should replace the existing parameter data.</param>
public void UpdateParameter(string parameterName, OrtValue parameter)
{
if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
}

throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
}

/// <summary>
/// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
///
/// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
/// The parameter is copied over to the provided OrtValue. The training session must be already created
/// with the checkpoint state that contains the parameter being retrieved.
/// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <returns>The parameter data that is retrieved from the checkpoint state.</returns>
public OrtValue GetParameter(string parameterName)
{
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));

return new OrtValue(parameterHandle);
}

#region SafeHandle
Expand Down
Loading

0 comments on commit 2f57f1e

Please sign in to comment.