From 5be66413fe2e416dbcea544273e03fb713873609 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 19 Sep 2023 11:25:28 -0700 Subject: [PATCH] Address pull request review comments for C# and C API --- .../Training/CheckpointState.shared.cs | 29 ++----------- .../Training/NativeTrainingMethods.shared.cs | 3 +- .../Training/TrainingSession.shared.cs | 41 ++++++------------- .../TrainingTest.cs | 2 +- .../include/onnxruntime_training_c_api.h | 6 ++- .../include/onnxruntime_training_cxx_inline.h | 13 ++---- .../onnxruntime_training_c_api.cc | 13 +++++- .../training_api/ort_training_apis.h | 3 +- 8 files changed, 38 insertions(+), 72 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index a31626ea85a28..8eae86aa8588e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -43,8 +43,7 @@ internal enum PropertyType : long private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; + T[] value = { propertyValue }; unsafe { fixed (T* memPtr = value) @@ -232,31 +231,9 @@ public void UpdateParameter(string parameterName, OrtValue parameter) public OrtValue GetParameter(string parameterName) { var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle)); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameterTypeAndShape(handle, parameterNameUtf8, out IntPtr typeAndShapeInfoHandle)); - - try - { - var typeAndShapeInfo = new OrtTensorTypeAndShapeInfo(typeAndShapeInfoHandle); - var parameter = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, typeAndShapeInfo.ElementDataType, typeAndShapeInfo.Shape); - - try - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, parameter.Handle)); - } - catch (OnnxRuntimeException e) - { - parameter.Dispose(); - throw e; - } - - return parameter; - } - finally - { - NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShapeInfoHandle); - } - + return new OrtValue(parameterHandle); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 6f1d94a8a8d25..d6341b90f28ff 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -387,7 +387,8 @@ out UIntPtr inputCount public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( IntPtr /*(OrtCheckpointState*)*/ checkpointState, byte[] /*(const char*)*/ parameterName, - IntPtr /*(OrtValue*)*/ parameter + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter ); public static DOrtGetParameter OrtGetParameter; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index e4e45fdd18400..877677dcad57b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -517,8 +517,7 @@ public OrtValue ToBuffer(bool onlyTrainable) float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; + var shape = new long[] { (long)bufferSize }; var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); @@ -529,46 +528,30 @@ public OrtValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(OrtValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - const bool onlyTrainable = true; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, onlyTrainable)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, onlyTrainable)); - return; + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, !onlyTrainable)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, !onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 5632d34e1431a..68b1d5bcc6147 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -525,7 +525,7 @@ public void TestFromBuffer() var fetchedShape = typeShape.Shape; Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); } } } diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 71b64ead0d388..0e8544a7639ba 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -704,19 +704,21 @@ struct OrtTrainingApi { /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. * * This function retrieves the model parameter data from the checkpoint state for the given parameter name. - * The parameter is copied over to the provided OrtValue. The training session must be already created + * The parameter is copied over and returned as an OrtValue. The training session must be already created * with the checkpoint state that contains the parameter being retrieved. * The parameter must exist in the checkpoint state to be able to retrieve it successfully. * * \param[in] checkpoint_state The checkpoint state. * \param[in] parameter_name Name of the parameter being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the parameter. * \param[out] parameter The parameter data that is retrieved from the checkpoint state. * * \snippet{doc} snippets.dox OrtStatus Return Value * */ ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Inout_ OrtValue* parameter); + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 553e17ba8c1b4..7d1326a10f8f8 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -284,18 +284,11 @@ inline void CheckpointState::UpdateParameter(const std::string& parameter_name, } inline Value CheckpointState::GetParameter(const std::string& parameter_name) { - OrtTensorTypeAndShapeInfo* parameter_type_and_shape_info; - ThrowOnError(GetTrainingApi().GetParameterTypeAndShape(p_, parameter_name.c_str(), ¶meter_type_and_shape_info)); - auto parameter_type_and_shape = TensorTypeAndShapeInfo{parameter_type_and_shape_info}; - auto shape = parameter_type_and_shape.GetShape(); - AllocatorWithDefaultOptions allocator; - Value parameter = Value::CreateTensor(allocator, shape.data(), shape.size(), - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - - ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), parameter)); + OrtValue* parameter; + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter)); - return parameter; + return Value{parameter}; } } // namespace Ort diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 23649d6d34b9b..0fd9242d68f75 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -599,7 +599,8 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState } ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Inout_ OrtValue* parameter) { + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter) { API_IMPL_BEGIN if (parameter == nullptr) { @@ -613,8 +614,16 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } + if (!it->second->Data().IsTensor()) { + return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type."); + } + const auto& parameter_tensor = it->second->Data().Get(); + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyTo( - chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter)); return nullptr; API_IMPL_END diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index 6d65d786848cd..2a8c1e30361c6 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_stat _In_ const char* parameter_name, _In_ OrtValue* parameter); ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Inout_ OrtValue* parameter); + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); } // namespace OrtTrainingApis