Skip to content

Commit

Permalink
Address pull request review comments for C# and C API
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Sep 19, 2023
1 parent a97b607 commit 5be6641
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ internal enum PropertyType : long
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
T[] value = new T[1];
value[0] = propertyValue;
T[] value = { propertyValue };
unsafe
{
fixed (T* memPtr = value)
Expand Down Expand Up @@ -232,31 +231,9 @@ public void UpdateParameter(string parameterName, OrtValue parameter)
public OrtValue GetParameter(string parameterName)
{
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameterTypeAndShape(handle, parameterNameUtf8, out IntPtr typeAndShapeInfoHandle));

try
{
var typeAndShapeInfo = new OrtTensorTypeAndShapeInfo(typeAndShapeInfoHandle);
var parameter = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, typeAndShapeInfo.ElementDataType, typeAndShapeInfo.Shape);

try
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, parameter.Handle));
}
catch (OnnxRuntimeException e)
{
parameter.Dispose();
throw e;
}

return parameter;
}
finally
{
NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShapeInfoHandle);
}

return new OrtValue(parameterHandle);
}

#region SafeHandle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ out UIntPtr inputCount
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
byte[] /*(const char*)*/ parameterName,
IntPtr /*(OrtValue*)*/ parameter
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(OrtValue**)*/ parameter
);

public static DOrtGetParameter OrtGetParameter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,7 @@ public OrtValue ToBuffer(bool onlyTrainable)

float[] bufferMemory = new float[bufferSize.ToUInt64()];

var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
var shape = new long[] { (long)bufferSize.ToUInt64() };
var shape = new long[] { (long)bufferSize };
var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape);

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable));
Expand All @@ -529,46 +528,30 @@ public OrtValue ToBuffer(bool onlyTrainable)
/// <summary>
/// Loads the training session model parameters from a contiguous buffer
/// </summary>
/// <param name="buffer">Contiguous buffer to load the parameters from.</param>
public void FromBuffer(OrtValue buffer)
/// <param name="ortValue">Contiguous buffer to load the parameters from.</param>
/// <param name="onlyTrainable">Whether to only load trainable parameters or to load all parameters.</param>
public void FromBuffer(OrtValue ortValue, bool onlyTrainable)
{
if (buffer.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
}

IntPtr typeAndShapeInfo = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Handle, out typeAndShapeInfo));
UIntPtr numDimensions = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
if (numDimensions.ToUInt64() != 1)
var tensorInfo = ortValue.GetTensorTypeAndShape();
if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float)
{
string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
throw new ArgumentException(errorMessage);
}

// Here buffer size represents the number of elements in the buffer
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize));

// OrtGetParametersSize returns the total number of elements in the model's parameters.
UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
const bool onlyTrainable = true;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, onlyTrainable));
if ((ulong)bufferSize == (ulong)numElementsTrainingOnly)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, onlyTrainable));
return;
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float.");
}

UIntPtr numElements = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, !onlyTrainable));
if ((ulong)bufferSize != (ulong)numElements)
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable));
if ((ulong)tensorInfo.ElementCount != (ulong)numElements)
{
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString();
throw new ArgumentException(errorMessage);
}

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, !onlyTrainable));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable));
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ public void TestFromBuffer()
var fetchedShape = typeShape.Shape;
Assert.Equal(397510, fetchedShape[0]);

trainingSession.FromBuffer(buffer);
trainingSession.FromBuffer(buffer, true);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,19 +704,21 @@ struct OrtTrainingApi {
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
*
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
* The parameter is copied over to the provided OrtValue. The training session must be already created
* The parameter is copied over and returned as an OrtValue. The training session must be already created
* with the checkpoint state that contains the parameter being retrieved.
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the parameter.
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtValue* parameter);
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter);

/// @}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,11 @@ inline void CheckpointState::UpdateParameter(const std::string& parameter_name,
}

inline Value CheckpointState::GetParameter(const std::string& parameter_name) {

Check warning on line 286 in orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h#L286

Add #include <string> for string [build/include_what_you_use] [4]
Raw output
orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h:286:  Add #include <string> for string  [build/include_what_you_use] [4]
OrtTensorTypeAndShapeInfo* parameter_type_and_shape_info;
ThrowOnError(GetTrainingApi().GetParameterTypeAndShape(p_, parameter_name.c_str(), &parameter_type_and_shape_info));
auto parameter_type_and_shape = TensorTypeAndShapeInfo{parameter_type_and_shape_info};
auto shape = parameter_type_and_shape.GetShape();

AllocatorWithDefaultOptions allocator;
Value parameter = Value::CreateTensor(allocator, shape.data(), shape.size(),
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);

ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), parameter));
OrtValue* parameter;
ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter));

return parameter;
return Value{parameter};
}

} // namespace Ort
13 changes: 11 additions & 2 deletions orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,8 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState
}

ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtValue* parameter) {
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter) {
API_IMPL_BEGIN

if (parameter == nullptr) {
Expand All @@ -613,8 +614,16 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
}

if (!it->second->Data().IsTensor()) {
return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type.");
}
const auto& parameter_tensor = it->second->Data().Get<onnxruntime::Tensor>();
ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue(
allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(),
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter));

ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyTo(
chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter));
chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter));

return nullptr;
API_IMPL_END
Expand Down
3 changes: 2 additions & 1 deletion orttraining/orttraining/training_api/ort_training_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_stat
_In_ const char* parameter_name, _In_ OrtValue* parameter);

ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtValue* parameter);
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter);

} // namespace OrtTrainingApis

0 comments on commit 5be6641

Please sign in to comment.