From 2f57f1e4d711920494faf6c48ab35b0ecb0210cf Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 2 Nov 2023 10:01:53 -0700 Subject: [PATCH] Some cherry-picks for the 1.16.2 release (#18218) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-pick PRs: #18026 #17912 #17901 “2 lines added whitespace errors when cherry-picking" #17293 #17364 #17505 #17885 This PR contains all the cherry-picks for the patch release except: 1. The PRs marked with sdxl_llama 2. #17772 which has a merge conflict. --------- Co-authored-by: Chi Lo Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: Scott McKay Co-authored-by: Baiju Meswani Co-authored-by: Kaz Nishimura Co-authored-by: Scott McKay --- .../NativeMethods.shared.cs | 53 +++-- .../Training/CheckpointState.shared.cs | 133 +++++++---- .../Training/NativeTrainingMethods.shared.cs | 34 +++ .../Training/TrainingSession.shared.cs | 55 ++--- .../TrainingTest.cs | 128 ++++++++-- docs/OperatorKernels.md | 1 + onnxruntime/core/mlas/lib/platform.cpp | 14 +- .../core/providers/cuda/cu_inc/common.cuh | 14 +- .../providers/cuda/cuda_execution_provider.cc | 22 ++ .../cuda/math/unary_elementwise_ops.cc | 1 + .../cuda/math/unary_elementwise_ops.h | 7 + .../cuda/math/unary_elementwise_ops_impl.cu | 170 +++++++------- .../cuda/math/unary_elementwise_ops_impl.h | 3 +- .../inc/IWinmlExecutionProvider.h | 2 +- .../core/providers/rocm/cu_inc/common.cuh | 18 +- .../providers/rocm/rocm_execution_provider.cc | 22 ++ .../tensorrt/tensorrt_execution_provider.cc | 12 +- .../tensorrt/tensorrt_execution_provider.h | 4 + .../test/providers/cpu/math/sign_test.cc | 10 +- .../python/orttraining_pybind_state.cc | 80 ++++++- .../python/training/api/checkpoint_state.py | 220 ++++++++++++++++-- .../orttraining_test_python_bindings.py | 71 +++++- .../training_api/core/training_capi_tests.cc | 102 ++++++++ .../training_api/checkpoint_property.h | 10 +- .../include/onnxruntime_training_c_api.h | 61 ++++- .../include/onnxruntime_training_cxx_api.h | 36 ++- .../include/onnxruntime_training_cxx_inline.h | 12 + .../orttraining/training_api/module.cc | 59 +++++ orttraining/orttraining/training_api/module.h | 5 +- .../onnxruntime_training_c_api.cc | 79 ++++++- .../training_api/ort_training_apis.h | 10 + winml/test/model/skip_model_tests.h | 10 +- 32 files changed, 1170 insertions(+), 288 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 2ba837be22041..f722ca9d30fa4 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -1860,54 +1860,61 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtFillStringTensor OrtFillStringTensor; + /// \param value A tensor created from OrtCreateTensor... function. + /// \param index The index of the entry in the tensor to resize. + /// \param length_in_bytes Length to resize the string to. + /// \param buffer The resized buffer. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer( - IntPtr /* OrtValue */ value, - UIntPtr /* size_t */ index, - UIntPtr /* size_t */ length_in_bytes, - out IntPtr /* char** */ buffer - ); + IntPtr /* OrtValue */ value, + UIntPtr /* size_t */ index, + UIntPtr /* size_t */ length_in_bytes, + out IntPtr /* char** */ buffer); public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent( - IntPtr /*(OrtValue*)*/ value, - byte[] /*(void*)*/ dst_buffer, - UIntPtr dst_buffer_len, - UIntPtr[] offsets, - UIntPtr offsets_len); + IntPtr /*(OrtValue*)*/ value, + byte[] /*(void*)*/ dst_buffer, + UIntPtr dst_buffer_len, + UIntPtr[] offsets, + UIntPtr offsets_len); public static DOrtGetStringTensorContent OrtGetStringTensorContent; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value, - out UIntPtr /*(size_t*)*/ len); + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ index, - out UIntPtr /*(size_t*)*/ len); + UIntPtr /*(size_t)*/ index, + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ bufferLength, - UIntPtr /*(size_t)*/ elementIndex, - byte[] buffer); + UIntPtr /*(size_t)*/ bufferLength, + UIntPtr /*(size_t)*/ elementIndex, + byte[] buffer); public static DOrtGetStringTensorElement OrtGetStringTensorElement; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ - DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo( + IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, + out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape( + IntPtr /*(OrtValue*)*/ value, + out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape; @@ -1917,12 +1924,16 @@ out IntPtr /* char** */ buffer public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out IntPtr /*(TensorElementType*)*/ output); public static DOrtGetTensorElementType OrtGetTensorElementType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out UIntPtr output); public static DOrtGetDimensionsCount OrtGetDimensionsCount; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 659c6303702ac..6889112acb385 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -40,20 +40,16 @@ internal enum PropertyType : long String = 2 } - private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; - Memory memory = value; - using (var memHandle = memory.Pin()) + T[] value = { propertyValue }; + unsafe { - IntPtr memPtr; - unsafe + fixed (T* memPtr = value) { - memPtr = (IntPtr)memHandle.Pointer; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr)); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -103,13 +99,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, } /// - /// Adds the given int property to the checkpoint state. + /// Adds or updates the given int property to/in the checkpoint state. /// - /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, long propertyValue) { @@ -117,13 +113,13 @@ public void AddProperty(string propertyName, long propertyValue) } /// - /// Adds the given float property to the checkpoint state. + /// Adds or updates the given float property to/in the checkpoint state. /// - /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, float propertyValue) { @@ -131,28 +127,25 @@ public void AddProperty(string propertyName, float propertyValue) } /// - /// Adds the given string property to the checkpoint state. + /// Adds or updates the given string property to/in the checkpoint state. /// - /// Runtime properties that are strings such as parameter names, custom strings, and others can be added - /// to the checkpoint state by the user if they desire by calling this function with the appropriate property - /// name and value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, string propertyValue) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); - IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); - try - { - Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); - } - finally + unsafe { - Marshal.FreeHGlobal(unmanagedPointer); + fixed (byte* p = propertyValueUtf8) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p)); + } } } @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue) /// Gets the property value from an existing entry in the checkpoint state. The property must /// exist in the checkpoint state to be able to retrieve it successfully. /// - /// Unique name of the property being retrieved. + /// Name of the property being retrieved. /// Property value associated with the given property name. public object GetProperty(string propertyName) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) + try { - var longPropertyValue = Marshal.ReadInt64(propertyValue); - allocator.FreeMemory(propertyValue); - return longPropertyValue; + if (propertyType == PropertyType.Int) + { + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.Float) + { + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue); + } + + throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } - else if (propertyType == PropertyType.Float) + finally { - float[] value = new float[1]; - Marshal.Copy(propertyValue, value, 0, 1); allocator.FreeMemory(propertyValue); - return value[0]; } - else if (propertyType == PropertyType.String) + } + + /// + /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + /// + /// This function updates a model parameter in the checkpoint state with the given parameter data. + /// The training session must be already created with the checkpoint state that contains the parameter + /// being updated. The given parameter is copied over to the registered device for the training session. + /// The parameter must exist in the checkpoint state to be able to update it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that should replace the existing parameter data. + public void UpdateParameter(string parameterName, OrtValue parameter) + { + if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { - return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); + throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter."); } - throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle)); + } + + /// + /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + /// + /// This function retrieves the model parameter data from the checkpoint state for the given parameter name. + /// The parameter is copied over to the provided OrtValue. The training session must be already created + /// with the checkpoint state that contains the parameter being retrieved. + /// The parameter must exist in the checkpoint state to be able to retrieve it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that is retrieved from the checkpoint state. + public OrtValue GetParameter(string parameterName) + { + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle)); + + return new OrtValue(parameterHandle); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index ac790242409e3..d6341b90f28ff 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -42,6 +42,9 @@ public struct OrtTrainingApi public IntPtr AddProperty; public IntPtr GetProperty; public IntPtr LoadCheckpointFromBuffer; + public IntPtr GetParameterTypeAndShape; + public IntPtr UpdateParameter; + public IntPtr GetParameter; } internal static class NativeTrainingMethods @@ -97,6 +100,9 @@ static NativeTrainingMethods() OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName)); OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty)); OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty)); + OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape)); + OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter)); + OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter)); } } @@ -359,6 +365,34 @@ out UIntPtr inputCount public static DOrtGetProperty OrtGetProperty; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape + ); + + public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtUpdateParameter OrtUpdateParameter; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter + ); + + public static DOrtGetParameter OrtGetParameter; + #endregion TrainingSession API public static bool TrainingEnabled() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 33993c2be135b..877677dcad57b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -358,13 +358,14 @@ public void EvalStep( IReadOnlyCollection inputValues, IReadOnlyCollection outputValues) { - if (!_evalOutputCount.Equals(outputValues.Count)) + if (_evalOutputCount != (ulong)outputValues.Count()) { - throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); + throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount})."); } - IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); + const bool isInput = true; + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput); - IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ + IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -509,18 +510,17 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec /// Returns a contiguous buffer that holds a copy of all training state parameters /// /// Whether to only copy trainable parameters or to copy all parameters. - public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) + public OrtValue ToBuffer(bool onlyTrainable) { UIntPtr bufferSize = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable)); float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; - var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + var shape = new long[] { (long)bufferSize }; + var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); return buffer; } @@ -528,45 +528,30 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(FixedBufferOnnxValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); - return; + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index ea2b6d7dbc118..68b1d5bcc6147 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -484,20 +484,23 @@ public void TestEvalModelOutputNames() public void TestToBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); + } } } @@ -505,22 +508,25 @@ public void TestToBuffer() public void TestFromBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); + } } } @@ -530,6 +536,82 @@ public void TestSetSeed() TrainingUtils.SetSeed(8888); } + [Fact(DisplayName = "TestGetParameter")] + public void TestGetParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(state); + Assert.NotNull(parameter); + + var typeShape = parameter.GetTensorTypeAndShape(); + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + } + } + + [Fact(DisplayName = "TestUpdateParameter")] + public void TestUpdateParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + { + Assert.NotNull(state); + + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape)) + { + state.UpdateParameter("fc1.weight", updated_parameter); + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + } + + state.UpdateParameter("fc1.weight", parameter); + + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index da4c248b58d94..c4984e423df21 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -763,6 +763,7 @@ Do not modify directly.* |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 86b7450a7c4e5..32cc69d0b8040 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -451,12 +451,16 @@ Return Value: #if defined(_WIN32) HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we - // disable it there. - uint64_t isar0_el1; - asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; #else + // Use the cpuinfo value which is read from sysctl and has some additional special cases. + // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips + // as well as failing on other ARM chips as it is an EL1 level register that requires extra + // privileges to read. + // + // uint64_t isar0_el1; + // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); + // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); #endif diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index a50b53315ec9a..0d9928baa86e0 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -20,7 +20,7 @@ namespace cuda { // float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions // CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2))) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) __device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); } __device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); } __device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); } @@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } +template +__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; } + +template +__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); } + +template +__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); } + +template <> +__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); } + template __device__ __inline__ T _Normcdf(T a); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index aa60db4d07222..ad892eab3b843 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub); @@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index f026444328b24..9ede1f8d90ecc 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13) UNARY_OP_HFD(Log, 13) UNARY_OP_HFD(Exp, 13) UNARY_OP_HFD(Erf, 13) +UNARY_OP_BWUZCSILHFD(Sign, 13) UNARY_LOGICALOP_NOT_TYPED(1, bool) UNARY_OP_HFD(Round, 11) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 3ff97a60114df..775b78c43a736 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class Sign final : public UnaryElementwise { + public: + Sign(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index ac7cc1126acb7..1298d53338337 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos) SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign) // When casting, half needs to be converted via float type from most other types template @@ -119,52 +120,52 @@ struct OP_Cast { } }; -#define IMPL_CAST_IMPL(InT, OutT) \ +#define IMPL_CAST_IMPL(InT, OutT) \ void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ } -#define IMPL_CAST_IMPL_THROW(InT, OutT) \ +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ } #if !defined(DISABLE_FLOAT8_TYPES) -#define IMPL_CAST_IMPL_FROM(T) \ - IMPL_CAST_IMPL(T, half) \ - IMPL_CAST_IMPL(T, float) \ - IMPL_CAST_IMPL(T, double) \ - IMPL_CAST_IMPL(T, int8_t) \ - IMPL_CAST_IMPL(T, int16_t) \ - IMPL_CAST_IMPL(T, int32_t) \ - IMPL_CAST_IMPL(T, int64_t) \ - IMPL_CAST_IMPL(T, uint8_t) \ - IMPL_CAST_IMPL(T, uint16_t) \ - IMPL_CAST_IMPL(T, uint32_t) \ - IMPL_CAST_IMPL(T, uint64_t) \ - IMPL_CAST_IMPL(T, bool) \ - IMPL_CAST_IMPL(T, BFloat16) \ - IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \ - IMPL_CAST_IMPL_THROW(T, Float8E5M2) \ +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ + IMPL_CAST_IMPL(T, BFloat16) \ + IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \ + IMPL_CAST_IMPL_THROW(T, Float8E5M2) \ IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \ IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ) #else -#define IMPL_CAST_IMPL_FROM(T) \ - IMPL_CAST_IMPL(T, half) \ - IMPL_CAST_IMPL(T, float) \ - IMPL_CAST_IMPL(T, double) \ - IMPL_CAST_IMPL(T, int8_t) \ - IMPL_CAST_IMPL(T, int16_t) \ - IMPL_CAST_IMPL(T, int32_t) \ - IMPL_CAST_IMPL(T, int64_t) \ - IMPL_CAST_IMPL(T, uint8_t) \ - IMPL_CAST_IMPL(T, uint16_t) \ - IMPL_CAST_IMPL(T, uint32_t) \ - IMPL_CAST_IMPL(T, uint64_t) \ - IMPL_CAST_IMPL(T, bool) \ +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ IMPL_CAST_IMPL(T, BFloat16) #endif @@ -199,58 +200,58 @@ struct OP_CastNoSat { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 -#define OP_CAST(T, NVT) \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const half& v) const { \ +#define OP_CAST(T, NVT) \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const half& v) const { \ return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const float& v) const { \ - return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const float& v) const { \ - return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ - } \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ + } \ + }; \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const float& v) const { \ + return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const float& v) const { \ + return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ + } \ }; #else -#define OP_CAST(T, NVT) \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(__half2float(v), true); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(__half2float(v), false); \ - } \ - }; \ - template <> \ - struct OP_CastSat { \ +#define OP_CAST(T, NVT) \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(__half2float(v), true); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(__half2float(v), false); \ + } \ + }; \ + template <> \ + struct OP_CastSat { \ __device__ __inline__ T operator()(const float& v) const { \ - return T(v, true); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ + return T(v, true); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ __device__ __inline__ T operator()(const float& v) const { \ - return T(v, false); \ - } \ + return T(v, false); \ + } \ }; #endif @@ -260,14 +261,13 @@ struct OP_CastNoSat { OP_CAST(Float8E4M3FN, __NV_E4M3) OP_CAST(Float8E5M2, __NV_E5M2) - -#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \ +#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \ void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \ - if (saturate) { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \ - } else { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \ - } \ + if (saturate) { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \ + } else { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \ + } \ } EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 3d4868b54abe6..608a81a24cf4f 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -31,7 +31,8 @@ namespace cuda { UNARY_OP_NAME_EXPR(Not, !a) \ UNARY_OP_NAME_EXPR(Round, _Round(a)) \ UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \ - UNARY_OP_NAME_EXPR(Cos, _Cos(a)) + UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \ + UNARY_OP_NAME_EXPR(Sign, _Sign(a)) #define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index cf93d24c29bcf..074f13b309181 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -83,7 +83,7 @@ namespace Windows::AI::MachineLearning::Adapter // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. struct DmlGraphNodeCreateInfo { - uint32_t nodeCount; + uint32_t nodeCount = 0; std::vector> nodesAsOperatorDesc; std::vector> nodesAsIDMLOperator; std::vector inputEdges; diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 5c516aac65aab..429ceb1f7c699 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } +template +__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; } + +template +__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); } + +template +__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); } + +template <> +__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); } + template __device__ __inline__ T _Normcdf(T a); @@ -337,7 +349,7 @@ struct GridDim { }; // aligned vector generates vectorized load/store -template +template struct alignas(sizeof(T) * vec_size) aligned_vector { T val[vec_size]; }; @@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector { // HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels. // TODO ROCM added support recently, should verify. #define HIP_KERNEL_ASSERT(...) -//#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) +// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) // WARP related definitions and functions constexpr int GPU_WARP_SIZE = warpSize; -inline int GPU_WARP_SIZE_HOST= warpSizeDynamic(); +inline int GPU_WARP_SIZE_HOST = warpSizeDynamic(); template __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) { diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 9401de64269b9..e6ea876d8957c 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1105,6 +1105,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); @@ -2067,6 +2078,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 14 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 88a576f3ffa73..ac92d46ca87fc 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -792,6 +792,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_))); + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_))); } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -1860,6 +1864,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { + sync_stream_after_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -2372,7 +2377,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -2400,6 +2405,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; + bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto trt_builder = trt_state->builder->get(); @@ -3001,6 +3007,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; + bool sync_stream_after_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; + // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3() + mutable bool sync_stream_after_enqueue_ = false; + CUDAGraph cuda_graph_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; diff --git a/onnxruntime/test/providers/cpu/math/sign_test.cc b/onnxruntime/test/providers/cpu/math/sign_test.cc index 12844068c47d2..15b3f40faa791 100644 --- a/onnxruntime/test/providers/cpu/math/sign_test.cc +++ b/onnxruntime/test/providers/cpu/math/sign_test.cc @@ -113,7 +113,7 @@ TestImpl(ForwardIter first, ForwardIter last, OutputIter out) { TEST(MathOpTest, Sign_uint64) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -129,7 +129,7 @@ TEST(MathOpTest, Sign_uint64) { // we disable this test for openvino as openvino ep supports only FP32 Precision TEST(MathOpTest, Sign_int64) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -146,7 +146,7 @@ TEST(MathOpTest, Sign_int64) { TEST(MathOpTest, Sign_float) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -162,7 +162,7 @@ TEST(MathOpTest, Sign_float) { TEST(MathOpTest, Sign_double) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -177,7 +177,7 @@ TEST(MathOpTest, Sign_double) { } TEST(MathOpTest, Sign_MLFloat16) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 3f3aa396e6ca0..35d9755ba0ba7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc"); checkpoint_state .def(py::init()) - .def("add_property", [](onnxruntime::training::api::CheckpointState* state, - const std::string& property_name, - const std::variant& property_value) { - state->property_bag.AddProperty(property_name, property_value); - }) - .def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.GetProperty(property_name); - }) - .def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.HasProperty(property_name); - }); + .def("add_property", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& property_name, + const std::variant& property_value) { + state->property_bag.AddProperty(property_name, property_value); + }) + .def("get_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.GetProperty(property_name); + }) + .def("has_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.HasProperty(property_name); + }) + .def("copy_parameter_from", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& parameter_name, OrtValue& value) -> void { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + ORT_THROW_IF_ERROR(it->second->CopyFrom( + state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }) + .def("get_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + return it->second; + }) + .def("has_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + return state->module_checkpoint_state.named_parameters.count(parameter_name); + }) + .def("parameter_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }) + .def("property_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->property_bag) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }); py::class_ training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc"); @@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(scheduler->Step()); }); + py::class_> + parameter(m, "Parameter"); + parameter + .def_property_readonly("name", &onnxruntime::training::api::Parameter::Name) + .def_property_readonly("data", &onnxruntime::training::api::Parameter::Data) + .def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient) + .def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad) + .def("copy_from", + [](onnxruntime::training::api::Parameter* parameter, + onnxruntime::training::api::CheckpointState* state, + OrtValue& value) -> void { + ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }); + m.def( "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index 285264bbed744..ba95cd04fce7e 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -5,70 +5,171 @@ import os +import numpy as np + from onnxruntime.capi import _pybind_state as C +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue -class CheckpointState: - """Class that holds the state of the training session +class Parameter: + """Class that represents a model parameter - This class holds all the state information of the training session such as the model parameters, - its gradients, the optimizer state and user defined properties. + This class represents a model parameter and provides access to its data, + gradient and other properties. This class is not expected to be instantiated directly. + Instead, it is returned by the `CheckpointState` object. + + Args: + parameter: The C.Parameter object that holds the underlying parameter data. + state: The C.CheckpointState object that holds the underlying session state. + """ + + def __init__(self, parameter: C.Parameter, state: C.CheckpointState): + self._parameter = parameter + self._state = state - User defined properties can be indexed by name from the `CheckpointState` object. + @property + def name(self) -> str: + """The name of the parameter""" + return self._parameter.name - To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + @property + def data(self) -> np.ndarray: + """The data of the parameter""" + return self._parameter.data.numpy() + + @data.setter + def data(self, value: np.ndarray) -> None: + """Sets the data of the parameter""" + self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + @property + def grad(self) -> np.ndarray: + """The gradient of the parameter""" + return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None + + @property + def requires_grad(self) -> bool: + """Whether or not the parameter requires its gradient to be computed""" + return self._parameter.requires_grad + + def __repr__(self) -> str: + """Returns a string representation of the parameter""" + return f"Parameter(name={self.name}, requires_grad={self.requires_grad})" + + +class Parameters: + """Class that holds all the model parameters + + This class holds all the model parameters and provides access to them. + This class is not expected to be instantiated directly. Instead, it is returned by the + `CheckpointState`'s parameters attribute. + This class behaves like a dictionary and provides access to the parameters by name. Args: - state: The C.Checkpoint state object that holds the underlying session state. + state: The C.CheckpointState object that holds the underlying session state. """ def __init__(self, state: C.CheckpointState): - if not isinstance(state, C.CheckpointState): - raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") self._state = state - @classmethod - def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: - """Loads the checkpoint state from the checkpoint file + def __getitem__(self, name: str) -> Parameter: + """Gets the parameter associated with the given name + + Searches for the name in the parameters of the checkpoint state. Args: - checkpoint_uri: The path to the checkpoint file. + name: The name of the parameter Returns: - CheckpointState: The checkpoint state object. + The value of the parameter + + Raises: + KeyError: If the parameter is not found """ - return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) - @classmethod - def save_checkpoint( - cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False - ) -> None: - """Saves the checkpoint state to the checkpoint file + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + return Parameter(self._state.get_parameter(name), self._state) + + def __setitem__(self, name: str, value: np.ndarray) -> None: + """Sets the parameter value for the given name + + Searches for the name in the parameters of the checkpoint state. + If the name is found in parameters, the value is updated. Args: - state: The checkpoint state object. - checkpoint_uri: The path to the checkpoint file. - include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + name: The name of the parameter + value: The value of the parameter as a numpy array + + Raises: + KeyError: If the parameter is not found """ - C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + def __contains__(self, name: str) -> bool: + """Checks if the parameter exists in the state + + Args: + name: The name of the parameter + + Returns: + True if the name is a parameter False otherwise + """ + + return self._state.has_parameter(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for parameter_name in self._state.parameter_names(): + yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state) + + def __repr__(self) -> str: + """Returns a string representation of the parameters""" + return self._state.parameter_names() + + def __len__(self) -> int: + """Returns the number of parameters""" + return len(self._state.parameter_names()) + + +class Properties: + def __init__(self, state: C.CheckpointState): + self._state = state def __getitem__(self, name: str) -> int | float | str: """Gets the property associated with the given name + Searches for the name in the properties of the checkpoint state. + Args: name: The name of the property Returns: The value of the property + + Raises: + KeyError: If the property is not found """ + + if name not in self: + raise KeyError(f"Property {name} not found.") + return self._state.get_property(name) def __setitem__(self, name: str, value: int | float | str) -> None: """Sets the property value for the given name + Searches for the name in the properties of the checkpoint state. + The value is added or updated in the properties. + Args: name: The name of the property value: The value of the property + Properties only support int, float and str values. """ self._state.add_property(name, value) @@ -79,6 +180,75 @@ def __contains__(self, name: str) -> bool: name: The name of the property Returns: - True if the property exists, False otherwise + True if the name is a property, False otherwise """ + return self._state.has_property(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for property_name in self._state.property_names(): + yield property_name, self._state.get_property(property_name) + + def __repr__(self) -> str: + """Returns a string representation of the properties""" + return self._state.property_names() + + def __len__(self) -> int: + """Returns the number of properties""" + return len(self._state.property_names()) + + +class CheckpointState: + """Class that holds the state of the training session + + This class holds all the state information of the training session such as the model parameters, + its gradients, the optimizer state and user defined properties. + + To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + + Args: + state: The C.Checkpoint state object that holds the underlying session state. + """ + + def __init__(self, state: C.CheckpointState): + if not isinstance(state, C.CheckpointState): + raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") + self._state = state + self._parameters = Parameters(self._state) + self._properties = Properties(self._state) + + @classmethod + def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: + """Loads the checkpoint state from the checkpoint file + + Args: + checkpoint_uri: The path to the checkpoint file. + + Returns: + CheckpointState: The checkpoint state object. + """ + return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) + + @classmethod + def save_checkpoint( + cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False + ) -> None: + """Saves the checkpoint state to the checkpoint file + + Args: + state: The checkpoint state object. + checkpoint_uri: The path to the checkpoint file. + include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + """ + C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + + @property + def parameters(self) -> Parameters: + """Returns the model parameters from the checkpoint state""" + return self._parameters + + @property + def properties(self) -> Properties: + """Returns the properties from the checkpoint state""" + return self._properties diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 56338ddbaffef..d5c37b3e36ee7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -360,14 +360,18 @@ def test_add_get_property(property_value): if isinstance(property_value, float): property_value = float(np.float32(property_value)) - state["property"] = property_value - assert "property" in state - assert state["property"] == property_value + assert len(state.properties) == 0 + + state.properties["property"] = property_value + assert "property" in state.properties + assert state.properties["property"] == property_value + assert len(state.properties) == 1 CheckpointState.save_checkpoint(state, checkpoint_file_path) new_state = CheckpointState.load_checkpoint(checkpoint_file_path) - assert "property" in new_state - assert new_state["property"] == property_value + assert "property" in new_state.properties + assert new_state.properties["property"] == property_value + assert len(new_state.properties) == 1 def test_get_input_output_names(): @@ -563,3 +567,60 @@ def test_eval_step_with_ort_values(): fetches = model(inputs, labels) assert isinstance(fetches, OrtValue) assert fetches + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_and_set_parameter_values(device): + with tempfile.TemporaryDirectory() as temp_dir: + ( + checkpoint_file_path, + training_model_file_path, + eval_model_file_path, + _, + pt_model, + ) = _create_training_artifacts( + temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] + ) + + state = CheckpointState.load_checkpoint(checkpoint_file_path) + + model = Module(training_model_file_path, state, eval_model_file_path, device=device) + + state_dict = pt_model.state_dict() + assert len(state_dict) == len(state.parameters) + for parameter_name, _ in state.parameters: + assert parameter_name in state_dict + + for name, pt_param in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32)) + + original_param = state.parameters["fc1.weight"].data + state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32) + updated_param = state.parameters["fc1.weight"].data + assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32)) + + model.train() + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + loss = model(inputs, labels) + assert loss is not None + for name, _ in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert ort_param.grad.any() + + state.parameters["fc1.weight"] = original_param + assert np.allclose(state.parameters["fc1.weight"].data, original_param) diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index d734be8e3474b..e46952d87c2bf 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); } } + +TEST(TrainingCApiTest, GetParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); +} + +TEST(TrainingCApiTest, UpdateParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} + +#ifdef USE_CUDA +TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} +#endif + } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h index d7b1e295df53e..3c38c99b3152f 100644 --- a/orttraining/orttraining/training_api/checkpoint_property.h +++ b/orttraining/orttraining/training_api/checkpoint_property.h @@ -22,10 +22,12 @@ struct PropertyBag { PropertyBag() = default; void AddProperty(const std::string& name, const PropertyDataType& val) { - ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(), - "Duplicated property named ", name); - - named_properties_.insert({name, val}); + auto it = named_properties_.find(name); + if (it == named_properties_.end()) { + named_properties_.insert({name, val}); + } else { + it->second = val; + } } template diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 0af737074964d..0e8544a7639ba 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -608,14 +608,14 @@ struct OrtTrainingApi { /// \name Accessing The Training Session State /// @{ - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * * \param[in] checkpoint_state The checkpoint state which should hold the property. - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_type Type of the property associated with the given name. * \param[in] property_value Property value associated with the given name. * @@ -632,7 +632,7 @@ struct OrtTrainingApi { * exist in the checkpoint state to be able to retrieve it successfully. * * \param[in] checkpoint_state The checkpoint state that is currently holding the property. - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \param[in] allocator Allocator used to allocate the memory for the property_value. * \param[out] property_type Type of the property associated with the given name. * \param[out] property_value Property value associated with the given name. @@ -669,6 +669,57 @@ struct OrtTrainingApi { ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name. + * + * This function retrieves the type and shape of the parameter associated with the given parameter name. + * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over and returned as an OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the parameter. + * \param[out] parameter The parameter data that is retrieved from the checkpoint state. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 0edef20ba6da8..218bef524200c 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -112,13 +112,13 @@ class CheckpointState : public detail::Base { const std::basic_string& path_to_checkpoint, const bool include_optimizer_state = false); - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_value Property value associated with the given name. * */ @@ -129,12 +129,38 @@ class CheckpointState : public detail::Base { * Gets the property value from an existing entry in the checkpoint state. The property must * exist in the checkpoint state to be able to retrieve it successfully. * - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \return Property value associated with the given property name. * */ Property GetProperty(const std::string& property_name); + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + */ + void UpdateParameter(const std::string& parameter_name, const Value& parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over to the provided OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] parameter_name Name of the parameter being retrieved. + * \return The parameter data that is retrieved from the checkpoint state. + * + */ + Value GetParameter(const std::string& parameter_name); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 066147708863f..a5efa3c0e4bef 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) { return property; } +inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) { + ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter)); +} + +inline Value CheckpointState::GetParameter(const std::string& parameter_name) { + AllocatorWithDefaultOptions allocator; + OrtValue* parameter; + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter)); + + return Value{parameter}; +} + } // namespace Ort diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index d1775e358163c..cf49a01517d6b 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph, #endif } // namespace +Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { + ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get(), *data.GetMutable())); + + return Status::OK(); +} + +Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) { + ORT_ENFORCE(data_.IsAllocated(), + "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get(), *data_.GetMutable())); + + return Status::OK(); +} + Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) { // assert param is allocated ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient."); @@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers, } } +Module::~Module() { + state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr; +} + size_t Module::GetTrainingModelOutputCount() const noexcept { return train_output_names_.size(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index adb633343263e..f323e6be72d49 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -21,6 +21,8 @@ struct Parameter { // Return the mutable data. OrtValue& Data() { return data_; } + Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const; + Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data); const std::string& Name() const { return name_; } // Returns whether this parameter is trainable or not. @@ -34,7 +36,6 @@ struct Parameter { // Reset and release the gradient buffer of this Parameter greedily. Status ResetGrad(); - protected: Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad); private: @@ -83,6 +84,8 @@ struct Module { const std::vector>& providers, gsl::span op_domains = gsl::span()); + ~Module(); + // Return the trainable/nontrainable parameters std::vector> Parameters() const; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 6693bba348648..38a9aad9640ea 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN + if (checkpoint_buffer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr."); + } + *checkpoint_state = nullptr; auto chkpt_state = std::make_unique(); const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer); @@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) { + API_IMPL_BEGIN + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter) { + API_IMPL_BEGIN + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter) { + API_IMPL_BEGIN + + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + if (!it->second->Data().IsTensor()) { + return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type."); + } + const auto& parameter_tensor = it->second->Data().Get(); + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + + auto status = it->second->CopyTo( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter); + if (!status.IsOK()) { + OrtApis::ReleaseValue(*parameter); + return onnxruntime::ToOrtStatus(status); + } + + return nullptr; + API_IMPL_END +} + static constexpr OrtTrainingApi ort_training_api = { // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially // released, it is OK to change the order here, however a corresponding matching change should also be done in the @@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::TrainingSessionGetEvalModelInputName, &OrtTrainingApis::AddProperty, &OrtTrainingApis::GetProperty, - &OrtTrainingApis::LoadCheckpointFromBuffer}; + &OrtTrainingApis::LoadCheckpointFromBuffer, + &OrtTrainingApis::GetParameterTypeAndShape, + &OrtTrainingApis::UpdateParameter, + &OrtTrainingApis::GetParameter}; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { // No constraints on the API version yet. diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index c87108957c975..2a8c1e30361c6 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); +ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + +ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + +ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + } // namespace OrtTrainingApis diff --git a/winml/test/model/skip_model_tests.h b/winml/test/model/skip_model_tests.h index f815b197b6a00..9d66320343c43 100644 --- a/winml/test/model/skip_model_tests.h +++ b/winml/test/model/skip_model_tests.h @@ -163,10 +163,8 @@ std::unordered_map> disabledGpu test name -> absolute difference sampleTolerance */ std::unordered_map sampleTolerancePerTests({ - {"fp16_inception_v1_opset7_GPU",0.005 }, - {"fp16_inception_v1_opset8_GPU", 0.005}, - { "candy_opset9_GPU", - 0.00150000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ - { "fp16_tiny_yolov2_opset8_GPU", - 0.109000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ + {"fp16_inception_v1_opset7_GPU", 0.005}, + {"fp16_inception_v1_opset8_GPU", 0.005}, + { "candy_opset9_GPU", 0.00150000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ + { "fp16_tiny_yolov2_opset8_GPU", 0.109000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ });