diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 877677dcad57b..fec0d46e96dfb 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -282,6 +282,48 @@ public IDisposableReadOnlyCollection TrainStep( } } + /// + /// This function performs a training step that computes the outputs of the training model and the gradients + /// of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model + /// that was provided to the training session. + /// The TrainStep method is equivalent of running forward propagation and backward propagation in a single + /// step. + /// The gradients computed are stored inside the training session state so they can be later consumed + /// by the OptimizerStep function. + /// The gradients can be lazily reset by invoking the LazyResetGrad function. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.TrainStep(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the training model. + /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. + public IDisposableReadOnlyCollection TrainStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Convert native OrtValue handles to OrtValue instances /// in an exceptions safe manner. @@ -370,6 +412,42 @@ public void EvalStep( inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } + /// + /// This function performs an eval step that computes the outputs of the eval model for the given inputs. + /// Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was + /// provided to the training session. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.EvalSteps(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the eval model. + public IDisposableReadOnlyCollection EvalStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_evalOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_evalOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Sets the learning rate for this training session. @@ -702,6 +780,35 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v return valuesArray; } + private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection inputValues) + { + var valuesArray = new IntPtr[inputValues.Count]; + for (int index = 0; index < inputValues.Count; ++index) + { + valuesArray[index] = inputValues.ElementAt(index).Handle; + } + return valuesArray; + } + + private static IDisposableReadOnlyCollection CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles) + { + var outputValues = new DisposableList(disposableHandles.Span.Length); + try + { + for (int i = 0; i < disposableHandles.Span.Length; i++) + { + outputValues.Add(new OrtValue(disposableHandles.Span[i])); + disposableHandles.Span[i] = IntPtr.Zero; + } + return outputValues; + } + catch (Exception) + { + outputValues.Dispose(); + throw; + } + } + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, DisposableList cleanupList) { cleanupList.Capacity += names.Count; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 68b1d5bcc6147..9b72326201322 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -612,6 +612,81 @@ public void TestUpdateParameter() } } + [Fact(DisplayName = "TestTrainingSessionTrainStepWithOrtValues")] + public void TestTrainingSessionTrainStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.TrainStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + + [Fact(DisplayName = "TestTrainingSessionEvalStepWithOrtValues")] + public void TestTrainingSessionEvalStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.EvalStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f;