From 545036fd847a6805fbfa831d1700d6bd5bb97357 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 28 Sep 2023 23:41:59 +0000 Subject: [PATCH 01/25] update --- .../providers/tensorrt/tensorrt_execution_provider.cc | 8 +++++++- .../core/providers/tensorrt/tensorrt_execution_provider.h | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 88a576f3ffa73..86fe2dcad2938 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1860,6 +1860,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { + sync_stream_before_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -2372,7 +2373,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], sync_stream_before_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -2400,6 +2401,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; + bool sync_stream_before_enqueue = trt_state->sync_stream_before_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto trt_builder = trt_state->builder->get(); @@ -2996,6 +2998,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorenqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 64ab2db2aedc9..c393a9dcf4f7b 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -111,6 +111,7 @@ struct TensorrtFuncState { std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; + bool sync_stream_before_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; + // Call cudaStreamSynchronize() before TRT enqueueV2()/enqueueV3() + sync_stream_before_enqueue_ = false; + CUDAGraph cuda_graph_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; From e0cd0cc614955c32b849efcf046ec5bd66a7c0dd Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 2 Oct 2023 19:16:34 +0000 Subject: [PATCH 02/25] update --- .../core/providers/tensorrt/tensorrt_execution_provider.cc | 2 +- .../core/providers/tensorrt/tensorrt_execution_provider.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 86fe2dcad2938..7f3047c7faa88 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2998,7 +2998,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector Date: Tue, 3 Oct 2023 01:52:01 +0000 Subject: [PATCH 03/25] update --- .../providers/tensorrt/tensorrt_execution_provider.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7f3047c7faa88..f7e783e1999c6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2998,15 +2998,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorenqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } + if (sync_stream_before_enqueue) { + cudaStreamSynchronize(stream); + } + // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { const std::string& output_name = output_binding_names[i]; From b2a582ae0833aa46c2c4479bc84157310aecfeed Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 19 Oct 2023 06:39:54 +0000 Subject: [PATCH 04/25] update --- .../providers/tensorrt/tensorrt_execution_provider.cc | 8 ++++---- .../core/providers/tensorrt/tensorrt_execution_provider.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index f7e783e1999c6..e04b7a6f296ae 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1860,7 +1860,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { - sync_stream_before_enqueue_ = true; + sync_stream_after_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -2373,7 +2373,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], sync_stream_before_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -2401,7 +2401,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; - bool sync_stream_before_enqueue = trt_state->sync_stream_before_enqueue; + bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto trt_builder = trt_state->builder->get(); @@ -3003,7 +3003,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; - bool sync_stream_before_enqueue = false; + bool sync_stream_after_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -264,7 +264,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { cublasHandle_t external_cublas_handle_ = nullptr; // Call cudaStreamSynchronize() before TRT enqueueV2()/enqueueV3() - mutable bool sync_stream_before_enqueue_ = false; + mutable bool sync_stream_after_enqueue_ = false; CUDAGraph cuda_graph_; bool is_graph_captured_ = false; From 0ec4b83212b84089ca685ef989659205a473b207 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:07:16 -0700 Subject: [PATCH 05/25] Fix comment --- .../core/providers/tensorrt/tensorrt_execution_provider.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index b8b82f650c7e3..3bf6bc05a65df 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -263,7 +263,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; - // Call cudaStreamSynchronize() before TRT enqueueV2()/enqueueV3() + // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3() mutable bool sync_stream_after_enqueue_ = false; CUDAGraph cuda_graph_; From e9a8c88c948968b841a0875e10ade32ebe645a57 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 12 Oct 2023 17:49:12 +0000 Subject: [PATCH 06/25] fix --- .../core/providers/tensorrt/tensorrt_execution_provider.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e04b7a6f296ae..ac92d46ca87fc 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -792,6 +792,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_))); + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_))); } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; From e71849673b9c68199f2ab899c462618f5943e4e6 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 12 Oct 2023 17:00:50 +1000 Subject: [PATCH 07/25] Fix missing attribute. Causes build error on release xamarin iOS build. Fix some long lines as well. --- .../NativeMethods.shared.cs | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 2ba837be22041..f722ca9d30fa4 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -1860,54 +1860,61 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtFillStringTensor OrtFillStringTensor; + /// \param value A tensor created from OrtCreateTensor... function. + /// \param index The index of the entry in the tensor to resize. + /// \param length_in_bytes Length to resize the string to. + /// \param buffer The resized buffer. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer( - IntPtr /* OrtValue */ value, - UIntPtr /* size_t */ index, - UIntPtr /* size_t */ length_in_bytes, - out IntPtr /* char** */ buffer - ); + IntPtr /* OrtValue */ value, + UIntPtr /* size_t */ index, + UIntPtr /* size_t */ length_in_bytes, + out IntPtr /* char** */ buffer); public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent( - IntPtr /*(OrtValue*)*/ value, - byte[] /*(void*)*/ dst_buffer, - UIntPtr dst_buffer_len, - UIntPtr[] offsets, - UIntPtr offsets_len); + IntPtr /*(OrtValue*)*/ value, + byte[] /*(void*)*/ dst_buffer, + UIntPtr dst_buffer_len, + UIntPtr[] offsets, + UIntPtr offsets_len); public static DOrtGetStringTensorContent OrtGetStringTensorContent; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value, - out UIntPtr /*(size_t*)*/ len); + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ index, - out UIntPtr /*(size_t*)*/ len); + UIntPtr /*(size_t)*/ index, + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ bufferLength, - UIntPtr /*(size_t)*/ elementIndex, - byte[] buffer); + UIntPtr /*(size_t)*/ bufferLength, + UIntPtr /*(size_t)*/ elementIndex, + byte[] buffer); public static DOrtGetStringTensorElement OrtGetStringTensorElement; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ - DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo( + IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, + out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape( + IntPtr /*(OrtValue*)*/ value, + out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape; @@ -1917,12 +1924,16 @@ out IntPtr /* char** */ buffer public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out IntPtr /*(TensorElementType*)*/ output); public static DOrtGetTensorElementType OrtGetTensorElementType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out UIntPtr output); public static DOrtGetDimensionsCount OrtGetDimensionsCount; From 5ada1adadb269a26c55bc118c5c5ad5a63b9c3aa Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Fri, 25 Aug 2023 21:14:18 +0000 Subject: [PATCH 08/25] Address pull request review comment --- .../core/providers/cuda/cu_inc/common.cuh | 14 +- .../providers/cuda/cuda_execution_provider.cc | 22 +++ .../cuda/math/unary_elementwise_ops.cc | 1 + .../cuda/math/unary_elementwise_ops.h | 7 + .../cuda/math/unary_elementwise_ops_impl.cu | 170 +++++++++--------- .../cuda/math/unary_elementwise_ops_impl.h | 3 +- .../test/providers/cpu/math/sign_test.cc | 10 +- 7 files changed, 135 insertions(+), 92 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index a50b53315ec9a..0d9928baa86e0 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -20,7 +20,7 @@ namespace cuda { // float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions // CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2))) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) __device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); } __device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); } __device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); } @@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } +template +__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; } + +template +__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); } + +template +__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); } + +template <> +__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); } + template __device__ __inline__ T _Normcdf(T a); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index aa60db4d07222..ad892eab3b843 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub); @@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index f026444328b24..9ede1f8d90ecc 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13) UNARY_OP_HFD(Log, 13) UNARY_OP_HFD(Exp, 13) UNARY_OP_HFD(Erf, 13) +UNARY_OP_BWUZCSILHFD(Sign, 13) UNARY_LOGICALOP_NOT_TYPED(1, bool) UNARY_OP_HFD(Round, 11) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 3ff97a60114df..775b78c43a736 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class Sign final : public UnaryElementwise { + public: + Sign(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index ac7cc1126acb7..1298d53338337 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos) SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign) // When casting, half needs to be converted via float type from most other types template @@ -119,52 +120,52 @@ struct OP_Cast { } }; -#define IMPL_CAST_IMPL(InT, OutT) \ +#define IMPL_CAST_IMPL(InT, OutT) \ void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ } -#define IMPL_CAST_IMPL_THROW(InT, OutT) \ +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ } #if !defined(DISABLE_FLOAT8_TYPES) -#define IMPL_CAST_IMPL_FROM(T) \ - IMPL_CAST_IMPL(T, half) \ - IMPL_CAST_IMPL(T, float) \ - IMPL_CAST_IMPL(T, double) \ - IMPL_CAST_IMPL(T, int8_t) \ - IMPL_CAST_IMPL(T, int16_t) \ - IMPL_CAST_IMPL(T, int32_t) \ - IMPL_CAST_IMPL(T, int64_t) \ - IMPL_CAST_IMPL(T, uint8_t) \ - IMPL_CAST_IMPL(T, uint16_t) \ - IMPL_CAST_IMPL(T, uint32_t) \ - IMPL_CAST_IMPL(T, uint64_t) \ - IMPL_CAST_IMPL(T, bool) \ - IMPL_CAST_IMPL(T, BFloat16) \ - IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \ - IMPL_CAST_IMPL_THROW(T, Float8E5M2) \ +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ + IMPL_CAST_IMPL(T, BFloat16) \ + IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \ + IMPL_CAST_IMPL_THROW(T, Float8E5M2) \ IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \ IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ) #else -#define IMPL_CAST_IMPL_FROM(T) \ - IMPL_CAST_IMPL(T, half) \ - IMPL_CAST_IMPL(T, float) \ - IMPL_CAST_IMPL(T, double) \ - IMPL_CAST_IMPL(T, int8_t) \ - IMPL_CAST_IMPL(T, int16_t) \ - IMPL_CAST_IMPL(T, int32_t) \ - IMPL_CAST_IMPL(T, int64_t) \ - IMPL_CAST_IMPL(T, uint8_t) \ - IMPL_CAST_IMPL(T, uint16_t) \ - IMPL_CAST_IMPL(T, uint32_t) \ - IMPL_CAST_IMPL(T, uint64_t) \ - IMPL_CAST_IMPL(T, bool) \ +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ IMPL_CAST_IMPL(T, BFloat16) #endif @@ -199,58 +200,58 @@ struct OP_CastNoSat { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 -#define OP_CAST(T, NVT) \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const half& v) const { \ +#define OP_CAST(T, NVT) \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const half& v) const { \ return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const float& v) const { \ - return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const float& v) const { \ - return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ - } \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ + } \ + }; \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const float& v) const { \ + return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const float& v) const { \ + return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ + } \ }; #else -#define OP_CAST(T, NVT) \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(__half2float(v), true); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(__half2float(v), false); \ - } \ - }; \ - template <> \ - struct OP_CastSat { \ +#define OP_CAST(T, NVT) \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(__half2float(v), true); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(__half2float(v), false); \ + } \ + }; \ + template <> \ + struct OP_CastSat { \ __device__ __inline__ T operator()(const float& v) const { \ - return T(v, true); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ + return T(v, true); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ __device__ __inline__ T operator()(const float& v) const { \ - return T(v, false); \ - } \ + return T(v, false); \ + } \ }; #endif @@ -260,14 +261,13 @@ struct OP_CastNoSat { OP_CAST(Float8E4M3FN, __NV_E4M3) OP_CAST(Float8E5M2, __NV_E5M2) - -#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \ +#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \ void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \ - if (saturate) { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \ - } else { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \ - } \ + if (saturate) { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \ + } else { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \ + } \ } EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 3d4868b54abe6..608a81a24cf4f 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -31,7 +31,8 @@ namespace cuda { UNARY_OP_NAME_EXPR(Not, !a) \ UNARY_OP_NAME_EXPR(Round, _Round(a)) \ UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \ - UNARY_OP_NAME_EXPR(Cos, _Cos(a)) + UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \ + UNARY_OP_NAME_EXPR(Sign, _Sign(a)) #define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \ diff --git a/onnxruntime/test/providers/cpu/math/sign_test.cc b/onnxruntime/test/providers/cpu/math/sign_test.cc index 12844068c47d2..15b3f40faa791 100644 --- a/onnxruntime/test/providers/cpu/math/sign_test.cc +++ b/onnxruntime/test/providers/cpu/math/sign_test.cc @@ -113,7 +113,7 @@ TestImpl(ForwardIter first, ForwardIter last, OutputIter out) { TEST(MathOpTest, Sign_uint64) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -129,7 +129,7 @@ TEST(MathOpTest, Sign_uint64) { // we disable this test for openvino as openvino ep supports only FP32 Precision TEST(MathOpTest, Sign_int64) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -146,7 +146,7 @@ TEST(MathOpTest, Sign_int64) { TEST(MathOpTest, Sign_float) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -162,7 +162,7 @@ TEST(MathOpTest, Sign_float) { TEST(MathOpTest, Sign_double) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -177,7 +177,7 @@ TEST(MathOpTest, Sign_double) { } TEST(MathOpTest, Sign_MLFloat16) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; From d30dcfac62c88805f0f85c82a715f0d1ba0a1510 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 28 Aug 2023 16:33:15 +0000 Subject: [PATCH 09/25] Add rocm execution provider for Sign --- .../providers/rocm/rocm_execution_provider.cc | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 9401de64269b9..e6ea876d8957c 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1105,6 +1105,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); @@ -2067,6 +2078,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 14 BuildKernelCreateInfo, From 1e0736cd1700340208fc519184ffc35b2ea7f562 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 28 Aug 2023 17:20:28 +0000 Subject: [PATCH 10/25] Update operator kernels doc --- docs/OperatorKernels.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index c76f760ef04bd..764dded991abb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -759,6 +759,7 @@ Do not modify directly.* |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| From 9c3df0495dd77f6af27678c7c7ad1eb6067c58cd Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 28 Aug 2023 18:06:55 +0000 Subject: [PATCH 11/25] Add definitions to rocm ep --- .../core/providers/rocm/cu_inc/common.cuh | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 5c516aac65aab..429ceb1f7c699 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } +template +__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; } + +template +__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); } + +template +__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); } + +template <> +__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); } + template __device__ __inline__ T _Normcdf(T a); @@ -337,7 +349,7 @@ struct GridDim { }; // aligned vector generates vectorized load/store -template +template struct alignas(sizeof(T) * vec_size) aligned_vector { T val[vec_size]; }; @@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector { // HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels. // TODO ROCM added support recently, should verify. #define HIP_KERNEL_ASSERT(...) -//#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) +// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) // WARP related definitions and functions constexpr int GPU_WARP_SIZE = warpSize; -inline int GPU_WARP_SIZE_HOST= warpSizeDynamic(); +inline int GPU_WARP_SIZE_HOST = warpSizeDynamic(); template __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) { From 9353c919fe1311f2850735f5ce95297af99a3144 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 28 Aug 2023 17:45:38 +0000 Subject: [PATCH 12/25] Update count check in EvalStep C# --- .../Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 33993c2be135b..40f4031846161 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -358,7 +358,7 @@ public void EvalStep( IReadOnlyCollection inputValues, IReadOnlyCollection outputValues) { - if (!_evalOutputCount.Equals(outputValues.Count)) + if (_evalOutputCount != (ulong)outputValues.Count()) { throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); } From b78f80bf6b3bd1a5e95a439d4e0e01b167ad0ce3 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 29 Aug 2023 23:18:10 +0000 Subject: [PATCH 13/25] C, C++ functions for Updating and Getting a checkpoint parameter --- .../training_api/core/training_capi_tests.cc | 102 ++++++++++++++++++ .../training_api/checkpoint_property.h | 10 +- .../include/onnxruntime_training_c_api.h | 59 +++++++++- .../include/onnxruntime_training_cxx_api.h | 36 ++++++- .../include/onnxruntime_training_cxx_inline.h | 19 ++++ .../orttraining/training_api/module.cc | 37 +++++++ orttraining/orttraining/training_api/module.h | 3 +- .../onnxruntime_training_c_api.cc | 66 +++++++++++- .../training_api/ort_training_apis.h | 9 ++ 9 files changed, 325 insertions(+), 16 deletions(-) diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index d734be8e3474b..e46952d87c2bf 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); } } + +TEST(TrainingCApiTest, GetParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); +} + +TEST(TrainingCApiTest, UpdateParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} + +#ifdef USE_CUDA +TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} +#endif + } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h index d7b1e295df53e..3c38c99b3152f 100644 --- a/orttraining/orttraining/training_api/checkpoint_property.h +++ b/orttraining/orttraining/training_api/checkpoint_property.h @@ -22,10 +22,12 @@ struct PropertyBag { PropertyBag() = default; void AddProperty(const std::string& name, const PropertyDataType& val) { - ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(), - "Duplicated property named ", name); - - named_properties_.insert({name, val}); + auto it = named_properties_.find(name); + if (it == named_properties_.end()) { + named_properties_.insert({name, val}); + } else { + it->second = val; + } } template diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 0af737074964d..71b64ead0d388 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -608,14 +608,14 @@ struct OrtTrainingApi { /// \name Accessing The Training Session State /// @{ - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * * \param[in] checkpoint_state The checkpoint state which should hold the property. - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_type Type of the property associated with the given name. * \param[in] property_value Property value associated with the given name. * @@ -632,7 +632,7 @@ struct OrtTrainingApi { * exist in the checkpoint state to be able to retrieve it successfully. * * \param[in] checkpoint_state The checkpoint state that is currently holding the property. - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \param[in] allocator Allocator used to allocate the memory for the property_value. * \param[out] property_type Type of the property associated with the given name. * \param[out] property_value Property value associated with the given name. @@ -669,6 +669,55 @@ struct OrtTrainingApi { ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name. + * + * This function retrieves the type and shape of the parameter associated with the given parameter name. + * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over to the provided OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[out] parameter The parameter data that is retrieved from the checkpoint state. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtValue* parameter); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 0edef20ba6da8..218bef524200c 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -112,13 +112,13 @@ class CheckpointState : public detail::Base { const std::basic_string& path_to_checkpoint, const bool include_optimizer_state = false); - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_value Property value associated with the given name. * */ @@ -129,12 +129,38 @@ class CheckpointState : public detail::Base { * Gets the property value from an existing entry in the checkpoint state. The property must * exist in the checkpoint state to be able to retrieve it successfully. * - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \return Property value associated with the given property name. * */ Property GetProperty(const std::string& property_name); + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + */ + void UpdateParameter(const std::string& parameter_name, const Value& parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over to the provided OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] parameter_name Name of the parameter being retrieved. + * \return The parameter data that is retrieved from the checkpoint state. + * + */ + Value GetParameter(const std::string& parameter_name); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 066147708863f..d6a8764a7dbe5 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -279,4 +279,23 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) { return property; } +inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) { + ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter)); +} + +inline Value CheckpointState::GetParameter(const std::string& parameter_name) { + OrtTensorTypeAndShapeInfo* parameter_type_and_shape_info; + ThrowOnError(GetTrainingApi().GetParameterTypeAndShape(p_, parameter_name.c_str(), ¶meter_type_and_shape_info)); + auto parameter_type_and_shape = TensorTypeAndShapeInfo{parameter_type_and_shape_info}; + auto shape = parameter_type_and_shape.GetShape(); + + AllocatorWithDefaultOptions allocator; + Value parameter = Value::CreateTensor(allocator, shape.data(), shape.size(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), parameter)); + + return parameter; +} + } // namespace Ort diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index d1775e358163c..e86526f902a9c 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -119,6 +119,43 @@ Status TransformModelInputsForInference(Graph& inference_graph, #endif } // namespace +Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { + ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot cope the checkpoint parameter to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get(), *data.GetMutable())); + + return Status::OK(); +} + +Status Parameter::CopyFrom(const OrtValue& data, const DataTransferManager* data_transfer_manager) { + ORT_ENFORCE(data_.IsAllocated(), + "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get(), *data_.GetMutable())); + + return Status::OK(); +} + Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) { // assert param is allocated ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient."); diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index adb633343263e..a638a421ecf90 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -21,6 +21,8 @@ struct Parameter { // Return the mutable data. OrtValue& Data() { return data_; } + Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const; + Status CopyFrom(const OrtValue& data, const DataTransferManager* data_transfer_manager); const std::string& Name() const { return name_; } // Returns whether this parameter is trainable or not. @@ -34,7 +36,6 @@ struct Parameter { // Reset and release the gradient buffer of this Parameter greedily. Status ResetGrad(); - protected: Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad); private: diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 6693bba348648..23649d6d34b9b 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN + if (checkpoint_buffer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr."); + } + *checkpoint_state = nullptr; auto chkpt_state = std::make_unique(); const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer); @@ -559,6 +563,63 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) { + API_IMPL_BEGIN + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter) { + API_IMPL_BEGIN + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( + *parameter, chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtValue* parameter) { + API_IMPL_BEGIN + + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyTo( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + + return nullptr; + API_IMPL_END +} + static constexpr OrtTrainingApi ort_training_api = { // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially // released, it is OK to change the order here, however a corresponding matching change should also be done in the @@ -592,7 +653,10 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::TrainingSessionGetEvalModelInputName, &OrtTrainingApis::AddProperty, &OrtTrainingApis::GetProperty, - &OrtTrainingApis::LoadCheckpointFromBuffer}; + &OrtTrainingApis::LoadCheckpointFromBuffer, + &OrtTrainingApis::GetParameterTypeAndShape, + &OrtTrainingApis::UpdateParameter, + &OrtTrainingApis::GetParameter}; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { // No constraints on the API version yet. diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index c87108957c975..6d65d786848cd 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -94,4 +94,13 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); +ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + +ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + +ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtValue* parameter); + } // namespace OrtTrainingApis From 0ab37c6f5e6a48e83b7d2afb43cafb7debd82923 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 30 Aug 2023 21:12:34 +0000 Subject: [PATCH 14/25] Adding UpdateParameter and GetParameter to C# --- .../Training/CheckpointState.shared.cs | 85 +++++++++++++++---- .../Training/NativeTrainingMethods.shared.cs | 33 +++++++ .../TrainingTest.cs | 84 ++++++++++++++++++ 3 files changed, 186 insertions(+), 16 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 659c6303702ac..47de5a82176e3 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -103,13 +103,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, } /// - /// Adds the given int property to the checkpoint state. + /// Adds or updates the given int property to/in the checkpoint state. /// - /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, long propertyValue) { @@ -117,13 +117,13 @@ public void AddProperty(string propertyName, long propertyValue) } /// - /// Adds the given float property to the checkpoint state. + /// Adds or updates the given float property to/in the checkpoint state. /// - /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, float propertyValue) { @@ -131,13 +131,13 @@ public void AddProperty(string propertyName, float propertyValue) } /// - /// Adds the given string property to the checkpoint state. + /// Adds or updates the given string property to/in the checkpoint state. /// - /// Runtime properties that are strings such as parameter names, custom strings, and others can be added - /// to the checkpoint state by the user if they desire by calling this function with the appropriate property - /// name and value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, string propertyValue) { @@ -162,7 +162,7 @@ public void AddProperty(string propertyName, string propertyValue) /// Gets the property value from an existing entry in the checkpoint state. The property must /// exist in the checkpoint state to be able to retrieve it successfully. ///
- /// Unique name of the property being retrieved. + /// Name of the property being retrieved. /// Property value associated with the given property name. public object GetProperty(string propertyName) { @@ -192,6 +192,59 @@ public object GetProperty(string propertyName) throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } + /// + /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + /// + /// This function updates a model parameter in the checkpoint state with the given parameter data. + /// The training session must be already created with the checkpoint state that contains the parameter + /// being updated. The given parameter is copied over to the registered device for the training session. + /// The parameter must exist in the checkpoint state to be able to update it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that should replace the existing parameter data. + public void UpdateParameter(string parameterName, OrtValue parameter) + { + if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) + { + throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter."); + } + + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle)); + } + + /// + /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + /// + /// This function retrieves the model parameter data from the checkpoint state for the given parameter name. + /// The parameter is copied over to the provided OrtValue. The training session must be already created + /// with the checkpoint state that contains the parameter being retrieved. + /// The parameter must exist in the checkpoint state to be able to retrieve it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that is retrieved from the checkpoint state. + public OrtValue GetParameter(string parameterName) + { + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameterTypeAndShape(handle, parameterNameUtf8, out IntPtr typeAndShapeInfoHandle)); + + try + { + var typeAndShapeInfo = new OrtTensorTypeAndShapeInfo(typeAndShapeInfoHandle); + var parameter = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, typeAndShapeInfo.ElementDataType, typeAndShapeInfo.Shape); + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, parameter.Handle)); + + return parameter; + } + finally + { + NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShapeInfoHandle); + } + + } + #region SafeHandle /// /// Overrides SafeHandle.ReleaseHandle() to properly dispose of diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index ac790242409e3..6f1d94a8a8d25 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -42,6 +42,9 @@ public struct OrtTrainingApi public IntPtr AddProperty; public IntPtr GetProperty; public IntPtr LoadCheckpointFromBuffer; + public IntPtr GetParameterTypeAndShape; + public IntPtr UpdateParameter; + public IntPtr GetParameter; } internal static class NativeTrainingMethods @@ -97,6 +100,9 @@ static NativeTrainingMethods() OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName)); OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty)); OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty)); + OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape)); + OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter)); + OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter)); } } @@ -359,6 +365,33 @@ out UIntPtr inputCount public static DOrtGetProperty OrtGetProperty; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape + ); + + public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtUpdateParameter OrtUpdateParameter; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtGetParameter OrtGetParameter; + #endregion TrainingSession API public static bool TrainingEnabled() diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index ea2b6d7dbc118..82d8bbe715b74 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -530,6 +530,90 @@ public void TestSetSeed() TrainingUtils.SetSeed(8888); } + [Fact(DisplayName = "TestGetParameter")] + public void TestGetParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + var parameter = state.GetParameter("fc1.weight"); + cleanUp.Add(parameter); + + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + } + } + + [Fact(DisplayName = "TestUpdateParameter")] + public void TestUpdateParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + var parameter = state.GetParameter("fc1.weight"); + cleanUp.Add(parameter); + + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape); + cleanUp.Add(updated_parameter); + + state.UpdateParameter("fc1.weight", updated_parameter); + var current_parameter = state.GetParameter("fc1.weight"); + cleanUp.Add(current_parameter); + + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + + state.UpdateParameter("fc1.weight", parameter); + current_parameter = state.GetParameter("fc1.weight"); + cleanUp.Add(current_parameter); + + current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; From 619cb81f5caac7ca525fc848738fda699e35a76c Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 31 Aug 2023 02:53:14 +0000 Subject: [PATCH 15/25] Expose model parameters and their gradients in Python --- .../python/orttraining_pybind_state.cc | 62 +++++++++--- .../python/training/api/checkpoint_state.py | 96 ++++++++++++++++--- .../orttraining_test_python_bindings.py | 52 ++++++++++ 3 files changed, 185 insertions(+), 25 deletions(-) diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 3f3aa396e6ca0..bf6c7666b6a80 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -1065,17 +1065,42 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc"); checkpoint_state .def(py::init()) - .def("add_property", [](onnxruntime::training::api::CheckpointState* state, - const std::string& property_name, - const std::variant& property_value) { - state->property_bag.AddProperty(property_name, property_value); - }) - .def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.GetProperty(property_name); - }) - .def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.HasProperty(property_name); - }); + .def("add_property", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& property_name, + const std::variant& property_value) { + state->property_bag.AddProperty(property_name, property_value); + }) + .def("get_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.GetProperty(property_name); + }) + .def("has_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.HasProperty(property_name); + }) + .def("copy_parameter_from", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& parameter_name, OrtValue& value) -> void { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + ORT_THROW_IF_ERROR(it->second->CopyFrom( + value, state->module_checkpoint_state.train_session_data_transfer_mgr)); + }) + .def("get_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + return it->second; + }) + .def("has_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + return state->module_checkpoint_state.named_parameters.count(parameter_name); + }); py::class_ training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc"); @@ -1111,6 +1136,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(scheduler->Step()); }); + py::class_> + parameter(m, "Parameter"); + parameter + .def_property_readonly("name", &onnxruntime::training::api::Parameter::Name) + .def_property_readonly("data", &onnxruntime::training::api::Parameter::Data) + .def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient) + .def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad) + .def("copy_from", + [](onnxruntime::training::api::Parameter* parameter, + onnxruntime::training::api::CheckpointState* state, + OrtValue& value) -> void { + ORT_THROW_IF_ERROR(parameter->CopyFrom(value, state->module_checkpoint_state.train_session_data_transfer_mgr)); + }); + m.def( "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index 285264bbed744..d723c86711124 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -5,7 +5,56 @@ import os +import numpy as np + from onnxruntime.capi import _pybind_state as C +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue + + +class Parameter: + """Class that represents a model parameter + + This class represents a model parameter and provides access to its data, + gradient and other properties. This class is not expected to be instantiated directly. + Instead, it is returned by the `CheckpointState` object. + + Args: + parameter: The C.Parameter object that holds the underlying parameter data. + state: The C.CheckpointState object that holds the underlying session state. + """ + + def __init__(self, parameter: C.Parameter, state: C.CheckpointState): + self._parameter = parameter + self._state = state + + @property + def name(self) -> str: + """The name of the parameter""" + return self._parameter.name + + @property + def data(self) -> np.ndarray: + """The data of the parameter""" + return self._parameter.data.numpy() + + @data.setter + def data(self, value: np.ndarray) -> None: + """Sets the data of the parameter""" + self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + @property + def grad(self) -> np.ndarray: + """The gradient of the parameter""" + return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None + + @property + def requires_grad(self) -> bool: + """Whether or not the parameter requires its gradient to be computed""" + return self._parameter.requires_grad + + def __repr__(self) -> str: + """Returns a string representation of the parameter""" + return f"Parameter(name={self.name}, requires_grad={self.requires_grad})" class CheckpointState: @@ -52,33 +101,52 @@ def save_checkpoint( """ C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) - def __getitem__(self, name: str) -> int | float | str: - """Gets the property associated with the given name + def __getitem__(self, name: str) -> int | float | str | Parameter: + """Gets the parameter or property associated with the given name + + Searches for the name in the parameters and properties of the checkpoint state. Args: - name: The name of the property + name: The name of the parameter or property Returns: - The value of the property + The value of the parameter or property """ - return self._state.get_property(name) - def __setitem__(self, name: str, value: int | float | str) -> None: - """Sets the property value for the given name + if self._state.has_parameter(name): + return Parameter(self._state.get_parameter(name), self._state) + elif self._state.has_property(name): + return self._state.get_property(name) + else: + raise KeyError(f"Could not find {name} in the checkpoint state.") + + def __setitem__(self, name: str, value: int | float | str | np.ndarray) -> None: + """Sets the parameter or property value for the given name + + Searches for the name in the parameters and properties of the checkpoint state. + If the name is found in parameters, the value is updated. + Else, the value is added or updated in the properties. Args: - name: The name of the property - value: The value of the property + name: The name of the parameter or property + value: The value of the parameter or property + Properties only support int, float and str values. """ - self._state.add_property(name, value) + if self._state.has_parameter(name): + self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) + else: + self._state.add_property(name, value) def __contains__(self, name: str) -> bool: - """Checks if the property exists in the state + """Checks if the parameter or property exists in the state + + Tthe name is searched in both parameters and properties. Args: - name: The name of the property + name: The name of the parameter or property Returns: - True if the property exists, False otherwise + True if the name is either a parameter or a property, False otherwise """ - return self._state.has_property(name) + + return self._state.has_parameter(name) or self._state.has_property(name) diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 56338ddbaffef..8debf4a9cbf10 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -563,3 +563,55 @@ def test_eval_step_with_ort_values(): fetches = model(inputs, labels) assert isinstance(fetches, OrtValue) assert fetches + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_and_set_parameter_values(device): + with tempfile.TemporaryDirectory() as temp_dir: + ( + checkpoint_file_path, + training_model_file_path, + eval_model_file_path, + _, + pt_model, + ) = _create_training_artifacts( + temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] + ) + + state = CheckpointState.load_checkpoint(checkpoint_file_path) + + model = Module(training_model_file_path, state, eval_model_file_path, device=device) + + for name, pt_param in pt_model.named_parameters(): + ort_param = state[name] + assert ort_param.name == name + assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32)) + + original_param = state["fc1.weight"].data + state["fc1.weight"].data = np.ones_like(state["fc1.weight"].data, dtype=np.float32) + updated_param = state["fc1.weight"].data + assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32)) + + model.train() + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + loss = model(inputs, labels) + assert loss is not None + for name, _ in pt_model.named_parameters(): + ort_param = state[name] + assert ort_param.name == name + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert ort_param.grad.any() + + state["fc1.weight"] = original_param + assert np.allclose(state["fc1.weight"].data, original_param) From 9262106206cf11c35751c8d81cc7bfadc0ef755a Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 18 Sep 2023 12:41:55 -0700 Subject: [PATCH 16/25] Address pull request review comments --- .../Training/CheckpointState.shared.cs | 38 ++-- .../Training/TrainingSession.shared.cs | 28 +-- .../TrainingTest.cs | 162 +++++++++--------- 3 files changed, 117 insertions(+), 111 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 47de5a82176e3..c0b5d8b3ae8ca 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -144,15 +144,12 @@ public void AddProperty(string propertyName, string propertyValue) var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); - IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); - try - { - Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); - } - finally + unsafe { - Marshal.FreeHGlobal(unmanagedPointer); + fixed (byte* p = propertyValueUtf8) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p)); + } } } @@ -173,23 +170,32 @@ public object GetProperty(string propertyName) if (propertyType == PropertyType.Int) { - var longPropertyValue = Marshal.ReadInt64(propertyValue); - allocator.FreeMemory(propertyValue); - return longPropertyValue; + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; } else if (propertyType == PropertyType.Float) { - float[] value = new float[1]; - Marshal.Copy(propertyValue, value, 0, 1); - allocator.FreeMemory(propertyValue); - return value[0]; + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; } else if (propertyType == PropertyType.String) { return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); } - throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); + try { + throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); + } finally { + allocator.FreeMemory(propertyValue); + } } /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 40f4031846161..e4e45fdd18400 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -360,11 +360,12 @@ public void EvalStep( { if (_evalOutputCount != (ulong)outputValues.Count()) { - throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); + throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount})."); } - IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); + const bool isInput = true; + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput); - IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ + IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -509,7 +510,7 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec /// Returns a contiguous buffer that holds a copy of all training state parameters /// /// Whether to only copy trainable parameters or to copy all parameters. - public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) + public OrtValue ToBuffer(bool onlyTrainable) { UIntPtr bufferSize = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable)); @@ -518,9 +519,9 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) var memInfo = OrtMemoryInfo.DefaultInstance; // CPU var shape = new long[] { (long)bufferSize.ToUInt64() }; - var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); return buffer; } @@ -529,15 +530,15 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) /// Loads the training session model parameters from a contiguous buffer /// /// Contiguous buffer to load the parameters from. - public void FromBuffer(FixedBufferOnnxValue buffer) + public void FromBuffer(OrtValue buffer) { - if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (buffer.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Handle, out typeAndShapeInfo)); UIntPtr numDimensions = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); if (numDimensions.ToUInt64() != 1) @@ -551,22 +552,23 @@ public void FromBuffer(FixedBufferOnnxValue buffer) // OrtGetParametersSize returns the total number of elements in the model's parameters. UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); + const bool onlyTrainable = true; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, onlyTrainable)); if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, onlyTrainable)); return; } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, !onlyTrainable)); if ((ulong)bufferSize != (ulong)numElements) { string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, !onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 82d8bbe715b74..5632d34e1431a 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -484,20 +484,23 @@ public void TestEvalModelOutputNames() public void TestToBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); + } } } @@ -505,22 +508,25 @@ public void TestToBuffer() public void TestFromBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer); + } } } @@ -534,24 +540,18 @@ public void TestSetSeed() public void TestGetParameter() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + using (var parameter = state.GetParameter("fc1.weight")) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - - var parameter = state.GetParameter("fc1.weight"); - cleanUp.Add(parameter); - Assert.NotNull(parameter); - var typeShape = parameter.GetTensorTypeAndShape(); + var typeShape = parameter.GetTensorTypeAndShape(); Assert.Equal(2, typeShape.DimensionsCount); var fetchedShape = typeShape.Shape; Assert.Equal(500, fetchedShape[0]); @@ -563,54 +563,52 @@ public void TestGetParameter() public void TestUpdateParameter() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var parameter = state.GetParameter("fc1.weight"); - cleanUp.Add(parameter); - - Assert.NotNull(parameter); - var typeShape = parameter.GetTensorTypeAndShape(); - - Assert.Equal(2, typeShape.DimensionsCount); - var fetchedShape = typeShape.Shape; - Assert.Equal(500, fetchedShape[0]); - Assert.Equal(784, fetchedShape[1]); - - float maxVal = 20; - Random randNum = new Random(); - float[] updated_parameter_buffer = Enumerable - .Repeat(0, 500 * 784) - .Select(i => maxVal * (float)randNum.NextDouble()) - .ToArray(); - - var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape); - cleanUp.Add(updated_parameter); - - state.UpdateParameter("fc1.weight", updated_parameter); - var current_parameter = state.GetParameter("fc1.weight"); - cleanUp.Add(current_parameter); - - var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); - Assert.Equal(updated_parameter_buffer, current_parameter_tensor); - Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); - - state.UpdateParameter("fc1.weight", parameter); - current_parameter = state.GetParameter("fc1.weight"); - cleanUp.Add(current_parameter); - - current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); - Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); - Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape)) + { + state.UpdateParameter("fc1.weight", updated_parameter); + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + } + + state.UpdateParameter("fc1.weight", parameter); + + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + } } } From 17ae5dfd8d6fcb14f016ca23e1f124d0b69e77d4 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 18 Sep 2023 21:06:03 -0700 Subject: [PATCH 17/25] Address C# bindings pull request review comments --- .../Training/CheckpointState.shared.cs | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index c0b5d8b3ae8ca..a31626ea85a28 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -40,20 +40,17 @@ internal enum PropertyType : long String = 2 } - private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); T[] value = new T[1]; value[0] = propertyValue; - Memory memory = value; - using (var memHandle = memory.Pin()) + unsafe { - IntPtr memPtr; - unsafe + fixed (T* memPtr = value) { - memPtr = (IntPtr)memHandle.Pointer; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr)); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -191,9 +188,12 @@ public object GetProperty(string propertyName) return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); } - try { + try + { throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); - } finally { + } + finally + { allocator.FreeMemory(propertyValue); } } @@ -240,7 +240,15 @@ public OrtValue GetParameter(string parameterName) var typeAndShapeInfo = new OrtTensorTypeAndShapeInfo(typeAndShapeInfoHandle); var parameter = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, typeAndShapeInfo.ElementDataType, typeAndShapeInfo.Shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, parameter.Handle)); + try + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, parameter.Handle)); + } + catch (OnnxRuntimeException e) + { + parameter.Dispose(); + throw e; + } return parameter; } From a14546e345be1bc7ebde5346af13ad6463248f15 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 19 Sep 2023 11:25:28 -0700 Subject: [PATCH 18/25] Address pull request review comments for C# and C API --- .../Training/CheckpointState.shared.cs | 29 ++----------- .../Training/NativeTrainingMethods.shared.cs | 3 +- .../Training/TrainingSession.shared.cs | 41 ++++++------------- .../TrainingTest.cs | 2 +- .../include/onnxruntime_training_c_api.h | 6 ++- .../include/onnxruntime_training_cxx_inline.h | 13 ++---- .../onnxruntime_training_c_api.cc | 13 +++++- .../training_api/ort_training_apis.h | 3 +- 8 files changed, 38 insertions(+), 72 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index a31626ea85a28..8eae86aa8588e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -43,8 +43,7 @@ internal enum PropertyType : long private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; + T[] value = { propertyValue }; unsafe { fixed (T* memPtr = value) @@ -232,31 +231,9 @@ public void UpdateParameter(string parameterName, OrtValue parameter) public OrtValue GetParameter(string parameterName) { var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle)); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameterTypeAndShape(handle, parameterNameUtf8, out IntPtr typeAndShapeInfoHandle)); - - try - { - var typeAndShapeInfo = new OrtTensorTypeAndShapeInfo(typeAndShapeInfoHandle); - var parameter = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, typeAndShapeInfo.ElementDataType, typeAndShapeInfo.Shape); - - try - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, parameter.Handle)); - } - catch (OnnxRuntimeException e) - { - parameter.Dispose(); - throw e; - } - - return parameter; - } - finally - { - NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShapeInfoHandle); - } - + return new OrtValue(parameterHandle); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 6f1d94a8a8d25..d6341b90f28ff 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -387,7 +387,8 @@ out UIntPtr inputCount public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( IntPtr /*(OrtCheckpointState*)*/ checkpointState, byte[] /*(const char*)*/ parameterName, - IntPtr /*(OrtValue*)*/ parameter + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter ); public static DOrtGetParameter OrtGetParameter; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index e4e45fdd18400..877677dcad57b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -517,8 +517,7 @@ public OrtValue ToBuffer(bool onlyTrainable) float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; + var shape = new long[] { (long)bufferSize }; var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); @@ -529,46 +528,30 @@ public OrtValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(OrtValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - const bool onlyTrainable = true; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, onlyTrainable)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, onlyTrainable)); - return; + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, !onlyTrainable)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, !onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 5632d34e1431a..68b1d5bcc6147 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -525,7 +525,7 @@ public void TestFromBuffer() var fetchedShape = typeShape.Shape; Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); } } } diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 71b64ead0d388..0e8544a7639ba 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -704,19 +704,21 @@ struct OrtTrainingApi { /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. * * This function retrieves the model parameter data from the checkpoint state for the given parameter name. - * The parameter is copied over to the provided OrtValue. The training session must be already created + * The parameter is copied over and returned as an OrtValue. The training session must be already created * with the checkpoint state that contains the parameter being retrieved. * The parameter must exist in the checkpoint state to be able to retrieve it successfully. * * \param[in] checkpoint_state The checkpoint state. * \param[in] parameter_name Name of the parameter being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the parameter. * \param[out] parameter The parameter data that is retrieved from the checkpoint state. * * \snippet{doc} snippets.dox OrtStatus Return Value * */ ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Inout_ OrtValue* parameter); + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index d6a8764a7dbe5..a5efa3c0e4bef 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -284,18 +284,11 @@ inline void CheckpointState::UpdateParameter(const std::string& parameter_name, } inline Value CheckpointState::GetParameter(const std::string& parameter_name) { - OrtTensorTypeAndShapeInfo* parameter_type_and_shape_info; - ThrowOnError(GetTrainingApi().GetParameterTypeAndShape(p_, parameter_name.c_str(), ¶meter_type_and_shape_info)); - auto parameter_type_and_shape = TensorTypeAndShapeInfo{parameter_type_and_shape_info}; - auto shape = parameter_type_and_shape.GetShape(); - AllocatorWithDefaultOptions allocator; - Value parameter = Value::CreateTensor(allocator, shape.data(), shape.size(), - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - - ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), parameter)); + OrtValue* parameter; + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter)); - return parameter; + return Value{parameter}; } } // namespace Ort diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 23649d6d34b9b..0fd9242d68f75 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -599,7 +599,8 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState } ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Inout_ OrtValue* parameter) { + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter) { API_IMPL_BEGIN if (parameter == nullptr) { @@ -613,8 +614,16 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } + if (!it->second->Data().IsTensor()) { + return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type."); + } + const auto& parameter_tensor = it->second->Data().Get(); + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyTo( - chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter)); return nullptr; API_IMPL_END diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index 6d65d786848cd..2a8c1e30361c6 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_stat _In_ const char* parameter_name, _In_ OrtValue* parameter); ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Inout_ OrtValue* parameter); + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); } // namespace OrtTrainingApis From 680337407cae8243034a7a070b0a0042f47d5b75 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 20 Sep 2023 15:59:29 -0700 Subject: [PATCH 19/25] Address C# comments --- .../Training/CheckpointState.shared.cs | 41 ++++++++++--------- .../orttraining/training_api/module.cc | 6 ++- orttraining/orttraining/training_api/module.h | 4 +- .../onnxruntime_training_c_api.cc | 6 +-- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 8eae86aa8588e..93105d0afa02d 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -162,33 +162,34 @@ public object GetProperty(string propertyName) var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) + try { - Int64 value; - unsafe + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); + + if (propertyType == PropertyType.Int) { - value = *(Int64*)propertyValue; + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; } - return value; - } - else if (propertyType == PropertyType.Float) - { - float value; - unsafe + else if (propertyType == PropertyType.Float) { - value = *(float*)propertyValue; + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue); } - return value; - } - else if (propertyType == PropertyType.String) - { - return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); - } - try - { throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } finally diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index e86526f902a9c..cea54bc65a81e 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -137,7 +137,7 @@ Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtVa return Status::OK(); } -Status Parameter::CopyFrom(const OrtValue& data, const DataTransferManager* data_transfer_manager) { +Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) { ORT_ENFORCE(data_.IsAllocated(), "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it."); ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); @@ -371,6 +371,10 @@ Module::Module(const ModelIdentifiers& model_identifiers, } } +Module::~Module() { + state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr; +} + size_t Module::GetTrainingModelOutputCount() const noexcept { return train_output_names_.size(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index a638a421ecf90..f323e6be72d49 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -22,7 +22,7 @@ struct Parameter { // Return the mutable data. OrtValue& Data() { return data_; } Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const; - Status CopyFrom(const OrtValue& data, const DataTransferManager* data_transfer_manager); + Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data); const std::string& Name() const { return name_; } // Returns whether this parameter is trainable or not. @@ -84,6 +84,8 @@ struct Module { const std::vector>& providers, gsl::span op_domains = gsl::span()); + ~Module(); + // Return the trainable/nontrainable parameters std::vector> Parameters() const; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 0fd9242d68f75..45aeaebaac236 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -592,7 +592,7 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( - *parameter, chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr)); + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); return nullptr; API_IMPL_END @@ -619,8 +619,8 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState } const auto& parameter_tensor = it->second->Data().Get(); ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( - allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyTo( chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter)); From d8c313d092575fa8bb9ee9e7dd61339a631b9cf4 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 20 Sep 2023 16:09:35 -0700 Subject: [PATCH 20/25] Address C# comments --- .../Training/CheckpointState.shared.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 93105d0afa02d..6889112acb385 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -163,10 +163,10 @@ public object GetProperty(string propertyName) var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); + try { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) { Int64 value; From e1fb060eded4940e44f4de3b58e428d5fd8b1e0b Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 21 Sep 2023 18:56:48 +0000 Subject: [PATCH 21/25] Address pull request review comments --- .../python/orttraining_pybind_state.cc | 22 +- .../python/training/api/checkpoint_state.py | 202 +++++++++++++----- .../orttraining_test_python_bindings.py | 33 +-- .../orttraining/training_api/module.cc | 18 ++ .../onnxruntime_training_c_api.cc | 8 +- 5 files changed, 217 insertions(+), 66 deletions(-) diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index bf6c7666b6a80..35d9755ba0ba7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -1087,7 +1087,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW("Parameter with name ", parameter_name, " does not exist."); } ORT_THROW_IF_ERROR(it->second->CopyFrom( - value, state->module_checkpoint_state.train_session_data_transfer_mgr)); + state->module_checkpoint_state.train_session_data_transfer_mgr, value)); }) .def("get_parameter", [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { @@ -1100,6 +1100,24 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def("has_parameter", [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { return state->module_checkpoint_state.named_parameters.count(parameter_name); + }) + .def("parameter_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }) + .def("property_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->property_bag) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; }); py::class_ @@ -1148,7 +1166,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn [](onnxruntime::training::api::Parameter* parameter, onnxruntime::training::api::CheckpointState* state, OrtValue& value) -> void { - ORT_THROW_IF_ERROR(parameter->CopyFrom(value, state->module_checkpoint_state.train_session_data_transfer_mgr)); + ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value)); }); m.def( diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index d723c86711124..ba95cd04fce7e 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -57,14 +57,154 @@ def __repr__(self) -> str: return f"Parameter(name={self.name}, requires_grad={self.requires_grad})" +class Parameters: + """Class that holds all the model parameters + + This class holds all the model parameters and provides access to them. + This class is not expected to be instantiated directly. Instead, it is returned by the + `CheckpointState`'s parameters attribute. + This class behaves like a dictionary and provides access to the parameters by name. + + Args: + state: The C.CheckpointState object that holds the underlying session state. + """ + + def __init__(self, state: C.CheckpointState): + self._state = state + + def __getitem__(self, name: str) -> Parameter: + """Gets the parameter associated with the given name + + Searches for the name in the parameters of the checkpoint state. + + Args: + name: The name of the parameter + + Returns: + The value of the parameter + + Raises: + KeyError: If the parameter is not found + """ + + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + return Parameter(self._state.get_parameter(name), self._state) + + def __setitem__(self, name: str, value: np.ndarray) -> None: + """Sets the parameter value for the given name + + Searches for the name in the parameters of the checkpoint state. + If the name is found in parameters, the value is updated. + + Args: + name: The name of the parameter + value: The value of the parameter as a numpy array + + Raises: + KeyError: If the parameter is not found + """ + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + def __contains__(self, name: str) -> bool: + """Checks if the parameter exists in the state + + Args: + name: The name of the parameter + + Returns: + True if the name is a parameter False otherwise + """ + + return self._state.has_parameter(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for parameter_name in self._state.parameter_names(): + yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state) + + def __repr__(self) -> str: + """Returns a string representation of the parameters""" + return self._state.parameter_names() + + def __len__(self) -> int: + """Returns the number of parameters""" + return len(self._state.parameter_names()) + + +class Properties: + def __init__(self, state: C.CheckpointState): + self._state = state + + def __getitem__(self, name: str) -> int | float | str: + """Gets the property associated with the given name + + Searches for the name in the properties of the checkpoint state. + + Args: + name: The name of the property + + Returns: + The value of the property + + Raises: + KeyError: If the property is not found + """ + + if name not in self: + raise KeyError(f"Property {name} not found.") + + return self._state.get_property(name) + + def __setitem__(self, name: str, value: int | float | str) -> None: + """Sets the property value for the given name + + Searches for the name in the properties of the checkpoint state. + The value is added or updated in the properties. + + Args: + name: The name of the property + value: The value of the property + Properties only support int, float and str values. + """ + self._state.add_property(name, value) + + def __contains__(self, name: str) -> bool: + """Checks if the property exists in the state + + Args: + name: The name of the property + + Returns: + True if the name is a property, False otherwise + """ + + return self._state.has_property(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for property_name in self._state.property_names(): + yield property_name, self._state.get_property(property_name) + + def __repr__(self) -> str: + """Returns a string representation of the properties""" + return self._state.property_names() + + def __len__(self) -> int: + """Returns the number of properties""" + return len(self._state.property_names()) + + class CheckpointState: """Class that holds the state of the training session This class holds all the state information of the training session such as the model parameters, its gradients, the optimizer state and user defined properties. - User defined properties can be indexed by name from the `CheckpointState` object. - To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. Args: @@ -75,6 +215,8 @@ def __init__(self, state: C.CheckpointState): if not isinstance(state, C.CheckpointState): raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") self._state = state + self._parameters = Parameters(self._state) + self._properties = Properties(self._state) @classmethod def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: @@ -101,52 +243,12 @@ def save_checkpoint( """ C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) - def __getitem__(self, name: str) -> int | float | str | Parameter: - """Gets the parameter or property associated with the given name - - Searches for the name in the parameters and properties of the checkpoint state. - - Args: - name: The name of the parameter or property - - Returns: - The value of the parameter or property - """ - - if self._state.has_parameter(name): - return Parameter(self._state.get_parameter(name), self._state) - elif self._state.has_property(name): - return self._state.get_property(name) - else: - raise KeyError(f"Could not find {name} in the checkpoint state.") - - def __setitem__(self, name: str, value: int | float | str | np.ndarray) -> None: - """Sets the parameter or property value for the given name - - Searches for the name in the parameters and properties of the checkpoint state. - If the name is found in parameters, the value is updated. - Else, the value is added or updated in the properties. - - Args: - name: The name of the parameter or property - value: The value of the parameter or property - Properties only support int, float and str values. - """ - if self._state.has_parameter(name): - self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) - else: - self._state.add_property(name, value) - - def __contains__(self, name: str) -> bool: - """Checks if the parameter or property exists in the state - - Tthe name is searched in both parameters and properties. - - Args: - name: The name of the parameter or property - - Returns: - True if the name is either a parameter or a property, False otherwise - """ + @property + def parameters(self) -> Parameters: + """Returns the model parameters from the checkpoint state""" + return self._parameters - return self._state.has_parameter(name) or self._state.has_property(name) + @property + def properties(self) -> Properties: + """Returns the properties from the checkpoint state""" + return self._properties diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 8debf4a9cbf10..d5c37b3e36ee7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -360,14 +360,18 @@ def test_add_get_property(property_value): if isinstance(property_value, float): property_value = float(np.float32(property_value)) - state["property"] = property_value - assert "property" in state - assert state["property"] == property_value + assert len(state.properties) == 0 + + state.properties["property"] = property_value + assert "property" in state.properties + assert state.properties["property"] == property_value + assert len(state.properties) == 1 CheckpointState.save_checkpoint(state, checkpoint_file_path) new_state = CheckpointState.load_checkpoint(checkpoint_file_path) - assert "property" in new_state - assert new_state["property"] == property_value + assert "property" in new_state.properties + assert new_state.properties["property"] == property_value + assert len(new_state.properties) == 1 def test_get_input_output_names(): @@ -582,8 +586,13 @@ def test_get_and_set_parameter_values(device): model = Module(training_model_file_path, state, eval_model_file_path, device=device) + state_dict = pt_model.state_dict() + assert len(state_dict) == len(state.parameters) + for parameter_name, _ in state.parameters: + assert parameter_name in state_dict + for name, pt_param in pt_model.named_parameters(): - ort_param = state[name] + ort_param = state.parameters[name] assert ort_param.name == name assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) if name in ["fc1.weight", "fc1.bias"]: @@ -593,9 +602,9 @@ def test_get_and_set_parameter_values(device): assert ort_param.requires_grad is True assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32)) - original_param = state["fc1.weight"].data - state["fc1.weight"].data = np.ones_like(state["fc1.weight"].data, dtype=np.float32) - updated_param = state["fc1.weight"].data + original_param = state.parameters["fc1.weight"].data + state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32) + updated_param = state.parameters["fc1.weight"].data assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32)) model.train() @@ -604,7 +613,7 @@ def test_get_and_set_parameter_values(device): loss = model(inputs, labels) assert loss is not None for name, _ in pt_model.named_parameters(): - ort_param = state[name] + ort_param = state.parameters[name] assert ort_param.name == name if name in ["fc1.weight", "fc1.bias"]: assert ort_param.requires_grad is False @@ -613,5 +622,5 @@ def test_get_and_set_parameter_values(device): assert ort_param.requires_grad is True assert ort_param.grad.any() - state["fc1.weight"] = original_param - assert np.allclose(state["fc1.weight"].data, original_param) + state.parameters["fc1.weight"] = original_param + assert np.allclose(state.parameters["fc1.weight"].data, original_param) diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index cea54bc65a81e..2e1594f9dc42b 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -125,6 +125,15 @@ Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtVa ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), "Parameter data type mismatch. Expected: ", data_.Get().DataType(), ", Got: ", data.Get().DataType()); @@ -144,6 +153,15 @@ Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, con ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), "Parameter data type mismatch. Expected: ", data_.Get().DataType(), ", Got: ", data.Get().DataType()); diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 45aeaebaac236..38a9aad9640ea 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -622,8 +622,12 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); - ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyTo( - chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter)); + auto status = it->second->CopyTo( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter); + if (!status.IsOK()) { + OrtApis::ReleaseValue(*parameter); + return onnxruntime::ToOrtStatus(status); + } return nullptr; API_IMPL_END From 265551efa569ddd46d90a71ac33b018152006ee6 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 21 Sep 2023 18:58:28 +0000 Subject: [PATCH 22/25] fix typo --- orttraining/orttraining/training_api/module.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 2e1594f9dc42b..cf49a01517d6b 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -120,7 +120,7 @@ Status TransformModelInputsForInference(Graph& inference_graph, } // namespace Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { - ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot cope the checkpoint parameter to it."); + ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it."); ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), From f5fe5da14f53b7e814c892f96d623001aa873031 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Tue, 12 Sep 2023 10:03:29 +0900 Subject: [PATCH 23/25] Add missing member init --- .../dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 232a022d869f4..04381b6ce355c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -80,7 +80,7 @@ namespace Windows::AI::MachineLearning::Adapter // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. struct DmlGraphNodeCreateInfo { - uint32_t nodeCount; + uint32_t nodeCount = 0; std::vector> nodesAsOperatorDesc; std::vector> nodesAsIDMLOperator; std::vector inputEdges; From 321e2ac1bd20b8c51e7f012f498737de00ccea8a Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 13 Oct 2023 08:27:15 +1000 Subject: [PATCH 24/25] Fix illegal opcode error from mlas (#17885) ### Description Use cpuinfo value when checking to dot product is available. Reading the ID_AA64ISAR0_EL1 register is unsafe. ### Motivation and Context #17647 #17541 #17851 --- onnxruntime/core/mlas/lib/platform.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 86b7450a7c4e5..32cc69d0b8040 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -451,12 +451,16 @@ Return Value: #if defined(_WIN32) HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we - // disable it there. - uint64_t isar0_el1; - asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; #else + // Use the cpuinfo value which is read from sysctl and has some additional special cases. + // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips + // as well as failing on other ARM chips as it is an EL1 level register that requires extra + // privileges to read. + // + // uint64_t isar0_el1; + // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); + // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); #endif From 3744ae3bf0b43bfd3870af50979eefb2479ef678 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 1 Nov 2023 12:26:42 -0700 Subject: [PATCH 25/25] format code --- winml/test/model/skip_model_tests.h | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/winml/test/model/skip_model_tests.h b/winml/test/model/skip_model_tests.h index e2ea83faa743d..9d66320343c43 100644 --- a/winml/test/model/skip_model_tests.h +++ b/winml/test/model/skip_model_tests.h @@ -147,7 +147,7 @@ std::unordered_map disabledTests({ { "mlperf_ssd_mobilenet_300_opset10_GPU", "Bug 31005624: mlperf_ssd_mobilenet_300 opset 10 model fails to evaluate in DirectML https://microsoft.visualstudio.com/OS/_workitems/edit/31005624" }, { "mlperf_ssd_resnet34_1200_opset10_GPU", - "Bug 31005624: mlperf_ssd_resnet34_1200_opset10_GPU opset 10 model fails to evaluate in DirectML https://microsoft.visualstudio.com/OS/_workitems/edit/31005624" }, + "Bug 31005624: mlperf_ssd_resnet34_1200_opset10_GPU opset 10 model fails to evaluate in DirectML https://microsoft.visualstudio.com/OS/_workitems/edit/31005624" }, }); /* @@ -163,10 +163,8 @@ std::unordered_map> disabledGpu test name -> absolute difference sampleTolerance */ std::unordered_map sampleTolerancePerTests({ - {"fp16_inception_v1_opset7_GPU",0.005 }, - {"fp16_inception_v1_opset8_GPU", 0.005}, - { "candy_opset9_GPU", - 0.00150000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ - { "fp16_tiny_yolov2_opset8_GPU", - 0.109000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ + {"fp16_inception_v1_opset7_GPU", 0.005}, + {"fp16_inception_v1_opset8_GPU", 0.005}, + { "candy_opset9_GPU", 0.00150000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ + { "fp16_tiny_yolov2_opset8_GPU", 0.109000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ });