diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 2ba837be22041..f722ca9d30fa4 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -1860,54 +1860,61 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public static DOrtFillStringTensor OrtFillStringTensor;
+ /// \param value A tensor created from OrtCreateTensor... function.
+ /// \param index The index of the entry in the tensor to resize.
+ /// \param length_in_bytes Length to resize the string to.
+ /// \param buffer The resized buffer.
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer(
- IntPtr /* OrtValue */ value,
- UIntPtr /* size_t */ index,
- UIntPtr /* size_t */ length_in_bytes,
- out IntPtr /* char** */ buffer
- );
+ IntPtr /* OrtValue */ value,
+ UIntPtr /* size_t */ index,
+ UIntPtr /* size_t */ length_in_bytes,
+ out IntPtr /* char** */ buffer);
public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent(
- IntPtr /*(OrtValue*)*/ value,
- byte[] /*(void*)*/ dst_buffer,
- UIntPtr dst_buffer_len,
- UIntPtr[] offsets,
- UIntPtr offsets_len);
+ IntPtr /*(OrtValue*)*/ value,
+ byte[] /*(void*)*/ dst_buffer,
+ UIntPtr dst_buffer_len,
+ UIntPtr[] offsets,
+ UIntPtr offsets_len);
public static DOrtGetStringTensorContent OrtGetStringTensorContent;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value,
- out UIntPtr /*(size_t*)*/ len);
+ out UIntPtr /*(size_t*)*/ len);
public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value,
- UIntPtr /*(size_t)*/ index,
- out UIntPtr /*(size_t*)*/ len);
+ UIntPtr /*(size_t)*/ index,
+ out UIntPtr /*(size_t*)*/ len);
public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value,
- UIntPtr /*(size_t)*/ bufferLength,
- UIntPtr /*(size_t)*/ elementIndex,
- byte[] buffer);
+ UIntPtr /*(size_t)*/ bufferLength,
+ UIntPtr /*(size_t)*/ elementIndex,
+ byte[] buffer);
public static DOrtGetStringTensorElement OrtGetStringTensorElement;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
- public delegate IntPtr /*(OrtStatus*)*/
- DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo(
+ IntPtr /*(struct OrtTypeInfo*)*/ typeInfo,
+ out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
- public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(
+ IntPtr /*(OrtValue*)*/ value,
+ out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape;
@@ -1917,12 +1924,16 @@ out IntPtr /* char** */ buffer
public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
- public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output);
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(
+ IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
+ out IntPtr /*(TensorElementType*)*/ output);
public static DOrtGetTensorElementType OrtGetTensorElementType;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
- public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output);
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(
+ IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
+ out UIntPtr output);
public static DOrtGetDimensionsCount OrtGetDimensionsCount;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
index 659c6303702ac..6889112acb385 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
@@ -40,20 +40,16 @@ internal enum PropertyType : long
String = 2
}
- private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue)
+ private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
- T[] value = new T[1];
- value[0] = propertyValue;
- Memory memory = value;
- using (var memHandle = memory.Pin())
+ T[] value = { propertyValue };
+ unsafe
{
- IntPtr memPtr;
- unsafe
+ fixed (T* memPtr = value)
{
- memPtr = (IntPtr)memHandle.Pointer;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr));
}
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}
@@ -103,13 +99,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath,
}
///
- /// Adds the given int property to the checkpoint state.
+ /// Adds or updates the given int property to/in the checkpoint state.
///
- /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint
- /// state by the user if they desire by calling this function with the appropriate property name and
- /// value. The given property name must be unique to be able to successfully add the property.
+ /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
+ /// state by the user by calling this function with the corresponding property name and value.
+ /// The given property name must be unique to be able to successfully add the property.
///
- /// Unique name of the property being added.
+ /// Name of the property being added or updated.
/// Property value associated with the given name.
public void AddProperty(string propertyName, long propertyValue)
{
@@ -117,13 +113,13 @@ public void AddProperty(string propertyName, long propertyValue)
}
///
- /// Adds the given float property to the checkpoint state.
+ /// Adds or updates the given float property to/in the checkpoint state.
///
- /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint
- /// state by the user if they desire by calling this function with the appropriate property name and
- /// value. The given property name must be unique to be able to successfully add the property.
+ /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
+ /// state by the user by calling this function with the corresponding property name and value.
+ /// The given property name must be unique to be able to successfully add the property.
///
- /// Unique name of the property being added.
+ /// Name of the property being added or updated.
/// Property value associated with the given name.
public void AddProperty(string propertyName, float propertyValue)
{
@@ -131,28 +127,25 @@ public void AddProperty(string propertyName, float propertyValue)
}
///
- /// Adds the given string property to the checkpoint state.
+ /// Adds or updates the given string property to/in the checkpoint state.
///
- /// Runtime properties that are strings such as parameter names, custom strings, and others can be added
- /// to the checkpoint state by the user if they desire by calling this function with the appropriate property
- /// name and value. The given property name must be unique to be able to successfully add the property.
+ /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
+ /// state by the user by calling this function with the corresponding property name and value.
+ /// The given property name must be unique to be able to successfully add the property.
///
- /// Unique name of the property being added.
+ /// Name of the property being added or updated.
/// Property value associated with the given name.
public void AddProperty(string propertyName, string propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
- IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
- try
- {
- Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
- }
- finally
+ unsafe
{
- Marshal.FreeHGlobal(unmanagedPointer);
+ fixed (byte* p = propertyValueUtf8)
+ {
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
+ }
}
}
@@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue)
/// Gets the property value from an existing entry in the checkpoint state. The property must
/// exist in the checkpoint state to be able to retrieve it successfully.
///
- /// Unique name of the property being retrieved.
+ /// Name of the property being retrieved.
/// Property value associated with the given property name.
public object GetProperty(string propertyName)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var allocator = OrtAllocator.DefaultInstance;
IntPtr propertyValue = IntPtr.Zero;
+
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
- if (propertyType == PropertyType.Int)
+ try
{
- var longPropertyValue = Marshal.ReadInt64(propertyValue);
- allocator.FreeMemory(propertyValue);
- return longPropertyValue;
+ if (propertyType == PropertyType.Int)
+ {
+ Int64 value;
+ unsafe
+ {
+ value = *(Int64*)propertyValue;
+ }
+ return value;
+ }
+ else if (propertyType == PropertyType.Float)
+ {
+ float value;
+ unsafe
+ {
+ value = *(float*)propertyValue;
+ }
+ return value;
+ }
+ else if (propertyType == PropertyType.String)
+ {
+ return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
+ }
+
+ throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
- else if (propertyType == PropertyType.Float)
+ finally
{
- float[] value = new float[1];
- Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
- return value[0];
}
- else if (propertyType == PropertyType.String)
+ }
+
+ ///
+ /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
+ ///
+ /// This function updates a model parameter in the checkpoint state with the given parameter data.
+ /// The training session must be already created with the checkpoint state that contains the parameter
+ /// being updated. The given parameter is copied over to the registered device for the training session.
+ /// The parameter must exist in the checkpoint state to be able to update it successfully.
+ ///
+ /// Name of the parameter being updated.
+ /// The parameter data that should replace the existing parameter data.
+ public void UpdateParameter(string parameterName, OrtValue parameter)
+ {
+ if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
- return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
+ throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
}
- throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
+ var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
+ }
+
+ ///
+ /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
+ ///
+ /// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
+ /// The parameter is copied over to the provided OrtValue. The training session must be already created
+ /// with the checkpoint state that contains the parameter being retrieved.
+ /// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
+ ///
+ /// Name of the parameter being updated.
+ /// The parameter data that is retrieved from the checkpoint state.
+ public OrtValue GetParameter(string parameterName)
+ {
+ var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));
+
+ return new OrtValue(parameterHandle);
}
#region SafeHandle
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
index ac790242409e3..d6341b90f28ff 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
@@ -42,6 +42,9 @@ public struct OrtTrainingApi
public IntPtr AddProperty;
public IntPtr GetProperty;
public IntPtr LoadCheckpointFromBuffer;
+ public IntPtr GetParameterTypeAndShape;
+ public IntPtr UpdateParameter;
+ public IntPtr GetParameter;
}
internal static class NativeTrainingMethods
@@ -97,6 +100,9 @@ static NativeTrainingMethods()
OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName));
OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty));
OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty));
+ OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape));
+ OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter));
+ OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter));
}
}
@@ -359,6 +365,34 @@ out UIntPtr inputCount
public static DOrtGetProperty OrtGetProperty;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape(
+ IntPtr /*(OrtCheckpointState*)*/ checkpointState,
+ byte[] /*(const char*)*/ parameterName,
+ out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape
+ );
+
+ public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter(
+ IntPtr /*(OrtCheckpointState*)*/ checkpointState,
+ byte[] /*(const char*)*/ parameterName,
+ IntPtr /*(OrtValue*)*/ parameter
+ );
+
+ public static DOrtUpdateParameter OrtUpdateParameter;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter(
+ IntPtr /*(OrtCheckpointState*)*/ checkpointState,
+ byte[] /*(const char*)*/ parameterName,
+ IntPtr /*(OrtAllocator*)*/ allocator,
+ out IntPtr /*(OrtValue**)*/ parameter
+ );
+
+ public static DOrtGetParameter OrtGetParameter;
+
#endregion TrainingSession API
public static bool TrainingEnabled()
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
index 33993c2be135b..877677dcad57b 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
@@ -358,13 +358,14 @@ public void EvalStep(
IReadOnlyCollection inputValues,
IReadOnlyCollection outputValues)
{
- if (!_evalOutputCount.Equals(outputValues.Count))
+ if (_evalOutputCount != (ulong)outputValues.Count())
{
- throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
+ throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
}
- IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
+ const bool isInput = true;
+ IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput);
- IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
+ IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
}
@@ -509,18 +510,17 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec
/// Returns a contiguous buffer that holds a copy of all training state parameters
///
/// Whether to only copy trainable parameters or to copy all parameters.
- public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)
+ public OrtValue ToBuffer(bool onlyTrainable)
{
UIntPtr bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable));
float[] bufferMemory = new float[bufferSize.ToUInt64()];
- var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
- var shape = new long[] { (long)bufferSize.ToUInt64() };
- var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
+ var shape = new long[] { (long)bufferSize };
+ var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape);
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable));
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable));
return buffer;
}
@@ -528,45 +528,30 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)
///
/// Loads the training session model parameters from a contiguous buffer
///
- /// Contiguous buffer to load the parameters from.
- public void FromBuffer(FixedBufferOnnxValue buffer)
+ /// Contiguous buffer to load the parameters from.
+ /// Whether to only load trainable parameters or to load all parameters.
+ public void FromBuffer(OrtValue ortValue, bool onlyTrainable)
{
- if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
+ if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
}
- IntPtr typeAndShapeInfo = IntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
- UIntPtr numDimensions = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
- if (numDimensions.ToUInt64() != 1)
+ var tensorInfo = ortValue.GetTensorTypeAndShape();
+ if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float)
{
- string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
- throw new ArgumentException(errorMessage);
- }
-
- // Here buffer size represents the number of elements in the buffer
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize));
-
- // OrtGetParametersSize returns the total number of elements in the model's parameters.
- UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
- if ((ulong)bufferSize == (ulong)numElementsTrainingOnly)
- {
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
- return;
+ throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float.");
}
UIntPtr numElements = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
- if ((ulong)bufferSize != (ulong)numElements)
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable));
+ if ((ulong)tensorInfo.ElementCount != (ulong)numElements)
{
- string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
+ string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString();
throw new ArgumentException(errorMessage);
}
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable));
}
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
index ea2b6d7dbc118..68b1d5bcc6147 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
@@ -484,20 +484,23 @@ public void TestEvalModelOutputNames()
public void TestToBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
- using (var cleanUp = new DisposableListTest())
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
+ using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
- var state = CheckpointState.LoadCheckpoint(checkpointPath);
- cleanUp.Add(state);
Assert.NotNull(state);
- string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
- string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
- string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
- var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
- cleanUp.Add(trainingSession);
-
- var buffer = trainingSession.ToBuffer(true);
- cleanUp.Add(buffer);
+ using (var buffer = trainingSession.ToBuffer(true))
+ {
+ Assert.NotNull(buffer);
+ var typeShape = buffer.GetTensorTypeAndShape();
+ Assert.Equal(1, typeShape.DimensionsCount);
+ var fetchedShape = typeShape.Shape;
+ Assert.Equal(397510, fetchedShape[0]);
+ }
}
}
@@ -505,22 +508,25 @@ public void TestToBuffer()
public void TestFromBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
- using (var cleanUp = new DisposableListTest())
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
+ using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
- var state = CheckpointState.LoadCheckpoint(checkpointPath);
- cleanUp.Add(state);
Assert.NotNull(state);
- string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
- string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
- string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
-
- var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
- cleanUp.Add(trainingSession);
- var buffer = trainingSession.ToBuffer(true);
- cleanUp.Add(buffer);
+ using (var buffer = trainingSession.ToBuffer(true))
+ {
+ Assert.NotNull(buffer);
+ var typeShape = buffer.GetTensorTypeAndShape();
+ Assert.Equal(1, typeShape.DimensionsCount);
+ var fetchedShape = typeShape.Shape;
+ Assert.Equal(397510, fetchedShape[0]);
- trainingSession.FromBuffer(buffer);
+ trainingSession.FromBuffer(buffer, true);
+ }
}
}
@@ -530,6 +536,82 @@ public void TestSetSeed()
TrainingUtils.SetSeed(8888);
}
+ [Fact(DisplayName = "TestGetParameter")]
+ public void TestGetParameter()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
+ using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
+ using (var parameter = state.GetParameter("fc1.weight"))
+ {
+ Assert.NotNull(state);
+ Assert.NotNull(parameter);
+
+ var typeShape = parameter.GetTensorTypeAndShape();
+ Assert.Equal(2, typeShape.DimensionsCount);
+ var fetchedShape = typeShape.Shape;
+ Assert.Equal(500, fetchedShape[0]);
+ Assert.Equal(784, fetchedShape[1]);
+ }
+ }
+
+ [Fact(DisplayName = "TestUpdateParameter")]
+ public void TestUpdateParameter()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
+ using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
+ {
+ Assert.NotNull(state);
+
+ using (var parameter = state.GetParameter("fc1.weight"))
+ {
+ Assert.NotNull(parameter);
+ var typeShape = parameter.GetTensorTypeAndShape();
+
+ Assert.Equal(2, typeShape.DimensionsCount);
+ var fetchedShape = typeShape.Shape;
+ Assert.Equal(500, fetchedShape[0]);
+ Assert.Equal(784, fetchedShape[1]);
+
+ float maxVal = 20;
+ Random randNum = new Random();
+ float[] updated_parameter_buffer = Enumerable
+ .Repeat(0, 500 * 784)
+ .Select(i => maxVal * (float)randNum.NextDouble())
+ .ToArray();
+
+ using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape))
+ {
+ state.UpdateParameter("fc1.weight", updated_parameter);
+ using (var current_parameter = state.GetParameter("fc1.weight"))
+ {
+ var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray();
+ Assert.Equal(updated_parameter_buffer, current_parameter_tensor);
+ Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor);
+ }
+
+ state.UpdateParameter("fc1.weight", parameter);
+
+ using (var current_parameter = state.GetParameter("fc1.weight"))
+ {
+ var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray();
+ Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor);
+ Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor);
+ }
+ }
+ }
+ }
+ }
+
internal class FloatComparer : IEqualityComparer
{
private float atol = 1e-3f;
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index da4c248b58d94..c4984e423df21 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -763,6 +763,7 @@ Do not modify directly.*
|Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)|
|Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 86b7450a7c4e5..32cc69d0b8040 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -451,12 +451,16 @@ Return Value:
#if defined(_WIN32)
HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
-#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we
- // disable it there.
- uint64_t isar0_el1;
- asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :);
- HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;
#else
+ // Use the cpuinfo value which is read from sysctl and has some additional special cases.
+ // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379
+ // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips
+ // as well as failing on other ARM chips as it is an EL1 level register that requires extra
+ // privileges to read.
+ //
+ // uint64_t isar0_el1;
+ // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :);
+ // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;
HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();
#endif
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index a50b53315ec9a..0d9928baa86e0 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -20,7 +20,7 @@ namespace cuda {
// float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions
// CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2)))
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
__device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
__device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); }
__device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); }
@@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
+template
+__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
+
+template
+__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
+
+template
+__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); }
+
+template <>
+__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
+
template
__device__ __inline__ T _Normcdf(T a);
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index aa60db4d07222..ad892eab3b843 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
@@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index f026444328b24..9ede1f8d90ecc 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
UNARY_OP_HFD(Exp, 13)
UNARY_OP_HFD(Erf, 13)
+UNARY_OP_BWUZCSILHFD(Sign, 13)
UNARY_LOGICALOP_NOT_TYPED(1, bool)
UNARY_OP_HFD(Round, 11)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
index 3ff97a60114df..775b78c43a736 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};
+template
+class Sign final : public UnaryElementwise {
+ public:
+ Sign(const OpKernelInfo& info) : UnaryElementwise(info) {}
+ Status ComputeInternal(OpKernelContext* context) const override;
+};
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index ac7cc1126acb7..1298d53338337 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign)
// When casting, half needs to be converted via float type from most other types
template
@@ -119,52 +120,52 @@ struct OP_Cast {
}
};
-#define IMPL_CAST_IMPL(InT, OutT) \
+#define IMPL_CAST_IMPL(InT, OutT) \
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
- UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \
+ UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \
}
-#define IMPL_CAST_IMPL_THROW(InT, OutT) \
+#define IMPL_CAST_IMPL_THROW(InT, OutT) \
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
- ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
+ ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
}
#if !defined(DISABLE_FLOAT8_TYPES)
-#define IMPL_CAST_IMPL_FROM(T) \
- IMPL_CAST_IMPL(T, half) \
- IMPL_CAST_IMPL(T, float) \
- IMPL_CAST_IMPL(T, double) \
- IMPL_CAST_IMPL(T, int8_t) \
- IMPL_CAST_IMPL(T, int16_t) \
- IMPL_CAST_IMPL(T, int32_t) \
- IMPL_CAST_IMPL(T, int64_t) \
- IMPL_CAST_IMPL(T, uint8_t) \
- IMPL_CAST_IMPL(T, uint16_t) \
- IMPL_CAST_IMPL(T, uint32_t) \
- IMPL_CAST_IMPL(T, uint64_t) \
- IMPL_CAST_IMPL(T, bool) \
- IMPL_CAST_IMPL(T, BFloat16) \
- IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
- IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
+#define IMPL_CAST_IMPL_FROM(T) \
+ IMPL_CAST_IMPL(T, half) \
+ IMPL_CAST_IMPL(T, float) \
+ IMPL_CAST_IMPL(T, double) \
+ IMPL_CAST_IMPL(T, int8_t) \
+ IMPL_CAST_IMPL(T, int16_t) \
+ IMPL_CAST_IMPL(T, int32_t) \
+ IMPL_CAST_IMPL(T, int64_t) \
+ IMPL_CAST_IMPL(T, uint8_t) \
+ IMPL_CAST_IMPL(T, uint16_t) \
+ IMPL_CAST_IMPL(T, uint32_t) \
+ IMPL_CAST_IMPL(T, uint64_t) \
+ IMPL_CAST_IMPL(T, bool) \
+ IMPL_CAST_IMPL(T, BFloat16) \
+ IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
+ IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \
IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ)
#else
-#define IMPL_CAST_IMPL_FROM(T) \
- IMPL_CAST_IMPL(T, half) \
- IMPL_CAST_IMPL(T, float) \
- IMPL_CAST_IMPL(T, double) \
- IMPL_CAST_IMPL(T, int8_t) \
- IMPL_CAST_IMPL(T, int16_t) \
- IMPL_CAST_IMPL(T, int32_t) \
- IMPL_CAST_IMPL(T, int64_t) \
- IMPL_CAST_IMPL(T, uint8_t) \
- IMPL_CAST_IMPL(T, uint16_t) \
- IMPL_CAST_IMPL(T, uint32_t) \
- IMPL_CAST_IMPL(T, uint64_t) \
- IMPL_CAST_IMPL(T, bool) \
+#define IMPL_CAST_IMPL_FROM(T) \
+ IMPL_CAST_IMPL(T, half) \
+ IMPL_CAST_IMPL(T, float) \
+ IMPL_CAST_IMPL(T, double) \
+ IMPL_CAST_IMPL(T, int8_t) \
+ IMPL_CAST_IMPL(T, int16_t) \
+ IMPL_CAST_IMPL(T, int32_t) \
+ IMPL_CAST_IMPL(T, int64_t) \
+ IMPL_CAST_IMPL(T, uint8_t) \
+ IMPL_CAST_IMPL(T, uint16_t) \
+ IMPL_CAST_IMPL(T, uint32_t) \
+ IMPL_CAST_IMPL(T, uint64_t) \
+ IMPL_CAST_IMPL(T, bool) \
IMPL_CAST_IMPL(T, BFloat16)
#endif
@@ -199,58 +200,58 @@ struct OP_CastNoSat {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
-#define OP_CAST(T, NVT) \
- template <> \
- struct OP_CastSat { \
- __device__ __inline__ T operator()(const half& v) const { \
+#define OP_CAST(T, NVT) \
+ template <> \
+ struct OP_CastSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
- __device__ __inline__ T operator()(const half& v) const { \
- return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
- } \
- }; \
- template <> \
- struct OP_CastSat { \
- __device__ __inline__ T operator()(const float& v) const { \
- return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
- __device__ __inline__ T operator()(const float& v) const { \
- return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
- } \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
+ return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastSat { \
+ __device__ __inline__ T operator()(const float& v) const { \
+ return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
+ __device__ __inline__ T operator()(const float& v) const { \
+ return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
+ } \
};
#else
-#define OP_CAST(T, NVT) \
- template <> \
- struct OP_CastSat { \
- __device__ __inline__ T operator()(const half& v) const { \
- return T(__half2float(v), true); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
- __device__ __inline__ T operator()(const half& v) const { \
- return T(__half2float(v), false); \
- } \
- }; \
- template <> \
- struct OP_CastSat { \
+#define OP_CAST(T, NVT) \
+ template <> \
+ struct OP_CastSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
+ return T(__half2float(v), true); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
+ return T(__half2float(v), false); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastSat { \
__device__ __inline__ T operator()(const float& v) const { \
- return T(v, true); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
+ return T(v, true); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
__device__ __inline__ T operator()(const float& v) const { \
- return T(v, false); \
- } \
+ return T(v, false); \
+ } \
};
#endif
@@ -260,14 +261,13 @@ struct OP_CastNoSat {
OP_CAST(Float8E4M3FN, __NV_E4M3)
OP_CAST(Float8E5M2, __NV_E5M2)
-
-#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
+#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \
- if (saturate) { \
- UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \
- } else { \
- UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \
- } \
+ if (saturate) { \
+ UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \
+ } else { \
+ UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \
+ } \
}
EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
index 3d4868b54abe6..608a81a24cf4f 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -31,7 +31,8 @@ namespace cuda {
UNARY_OP_NAME_EXPR(Not, !a) \
UNARY_OP_NAME_EXPR(Round, _Round(a)) \
UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
- UNARY_OP_NAME_EXPR(Cos, _Cos(a))
+ UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \
+ UNARY_OP_NAME_EXPR(Sign, _Sign(a))
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template \
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
index cf93d24c29bcf..074f13b309181 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
@@ -83,7 +83,7 @@ namespace Windows::AI::MachineLearning::Adapter
// Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size.
struct DmlGraphNodeCreateInfo
{
- uint32_t nodeCount;
+ uint32_t nodeCount = 0;
std::vector> nodesAsOperatorDesc;
std::vector> nodesAsIDMLOperator;
std::vector inputEdges;
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
index 5c516aac65aab..429ceb1f7c699 100644
--- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
+template
+__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
+
+template
+__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
+
+template
+__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); }
+
+template <>
+__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
+
template
__device__ __inline__ T _Normcdf(T a);
@@ -337,7 +349,7 @@ struct GridDim {
};
// aligned vector generates vectorized load/store
-template
+template
struct alignas(sizeof(T) * vec_size) aligned_vector {
T val[vec_size];
};
@@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector {
// HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels.
// TODO ROCM added support recently, should verify.
#define HIP_KERNEL_ASSERT(...)
-//#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
+// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
// WARP related definitions and functions
constexpr int GPU_WARP_SIZE = warpSize;
-inline int GPU_WARP_SIZE_HOST= warpSizeDynamic();
+inline int GPU_WARP_SIZE_HOST = warpSizeDynamic();
template
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) {
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 9401de64269b9..e6ea876d8957c 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1105,6 +1105,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
@@ -2067,6 +2078,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// OpSet 14
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 88a576f3ffa73..ac92d46ca87fc 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -792,6 +792,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (info.has_user_compute_stream) {
external_stream_ = true;
stream_ = static_cast(info.user_compute_stream);
+ ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
+ ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_)));
+ ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
+ ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_)));
}
std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
@@ -1860,6 +1864,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
} else if (number_of_trt_nodes == number_of_ort_nodes) {
LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
} else {
+ sync_stream_after_enqueue_ = true;
LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
}
@@ -2372,7 +2377,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name,
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
- input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
+ input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
@@ -2400,6 +2405,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& input_indexes = (trt_state->input_info)[0];
const std::unordered_map& output_indexes = (trt_state->output_info)[0];
const std::unordered_map& output_types = (trt_state->output_info)[1];
+ bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
auto fused_node_name = trt_state->fused_node_name;
auto& shape_ranges = trt_state->input_shape_ranges;
auto trt_builder = trt_state->builder->get();
@@ -3001,6 +3007,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector> input_info;
std::vector> output_info;
std::unordered_map>>> input_shape_ranges;
+ bool sync_stream_after_enqueue = false;
OrtMutex* tensorrt_mu_ptr = nullptr;
bool fp16_enable = false;
bool int8_enable = false;
@@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
cudnnHandle_t external_cudnn_handle_ = nullptr;
cublasHandle_t external_cublas_handle_ = nullptr;
+ // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
+ mutable bool sync_stream_after_enqueue_ = false;
+
CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
diff --git a/onnxruntime/test/providers/cpu/math/sign_test.cc b/onnxruntime/test/providers/cpu/math/sign_test.cc
index 12844068c47d2..15b3f40faa791 100644
--- a/onnxruntime/test/providers/cpu/math/sign_test.cc
+++ b/onnxruntime/test/providers/cpu/math/sign_test.cc
@@ -113,7 +113,7 @@ TestImpl(ForwardIter first, ForwardIter last, OutputIter out) {
TEST(MathOpTest, Sign_uint64) {
using namespace test_sign_internal;
- OpTester test("Sign", 9);
+ OpTester test("Sign", 13);
std::vector input_dims{7};
std::vector input;
@@ -129,7 +129,7 @@ TEST(MathOpTest, Sign_uint64) {
// we disable this test for openvino as openvino ep supports only FP32 Precision
TEST(MathOpTest, Sign_int64) {
using namespace test_sign_internal;
- OpTester test("Sign", 9);
+ OpTester test("Sign", 13);
std::vector input_dims{7};
std::vector input;
@@ -146,7 +146,7 @@ TEST(MathOpTest, Sign_int64) {
TEST(MathOpTest, Sign_float) {
using namespace test_sign_internal;
- OpTester test("Sign", 9);
+ OpTester test("Sign", 13);
std::vector input_dims{7};
std::vector input;
@@ -162,7 +162,7 @@ TEST(MathOpTest, Sign_float) {
TEST(MathOpTest, Sign_double) {
using namespace test_sign_internal;
- OpTester test("Sign", 9);
+ OpTester test("Sign", 13);
std::vector input_dims{7};
std::vector input;
@@ -177,7 +177,7 @@ TEST(MathOpTest, Sign_double) {
}
TEST(MathOpTest, Sign_MLFloat16) {
using namespace test_sign_internal;
- OpTester test("Sign", 9);
+ OpTester test("Sign", 13);
std::vector input_dims{7};
std::vector input;
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 3f3aa396e6ca0..35d9755ba0ba7 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");
checkpoint_state
.def(py::init())
- .def("add_property", [](onnxruntime::training::api::CheckpointState* state,
- const std::string& property_name,
- const std::variant& property_value) {
- state->property_bag.AddProperty(property_name, property_value);
- })
- .def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
- return state->property_bag.GetProperty(property_name);
- })
- .def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
- return state->property_bag.HasProperty(property_name);
- });
+ .def("add_property",
+ [](onnxruntime::training::api::CheckpointState* state,
+ const std::string& property_name,
+ const std::variant& property_value) {
+ state->property_bag.AddProperty(property_name, property_value);
+ })
+ .def("get_property",
+ [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
+ return state->property_bag.GetProperty(property_name);
+ })
+ .def("has_property",
+ [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
+ return state->property_bag.HasProperty(property_name);
+ })
+ .def("copy_parameter_from",
+ [](onnxruntime::training::api::CheckpointState* state,
+ const std::string& parameter_name, OrtValue& value) -> void {
+ auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
+ if (it == state->module_checkpoint_state.named_parameters.end()) {
+ ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
+ }
+ ORT_THROW_IF_ERROR(it->second->CopyFrom(
+ state->module_checkpoint_state.train_session_data_transfer_mgr, value));
+ })
+ .def("get_parameter",
+ [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
+ auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
+ if (it == state->module_checkpoint_state.named_parameters.end()) {
+ ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
+ }
+ return it->second;
+ })
+ .def("has_parameter",
+ [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
+ return state->module_checkpoint_state.named_parameters.count(parameter_name);
+ })
+ .def("parameter_names",
+ [](onnxruntime::training::api::CheckpointState* state) {
+ std::vector names;
+ for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) {
+ names.push_back(name);
+ }
+ std::sort(names.begin(), names.end());
+ return names;
+ })
+ .def("property_names",
+ [](onnxruntime::training::api::CheckpointState* state) {
+ std::vector names;
+ for ([[maybe_unused]] auto& [name, value] : state->property_bag) {
+ names.push_back(name);
+ }
+ std::sort(names.begin(), names.end());
+ return names;
+ });
py::class_
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
@@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
ORT_THROW_IF_ERROR(scheduler->Step());
});
+ py::class_>
+ parameter(m, "Parameter");
+ parameter
+ .def_property_readonly("name", &onnxruntime::training::api::Parameter::Name)
+ .def_property_readonly("data", &onnxruntime::training::api::Parameter::Data)
+ .def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient)
+ .def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad)
+ .def("copy_from",
+ [](onnxruntime::training::api::Parameter* parameter,
+ onnxruntime::training::api::CheckpointState* state,
+ OrtValue& value) -> void {
+ ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value));
+ });
+
m.def(
"save_checkpoint",
[](const std::vector& trainable_tensor_protos_pybytes,
diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py
index 285264bbed744..ba95cd04fce7e 100644
--- a/orttraining/orttraining/python/training/api/checkpoint_state.py
+++ b/orttraining/orttraining/python/training/api/checkpoint_state.py
@@ -5,70 +5,171 @@
import os
+import numpy as np
+
from onnxruntime.capi import _pybind_state as C
+from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
-class CheckpointState:
- """Class that holds the state of the training session
+class Parameter:
+ """Class that represents a model parameter
- This class holds all the state information of the training session such as the model parameters,
- its gradients, the optimizer state and user defined properties.
+ This class represents a model parameter and provides access to its data,
+ gradient and other properties. This class is not expected to be instantiated directly.
+ Instead, it is returned by the `CheckpointState` object.
+
+ Args:
+ parameter: The C.Parameter object that holds the underlying parameter data.
+ state: The C.CheckpointState object that holds the underlying session state.
+ """
+
+ def __init__(self, parameter: C.Parameter, state: C.CheckpointState):
+ self._parameter = parameter
+ self._state = state
- User defined properties can be indexed by name from the `CheckpointState` object.
+ @property
+ def name(self) -> str:
+ """The name of the parameter"""
+ return self._parameter.name
- To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method.
+ @property
+ def data(self) -> np.ndarray:
+ """The data of the parameter"""
+ return self._parameter.data.numpy()
+
+ @data.setter
+ def data(self, value: np.ndarray) -> None:
+ """Sets the data of the parameter"""
+ self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue)
+
+ @property
+ def grad(self) -> np.ndarray:
+ """The gradient of the parameter"""
+ return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None
+
+ @property
+ def requires_grad(self) -> bool:
+ """Whether or not the parameter requires its gradient to be computed"""
+ return self._parameter.requires_grad
+
+ def __repr__(self) -> str:
+ """Returns a string representation of the parameter"""
+ return f"Parameter(name={self.name}, requires_grad={self.requires_grad})"
+
+
+class Parameters:
+ """Class that holds all the model parameters
+
+ This class holds all the model parameters and provides access to them.
+ This class is not expected to be instantiated directly. Instead, it is returned by the
+ `CheckpointState`'s parameters attribute.
+ This class behaves like a dictionary and provides access to the parameters by name.
Args:
- state: The C.Checkpoint state object that holds the underlying session state.
+ state: The C.CheckpointState object that holds the underlying session state.
"""
def __init__(self, state: C.CheckpointState):
- if not isinstance(state, C.CheckpointState):
- raise TypeError(f"Invalid argument for CheckpointState received {type(state)}")
self._state = state
- @classmethod
- def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState:
- """Loads the checkpoint state from the checkpoint file
+ def __getitem__(self, name: str) -> Parameter:
+ """Gets the parameter associated with the given name
+
+ Searches for the name in the parameters of the checkpoint state.
Args:
- checkpoint_uri: The path to the checkpoint file.
+ name: The name of the parameter
Returns:
- CheckpointState: The checkpoint state object.
+ The value of the parameter
+
+ Raises:
+ KeyError: If the parameter is not found
"""
- return cls(C.load_checkpoint(os.fspath(checkpoint_uri)))
- @classmethod
- def save_checkpoint(
- cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False
- ) -> None:
- """Saves the checkpoint state to the checkpoint file
+ if name not in self:
+ raise KeyError(f"Parameter {name} not found.")
+
+ return Parameter(self._state.get_parameter(name), self._state)
+
+ def __setitem__(self, name: str, value: np.ndarray) -> None:
+ """Sets the parameter value for the given name
+
+ Searches for the name in the parameters of the checkpoint state.
+ If the name is found in parameters, the value is updated.
Args:
- state: The checkpoint state object.
- checkpoint_uri: The path to the checkpoint file.
- include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file.
+ name: The name of the parameter
+ value: The value of the parameter as a numpy array
+
+ Raises:
+ KeyError: If the parameter is not found
"""
- C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state)
+ if name not in self:
+ raise KeyError(f"Parameter {name} not found.")
+
+ self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue)
+
+ def __contains__(self, name: str) -> bool:
+ """Checks if the parameter exists in the state
+
+ Args:
+ name: The name of the parameter
+
+ Returns:
+ True if the name is a parameter False otherwise
+ """
+
+ return self._state.has_parameter(name)
+
+ def __iter__(self):
+ """Returns an iterator over the properties"""
+ for parameter_name in self._state.parameter_names():
+ yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state)
+
+ def __repr__(self) -> str:
+ """Returns a string representation of the parameters"""
+ return self._state.parameter_names()
+
+ def __len__(self) -> int:
+ """Returns the number of parameters"""
+ return len(self._state.parameter_names())
+
+
+class Properties:
+ def __init__(self, state: C.CheckpointState):
+ self._state = state
def __getitem__(self, name: str) -> int | float | str:
"""Gets the property associated with the given name
+ Searches for the name in the properties of the checkpoint state.
+
Args:
name: The name of the property
Returns:
The value of the property
+
+ Raises:
+ KeyError: If the property is not found
"""
+
+ if name not in self:
+ raise KeyError(f"Property {name} not found.")
+
return self._state.get_property(name)
def __setitem__(self, name: str, value: int | float | str) -> None:
"""Sets the property value for the given name
+ Searches for the name in the properties of the checkpoint state.
+ The value is added or updated in the properties.
+
Args:
name: The name of the property
value: The value of the property
+ Properties only support int, float and str values.
"""
self._state.add_property(name, value)
@@ -79,6 +180,75 @@ def __contains__(self, name: str) -> bool:
name: The name of the property
Returns:
- True if the property exists, False otherwise
+ True if the name is a property, False otherwise
"""
+
return self._state.has_property(name)
+
+ def __iter__(self):
+ """Returns an iterator over the properties"""
+ for property_name in self._state.property_names():
+ yield property_name, self._state.get_property(property_name)
+
+ def __repr__(self) -> str:
+ """Returns a string representation of the properties"""
+ return self._state.property_names()
+
+ def __len__(self) -> int:
+ """Returns the number of properties"""
+ return len(self._state.property_names())
+
+
+class CheckpointState:
+ """Class that holds the state of the training session
+
+ This class holds all the state information of the training session such as the model parameters,
+ its gradients, the optimizer state and user defined properties.
+
+ To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method.
+
+ Args:
+ state: The C.Checkpoint state object that holds the underlying session state.
+ """
+
+ def __init__(self, state: C.CheckpointState):
+ if not isinstance(state, C.CheckpointState):
+ raise TypeError(f"Invalid argument for CheckpointState received {type(state)}")
+ self._state = state
+ self._parameters = Parameters(self._state)
+ self._properties = Properties(self._state)
+
+ @classmethod
+ def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState:
+ """Loads the checkpoint state from the checkpoint file
+
+ Args:
+ checkpoint_uri: The path to the checkpoint file.
+
+ Returns:
+ CheckpointState: The checkpoint state object.
+ """
+ return cls(C.load_checkpoint(os.fspath(checkpoint_uri)))
+
+ @classmethod
+ def save_checkpoint(
+ cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False
+ ) -> None:
+ """Saves the checkpoint state to the checkpoint file
+
+ Args:
+ state: The checkpoint state object.
+ checkpoint_uri: The path to the checkpoint file.
+ include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file.
+ """
+ C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state)
+
+ @property
+ def parameters(self) -> Parameters:
+ """Returns the model parameters from the checkpoint state"""
+ return self._parameters
+
+ @property
+ def properties(self) -> Properties:
+ """Returns the properties from the checkpoint state"""
+ return self._properties
diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py
index 56338ddbaffef..d5c37b3e36ee7 100644
--- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py
+++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py
@@ -360,14 +360,18 @@ def test_add_get_property(property_value):
if isinstance(property_value, float):
property_value = float(np.float32(property_value))
- state["property"] = property_value
- assert "property" in state
- assert state["property"] == property_value
+ assert len(state.properties) == 0
+
+ state.properties["property"] = property_value
+ assert "property" in state.properties
+ assert state.properties["property"] == property_value
+ assert len(state.properties) == 1
CheckpointState.save_checkpoint(state, checkpoint_file_path)
new_state = CheckpointState.load_checkpoint(checkpoint_file_path)
- assert "property" in new_state
- assert new_state["property"] == property_value
+ assert "property" in new_state.properties
+ assert new_state.properties["property"] == property_value
+ assert len(new_state.properties) == 1
def test_get_input_output_names():
@@ -563,3 +567,60 @@ def test_eval_step_with_ort_values():
fetches = model(inputs, labels)
assert isinstance(fetches, OrtValue)
assert fetches
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_get_and_set_parameter_values(device):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ (
+ checkpoint_file_path,
+ training_model_file_path,
+ eval_model_file_path,
+ _,
+ pt_model,
+ ) = _create_training_artifacts(
+ temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"]
+ )
+
+ state = CheckpointState.load_checkpoint(checkpoint_file_path)
+
+ model = Module(training_model_file_path, state, eval_model_file_path, device=device)
+
+ state_dict = pt_model.state_dict()
+ assert len(state_dict) == len(state.parameters)
+ for parameter_name, _ in state.parameters:
+ assert parameter_name in state_dict
+
+ for name, pt_param in pt_model.named_parameters():
+ ort_param = state.parameters[name]
+ assert ort_param.name == name
+ assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data)
+ if name in ["fc1.weight", "fc1.bias"]:
+ assert ort_param.requires_grad is False
+ assert ort_param.grad is None
+ else:
+ assert ort_param.requires_grad is True
+ assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32))
+
+ original_param = state.parameters["fc1.weight"].data
+ state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32)
+ updated_param = state.parameters["fc1.weight"].data
+ assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32))
+
+ model.train()
+ inputs = torch.randn(64, 784).numpy()
+ labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy()
+ loss = model(inputs, labels)
+ assert loss is not None
+ for name, _ in pt_model.named_parameters():
+ ort_param = state.parameters[name]
+ assert ort_param.name == name
+ if name in ["fc1.weight", "fc1.bias"]:
+ assert ort_param.requires_grad is False
+ assert ort_param.grad is None
+ else:
+ assert ort_param.requires_grad is True
+ assert ort_param.grad.any()
+
+ state.parameters["fc1.weight"] = original_param
+ assert np.allclose(state.parameters["fc1.weight"].data, original_param)
diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
index d734be8e3474b..e46952d87c2bf 100644
--- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
@@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) {
testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL."));
}
}
+
+TEST(TrainingCApiTest, GetParameter) {
+ auto model_uri = MODEL_FOLDER "training_model.onnx";
+
+ Ort::Env env;
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+ Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
+
+ Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
+ auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
+ auto shape = tensor_info.GetShape();
+ ASSERT_EQ(shape.size(), 2U);
+ ASSERT_EQ(shape.front(), static_cast(500));
+ ASSERT_EQ(shape.back(), static_cast(784));
+}
+
+TEST(TrainingCApiTest, UpdateParameter) {
+ auto model_uri = MODEL_FOLDER "training_model.onnx";
+
+ Ort::Env env;
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+ Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
+
+ Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
+ auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
+ auto shape = tensor_info.GetShape();
+ ASSERT_EQ(shape.size(), 2U);
+ ASSERT_EQ(shape.front(), static_cast(500));
+ ASSERT_EQ(shape.back(), static_cast(784));
+
+ OrtValue* updated_param_value = std::make_unique().release();
+ GenerateRandomInput(std::array{500, 784}, *updated_param_value);
+ Ort::Value updated_parameter{updated_param_value};
+ checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
+
+ Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
+ gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(),
+ current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(),
+ updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(),
+ parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ ASSERT_EQ(actual, expected);
+ ASSERT_NE(actual, not_expected);
+
+ checkpoint_state.UpdateParameter("fc1.weight", parameter);
+ current_parameter = checkpoint_state.GetParameter("fc1.weight");
+ actual = gsl::span(current_parameter.GetTensorMutableData(),
+ current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ expected = gsl::span(parameter.GetTensorMutableData(),
+ parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ not_expected = gsl::span(updated_parameter.GetTensorMutableData(),
+ updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ ASSERT_EQ(actual, expected);
+ ASSERT_NE(actual, not_expected);
+}
+
+#ifdef USE_CUDA
+TEST(TrainingCApiTest, UpdateParameterDifferentDevices) {
+ auto model_uri = MODEL_FOLDER "training_model.onnx";
+
+ Ort::Env env;
+ Ort::SessionOptions session_options;
+ Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+ Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri);
+
+ Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
+ auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
+ auto shape = tensor_info.GetShape();
+ ASSERT_EQ(shape.size(), 2U);
+ ASSERT_EQ(shape.front(), static_cast(500));
+ ASSERT_EQ(shape.back(), static_cast(784));
+
+ OrtValue* updated_param_value = std::make_unique().release();
+ GenerateRandomInput(std::array{500, 784}, *updated_param_value);
+ Ort::Value updated_parameter{updated_param_value};
+ checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
+
+ Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
+ gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(),
+ current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(),
+ updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(),
+ parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ ASSERT_EQ(actual, expected);
+ ASSERT_NE(actual, not_expected);
+
+ checkpoint_state.UpdateParameter("fc1.weight", parameter);
+ current_parameter = checkpoint_state.GetParameter("fc1.weight");
+ actual = gsl::span(current_parameter.GetTensorMutableData(),
+ current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ expected = gsl::span(parameter.GetTensorMutableData(),
+ parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ not_expected = gsl::span(updated_parameter.GetTensorMutableData(),
+ updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
+ ASSERT_EQ(actual, expected);
+ ASSERT_NE(actual, not_expected);
+}
+#endif
+
} // namespace onnxruntime::training::test
diff --git a/orttraining/orttraining/training_api/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h
index d7b1e295df53e..3c38c99b3152f 100644
--- a/orttraining/orttraining/training_api/checkpoint_property.h
+++ b/orttraining/orttraining/training_api/checkpoint_property.h
@@ -22,10 +22,12 @@ struct PropertyBag {
PropertyBag() = default;
void AddProperty(const std::string& name, const PropertyDataType& val) {
- ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(),
- "Duplicated property named ", name);
-
- named_properties_.insert({name, val});
+ auto it = named_properties_.find(name);
+ if (it == named_properties_.end()) {
+ named_properties_.insert({name, val});
+ } else {
+ it->second = val;
+ }
}
template
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
index 0af737074964d..0e8544a7639ba 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
@@ -608,14 +608,14 @@ struct OrtTrainingApi {
/// \name Accessing The Training Session State
/// @{
- /** \brief Adds the given property to the checkpoint state.
+ /** \brief Adds or updates the given property to/in the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
- * state by the user if they desire by calling this function with the appropriate property name and
- * value. The given property name must be unique to be able to successfully add the property.
+ * state by the user by calling this function with the corresponding property name and value.
+ * The given property name must be unique to be able to successfully add the property.
*
* \param[in] checkpoint_state The checkpoint state which should hold the property.
- * \param[in] property_name Unique name of the property being added.
+ * \param[in] property_name Name of the property being added or updated.
* \param[in] property_type Type of the property associated with the given name.
* \param[in] property_value Property value associated with the given name.
*
@@ -632,7 +632,7 @@ struct OrtTrainingApi {
* exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
- * \param[in] property_name Unique name of the property being retrieved.
+ * \param[in] property_name Name of the property being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the property_value.
* \param[out] property_type Type of the property associated with the given name.
* \param[out] property_value Property value associated with the given name.
@@ -669,6 +669,57 @@ struct OrtTrainingApi {
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
+ /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
+ *
+ * This function retrieves the type and shape of the parameter associated with the given parameter name.
+ * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
+ *
+ * \param[in] checkpoint_state The checkpoint state.
+ * \param[in] parameter_name Name of the parameter being retrieved.
+ * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ */
+ ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
+
+ /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
+ *
+ * This function updates a model parameter in the checkpoint state with the given parameter data.
+ * The training session must be already created with the checkpoint state that contains the parameter
+ * being updated. The given parameter is copied over to the registered device for the training session.
+ * The parameter must exist in the checkpoint state to be able to update it successfully.
+ *
+ * \param[in] checkpoint_state The checkpoint state.
+ * \param[in] parameter_name Name of the parameter being updated.
+ * \param[in] parameter The parameter data that should replace the existing parameter data.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ */
+ ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _In_ OrtValue* parameter);
+
+ /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
+ *
+ * This function retrieves the model parameter data from the checkpoint state for the given parameter name.
+ * The parameter is copied over and returned as an OrtValue. The training session must be already created
+ * with the checkpoint state that contains the parameter being retrieved.
+ * The parameter must exist in the checkpoint state to be able to retrieve it successfully.
+ *
+ * \param[in] checkpoint_state The checkpoint state.
+ * \param[in] parameter_name Name of the parameter being retrieved.
+ * \param[in] allocator Allocator used to allocate the memory for the parameter.
+ * \param[out] parameter The parameter data that is retrieved from the checkpoint state.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ */
+ ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
+ _Outptr_ OrtValue** parameter);
+
/// @}
};
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
index 0edef20ba6da8..218bef524200c 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
@@ -112,13 +112,13 @@ class CheckpointState : public detail::Base {
const std::basic_string& path_to_checkpoint,
const bool include_optimizer_state = false);
- /** \brief Adds the given property to the checkpoint state.
+ /** \brief Adds or updates the given property to/in the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
- * state by the user if they desire by calling this function with the appropriate property name and
- * value. The given property name must be unique to be able to successfully add the property.
+ * state by the user by calling this function with the corresponding property name and value.
+ * The given property name must be unique to be able to successfully add the property.
*
- * \param[in] property_name Unique name of the property being added.
+ * \param[in] property_name Name of the property being added or updated.
* \param[in] property_value Property value associated with the given name.
*
*/
@@ -129,12 +129,38 @@ class CheckpointState : public detail::Base {
* Gets the property value from an existing entry in the checkpoint state. The property must
* exist in the checkpoint state to be able to retrieve it successfully.
*
- * \param[in] property_name Unique name of the property being retrieved.
+ * \param[in] property_name Name of the property being retrieved.
* \return Property value associated with the given property name.
*
*/
Property GetProperty(const std::string& property_name);
+ /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
+ *
+ * This function updates a model parameter in the checkpoint state with the given parameter data.
+ * The training session must be already created with the checkpoint state that contains the parameter
+ * being updated. The given parameter is copied over to the registered device for the training session.
+ * The parameter must exist in the checkpoint state to be able to update it successfully.
+ *
+ * \param[in] parameter_name Name of the parameter being updated.
+ * \param[in] parameter The parameter data that should replace the existing parameter data.
+ *
+ */
+ void UpdateParameter(const std::string& parameter_name, const Value& parameter);
+
+ /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
+ *
+ * This function retrieves the model parameter data from the checkpoint state for the given parameter name.
+ * The parameter is copied over to the provided OrtValue. The training session must be already created
+ * with the checkpoint state that contains the parameter being retrieved.
+ * The parameter must exist in the checkpoint state to be able to retrieve it successfully.
+ *
+ * \param[in] parameter_name Name of the parameter being retrieved.
+ * \return The parameter data that is retrieved from the checkpoint state.
+ *
+ */
+ Value GetParameter(const std::string& parameter_name);
+
/// @}
};
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
index 066147708863f..a5efa3c0e4bef 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
@@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) {
return property;
}
+inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
+ ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
+}
+
+inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
+ AllocatorWithDefaultOptions allocator;
+ OrtValue* parameter;
+ ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter));
+
+ return Value{parameter};
+}
+
} // namespace Ort
diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc
index d1775e358163c..cf49a01517d6b 100644
--- a/orttraining/orttraining/training_api/module.cc
+++ b/orttraining/orttraining/training_api/module.cc
@@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph,
#endif
} // namespace
+Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const {
+ ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it.");
+ ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
+ ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(),
+ "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(),
+ ", Got: ", data.Get().Shape().ToString());
+#ifdef ENABLE_STRIDED_TENSORS
+ auto data_strides = data.Get().Strides();
+ auto param_strides = data_.Get().Strides();
+ ORT_ENFORCE(data_strides.size() == param_strides.size(),
+ "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
+ ", Got: ", data_strides.size());
+ ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
+ "Parameter data stride value mismatch.");
+#endif
+ ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(),
+ "Parameter data type mismatch. Expected: ", data_.Get().DataType(),
+ ", Got: ", data.Get().DataType());
+ ORT_ENFORCE(data_transfer_manager != nullptr,
+ "Data transfer manager must be provided to copy data to the parameter. "
+ "Please create the TrainingSession before trying to update the parameter.");
+
+ ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get(), *data.GetMutable()));
+
+ return Status::OK();
+}
+
+Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) {
+ ORT_ENFORCE(data_.IsAllocated(),
+ "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it.");
+ ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
+ ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(),
+ "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(),
+ ", Got: ", data.Get().Shape().ToString());
+#ifdef ENABLE_STRIDED_TENSORS
+ auto data_strides = data.Get().Strides();
+ auto param_strides = data_.Get().Strides();
+ ORT_ENFORCE(data_strides.size() == param_strides.size(),
+ "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
+ ", Got: ", data_strides.size());
+ ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
+ "Parameter data stride value mismatch.");
+#endif
+ ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(),
+ "Parameter data type mismatch. Expected: ", data_.Get().DataType(),
+ ", Got: ", data.Get().DataType());
+ ORT_ENFORCE(data_transfer_manager != nullptr,
+ "Data transfer manager must be provided to copy data to the parameter. "
+ "Please create the TrainingSession before trying to update the parameter.");
+
+ ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get(), *data_.GetMutable()));
+
+ return Status::OK();
+}
+
Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) {
// assert param is allocated
ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient.");
@@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers,
}
}
+Module::~Module() {
+ state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr;
+}
+
size_t Module::GetTrainingModelOutputCount() const noexcept {
return train_output_names_.size();
}
diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h
index adb633343263e..f323e6be72d49 100644
--- a/orttraining/orttraining/training_api/module.h
+++ b/orttraining/orttraining/training_api/module.h
@@ -21,6 +21,8 @@ struct Parameter {
// Return the mutable data.
OrtValue& Data() { return data_; }
+ Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const;
+ Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data);
const std::string& Name() const { return name_; }
// Returns whether this parameter is trainable or not.
@@ -34,7 +36,6 @@ struct Parameter {
// Reset and release the gradient buffer of this Parameter greedily.
Status ResetGrad();
- protected:
Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad);
private:
@@ -83,6 +84,8 @@ struct Module {
const std::vector>& providers,
gsl::span op_domains = gsl::span());
+ ~Module();
+
// Return the trainable/nontrainable parameters
std::vector> Parameters() const;
diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
index 6693bba348648..38a9aad9640ea 100644
--- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
+++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
@@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void*
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) {
API_IMPL_BEGIN
+ if (checkpoint_buffer == nullptr) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr.");
+ }
+
*checkpoint_state = nullptr;
auto chkpt_state = std::make_unique();
const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer);
@@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState*
API_IMPL_END
}
+ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) {
+ API_IMPL_BEGIN
+
+ auto chkpt_state = reinterpret_cast(checkpoint_state);
+ auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
+ if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
+ std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
+ }
+
+ return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape);
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _In_ OrtValue* parameter) {
+ API_IMPL_BEGIN
+ if (parameter == nullptr) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
+ }
+
+ auto chkpt_state = reinterpret_cast(checkpoint_state);
+ auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
+ if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
+ std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
+ }
+ ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom(
+ chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter));
+
+ return nullptr;
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
+ _Outptr_ OrtValue** parameter) {
+ API_IMPL_BEGIN
+
+ if (parameter == nullptr) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
+ }
+
+ auto chkpt_state = reinterpret_cast(checkpoint_state);
+ auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
+ if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
+ std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
+ }
+
+ if (!it->second->Data().IsTensor()) {
+ return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type.");
+ }
+ const auto& parameter_tensor = it->second->Data().Get();
+ ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue(
+ allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(),
+ ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter));
+
+ auto status = it->second->CopyTo(
+ chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter);
+ if (!status.IsOK()) {
+ OrtApis::ReleaseValue(*parameter);
+ return onnxruntime::ToOrtStatus(status);
+ }
+
+ return nullptr;
+ API_IMPL_END
+}
+
static constexpr OrtTrainingApi ort_training_api = {
// NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially
// released, it is OK to change the order here, however a corresponding matching change should also be done in the
@@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = {
&OrtTrainingApis::TrainingSessionGetEvalModelInputName,
&OrtTrainingApis::AddProperty,
&OrtTrainingApis::GetProperty,
- &OrtTrainingApis::LoadCheckpointFromBuffer};
+ &OrtTrainingApis::LoadCheckpointFromBuffer,
+ &OrtTrainingApis::GetParameterTypeAndShape,
+ &OrtTrainingApis::UpdateParameter,
+ &OrtTrainingApis::GetParameter};
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
// No constraints on the API version yet.
diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h
index c87108957c975..2a8c1e30361c6 100644
--- a/orttraining/orttraining/training_api/ort_training_apis.h
+++ b/orttraining/orttraining/training_api/ort_training_apis.h
@@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state
ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
+ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
+
+ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _In_ OrtValue* parameter);
+
+ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
+ _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
+ _Outptr_ OrtValue** parameter);
+
} // namespace OrtTrainingApis
diff --git a/winml/test/model/skip_model_tests.h b/winml/test/model/skip_model_tests.h
index f815b197b6a00..9d66320343c43 100644
--- a/winml/test/model/skip_model_tests.h
+++ b/winml/test/model/skip_model_tests.h
@@ -163,10 +163,8 @@ std::unordered_map> disabledGpu
test name -> absolute difference sampleTolerance
*/
std::unordered_map sampleTolerancePerTests({
- {"fp16_inception_v1_opset7_GPU",0.005 },
- {"fp16_inception_v1_opset8_GPU", 0.005},
- { "candy_opset9_GPU",
- 0.00150000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
- { "fp16_tiny_yolov2_opset8_GPU",
- 0.109000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
+ {"fp16_inception_v1_opset7_GPU", 0.005},
+ {"fp16_inception_v1_opset8_GPU", 0.005},
+ { "candy_opset9_GPU", 0.00150000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
+ { "fp16_tiny_yolov2_opset8_GPU", 0.109000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
});