Skip to content

Commit

Permalink
[On-Device Training] Expose Parameters through the Training API (#17364)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Sep 26, 2023
1 parent 95e8dfa commit ccb73fd
Show file tree
Hide file tree
Showing 16 changed files with 936 additions and 159 deletions.
133 changes: 89 additions & 44 deletions csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,16 @@ internal enum PropertyType : long
String = 2
}

private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
T[] value = new T[1];
value[0] = propertyValue;
Memory<T> memory = value;
using (var memHandle = memory.Pin())
T[] value = { propertyValue };
unsafe
{
IntPtr memPtr;
unsafe
fixed (T* memPtr = value)
{
memPtr = (IntPtr)memHandle.Pointer;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr));
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}

Expand Down Expand Up @@ -103,56 +99,53 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath,
}

/// <summary>
/// Adds the given int property to the checkpoint state.
/// Adds or updates the given int property to/in the checkpoint state.
///
/// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, long propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
}

/// <summary>
/// Adds the given float property to the checkpoint state.
/// Adds or updates the given float property to/in the checkpoint state.
///
/// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, float propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
}

/// <summary>
/// Adds the given string property to the checkpoint state.
/// Adds or updates the given string property to/in the checkpoint state.
///
/// Runtime properties that are strings such as parameter names, custom strings, and others can be added
/// to the checkpoint state by the user if they desire by calling this function with the appropriate property
/// name and value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, string propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);

IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
try
{
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
}
finally
unsafe
{
Marshal.FreeHGlobal(unmanagedPointer);
fixed (byte* p = propertyValueUtf8)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
}
}
}

Expand All @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue)
/// Gets the property value from an existing entry in the checkpoint state. The property must
/// exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="propertyName">Unique name of the property being retrieved.</param>
/// <param name="propertyName">Name of the property being retrieved.</param>
/// <returns>Property value associated with the given property name.</returns>
public object GetProperty(string propertyName)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var allocator = OrtAllocator.DefaultInstance;
IntPtr propertyValue = IntPtr.Zero;

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));

if (propertyType == PropertyType.Int)
try
{
var longPropertyValue = Marshal.ReadInt64(propertyValue);
allocator.FreeMemory(propertyValue);
return longPropertyValue;
if (propertyType == PropertyType.Int)
{
Int64 value;
unsafe
{
value = *(Int64*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.Float)
{
float value;
unsafe
{
value = *(float*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.String)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
}

throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
else if (propertyType == PropertyType.Float)
finally
{
float[] value = new float[1];
Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
return value[0];
}
else if (propertyType == PropertyType.String)
}

/// <summary>
/// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
///
/// This function updates a model parameter in the checkpoint state with the given parameter data.
/// The training session must be already created with the checkpoint state that contains the parameter
/// being updated. The given parameter is copied over to the registered device for the training session.
/// The parameter must exist in the checkpoint state to be able to update it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <param name="parameter">The parameter data that should replace the existing parameter data.</param>
public void UpdateParameter(string parameterName, OrtValue parameter)
{
if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
}

throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
}

/// <summary>
/// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
///
/// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
/// The parameter is copied over to the provided OrtValue. The training session must be already created
/// with the checkpoint state that contains the parameter being retrieved.
/// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <returns>The parameter data that is retrieved from the checkpoint state.</returns>
public OrtValue GetParameter(string parameterName)
{
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));

return new OrtValue(parameterHandle);
}

#region SafeHandle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ public struct OrtTrainingApi
public IntPtr AddProperty;
public IntPtr GetProperty;
public IntPtr LoadCheckpointFromBuffer;
public IntPtr GetParameterTypeAndShape;
public IntPtr UpdateParameter;
public IntPtr GetParameter;
}

internal static class NativeTrainingMethods
Expand Down Expand Up @@ -97,6 +100,9 @@ static NativeTrainingMethods()
OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName));
OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty));
OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty));
OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape));
OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter));
OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter));
}

}
Expand Down Expand Up @@ -359,6 +365,34 @@ out UIntPtr inputCount

public static DOrtGetProperty OrtGetProperty;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
byte[] /*(const char*)*/ parameterName,
out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape
);

public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
byte[] /*(const char*)*/ parameterName,
IntPtr /*(OrtValue*)*/ parameter
);

public static DOrtUpdateParameter OrtUpdateParameter;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
byte[] /*(const char*)*/ parameterName,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(OrtValue**)*/ parameter
);

public static DOrtGetParameter OrtGetParameter;

#endregion TrainingSession API

public static bool TrainingEnabled()
Expand Down
Loading

0 comments on commit ccb73fd

Please sign in to comment.