diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7d7304630c00e..0eb224623f678 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -97,7 +97,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) -cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) @@ -707,20 +706,16 @@ if (onnxruntime_USE_CUDA) enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") + if (onnxruntime_DISABLE_CONTRIB_OPS) + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) - message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") - set(onnxruntime_USE_CUTLASS OFF) + message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() else() - set(onnxruntime_USE_CUTLASS OFF) -endif() - -if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS) - if (onnxruntime_DISABLE_CONTRIB_OPS) - message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled") - else() - message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled") - endif() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -906,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() - if (onnxruntime_USE_CUTLASS) - target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) - endif() - if(USE_NEURAL_SPEED) target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) endif() diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index efc708bd681c0..f04f4bec76cd5 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,13 +1,11 @@ -if (onnxruntime_USE_CUTLASS) - include(FetchContent) - FetchContent_Declare( - cutlass - URL ${DEP_URL_cutlass} - URL_HASH SHA1=${DEP_SHA1_cutlass} - ) +include(FetchContent) +FetchContent_Declare( + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} +) - FetchContent_GetProperties(cutlass) - if(NOT cutlass_POPULATED) - FetchContent_Populate(cutlass) - endif() +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) endif() diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index f70961a66329a..d485abe6bb1a6 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -47,6 +47,9 @@ set(contrib_ops_excluded_files "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" + "diffusion/group_norm_impl_kernel.cuh" + "diffusion/group_norm_common_base.h" + "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 877677dcad57b..fec0d46e96dfb 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -282,6 +282,48 @@ public IDisposableReadOnlyCollection TrainStep( } } + /// + /// This function performs a training step that computes the outputs of the training model and the gradients + /// of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model + /// that was provided to the training session. + /// The TrainStep method is equivalent of running forward propagation and backward propagation in a single + /// step. + /// The gradients computed are stored inside the training session state so they can be later consumed + /// by the OptimizerStep function. + /// The gradients can be lazily reset by invoking the LazyResetGrad function. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.TrainStep(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the training model. + /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. + public IDisposableReadOnlyCollection TrainStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Convert native OrtValue handles to OrtValue instances /// in an exceptions safe manner. @@ -370,6 +412,42 @@ public void EvalStep( inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } + /// + /// This function performs an eval step that computes the outputs of the eval model for the given inputs. + /// Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was + /// provided to the training session. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.EvalSteps(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the eval model. + public IDisposableReadOnlyCollection EvalStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_evalOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_evalOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Sets the learning rate for this training session. @@ -702,6 +780,35 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v return valuesArray; } + private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection inputValues) + { + var valuesArray = new IntPtr[inputValues.Count]; + for (int index = 0; index < inputValues.Count; ++index) + { + valuesArray[index] = inputValues.ElementAt(index).Handle; + } + return valuesArray; + } + + private static IDisposableReadOnlyCollection CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles) + { + var outputValues = new DisposableList(disposableHandles.Span.Length); + try + { + for (int i = 0; i < disposableHandles.Span.Length; i++) + { + outputValues.Add(new OrtValue(disposableHandles.Span[i])); + disposableHandles.Span[i] = IntPtr.Zero; + } + return outputValues; + } + catch (Exception) + { + outputValues.Dispose(); + throw; + } + } + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, DisposableList cleanupList) { cleanupList.Capacity += names.Count; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 68b1d5bcc6147..9b72326201322 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -612,6 +612,81 @@ public void TestUpdateParameter() } } + [Fact(DisplayName = "TestTrainingSessionTrainStepWithOrtValues")] + public void TestTrainingSessionTrainStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.TrainStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + + [Fact(DisplayName = "TestTrainingSessionEvalStepWithOrtValues")] + public void TestTrainingSessionEvalStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.EvalStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 8ca025d66550c..a48fe99570abf 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -428,13 +428,26 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const baseAlignment = data.length <= 2 ? data.length * 4 : 16; + const sizeOfElement = v.type === 'float16' ? 2 : 4; + let sizeOfVecOrMat; + let baseAlignment; + if (v.type === 'float16') { + baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); + sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; + } else { + baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; + sizeOfVecOrMat = 16; + } currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); - // When data.length > 4, the uniform variable is of type array,N>, where N = - // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * - // SizeOf(vec4). - currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; + // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where + // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * + // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type + // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte + // length is N * SizeOf(mat2x4). + const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : + data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set @@ -449,6 +462,9 @@ export class WebGpuBackend { new Int32Array(arrayBuffer, offset, data.length).set(data); } else if (v.type === 'uint32') { new Uint32Array(arrayBuffer, offset, data.length).set(data); + } else if (v.type === 'float16') { + // TODO: use Float16Array. + new Uint16Array(arrayBuffer, offset, data.length).set(data); } else { new Float32Array(arrayBuffer, offset, data.length).set(data); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3638938df7dbe..1a03621512888 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -88,10 +88,10 @@ const conv2dCommonSnippet = let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; - let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels); + let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]); + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; let xCh = ${col} % inChannels; var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for @@ -108,7 +108,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -117,7 +117,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -129,9 +129,8 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -142,7 +141,7 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; @@ -181,31 +180,46 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const fitAOuter = dimAOuter % tileAOuter === 0; const fitBOuter = dimBOuter % tileBOuter === 0; const fitInner = dimInner % tileInner === 0; - const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = - inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, + {type: 'int32', data: attributes.dilations} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, + {name: 'dilation', type: 'i32', length: 2} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } - let declareFunctions = ` + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } @@ -213,51 +227,50 @@ export const createConv2DMatMulProgramInfo = let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` + const x = inputVariable( + 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + } + + return ` ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)} - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); - const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} - ${ + ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + sequentialAccessByThreads)}`; + }; + return { + name: 'Conv2DMatMul', + shaderCache: { + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ + tileAOuter};${tileBOuter};${tileInner}`, + inputDependencies + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d425155857e14..33e50a9a39cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; - const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readASnippet = ` - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + let WRow = ${col} / (uniforms.filter_dims[1] * inChannels); + let WCol = ${col} / inChannels % uniforms.filter_dims[1]; + let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]); + let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${type}(0.0); } @@ -103,25 +103,25 @@ const conv2dTransposeCommonSnippet = const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readASnippet} } return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); - let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; + let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); + let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; if (${ - isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : - 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : + 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -129,9 +129,8 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const applyActivation = getActivationSnippet(attributes, type); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -142,7 +141,7 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueInput; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} @@ -186,65 +185,64 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const inputVariables = [x, w]; - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const effectiveFilterDims = [ + filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) + ]; - let declareFunctions = ''; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, + {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; + inputDependencies.push('rank'); } - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DTransposeMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, + {name: 'filter_dims', type: 'i32', length: filterDims.length}, + {name: 'pads', type: 'i32', length: pads.length} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` ${utilFunctions('uniforms.result_strides')} - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)}; - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ - attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${ - attributes.pads[1] + attributes.pads[3]})/2); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} ${ @@ -252,6 +250,18 @@ export const createConv2DTransposeMatMulProgramInfo = elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, - undefined, sequentialAccessByThreads)}` + undefined, sequentialAccessByThreads)}`; + }; + + return { + name: 'Conv2DTransposeMatMul', + shaderCache: + {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 50b0841a0200a..380efc8bc577a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -20,24 +20,18 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, - dataType: string): string => { - const isChannelsLast = attributes.format === 'NHWC'; + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, + is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, + isChannelsLast = false): string => { const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; - const outputSize = ShapeUtil.size(outputShape); const workPerThread = isVec4 ? 2 : 1; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { @@ -50,20 +44,21 @@ const createConvTranspose2DOpProgramShaderSource = }`; } const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); const inputVariables = [dy, w]; if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); } - const output = outputVariable('result', inputs[0].dataType, outputShape, components); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); + let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. @@ -71,29 +66,29 @@ const createConvTranspose2DOpProgramShaderSource = for (var i = 0; i < ${workPerThread}; i++) { dotProd[i] = vec4<${dataType}>(0.0); } - for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); - let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || + for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); + let wRPerm = uniforms.filter_dims[0] - 1 - wR; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims[1] - 1 - wC; + for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -101,7 +96,7 @@ const createConvTranspose2DOpProgramShaderSource = let idyC: u32 = u32(dyC); let idyC2: u32 = u32(dyC2); if (bDyCVal && bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -123,7 +118,7 @@ const createConvTranspose2DOpProgramShaderSource = dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[${channelDim}]; + let d2Length = uniforms.Dy_shape[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -138,7 +133,7 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -167,39 +162,39 @@ const createConvTranspose2DOpProgramShaderSource = let d1 = ${output.indicesGet('outputIndices', channelDim)}; let r = ${output.indicesGet('outputIndices', rowDim)}; let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - pads; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; - let groupId = d1 / ${outputChannelsPerGroup}; - let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) { - if (wR % dilations.x != 0) { + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { continue; } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); - let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) { - if (wC % dilations.y != 0) { + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { continue; } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - var inputChannel = groupId * ${inputChannelsPerGroup}; - for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; @@ -214,27 +209,11 @@ const createConvTranspose2DOpProgramShaderSource = `; return ` - ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const dilations : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ${isVec4 ? codeSnippet4 : codeSnippet}}`; }; @@ -257,19 +236,72 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const isChannelsLast = attributes.format === 'NHWC'; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const strides = [attributes.strides[0], attributes.strides[1]]; + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const dilations = [attributes.dilations[0], attributes.dilations[1]]; + const effectiveFilterDims = [ + filterDims[0] + + (attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + + (attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 + ]; + + const isVec4 = false; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; + + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, + {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, + {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, + {name: 'filter_dims', type: 'u32', length: filterDims.length}, + {name: 'dilations', type: 'u32', length: filterDims.length}, + {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, + {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return `${ + createConvTranspose2DOpProgramShaderSource( + shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, + isChannelsLast)}`; + }; return { name: 'ConvTranspose2D', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, getRunData: () => ({ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType - }] + }], + programUniforms }), - getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, - dataType), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 47ec16a296712..ee71110245252 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let inputRow = tileRow + innerRow; @@ -204,7 +204,7 @@ export const makeMatMulPackedSource = let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { @@ -260,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; let tileRowB = i32(localId.y) * ${rowPerThreadB}; // Loop over shared dimension. -for (var t = 0; t < numTiles; t = t + 1) { +for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { @@ -322,7 +322,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${ + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -379,7 +380,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimAOuter && col < uniforms.dimInner) + if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -391,7 +392,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimInner && col < uniforms.dimBOuter) + if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -401,7 +402,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -422,16 +423,10 @@ export const createMatmulProgramInfo = isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; - const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const enableBatchUniforms = enableShapesUniforms(outerDims.length); - const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); - const dimAOuter = aShape[aShape.length - 2]; const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; @@ -446,72 +441,67 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); - const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; - + const aShapeOrRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); - const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; - + const bShapeOrRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); - const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - const inputVariables = [A, B]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (enableBatchUniforms) { - programUniforms.push(...createTensorShapeVariables(outerDims)); + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); } - if (enableAShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(aShapeTemp)); - } - if (enableBShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(bShapeTemp)); - } - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), + ...createTensorShapeVariables(bShapeTemp)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = matMulReadWriteFnSource( - components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], - isChannelsLast); if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); } programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchShapeOrRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); + const inputVariables = [A, B]; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + } + const uniforms: UniformsArrayType = + [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); + return ` ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} `; - // TODO: turn clipMax and clipMin to uniforms. + }; return { name: 'MatMul', shaderCache: { - hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + - `${isVec4}` + - `${isChannelsLast}`, + hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, inputDependencies }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index bc3265be955f0..643744108c0f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of variable. * @param index - the index of variable element. * @param length - the length of variable. + * @param type - the type of variable, optional. */ -export const getElementAt = (name: string, index: number|string, length: number): string => { - if (name.startsWith('uniforms.') && length > 4) { - if (typeof (index) === 'string') { - return `${name}[(${index}) / 4][(${index}) % 4]`; - } else { - return `${name}[${Math.floor(index / 4)}][${index % 4}]`; - } - } else { - return length > 1 ? `${name}[${index}]` : name; - } -}; +export const getElementAt = + (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + if (type === 'f16') { + return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; + } else { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } + } else { + if (type === 'f16') { + return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } + }; /** * A helper function to get a IndicesHelper for a given input or output. @@ -688,7 +698,7 @@ export const internalVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** @@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper { const uniformSnippets: string[] = []; for (const {name, type, length} of this.uniforms) { if (length && length > 4) { - uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + if (type === 'f16') { + uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); + } else { + uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } } else { const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; uniformSnippets.push(`${name}:${typeTemp}`); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 21b4953d3f90c..f81d6577890c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo, ProgramUniform} from '../types'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {getActivationSnippet} from './fuse-utils'; @@ -27,52 +27,75 @@ export const createGroupedConvProgramInfo = xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', inputs[0].dataType, outputShape); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); - const x = inputVariable('x', inputs[0].dataType, xShape); - const w = inputVariable('w', inputs[1].dataType, wShape); - const inputVars = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, + {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + ...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); - const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); - - ${shaderHelper.declareVariables(...inputVars, output)} + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const applyActivation = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length); + const w = inputVariable('w', inputs[1].dataType, wShape.length); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + } - ${activationFunction} + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, + {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; let batch: u32 = outputIndices[0]; let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ - isChannelLast ? 2 : 3}]) * strides - pads; - let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; + let group_id: u32 = output_channel / uniforms.output_channels_per_group; var value: ${output.type.value} = ${output.type.value}(0); - for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { - let input_channel = group_id * ${wShape[1]}u + wInChannel; - for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { - let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { + let input_channel = group_id * uniforms.w_shape[1] + wInChannel; + for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { + let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; - if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) { continue; } - for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { - let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; - if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) { + for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { + let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; + if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) { continue; } let xVal = ${ - isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : - x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : + x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; value += xVal*wVal; } @@ -82,15 +105,17 @@ export const createGroupedConvProgramInfo = ${applyActivation} ${output.setByOffset('global_idx', 'value')} }`; + }; return { name: 'GroupedConv', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType }], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; @@ -114,7 +139,7 @@ export const createGroupedConvVectorizeProgramInfo = const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const applyActivation = getActivationSnippet(attributes, output.type.value); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; @@ -129,7 +154,6 @@ export const createGroupedConvVectorizeProgramInfo = .registerUniform('strides', 'i32', 2) .registerUniform('pads', 'i32', 2) .declareVariables(...inputVars, output)} - ${activationFunction} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -179,7 +203,7 @@ export const createGroupedConvVectorizeProgramInfo = return { name: 'GroupedConv-Vectorize', shaderCache: { - hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 32b1d52ed94ca..33d16754c737a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -2,7 +2,6 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; @@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes { readonly outputShape: readonly number[]; } - const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - const cacheKey = attributes.cacheKey + [ - kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), - dilations.join(',') - ].join('_'); - Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); return newAttributes; }; @@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record const wIsConst = (attributes.wIsConst as () => boolean)(); const outputPadding = attributes.outputPadding as [number, number, number, number]; const outputShape = attributes.outputShape as [number, number]; - return createAttributeWithCacheKey({ + return { autoPad, format, dilations, @@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record pads, strides, wIsConst, - ...activationAttributes - }); + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 7af2c5db49f40..5afec0389fac8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; @@ -110,7 +110,7 @@ const getAdjustedConvAttributes = (attributes: T, inpu // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, {kernelShape, pads}); return newAttributes; }; @@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record): ConvAt const strides = attributes.strides as [number, number]; const wIsConst = (attributes.w_is_const as () => boolean)(); - return createAttributeWithCacheKey( - {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + return { + autoPad, + format, + dilations, + group, + kernelShape, + pads, + strides, + wIsConst, + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index e904ed330a27d..2e0aa33a957dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -7,49 +7,29 @@ export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; - readonly alpha?: number; - readonly beta?: number; - readonly activationCacheKey: string; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): - {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; - case 'Sigmoid': - return { - activationFunction: '', - applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` - }; - case 'Clip': - return { - activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ - attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - case 'HardSigmoid': - return { - activationFunction: - `const alpha_ = ${valueType}(${attributes.alpha!});const beta_ = ${valueType}(${attributes.beta!});`, - applyActivation: `value = max(${valueType}(0.0), min(${valueType}(1.0), alpha_ * value + beta_));` - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + // TODO: adding other activations that can be fused. + default: + return ''; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { const activation = attributes?.activation as string || ''; - if (activation === 'HardSigmoid') { - const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; - return {activation, alpha, beta, activationCacheKey: `${activation}:${alpha},${beta}`}; - } + if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; - return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + return {activation, clipMax, clipMin}; } - return {activation, activationCacheKey: activation}; + return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index de9309d1e436f..c946ea6366123 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -6,7 +6,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = @@ -27,11 +27,19 @@ export const createNaiveMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const batchSize = ShapeUtil.size(outerDims); const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape) + {type: 'uint32', data: K} ]; + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), + ...createTensorShapeVariables(bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } @@ -42,7 +50,7 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { @@ -57,6 +65,14 @@ export const createNaiveMatmulProgramInfo = const outerDimsB = bShape.slice(0, -2); const broadCastADims = getBroadcastDims(outerDimsA, outerDims); const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'K', type: 'u32'} + ]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; const name = variable.name; @@ -96,15 +112,10 @@ export const createNaiveMatmulProgramInfo = return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('M', 'u32') - .registerUniform('N', 'u32') - .registerUniform('K', 'u32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let col = (global_idx % (uniforms.N / ${components})) * ${components}; var index1 = global_idx / (uniforms.N / ${components}); let stride1 = uniforms.M / ${outputNumber}; @@ -134,8 +145,7 @@ export const createNaiveMatmulProgramInfo = return { name: 'MatMulNaive', shaderCache: { - hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ - isChannelsLast}`, + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] }, getRunData: () => ({ @@ -166,9 +176,8 @@ export const matMul = (context: ComputeContext): void => { const N = outputShape[outputShape.length - 1]; const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; if (N < 8 && K < 8) { - context.compute( - createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index eca3fa7d944bb..c65b741e1105a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 1) { throw new Error('Too few inputs'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Input type must be float.'); + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) { + throw new Error('Input type must be float or float16.'); } if (inputs.length >= 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 82311d72e58b9..76929efb32537 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -279,7 +279,9 @@ export const tan = (context: ComputeContext): void => { }; export const tanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); + // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index e55bfb6ba9f16..789ac70a6913a 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -24,7 +24,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float32'|'uint32'; + type: 'int32'|'float16'|'float32'|'uint32'; data: number|readonly number[]; } diff --git a/js/web/test/data/ops/tanh.jsonc b/js/web/test/data/ops/tanh.jsonc new file mode 100644 index 0000000000000..f7691535bd71c --- /dev/null +++ b/js/web/test/data/ops/tanh.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "tanh with no attributes", + "operator": "Tanh", + "attributes": [], + "cases": [ + { + "name": "T[2,4]", + "inputs": [ + { + "data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 373b3c645df57..56db28b0a379c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1389,6 +1389,7 @@ "sub.jsonc", "sub_int32.jsonc", "tan.jsonc", + "tanh.jsonc", "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc", diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 9b989dac9a94b..40a667ffd5d83 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" @@ -204,5 +202,3 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h index cbd483fddab78..5ea4ae59c4020 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -36,5 +34,3 @@ class ShardedMoE final : public NcclKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index fa73950c9c6f5..8f368251f12c7 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,10 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -169,10 +167,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); @@ -272,10 +268,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -377,10 +371,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc new file mode 100644 index 0000000000000..5dec690528847 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc @@ -0,0 +1,101 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread) { + return NextSize(channels_per_block) / channels_per_thread; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; + } + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; + } + } + } + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h new file mode 100644 index 0000000000000..84f3403b8d5ae --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -0,0 +1,186 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "core/providers/cuda/cuda_common.h" +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread); + +static inline int32_t DivUp(int32_t m, int32_t n) { + return (m + n - 1) / n; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor); + +int GetChannelsPerBlock(int num_channels, int num_groups); + +template +struct GroupNormNHWCParams { + // The output buffer. Shape is (n, h, w, c). + T* dst; + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). + T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + + // The gamma scaling factor. + float const* gamma; + + // The beta term to add in GN. + float const* beta; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; + + // The number of instances in the batch. + int32_t n; + + // The height and width of each activation map. + int32_t h; + int32_t w; + + // Number of channels. + int32_t c; + + // Number of groups. + int32_t groups; + + // Do we apply the SiLU activation function? + bool use_silu; + + // Precomputed values and parameters to control the execution of the kernels. + + // Number of activations per instance (h * w) + int32_t hw; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; + + // The precomputed stride between instances. + int32_t hwc; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; + // The precomputed number of groups per block. + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; + + GroupNormNHWCParams(T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); + } + + this->use_silu = use_silu; + this->dst = output; + this->add_out = add_out; + this->src = input; + this->skip = skip; + this->bias = bias; + this->gamma = gamma; + this->beta = beta; + this->group_sum_buffer = reinterpret_cast(workspace); + this->n = batch_size; + this->h = height; + this->w = width; + this->c = num_channels; + this->groups = num_groups; + this->hw = this->h * this->w; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(this->hw, max_blocks_per_hw); + this->hw_per_block = DivUp(this->hw, blocks_per_hw); + + this->channels_per_block = channels_per_block; + this->channels_per_group = channels_per_group; + this->hwc = this->hw * this->c; + this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); + this->groups_per_block = channels_per_block / this->channels_per_group; + this->epsilon = epsilon; + this->broadcast_skip = broadcast_skip; + + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + this->skip_workspace = (this->add_out != nullptr) ? this->add_out : this->dst; + + this->threads_per_block = GetThreadsPerBlock(channels_per_block, CHANNELS_PER_THREAD); + } +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 48b161552ce0c..d7b2cc2379f4f 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -27,6 +27,8 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh" using namespace onnxruntime::cuda; @@ -34,329 +36,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -namespace { - -// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. -constexpr static int32_t CHANNELS_PER_THREAD = 2; - -constexpr static int kSizes[] = {128, 256, 320, 384, 512}; -constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); -constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; - -int NextSize(int x) { - for (size_t i = 0; i < kNumOfSizes; ++i) { - if (x <= kSizes[i]) { - return kSizes[i]; - } - } - - return x; -} -} // namespace - -static inline int32_t DivUp(int32_t m, int32_t n) { - return (m + n - 1) / n; -} - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sum_sq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -struct GroupNormNHWCParams { - // The output buffer. Shape is (n, h, w, c). - T* dst; - - // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). - T* add_out; - - // The input buffer. Shape is (n, h, w, c). - T const* src; - - // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). - T const* skip; - - // Optional input buffer for bias tensor. Shape is (c). - T const* bias; - - // The gamma scaling factor. - float const* gamma; - - // The beta term to add in GN. - float const* beta; - - // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. - float* group_sum_buffer; - - // The number of instances in the batch. - int32_t n; - - // The height and width of each activation map. - int32_t h; - int32_t w; - - // Number of channels. - int32_t c; - - // Number of groups. - int32_t groups; - - // Do we apply the SiLU activation function? - bool use_silu; - - // Precomputed values and parameters to control the execution of the kernels. - - // Number of activations per instance (h * w) - int32_t hw; - - // Number of activations per block - int32_t hw_per_block; - - // Number of channels per block in the C dimension. - int32_t channels_per_block; - - // Number of channels per group in the C dimension. - int32_t channels_per_group; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hw*channels_per_group to compute mean of a group. - float inv_hw_channels_per_group; - // The precomputed number of groups per block. - int32_t groups_per_block; - - // Number of threads per block - int32_t threads_per_block; - - // Epsilon to get stable variance in normalization. - float epsilon; - - // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. - bool broadcast_skip; - - // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. - T* skip_workspace; -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); - -template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - float2 f2 = __half22float2(h2); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] -template -inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); - h2 = h2 + b; - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - float2 b = *reinterpret_cast(&bias[bias_offset]); - f2.x += s.x + b.x; - f2.y += s.y + b.y; - - *reinterpret_cast(&add_out[offset]) = f2; - - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] -template -inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - f2.x += s.x; - f2.y += s.y; - *reinterpret_cast(&add_out[offset]) = f2; - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template -__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { - // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage temp_storage; - - // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. - __shared__ float2 smem[THREADS_PER_BLOCK]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The first activation loaded by that block. - int32_t hw_begin = blockIdx.y * params.hw_per_block; - // The last activation loaded by that block. - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - // The sums. - float sum = 0.F; - float sum_sq = 0.F; - - // Iterate over the activations to compute the sums. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - if (params.skip != nullptr) { - // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) - const int64_t bias_offset = static_cast(ci); - T* add_out = params.skip_workspace; - if (params.broadcast_skip) { - const int64_t skip_offset = static_cast(ni) * params.c + ci; - - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); - } - } - } else { - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); - } - } - } - } else { // GroupNorm - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - UpdateSum(params.src, offset, sum, sum_sq); - } - } - - // The group index relative to the first group within the same block. - int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; - // The channel in the group. - int32_t cj = ci % params.channels_per_group; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - - // Do the segmented scan. InclusiveScan is not deterministic. - GroupSums out; - BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced stores later). - // For each group, only the last thread of that group is picked to save sum to shared memory. - if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { - smem[gi] = make_float2(out.sum, out.sum_sq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groups_per_block) { - return; - } - - // The global group index. - // Use neighboring threads for coalesced write. - int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; - - if (gj < params.groups) { - float2 sums = smem[threadIdx.x]; - const int index = (2 * ni) * params.groups + gj; - atomicAdd(¶ms.group_sum_buffer[index], sums.x); - atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); - } -} - template void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; @@ -390,102 +69,6 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) } } -template -__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu); - -template <> -__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); -} - -template <> -__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast(&dst[offset]) = f2; -} - -template -__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on. - int32_t gi = ci / params.channels_per_group; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sum_sq = 0.F; - if (gi < params.groups) { - const int index = (2 * ni) * params.groups + gi; - sum = params.group_sum_buffer[index]; - sum_sq = params.group_sum_buffer[index + params.groups]; - } - - // Load gamma/beta. Fetch two per thread. - float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); - - // Compute the mean. - float mean = sum * params.inv_hw_channels_per_group; - // Compute the variance. - float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); - // Compute the inverse of the stddev. - float inv_std_dev = rsqrtf(var + params.epsilon); - - int32_t hw_begin = blockIdx.y * params.hw_per_block; - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); - } -} - template void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; @@ -517,60 +100,6 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea } } -int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { - int32_t max_divisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { - max_divisor = divisor1; - } - if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { - max_divisor = divisor2; - } - } - } - return max_divisor; -} - -// Find proper channels per block based on a cost function: The cost is number of channels corresponding to -// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has -// work to do so it is ideal case. -int FindChannelsPerBlock(int num_channels, int channels_per_group) { - int min_cost = -1; - int best_candidate = -1; - for (size_t i = kNumOfSizes; i > 0; --i) { - if (kSizes[i - 1] < channels_per_group) { - break; - } - - int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; - int blocks = (num_channels + channels_per_block - 1) / channels_per_block; - int cost = blocks * kSizes[i - 1] - num_channels; - if (cost == 0) { - return channels_per_block; - } - - if (min_cost == -1 || cost < min_cost) { - min_cost = cost; - best_candidate = channels_per_block; - } - } - - return best_candidate; -} - -int GetChannelsPerBlock(int num_channels, int num_groups) { - int32_t channels_per_group = num_channels / num_groups; - int32_t channels_per_block = channels_per_group; - if (channels_per_group < kMaxSize / 2) { - channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); - } - return channels_per_block; -} - template Status LaunchGroupNormKernel( cudaStream_t stream, @@ -591,19 +120,13 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { - GroupNormNHWCParams params; - - int32_t channels_per_group = num_channels / num_groups; - // channels_per_block is computed in PrePack. - // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. - if (channels_per_block < channels_per_group) { - channels_per_block = GetChannelsPerBlock(num_channels, num_groups); - } + GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, + batch_size, num_channels, height, width, num_groups, use_silu, + broadcast_skip, channels_per_block); - // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases - if (channels_per_block % channels_per_group != 0 || - channels_per_block > kMaxSize || - (channels_per_group % CHANNELS_PER_THREAD != 0)) { + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "GroupNorm in CUDA does not support the input: n=", batch_size, " h=", height, @@ -612,42 +135,6 @@ Status LaunchGroupNormKernel( " groups=", num_groups); } - params.use_silu = use_silu; - params.dst = output; - params.add_out = add_out; - params.src = input; - params.skip = skip; - params.bias = bias; - params.gamma = gamma; - params.beta = beta; - params.group_sum_buffer = reinterpret_cast(workspace); - params.n = batch_size; - params.h = height; - params.w = width; - params.c = num_channels; - params.groups = num_groups; - params.hw = params.h * params.w; - - // This will allocate as many blocks as possible to partition HW. - // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. - // TODO: tune this logic to find proper blocks when hw is small. - constexpr int32_t max_blocks_per_hw = 1024; - const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); - params.hw_per_block = DivUp(params.hw, blocks_per_hw); - - params.channels_per_block = channels_per_block; - params.channels_per_group = channels_per_group; - params.hwc = params.hw * params.c; - params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); - params.groups_per_block = channels_per_block / params.channels_per_group; - params.epsilon = epsilon; - params.broadcast_skip = broadcast_skip; - - // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. - params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; - - params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; - CUDA_RETURN_IF_ERROR(cudaMemsetAsync( params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh new file mode 100644 index 0000000000000..081e9a3de578c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -0,0 +1,355 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/diffusion/group_norm_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sum_sq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); + +template <> +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + + // Allocate shared memory for BlockScan. + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } + + // The first activation loaded by that block. + int32_t hw_begin = blockIdx.y * params.hw_per_block; + // The last activation loaded by that block. + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + + // The sums. + float sum = 0.F; + float sum_sq = 0.F; + + // Iterate over the activations to compute the sums. + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); + } + } + + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; + + // The data for the summations. + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; + + // Do the segmented scan. InclusiveScan is not deterministic. + GroupSums out; + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); + + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Threads that have nothing left to do, exit. + if (threadIdx.x >= params.groups_per_block) { + return; + } + + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } +} + +template +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); + +template <> +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +} + +template <> +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast(&dst[offset]) = f2; +} + +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; + + // Load the sum and sum of squares for the group. + float sum = 0.F, sum_sq = 0.F; + if (gi < params.groups) { + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; + } + + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); + + // Compute the mean. + float mean = sum * params.inv_hw_channels_per_group; + // Compute the variance. + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); + // Compute the inverse of the stddev. + float inv_std_dev = rsqrtf(var + params.epsilon); + + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h index 9b97690fe70fd..86136ea244e23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -13,9 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifdef USE_CUTLASS - #pragma once #include @@ -52,5 +49,3 @@ inline int compute_occupancy_for_kernel() { } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index f0abd46572a90..adc043e5689e2 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifdef USE_CUTLASS #include "cutlass_heuristic.h" @@ -66,9 +65,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, } // Check that the workspace has sufficient space for this split-k factor - const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); - const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); - const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + const size_t ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const size_t ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; if (required_ws_bytes > workspace_bytes) { return false; @@ -128,7 +127,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector= multi_processor_count * 256 ? 1 : split_k_limit; - for (int ii = 0; ii < candidate_configs.size(); ++ii) { + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { CutlassGemmConfig candidate_config = candidate_configs[ii]; TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); int occupancy = occupancies[ii]; @@ -186,5 +185,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector @@ -64,5 +62,3 @@ class MoeGemmRunner { }; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu index 1d0dfe7c5a647..1d9a249db4237 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu index 7a5d97902ee8f..7b250e6ca9060 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 3fd0fc47055a5..66950c9b65970 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -14,8 +14,6 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ #pragma GCC diagnostic push @@ -428,5 +426,3 @@ void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, con } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 9232e8d012933..f4f2b49032d23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include #include #include @@ -900,5 +898,3 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half cudaStream_t); } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index f09471de1cc2e..5cc2a3f79f003 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "moe_gemm_kernels.h" @@ -174,6 +172,4 @@ class CutlassMoeFCRunner> { } // namespace layout } // namespace cutlass - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 0da06192e266b..3f26a274109ad 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "moe.h" @@ -119,5 +117,3 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h index 710b914f0633d..c4d8c4dc64c57 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -26,5 +24,3 @@ class MoE final : public CudaKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index dc8b9d57f79f6..f55a7cde2e208 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "core/common/common.h" @@ -172,5 +170,3 @@ class MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 67384957d8dd2..d4d583906b7f4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -89,7 +89,7 @@ __device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* ha asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); } -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -120,7 +120,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, sums_half2[3] = sums_half2[3] + v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } #else -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -144,7 +144,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, } #endif -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc index b31f5d243e001..4cfa89a4d58c2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc @@ -203,23 +203,19 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { - Print(name, reinterpret_cast(tensor), dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { @@ -227,9 +223,14 @@ void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +} + +void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { @@ -242,6 +243,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int d DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } +void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); +} + void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); @@ -252,22 +258,31 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, i DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { + Print(name, reinterpret_cast(tensor), dim0, dim1); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); } void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { @@ -301,43 +316,52 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { } +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +} + void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int, int) const { diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h index 264ecd7cfe2f5..773401f79531a 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h @@ -16,20 +16,31 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons public: CudaTensorConsoleDumper() = default; virtual ~CudaTensorConsoleDumper() {} - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const float* tensor, int dim0, int dim1) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const half* tensor, int dim0, int dim1) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc index 24ba85cbf6e0d..83fee35481aa6 100644 --- a/onnxruntime/core/providers/js/operators/pad.cc +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 2, 10, kJsExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Pad); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 17, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 18, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX( 19, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fe6b959b962de..39e5f5be000e5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1834,13 +1834,21 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { } void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { - if (info_.custom_op_domain_list.empty()) { - common::Status status = CreateTensorRTCustomOpDomainList(info_); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + std::string extra_plugin_lib_paths{""}; + if (info_.has_trt_options) { + if (!info_.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info_.extra_plugin_lib_paths; } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; + } + } + auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths); + if (status != Status::OK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; } - custom_op_domain_list = info_.custom_op_domain_list; } // Check the graph is the subgraph of control flow op diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 4e466a5d568a6..eb340ba1e64b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -27,8 +27,12 @@ extern TensorrtLogger& GetTensorrtLogger(); * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { - std::unique_ptr custom_op_domain = std::make_unique(); - custom_op_domain->domain_ = "trt.plugins"; + static std::unique_ptr custom_op_domain = std::make_unique(); + static std::vector> created_custom_op_list; + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { + domain_list.push_back(custom_op_domain.get()); + return Status::OK(); + } // Load any extra TRT plugin library if any. // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. @@ -69,38 +73,19 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& continue; } - std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); - trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + created_custom_op_list.push_back(std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up + created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get()); registered_plugin_names.insert(plugin_name); } - domain_list.push_back(custom_op_domain.release()); + custom_op_domain->domain_ = "trt.plugins"; + domain_list.push_back(custom_op_domain.get()); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } return Status::OK(); } -common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) { - std::vector domain_list; - std::string extra_plugin_lib_paths{""}; - if (info.has_trt_options) { - if (!info.extra_plugin_lib_paths.empty()) { - extra_plugin_lib_paths = info.extra_plugin_lib_paths; - } - } else { - const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); - if (!extra_plugin_lib_paths_env.empty()) { - extra_plugin_lib_paths = extra_plugin_lib_paths_env; - } - } - auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); - if (!domain_list.empty()) { - info.custom_op_domain_list = domain_list; - } - return Status::OK(); -} - void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { if (domain != nullptr) { for (auto ptr : domain->custom_ops_) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3178c13d30eec..f48110aa7ee5b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1713,17 +1713,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id); - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); - - return nullptr; + OrtTensorRTProviderOptionsV2 tensorrt_options; + tensorrt_options.device_id = device_id; + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options); API_IMPL_END } @@ -1741,33 +1733,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN - - std::shared_ptr factory; - -#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) - auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; - // If EP context configs are provided in session options, we need to propagate them to provider options - if (ep_context_cache_enabled_from_sess_options) { - OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); - - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); - } else { - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); - } -#else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); -#endif - - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - AddTensorRTCustomOpDomainToSessionOption(options, ""); - - return nullptr; + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted); API_IMPL_END } @@ -1906,11 +1873,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, // if provider options already have the EP context configs provided, the configs in session options will be ignored // since provider options has higher priority than session options. if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { - // We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified. - // Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified. + // Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead. + // Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); } else { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f7ed5520727db..8e13982ca6861 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -443,9 +443,9 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti if (it != options.end()) { trt_extra_plugin_lib_paths = it->second; } - std::vector domain_list; - tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths); - for (auto ptr : domain_list) { + std::vector custom_op_domains; + tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths); + for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { so.custom_op_domains_.push_back(ptr); } else { diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 844cc877f2568..ebb0261deefa5 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" @@ -423,5 +421,3 @@ TEST(MoETest, MoETest_Relu) { } // namespace test } // namespace onnxruntime - -#endif diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index ff2e7c0468a21..b767b7276b428 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -136,11 +136,11 @@ stages: - template: templates/explicitly-defined-final-tasks.yml -- stage: Stale_Diffusion +- stage: Stable_Diffusion dependsOn: - Build_Onnxruntime_Cuda jobs: - - job: Stale_Diffusion + - job: Stable_Diffusion variables: skipComponentGovernanceDetection: true CCACHE_DIR: $(Pipeline.Workspace)/ccache @@ -171,12 +171,7 @@ stages: python3 -m pip install -r requirements-cuda11.txt; \ python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ echo Generate an image guided by a text prompt; \ - python3 demo_txt2img.py "astronaut riding a horse on mars"; \ - echo Generate an image with Stable Diffusion XL guided by a text prompt; \ - python3 demo_txt2img_xl.py 'starry night over Golden Gate Bridge by van gogh'; \ - python3 demo_txt2img_xl.py --enable-refiner 'starry night over Golden Gate Bridge by van gogh'; \ - echo Generate an image guided by a text prompt using LCM LoRA; \ - python3 demo_txt2img_xl.py --scheduler LCM --lora-weights latent-consistency/lcm-lora-sdxl --denoising-steps 4 "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"; \ + python3 demo_txt2img.py 'astronaut riding a horse on mars'; \ popd; \ " displayName: 'Run stable diffusion demo'