Merge branch 'main' into fused_conv_hardSigmoid
qjia7 committed Jan 26, 2024
2 parents bb35d9d + fc44f96 commit f15dae9
Showing 57 changed files with 1,478 additions and 1,158 deletions.
23 changes: 7 additions & 16 deletions cmake/CMakeLists.txt
@@ -97,7 +97,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)

cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

@@ -707,20 +706,16 @@ if (onnxruntime_USE_CUDA)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
set(onnxruntime_USE_CUTLASS OFF)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
else()
set(onnxruntime_USE_CUTLASS OFF)
endif()

if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS)
if (onnxruntime_DISABLE_CONTRIB_OPS)
message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
else()
message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled")
endif()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
@@ -906,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
endif()

if (onnxruntime_USE_CUTLASS)
target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
endif()

if(USE_NEURAL_SPEED)
target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
endif()
20 changes: 9 additions & 11 deletions cmake/external/cutlass.cmake
@@ -1,13 +1,11 @@
if (onnxruntime_USE_CUTLASS)
include(FetchContent)
FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
)
include(FetchContent)
FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
)

FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
3 changes: 3 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -47,6 +47,9 @@ set(contrib_ops_excluded_files
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/group_norm_impl_kernel.cuh"
"diffusion/group_norm_common_base.h"
"diffusion/group_norm_common_base.cc"
"diffusion/nhwc_conv.cc"
"math/gemm_float8.cc"
"math/gemm_float8.cu"
107 changes: 107 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
@@ -282,6 +282,48 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
}
}

/// <summary>
/// This function performs a training step that computes the outputs of the training model and the gradients
/// of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model
/// that was provided to the training session.
/// The TrainStep method is equivalent to running forward propagation and backward propagation in a single
/// step.
/// The gradients computed are stored inside the training session state so they can be later consumed
/// by the OptimizerStep function.
/// The gradients can be lazily reset by invoking the LazyResetGrad function.
/// Example usage:
/// <code>
/// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
/// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
/// List<OrtValue> inputValues = new List<OrtValue> { x, label };
/// using (var loss = trainingSession.TrainStep(inputValues))
/// {
/// // process output values
/// }
/// </code>
/// </summary>
/// <param name="inputValues">Specify a collection of <see cref="OrtValue"/> that indicates the input values to the training model.</param>
/// <returns>Output tensors in a collection of OrtValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection<OrtValue> TrainStep(IReadOnlyCollection<OrtValue> inputValues)
{
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues);
IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount];

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray));

var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray);
try
{
return CreateDisposableResult(disposableHandles);
}
finally
{
disposableHandles.Dispose();
}
}

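A minimal usage sketch (not part of this commit) of how LazyResetGrad, TrainStep, and OptimizerStep — all referenced in the doc comment above — compose into a training loop. The checkpoint/model paths, tensor shapes, and epoch count are hypothetical placeholders borrowed from the tests added later in this commit:

using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime;

// Hypothetical data; shapes mirror the tests added later in this commit.
float[] inputData = new float[2 * 784];
long[] inputShape = { 2, 784 };
int[] labelsData = { 1, 1 };
long[] labelsShape = { 2 };

var state = CheckpointState.LoadCheckpoint("checkpoint.ckpt");
using var trainingSession = new TrainingSession(state, "training_model.onnx", "adamw.onnx");

for (int epoch = 0; epoch < 5; epoch++)
{
    trainingSession.LazyResetGrad();  // reset gradients before the next forward/backward pass
    using OrtValue x = OrtValue.CreateTensorValueFromMemory(inputData, inputShape);
    using OrtValue label = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape);
    using (var loss = trainingSession.TrainStep(new List<OrtValue> { x, label }))
    {
        // loss[0] is the loss tensor; the caller owns and must dispose the outputs
    }
    trainingSession.OptimizerStep();  // apply the gradients held in the session state
}
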
/// <summary>
/// Convert native OrtValue handles to OrtValue instances
/// in an exceptions safe manner.
@@ -370,6 +412,42 @@ public void EvalStep(
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
}

/// <summary>
/// This function performs an eval step that computes the outputs of the eval model for the given inputs.
/// Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was
/// provided to the training session.
/// Example usage:
/// <code>
/// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
/// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
/// List<OrtValue> inputValues = new List<OrtValue> { x, label };
/// using (var loss = trainingSession.EvalStep(inputValues))
/// {
/// // process output values
/// }
/// </code>
/// </summary>
/// <param name="inputValues">Specify a collection of <see cref="OrtValue"/> that indicates the input values to the eval model.</param>
public IDisposableReadOnlyCollection<OrtValue> EvalStep(IReadOnlyCollection<OrtValue> inputValues)
{
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues);
IntPtr[] outputValuesArray = new IntPtr[(int)_evalOutputCount];

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)_evalOutputCount, outputValuesArray));

var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray);
try
{
return CreateDisposableResult(disposableHandles);
}
finally
{
disposableHandles.Dispose();
}
}

/// <summary>
/// Sets the learning rate for this training session.
@@ -702,6 +780,35 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> v
return valuesArray;
}

private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<OrtValue> inputValues)
{
var valuesArray = new IntPtr[inputValues.Count];
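// Note: Enumerable.ElementAt may re-enumerate the source on each call for collections that do not
// implement IList<T>; acceptable here given the small number of model inputs.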
for (int index = 0; index < inputValues.Count; ++index)
{
valuesArray[index] = inputValues.ElementAt(index).Handle;
}
return valuesArray;
}

private static IDisposableReadOnlyCollection<OrtValue> CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles)
{
var outputValues = new DisposableList<OrtValue>(disposableHandles.Span.Length);
try
{
for (int i = 0; i < disposableHandles.Span.Length; i++)
{
outputValues.Add(new OrtValue(disposableHandles.Span[i]));
disposableHandles.Span[i] = IntPtr.Zero;
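// Zeroing the handle transfers ownership to the new OrtValue, so the caller's
// finally { disposableHandles.Dispose(); } cannot double-free it.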
}
return outputValues;
}
catch (Exception)
{
outputValues.Dispose();
throw;
}
}

private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection<string> names, DisposableList<IDisposable> cleanupList)
{
cleanupList.Capacity += names.Count;
75 changes: 75 additions & 0 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
@@ -612,6 +612,81 @@ public void TestUpdateParameter()
}
}

[Fact(DisplayName = "TestTrainingSessionTrainStepWithOrtValues")]
public void TestTrainingSessionTrainStepWithOrtValues()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

var trainingSession = new TrainingSession(state, trainingPath, optimizerPath);
cleanUp.Add(trainingSession);

float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out");
var expectedOutputDimensions = new int[] { 1 };
float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in");
long[] inputShape = { 2, 784 };
Int32[] labelsData = { 1, 1 };
long[] labelsShape = { 2 };

using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory<float>(inputData, inputShape);
using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory<Int32>(labelsData, labelsShape);
var inputValues = new List<OrtValue> { inputOrtValue, labelsOrtValue };

using (var results = trainingSession.TrainStep(inputValues))
{
Assert.Single(results);
var outputOrtValue = results[0];
Assert.True(outputOrtValue.IsTensor);
var resultSpan = outputOrtValue.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(expectedOutput, resultSpan, new FloatComparer());
}
}
}

[Fact(DisplayName = "TestTrainingSessionEvalStepWithOrtValues")]
public void TestTrainingSessionEvalStepWithOrtValues()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");

var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);

float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out");
var expectedOutputDimensions = new int[] { 1 };
float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in");
long[] inputShape = { 2, 784 };
Int32[] labelsData = { 1, 1 };
long[] labelsShape = { 2 };

using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory<float>(inputData, inputShape);
using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory<Int32>(labelsData, labelsShape);
var inputValues = new List<OrtValue> { inputOrtValue, labelsOrtValue };

using (var results = trainingSession.EvalStep(inputValues))
{
Assert.Single(results);
var outputOrtValue = results[0];
Assert.True(outputOrtValue.IsTensor);
var resultSpan = outputOrtValue.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(expectedOutput, resultSpan, new FloatComparer());
}
}
}

internal class FloatComparer : IEqualityComparer<float>
{
private float atol = 1e-3f;
26 changes: 21 additions & 5 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Expand Up @@ -428,13 +428,26 @@ export class WebGpuBackend {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
const sizeOfElement = v.type === 'float16' ? 2 : 4;
let sizeOfVecOrMat;
let baseAlignment;
if (v.type === 'float16') {
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
} else {
baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
sizeOfVecOrMat = 16;
}
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
offsets.push(currentOffset);
// When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
// Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
// SizeOf(vec4<i32|u32|f32>).
currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
// For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
// N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
// length is N * SizeOf(mat2x4<f16>).
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
data.length * sizeOfElement;
});
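// Worked examples (illustrative, not part of this diff) of the packing rules above:
// - float32 uniform, data.length = 5: baseAlignment = 16; packed as array<vec4<f32>, 2>,
//   occupying ceil(5 / 4) * 16 = 32 bytes.
// - float16 uniform, data.length = 5: baseAlignment = 16; packed as array<mat2x4<f16>, 1>,
//   occupying ceil(5 / 8) * 16 = 16 bytes.
// - float16 uniform, data.length = 3: baseAlignment = 8; stored as 3 contiguous f16 values,
//   occupying 3 * 2 = 6 bytes.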

// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
@@ -449,6 +462,9 @@
new Int32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'uint32') {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'float16') {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
} else {
new Float32Array(arrayBuffer, offset, data.length).set(data);
}