diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 708d6ba18750b..1e5a36fb9efb9 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -30,7 +30,6 @@ - empty size={ _size() } size=({_size()}) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 8565ffbb6c379..c73f978bdf404 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2649,8 +2649,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
-<dd>Constrain input and output types to float/half_float tensors.</dd>
+<dt><tt>T1</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
+<dd>Constrain input and output types to float/half_float/brain_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
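For context on what the widened T1 constraint covers: MatMulBnb4 multiplies `A` by a weight stored as packed 4-bit codes (`B`) plus one scale per block (`absmax`); dequantization is a per-nibble lookup scaled by the block's absmax, the same `quant_map[code] * absmax` pattern used by `kDequantizeBlockwise` later in this diff. Below is a minimal NumPy sketch of that dequantization, assuming a caller-supplied 16-entry FP4/NF4 `quant_map` (e.g. the table `SetBnbQuantMap` fills in); the helper name is hypothetical and not part of this change.

```python
import numpy as np

def dequantize_bnb4(quant_data: np.ndarray, absmax: np.ndarray,
                    quant_map: np.ndarray, block_size: int) -> np.ndarray:
    """Illustrative sketch of blockwise 4-bit (FP4/NF4) dequantization.

    quant_data: uint8, two 4-bit codes packed per byte (high nibble first).
    absmax:     one scale per block of `block_size` dequantized elements.
    quant_map:  16-entry lookup table mapping a 4-bit code to a float value.
    """
    codes = np.empty(quant_data.size * 2, dtype=np.uint8)
    codes[0::2] = quant_data >> 4       # high nibble decoded first
    codes[1::2] = quant_data & 0x0F     # then the low nibble
    values = quant_map[codes].astype(np.float32)
    # one absmax scale per block of block_size dequantized values
    scales = np.repeat(absmax.astype(np.float32), block_size)[:values.size]
    return values * scales
```

MatMulBnb4 then multiplies `A` by the dequantized weight, which is why `absmax` shares the `T1` type constraint with `A` and `Y` and now admits bfloat16 as well.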
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 26b5ebbdbec36..16df788c284ee 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -840,7 +840,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/js/.eslintrc.js b/js/.eslintrc.js index fd30cb96a5bd0..0bf47c5264f61 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -5,10 +5,18 @@ module.exports = { root: true, - ignorePatterns: ['**/*.js', 'ort-schema/', 'common/test/type-tests/', 'test/data/', 'node_modules/', 'dist/'], + ignorePatterns: [ + '**/*.js', + 'node_modules/', + 'ort-schema/', + 'common/test/type-tests/', + 'web/types.d.ts', + 'test/data/', + 'dist/', + ], env: { 'es6': true }, parser: '@typescript-eslint/parser', - parserOptions: { 'project': 'tsconfig.json', 'sourceType': 'module' }, + parserOptions: { 'project': true, 'sourceType': 'module' }, plugins: ['@typescript-eslint', 'prefer-arrow', 'header', 'import', 'unicorn', 'jsdoc'], rules: { 'unicorn/filename-case': 'error', @@ -144,15 +152,56 @@ module.exports = { 'no-unused-expressions': 'off', } }, { - files: ['web/lib/**/*.ts'], - excludedFiles: 'web/lib/wasm/proxy-worker/**/*', - parserOptions: { 'project': 'web/tsconfig.json' }, - rules: { - 'no-underscore-dangle': 'off', + files: ['web/lib/**/*.ts'], rules: { + 'no-underscore-dangle': ['error', { + 'allow': [ + '_free', + '_malloc', + '_JsepGetNodeName', + '_JsepOutput', + '_OrtAddFreeDimensionOverride', + '_OrtAddRunConfigEntry', + '_OrtAddSessionConfigEntry', + '_OrtAppendExecutionProvider', + '_OrtBindInput', + '_OrtBindOutput', + '_OrtClearBoundOutputs', + '_OrtCreateBinding', + '_OrtCreateRunOptions', + '_OrtCreateSession', + '_OrtCreateSessionOptions', + '_OrtCreateTensor', + '_OrtEndProfiling', + '_OrtFree', + '_OrtGetInputName', + '_OrtGetInputOutputCount', + '_OrtGetLastError', + '_OrtGetOutputName', + '_OrtGetTensorData', + '_OrtInit', + '_OrtReleaseBinding', + '_OrtReleaseRunOptions', + '_OrtReleaseSession', + '_OrtReleaseSessionOptions', + '_OrtReleaseTensor', + '_OrtRun', + '_OrtRunWithBinding', + '_OrtTrainingCopyParametersFromBuffer', + '_OrtTrainingCopyParametersToBuffer', + '_OrtTrainingCreateSession', + '_OrtTrainingEvalStep', + '_OrtTrainingGetModelInputOutputCount', + '_OrtTrainingGetModelInputOutputName', + '_OrtTrainingGetParametersSize', + '_OrtTrainingLazyResetGrad', + '_OrtTrainingLoadCheckpoint', + '_OrtTrainingOptimizerStep', + '_OrtTrainingReleaseCheckpoint', + '_OrtTrainingReleaseSession', + '_OrtTrainingRunTrainStep' + ] + }] } - }, { - files: ['web/lib/wasm/proxy-worker/**/*.ts'], - parserOptions: { 'project': 'web/lib/wasm/proxy-worker/tsconfig.json' }, }, { files: ['web/lib/onnxjs/**/*.ts'], rules: { // TODO: those rules are useful. 
should turn on them in future (webgl refactor) @@ -164,6 +213,7 @@ module.exports = { 'import/no-internal-modules': 'off', 'prefer-arrow/prefer-arrow-functions': 'off', 'no-param-reassign': 'off', + 'no-underscore-dangle': 'off', 'guard-for-in': 'off' } }, { diff --git a/js/web/lib/onnxjs/attribute-with-cache-key.ts b/js/web/lib/onnxjs/attribute-with-cache-key.ts index 6608b00471e77..5d47570f267a6 100644 --- a/js/web/lib/onnxjs/attribute-with-cache-key.ts +++ b/js/web/lib/onnxjs/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts index adba0fb9d022d..ad56b92c1d869 100644 --- a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts +++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 9f5dceb8f4726..bac44328d8f44 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], ['Ceil', [unaryOps.ceil]], - ['ClipV10', [unaryOps.clipV10]], ['Clip', [unaryOps.clip]], ['Concat', [concat, parseConcatAttributes]], ['Conv', [conv, parseConvAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 0841da11d9e86..c033c0ba05356 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -17,8 +17,9 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], - vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number, - typeOutput: number, useShapesUniforms: boolean, additionalImplementation?: string) => { + vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, + typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, + additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -42,6 +43,8 @@ const createBinaryOpProgramShader = if (doBroadcast) { const isAOneElement = ShapeUtil.size(dimsA) === 1; const isBOneElement = ShapeUtil.size(dimsB) === 1; + const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 
4 === 0; if (isAOneElement || isBOneElement) { assignment = output.setByOffset( 'global_idx', @@ -55,7 +58,14 @@ const createBinaryOpProgramShader = let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)}; ${ output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + 'global_idx', + expressionVector( + sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 ? + a.getByOffset('offsetA / 4u') : + `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`, + sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 ? + b.getByOffset('offsetB / 4u') : + `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`))} `; } } else { @@ -118,6 +128,7 @@ const createBinaryOpProgramInfo = let outputSize = ShapeUtil.size(a.dims); let vectorize = false; + let sharedDimensionDivisibleBy4 = false; // TODO: deal with zero-sized tensors (eg. dims=[1,0]) const cacheKeyAux = [isBroadcast]; @@ -130,8 +141,12 @@ const createBinaryOpProgramInfo = outputSize = ShapeUtil.size(outputShape); const isAOneElement = ShapeUtil.size(a.dims) === 1; const isBOneElement = ShapeUtil.size(b.dims) === 1; + const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; cacheKeyAux.push(isAOneElement); cacheKeyAux.push(isBOneElement); + cacheKeyAux.push(aLastDimDivisibleBy4); + cacheKeyAux.push(bLastDimDivisibleBy4); // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { @@ -143,7 +158,10 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) { + if (sharedDimension % 4 === 0) { + sharedDimensionDivisibleBy4 = true; + vectorize = true; + } else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) { vectorize = true; } } else { @@ -160,8 +178,8 @@ const createBinaryOpProgramInfo = inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( - shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, - outputDataType, useShapesUniforms, additionalImplementation), + shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, + a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 4238449f9246f..119609e06f5a3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey { readonly max: number; } -export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { +const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { + const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3) ? 
inputs[2].getFloat32Array()[0] : MAX_CLIP; + return createAttributeWithCacheKey({min, max}); +}; + +export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { + const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfo( @@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo attributes.cacheKey), {inputs: [0]}); }; -const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; - return createAttributeWithCacheKey({min, max}); -}; - -export const clip = (context: ComputeContext): void => { - const attributes = generateClipAttributesFromInputs(context.inputs); - clipV10(context, attributes); -}; export const ceil = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil')); diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 7172a28316f16..108eea1a73fe9 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); @@ -313,6 +314,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc index 251850f621361..6cdccdb1becb1 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -14,17 +14,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX( \ - GemmFloat8, \ - kMSDomain, \ - 1, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("TA", BuildKernelDefConstraints()) \ - .TypeConstraint("TB", BuildKernelDefConstraints()) \ - .TypeConstraint("TR", BuildKernelDefConstraints()) \ - .TypeConstraint("TS", BuildKernelDefConstraints()), \ +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#else +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#endif + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + 
(*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TR", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ GemmFloat8); REGISTER_KERNEL() @@ -38,7 +44,7 @@ GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { alpha_ = info.GetAttrOrDefault("alpha", 1); beta_ = info.GetAttrOrDefault("beta", 0); -#if (CUDA_VERSION <= 12000) +#if (CUDA_VERSION < 12000) ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); #endif diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index df25342342cd5..56b541f5256bf 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: return 2; -#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) +#if !defined(DISABLE_FLOAT8_TYPES) case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: return 1; @@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { } auto first_type = input_A->GetElementType(); +#if !defined(DISABLE_FLOAT8_TYPES) bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; if (!is_float8) +#endif return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#if !defined(DISABLE_FLOAT8_TYPES) return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#endif } Status GemmFloat8::ComputeRowMajor( @@ -197,10 +201,15 @@ Status GemmFloat8::ComputeGemm( switch (d_cuda_type) { case CUDA_R_16F: switch (a_cuda_type) { +#if !defined(DISABLE_FLOAT8_TYPES) +#if CUDA_VERSION < 11080 +#error CUDA_R_8F_E4M3 (float 8 types) is defined with CUDA>=11.8. Set flag DISABLE_FLOAT8_TYPES. +#endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; break; +#endif default: compute_type = CUBLAS_COMPUTE_32F_FAST_16F; break; @@ -267,7 +276,7 @@ Status GemmFloat8::ComputeGemm( sizeof(p_scale_b))); // float 8 -#if CUDA_VERSION >= 11080 +#if !defined(DISABLE_FLOAT8_TYPES) if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { // For FP8 output, cuBLAS requires C_type to be same as bias_type @@ -280,15 +289,14 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR( cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); } - } else { - CUBLAS_RETURN_IF_ERROR( - cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); - } #else - // An output is still needed but it is not initialized. CUBLAS_RETURN_IF_ERROR( cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); #endif + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } if (row_major_compute) { cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; @@ -345,7 +353,7 @@ Status GemmFloat8::ComputeGemm( ". Check NVIDIA documentation to see what combination is valid: ", "https://docs.nvidia.com/cuda/cublas/" "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" - "cublasltmatmulalgogetheuristic."); + "cublasltmatmulalgogetheuristic. 
CUDA>=11.8 is required to use float 8 types."); void* workspace = nullptr; if (workspaceSize > 0) { @@ -381,7 +389,8 @@ Status GemmFloat8::ComputeGemm( ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, ", workspaceSize=", workspaceSize, - ", rowMajorCompute=", (row_major_compute ? 1 : 0), "."); + ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". CUDA>=11.8 is required to use float 8 types."); if (workspaceSize > 0) { CUDA_RETURN_IF_ERROR(cudaFree(workspace)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu index e58723f0b31e1..2f74dd41f0759 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -35,6 +35,8 @@ template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, c template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); +template Status SetBnbQuantMap(int quant_type, BFloat16* quant_map_buffer, cudaStream_t stream); + template __global__ void kDequantizeBlockwise( const T* quant_map, @@ -62,22 +64,15 @@ __global__ void kDequantizeBlockwise( valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; - local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + local_abs_max = absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]; __syncthreads(); LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; - vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; - #else - // half multiplication not supported - vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); - vals[j * 2 + 1] = - static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); - #endif + vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max); + vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max); } __syncthreads(); @@ -86,7 +81,7 @@ __global__ void kDequantizeBlockwise( } template -Status DequantizeBnb4( +void CallkDequantizeBlockwise( const T* quant_map, T* output, const uint8_t* quant_data, @@ -102,6 +97,18 @@ Status DequantizeBnb4( absmax, block_size / 2, numel); +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); return Status::OK(); } @@ -119,11 +126,36 @@ template Status DequantizeBnb4( const half* quant_map, half* output, const uint8_t* quant_data, - const half *absmax, + const half* absmax, int block_size, int numel, cudaStream_t stream); +template <> +Status DequantizeBnb4( + const BFloat16* quant_map, + BFloat16* output, + const uint8_t* quant_data, + const BFloat16* absmax, + int block_size, + int numel, + cudaStream_t stream) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + CallkDequantizeBlockwise( + reinterpret_cast(quant_map), + reinterpret_cast(output), + quant_data, + reinterpret_cast(absmax), + 
block_size, + numel, + stream); + #else + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); + #endif + + return Status::OK(); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh index 4aef3ab699f9c..a0d38c9853cd6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -11,6 +11,38 @@ namespace cuda { template Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); +// templated scalar multiply function +template +__device__ inline T ScalarMul(T a, T b); + +template <> +__device__ inline float ScalarMul(float a, float b) { + return a * b; +} + +template <> +__device__ inline half ScalarMul(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return a * b; + #else + // half multiplication not supported + return static_cast(static_cast(a) * static_cast(b)); + #endif +} + +template <> +__device__ inline BFloat16 ScalarMul(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline nv_bfloat16 ScalarMul(nv_bfloat16 a, nv_bfloat16 b) { + return a * b; +} +#endif + template Status DequantizeBnb4( const T* quant_map, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index ecf332715d470..bbcb7de99781f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -145,6 +145,17 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T2", DataTypeImpl::GetTensorType()), MatMulBnb4); +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu index 1d9aa75ff3701..098e3618beddd 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -6,12 +6,44 @@ #include #include #include +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" #include "matmul_bnb4.cuh" namespace onnxruntime { namespace contrib { namespace cuda { +template +__device__ inline float ScalarMulFloatOut(T a, T b); + +template <> +__device__ inline float ScalarMulFloatOut(float a, float b) { + return a * b; +} + +template <> +__device__ inline float ScalarMulFloatOut(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return static_cast(a * b); + #else + // half multiplication not supported + return static_cast(a) * static_cast(b); + #endif +} + +template <> +__device__ inline float ScalarMulFloatOut(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline float ScalarMulFloatOut(nv_bfloat16 a, nv_bfloat16 b) { + return static_cast(a * b); +} 
+#endif + #define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive( @@ -55,7 +87,7 @@ __global__ void kgemm_4bit_inference_naive( int inner_idx_halved = inner_idx / 2; int offset_B = ldb * row_B; int absidx = ((2 * offset_B) + inner_idx) / block_size; - local_absmax = __ldg(&(absmax[absidx])); + local_absmax = absmax[absidx]; if (row_B < N) { if ((inner_idx_halved + num_values_8bit) < (K / 2)) { @@ -78,18 +110,8 @@ __global__ void kgemm_4bit_inference_naive( for (int i = 0; i < 4; i++) { #pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; - local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; - #else - // half multiplication not supported - local_B[k * 2] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * - static_cast(local_absmax)); - local_B[k * 2 + 1] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * - static_cast(local_absmax)); - #endif + local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax); + local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax); } if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { @@ -116,12 +138,7 @@ __global__ void kgemm_4bit_inference_naive( // accumulate in float; small performance hit for Ampere, but lower error for outputs #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_C += static_cast(local_A[k] * local_B[k]); - #else - // half multiplication not supported - local_C += static_cast(local_A[k]) * static_cast(local_B[k]); - #endif + local_C += ScalarMulFloatOut(local_A[k], local_B[k]); } } } @@ -131,8 +148,19 @@ __global__ void kgemm_4bit_inference_naive( if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); } +bool CheckDims(int m, int k, int block_size) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + return true; +} + template -bool TryMatMulBnb4( +void Callkgemm_4bit_inference_naive( const T* quant_map, T* output, const T* a_data, @@ -143,22 +171,34 @@ bool TryMatMulBnb4( int k, int block_size, cudaStream_t stream) { - if (k % block_size != 0 || m > 1) { - return false; - } - // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] - if (block_size % 32 != 0 || block_size > 4096) { - return false; - } - int lda = k; int ldb = (k + 1) / 2; int ldc = n; int num_blocks = (n + 3) / 4; - constexpr int bits = std::is_same_v ? 16 : 32; + constexpr int bits = std::is_same_v ? 
32 : 16; kgemm_4bit_inference_naive<<>>( m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); return true; } @@ -187,6 +227,42 @@ template bool TryMatMulBnb4( int block_size, cudaStream_t stream); +template <> +bool TryMatMulBnb4( + const BFloat16* quant_map, + BFloat16* output, + const BFloat16* a_data, + const uint8_t* b_data_quant, + const BFloat16* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + Callkgemm_4bit_inference_naive( + reinterpret_cast(quant_map), + reinterpret_cast(output), + reinterpret_cast(a_data), + b_data_quant, + reinterpret_cast(absmax), + m, + n, + k, + block_size, + stream); + #else + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); + #endif + + return true; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index db0b13b0e1d27..4c0d78f0ee297 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3431,7 +3431,7 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 .Input(1, "B", "1-dimensional quantized data for weight", "T2") .Input(2, "absmax", "quantization constants", "T1") .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") - .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float/half_float/brain_float tensors.") .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3763e0758cc5c..d489a59c4b798 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMapExists()) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { - input_def = hit->second; + // Make sure we create a local to this subgraph definition + const auto* new_name_arg = hit->second; + input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto()); } } } @@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin Graph& graph_to_inline = *sub_graph; - std::string unique_id{if_node.Name()}; + std::string unique_id{"_if_"}; if (condition_value) { unique_id.append(then_branch); } else { @@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin // Reason: there are no explicit inputs to the subgraphs, and the subgraph's // implicit inputs must be covered by the implicit inputs of the If node. InlinedHashMap outer_scope_values; - const auto if_implicit_inputs = if_node.MutableImplicitInputDefs(); + const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs(); outer_scope_values.reserve(if_implicit_inputs.size()); for (auto* input : if_implicit_inputs) { @@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin // We are going to map the outputs of the graph to inline to the outputs of the If node. // They are assumed to be in the same order. - const auto node_output_defs = if_node.MutableOutputDefs(); - const auto graph_output_defs = graph_to_inline.GetOutputs(); + const auto& node_output_defs = if_node.MutableOutputDefs(); + const auto& graph_output_defs = graph_to_inline.GetOutputs(); for (size_t i = 0; i < graph_output_defs.size(); ++i) { name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]); } @@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin } } + auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr); // We want to make sure we get nodes in topological order // because Constant folding may cause the nodes appear in // a different order. 
@@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin auto* node = graph_to_inline.GetNode(node_idx); assert(node->OpType() != kConstant); - InlinedVector new_node_input_defs; - for (const auto* input_def : node->InputDefs()) { + // Inputs + // Chop off trailing non-existing defs, but preserve non-existing in the middle + auto& input_defs = node->MutableInputDefs(); + auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + input_defs.resize(std::distance(input_defs.begin(), last_existing.base())); + + InlinedVector new_input_defs; + for (auto* input_def : node->InputDefs()) { if (input_def->Exists()) { // Check if this is one of the implicit graph inputs - // then leave the name as is and re-use the NodeArg + // then re-assign the def to the outer scope value. const auto& input_name = input_def->Name(); auto outer_hit = outer_scope_values.find(input_name); if (outer_hit != outer_scope_values.cend()) { - new_node_input_defs.push_back(outer_hit->second); + // get/create local definition + NodeArg* outer_arg = outer_hit->second; + auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto()); + new_input_defs.push_back(&this_scope_arg); } else { auto hit = name_to_nodearg.find(input_name); if (hit != name_to_nodearg.cend()) { - // This is other node output, constant node or initializer that was renamed. - new_node_input_defs.push_back(hit->second); + // This is other node output in the dest graph, + // constant node or initializer that was renamed. + new_input_defs.push_back(hit->second); } else { ORT_THROW("Node's: ", node->Name(), " input: ", input_name, " is not If node's input or previous node output in this subgraph"); } } + } else { + new_input_defs.push_back(non_existing_arg); } } - InlinedVector new_node_output_defs; - for (const auto* output_def : node->OutputDefs()) { - const auto& output_name = output_def->Name(); - auto hit = name_to_nodearg.find(output_name); - if (hit != name_to_nodearg.cend()) { - // This is one of the graph outputs, we rename it to - // If node output. - new_node_output_defs.push_back(hit->second); + // Outputs + // Chop off trailing non-existing defs + auto& output_defs = node->MutableOutputDefs(); + last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + output_defs.resize(std::distance(output_defs.begin(), last_existing.base())); + + InlinedVector new_output_defs; + for (auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto& output_name = output_def->Name(); + auto hit = name_to_nodearg.find(output_name); + if (hit != name_to_nodearg.cend()) { + // This is one of the If node outputs, simply reassign the def. + // If node defs are already in the destination graph + new_output_defs.push_back(hit->second); + } else { + // We generate an output to downstream nodes. + auto new_name = GenerateNodeArgName(make_unique(output_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); + new_output_defs.push_back(&new_arg); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + } } else { - // We generate an output to downstream nodes. 
- auto new_name = GenerateNodeArgName(make_unique(output_name)); - NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); - new_node_output_defs.push_back(&new_arg); - ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + new_output_defs.push_back(non_existing_arg); } } const auto new_node_name = GenerateNodeName(make_unique(node->OpType())); Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(), - new_node_input_defs, - new_node_output_defs, + new_input_defs, + new_output_defs, nullptr, node->Domain()); + new_node.SetSinceVersion(node->SinceVersion()); + new_node.op_ = node->op_; + if (!is_this_main_graph) { map_defs(new_node, input_args, true); map_defs(new_node, output_args, false); new_nodes.push_back(&new_node); } - new_node.SetSinceVersion(node->SinceVersion()); - new_node.op_ = node->op_; - if (node->ContainsSubgraph()) { auto& subgraphs = node->MutableSubgraphs(); // Check if any of this node implicit inputs of this graph is in the renaming map + // that would mean they come from the destination graph, not from the parent + // of the destination graph. int renames_subgraph_names = 0; - auto& new_implicit_defs = node->MutableImplicitInputDefs(); - for (auto& input_def : new_implicit_defs) { + auto& implicit_defs = node->MutableImplicitInputDefs(); + for (auto& input_def : implicit_defs) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { input_def = hit->second; @@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin new_node.MutableSubgraphs() = std::move(subgraphs); new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph()); - new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs); + new_node.MutableImplicitInputDefs() = std::move(implicit_defs); } new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes()); diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 288ca8e97e34d..33f2938940e4d 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -62,7 +62,8 @@ const char* CudaDataTypeToString(cudaDataType_t dt) { return "CUDA_R_16BF"; case CUDA_R_32F: return "CUDA_R_32F"; -#if (CUDA_VERSION >= 11080) +#if !defined(DISABLE_FLOAT8_TYPES) + // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 case CUDA_R_8F_E4M3: return "CUDA_R_8F_E4M3"; case CUDA_R_8F_E5M2: @@ -101,7 +102,7 @@ cudaDataType_t ToCudaDataType(int32_t element_type) { return CUDA_R_16F; case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return CUDA_R_16BF; -#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) +#if !defined(DISABLE_FLOAT8_TYPES) case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: return CUDA_R_8F_E4M3; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 9cd4e721ccab8..707099bac3ce0 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -58,6 +58,8 @@ class ToCudaType { } }; +#if !defined(DISABLE_FLOAT8_TYPES) + template <> class ToCudaType { public: @@ -76,6 +78,8 @@ class ToCudaType { } }; +#endif + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) diff --git 
a/onnxruntime/core/providers/cuda/tensor/cast_op.cu b/onnxruntime/core/providers/cuda/tensor/cast_op.cu index 7542fb55757c6..f2c2e6d7458f9 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cu +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cu @@ -141,7 +141,7 @@ struct CastSat { #endif -#endif +#endif // DISABLE_FLOAT8_TYPES template __global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastStd cast) { diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu index ad2a44793fe26..1da308811fa48 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu @@ -104,7 +104,7 @@ struct RoundSat { #endif -#endif +#endif // DISABLE_FLOAT8_TYPES template <> struct RoundStd { @@ -189,7 +189,7 @@ __global__ void QuantizeLinearKernelAxisSat(const InT* input, OutT* output, cons } } -#endif +#endif // DISABLE_FLOAT8_TYPES template Status CudaQuantizeLinearStd(cudaStream_t stream, const InT* input, OutT* output, const InT* scale, const OutT* zero_point, size_t num_of_element) { diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index e9bbfabcf86bd..78563d30b0136 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -123,7 +123,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 3b344d6dc9342..407c3b80e153f 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -157,14 +157,14 @@ def hook_for_inputs(_, inputs, kwargs): for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): if type(value) is torch.Tensor: value.to(model.device) - # Didn't touch past_key_value now, please change it if you want if "use_cache" in key: onnx_inputs[idx] = with_past + out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out return input_keys, onnx_inputs, out.past_key_values -def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: +def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: """ According to the model size, we will upload it to CPU if has no GPU or enough GPU memory, @@ -307,7 +307,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit """ model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) - model = move_to_approprate_device(model, sample_inputs_tp) + model = move_to_appropriate_device(model, sample_inputs_tp) sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) diff --git a/onnxruntime/test/contrib_ops/gemm_float8_test.cc b/onnxruntime/test/contrib_ops/gemm_float8_test.cc new file mode 100644 index 0000000000000..c022736075cde --- /dev/null +++ 
b/onnxruntime/test/contrib_ops/gemm_float8_test.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +#if defined(USE_CUDA) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + +TEST(GemmFloat8OpTest, BFloat16) { + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)); + test.AddInput("A", {2, 4}, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f})); + test.AddInput("B", {4, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddInput("C", {2, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddOutput("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(GemmFloat8OpTest, Float) { + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + test.AddInput("A", {2, 4}, std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f})); + test.AddInput("B", {4, 3}, std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddInput("C", {2, 3}, std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddOutput("Y", {2, 3}, std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +std::vector _Cvt(const std::vector& tensor) { + std::vector fp16_data(tensor.size()); + ConvertFloatToMLFloat16(tensor.data(), fp16_data.data(), static_cast(tensor.size())); + return fp16_data; +} + +TEST(GemmFloat8OpTest, Float16) { + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + test.AddInput("A", {2, 4}, _Cvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}))); + test.AddInput("B", {4, 3}, _Cvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddInput("C", {2, 3}, _Cvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddOutput("Y", {2, 3}, _Cvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}))); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +#if (!defined(DISABLE_FLOAT8_TYPES)) 
&& (CUDA_VERSION >= 12000) + +template +std::vector _TypedCvt(const std::vector& tensor); + +template <> +std::vector _TypedCvt(const std::vector& tensor) { + return tensor; +} + +template <> +std::vector _TypedCvt(const std::vector& tensor) { + std::vector out(tensor.size()); + for (size_t i = 0; i < tensor.size(); ++i) { + out[i] = Float8E4M3FN(tensor[i]); + } + return out; +} + +template +void TestGemmFloat8WithFloat8(int64_t dtype) { + int min_cuda_architecture = 11080; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware does NOT support Matrix Multiplication for FLOAT8"; + return; + } + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)1); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", dtype); + test.AddInput("A", {2, 4}, _TypedCvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}))); + test.AddInput("B", {3, 4}, _TypedCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddInput("C", {2, 3}, _TypedCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddOutput("Y", {2, 3}, _TypedCvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}))); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(GemmFloat8OpTest, Float8E4M3FNToFloat) { + TestGemmFloat8WithFloat8(static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); +} + +TEST(GemmFloat8OpTest, Float8E4M3FNToFloat8E4M3FN) { + TestGemmFloat8WithFloat8(static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN)); +} + +#endif + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc index 17b26ed7ca4ca..ef6e2d531bc1a 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges) ASSERT_EQ(op_to_count["Cast"], 2); } +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) { + // This model has a Resize() call with a middle argument non-existing. + // We want to make sure that the input edges for that Resize() node + // are properly rebuilt with a middle argument non-existing + // during If constant folding + // This test is only valid if Resize() node resides in the nested subgraph which gets inlined + // however, the destination graph must not be the main graph. Then we test that the edges are rebuilt + // properly. 
Also Resize() should not be the first node in the resulting subgraph, so it has edges + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[128] x, float[128] x1) => (float[N] y) + { + y = local.aten_gather (x, x1) + } + < + opset_import: [ "" : 16, "local" : 1], + domain: "local" + > + aten_gather (self, index) => (result_16) + { + resize_scales = Constant () + tmp_0 = Size (index) + int64_0 = Constant () + int64_0_cast = CastLike (int64_0, tmp_0) + cond = Equal (tmp_0, int64_0_cast) + result_16 = If (cond) ( result) { + result = Identity (self) + }, else_branch: graph = elseGraph_10 () => ( result_15) { + tmp_1 = Shape (self) + tmp_2 = Size (tmp_1) + int64_0_3 = Constant () + int64_0_3_cast = CastLike (int64_0_3, tmp_2) + cond_4 = Equal (tmp_2, int64_0_3_cast) + self_8 = If (cond_4) ( self_6) { + tmp_5 = Constant () + self_6 = Reshape (self, tmp_5) + }, else_branch: graph = elseGraph_13 () => ( self_7) { + self_71 = Mul(self, self) + float_size = CastLike (tmp_0, resize_scales) + non_constant_resize_scales = Mul(float_size, resize_scales) + self_7 = Resize(self_71,, non_constant_resize_scales) + }> + tmp_9 = Size (index) + int64_0_10 = Constant () + int64_0_10_cast = CastLike (int64_0_10, tmp_9) + cond_11 = Equal (tmp_9, int64_0_10_cast) + result_15 = If (cond_11) ( result_12) { + result_12 = CastLike (index, self_8) + }, else_branch: graph = elseGraph_15 () => ( result_14) { + index_13 = Cast (index) + result_14 = GatherElements (self_8, index_13) + }> + }> + } + )"; + + /** Optimized model graph + < + ir_version: 8, + opset_import: ["" : 16, + "local" : 1, + "com.microsoft.nchwc" : 1, + "ai.onnx.ml" : 4, + "ai.onnx.training" : 1, + "ai.onnx.preview.training" : 1, + "com.microsoft" : 1, + "com.microsoft.experimental" : 1, "org.pytorch.aten" : 1] + > + agraph (float[128] x, float[128] x1) => (float[128] y) + + { + _inlfunc_aten_gather_tmp_0 = Size (x1) + _inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8) + y = If (_inlfunc_aten_gather_cond) + (float[128] _inlfunc_aten_gather_result) { + _inlfunc_aten_gather_result = Identity (x) + }, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15) + + { + _if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x) + _if_else_branch__inlfunc_aten_gather_float_size = Cast (_inlfunc_aten_gather_tmp_0) + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul ( + _if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales) + _inlfunc_aten_gather_self_8 = Resize ( + _if_else_branch__inlfunc_aten_gather_self_71, , + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales) + _inlfunc_aten_gather_tmp_9 = Size (x1) + _inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10) + _inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11) + (float[128] _inlfunc_aten_gather_result_12) { + _inlfunc_aten_gather_result_12 = Cast (x1) + }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) { + _inlfunc_aten_gather_index_13 = Cast (x1) + _inlfunc_aten_gather_result_14 = GatherElements ( + _inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13) + }> + }> + } + + */ + + ONNX_NAMESPACE::OnnxParser parser(code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed 
input unexpected."; + + std::string serialized_model; + const bool serialization_status = model_proto.SerializeToString(&serialized_model); + ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string"; + + // AOT inlining is necessary in this case, so the If nodes within the function + // are brought out to the outer scope. So we load this into a session object. + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Let's verify the correctness of the rebuild edges in the Resize node that still + // resides within an if else subgraph. + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["If"], 2); + ASSERT_EQ(op_to_count["Resize"], 1); + + auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(), + [](const auto& node) { return node.OpType() == "If"; }); + ASSERT_NE(graph.Nodes().cend(), if_node); + // Resize is in the else branch + auto subgraph_map = if_node->GetAttributeNameToSubgraphMap(); + auto branch = subgraph_map.find("else_branch"); + ASSERT_NE(subgraph_map.cend(), branch); + + auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(), + [](const auto& node) { return node.OpType() == "Resize"; }); + ASSERT_NE(branch->second->Nodes().cend(), resize_node); + + // Check the edges + ASSERT_EQ(2U, resize_node->GetInputEdgesCount()); + // Should have input edges with arg_pos 0 and 2 + // With 1 is missing + InlinedHashSet dest_edges; + auto zero_edge = resize_node->InputEdgesBegin(); + dest_edges.insert(zero_edge->GetDstArgIndex()); + ++zero_edge; + dest_edges.insert(zero_edge->GetDstArgIndex()); + ASSERT_TRUE(dest_edges.find(0) != dest_edges.end()); + ASSERT_TRUE(dest_edges.find(2) != dest_edges.end()); +} + // Check transformations in the case of a subgraph with constant inputs. 
TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx"; diff --git a/onnxruntime/test/python/onnxruntime_test_float8.py b/onnxruntime/test/python/onnxruntime_test_float8.py index 76ca5d9538374..bb63ea234498f 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8.py @@ -334,7 +334,7 @@ def test_model_cast_cast_cpu(self, name: str, float_name: str, saturate: int): ] ) @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_cast_cast_cuda(self, name: str, float_name: str, saturate: int, provider: str): so = onnxruntime.SessionOptions() so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL @@ -373,7 +373,7 @@ def test_model_cast_cast_cuda(self, name: str, float_name: str, saturate: int, p ] ) @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_cast_cast_cuda_ortvalue(self, name: str, float_name: str, saturate: int, provider: str): so = onnxruntime.SessionOptions() so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL @@ -627,7 +627,7 @@ def test_model_cast_like_x2_cpu(self, name: str, float_name: str, saturate: int) ] ) @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_qdq_cuda(self, name: str, float_name: str, saturate: int, provider: str): so = onnxruntime.SessionOptions() so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL @@ -693,7 +693,7 @@ def test_model_qdq_cuda_ortvalue(self, name: str, float_name: str, saturate: int self.assertEqual(expect.shape, y.shape) self.assertEqual(expect.dtype, y.dtype) - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_compare_cpu_cuda_e4m3fn(self): folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8") model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx") diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py index 784ae8ce70bd8..7dffad8f84c83 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -17,7 +17,9 @@ from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info from onnx.numpy_helper import from_array -from onnxruntime import InferenceSession +from onnxruntime import InferenceSession, get_available_providers + +available_providers = [provider for provider in get_available_providers()] class TestFloat8Gemm8(unittest.TestCase): @@ -192,21 +194,27 @@ def check(f): 
self.assertEqual(expected.shape, y.shape) self.assertEqual(expected.dtype, y.dtype) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_default_values(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_relu(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_gelu(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_bias(self): self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float16(self): self.common_test_model_gemm( "FLOAT16", @@ -215,6 +223,8 @@ def test_model_gemm_float16(self): transB=1, ) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") + @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") def test_model_gemm_float8_e4m3(self): self.common_test_model_gemm( "FLOAT8E4M3FN", @@ -226,6 +236,7 @@ def test_model_gemm_float8_e4m3(self): ) @parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1]))) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_combinations_square_matrices(self, transA, transB): self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3) @@ -237,6 +248,7 @@ def test_combinations_square_matrices(self, transA, transB): ((2, 3), (2, 5), 1, 0), ] ) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_combinations(self, shapeA, shapeB, transA, transB): model = make_model( make_graph( diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 4977272de5ac9..8efbe16d7d61d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -412,14 +412,24 @@ def _matmul4bit_export(g, n, *args, **kwargs): return None quant_state = args[4] - absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + if isinstance(quant_state, list): + # version <= 0.41.1 + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + nested = compressed_stats is not None + else: + # version > 0.41.1 + absmax = quant_state.absmax + shape = quant_state.shape + blocksize = quant_state.blocksize + nested = quant_state.nested + quant_type = quant_state.quant_type # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16 if blocksize < 16 or blocksize & (blocksize - 1) != 0: return None # MatMulBnb4 does not support 
double de-quantization (e.g. absmax is int, needs to be dequantized too) - if compressed_stats is not None: + if nested: return None # The PyTorch linear weight shape is [out_feature, in_feature] diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 6bd3e2533c045..3b1a0317c58f1 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -14,6 +14,15 @@ import sys from pathlib import Path + +def version_to_tuple(version: str) -> tuple: + v = [] + for s in version.split("."): + with contextlib.suppress(ValueError): + v.append(int(s)) + return tuple(v) + + SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) @@ -1084,6 +1093,12 @@ def generate_build_tree( if args.use_cuda: nvcc_threads = number_of_nvcc_threads(args) cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(nvcc_threads)) + if not disable_float8_types and args.cuda_version: + if version_to_tuple(args.cuda_version) < (11, 8): + raise BuildError( + f"Float 8 types require CUDA>=11.8. They must be disabled on CUDA=={args.cuda_version}. " + f"Add '--disable_types float8' to your command line. See option disable_types." + ) if args.use_rocm: cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml new file mode 100644 index 0000000000000..aee42d3675087 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -0,0 +1,39 @@ +trigger: none + +parameters: + - name: enable_linux_gpu + type: boolean + default: true + - name: enable_windows_gpu + type: boolean + default: true + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + - name: cuda_version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +stages: + - template: stages/py-cuda-packaging-stage.yml + parameters: + enable_linux_gpu: ${{ parameters.enable_linux_gpu }} + enable_windows_gpu: ${{ parameters.enable_windows_gpu }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: ${{ parameters.cuda_version }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml new file mode 100644 index 0000000000000..f3d68957d649c --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml @@ -0,0 +1,105 @@ +parameters: +- name: build_py_parameters + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +- name: enable_linux_gpu + displayName: 'Whether Linux GPU package is built.' + type: boolean + default: true + +- name: enable_windows_gpu + displayName: 'Whether Windows GPU package is built.' + type: boolean + default: true + +# TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. +- name: cmake_build_type + type: string + displayName: 'Linux packages cmake build type. Linux Only.' 
+ default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: cuda_version + type: string + displayName: 'CUDA version. Windows Only.' + default: '12.2' + values: + - 11.8 + - 12.2 + +stages: +- stage: Python_Packaging + dependsOn: [] + variables: + - name: docker_base_image + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + - name: linux_trt_version + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 + - name: win_trt_home + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 + - name: win_cuda_home + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: $(Agent.TempDirectory)\v12.2 + jobs: + - ${{ if eq(parameters.enable_windows_gpu, true) }}: + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.8' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.9' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.10' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.11' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + + - ${{ if eq(parameters.enable_linux_gpu, true) }}: + - template: ../templates/py-linux-gpu.yml + parameters: + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + docker_base_image: ${{ variables.docker_base_image }} + trt_version: ${{ variables.linux_trt_version }} + cuda_version: ${{ parameters.cuda_version }} diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index 4573c56963e34..ff7f0957e94ba 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ 
b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -34,7 +34,7 @@ steps: displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8' - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib" - displayName: 'Append CUDA SDK Directory to PATH' + displayName: 'Append TensorRT Directory to PATH' - ${{ if eq(parameters.CudaVersion, '12.2') }}: - powershell: | @@ -42,7 +42,7 @@ steps: displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0' - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib" - displayName: 'Append CUDA SDK Directory to PATH' + displayName: 'Append TensorRT Directory to PATH' - task: CmdLine@2 inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 9282cfccd02f0..e40c4d0e95dc5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -4,6 +4,7 @@ parameters: - name: EnvSetupScript type: string + default: setup_env.bat - name: job_name_suffix type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index f68847afff379..8cc48aac7a3b9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -17,7 +17,24 @@ parameters: - Release - RelWithDebInfo - MinSizeRel - +- name: docker_base_image + type: string + default: 'nvidia/cuda:11.8.0-cudnn8-devel-ubi8' + values: + - nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + - nvidia/cuda:12.2.2-cudnn8-devel-ubi8 +- name: trt_version + type: string + default: '8.6.1.6-1.cuda11.8' + values: + - 8.6.1.6-1.cuda11.8 + - 8.6.1.6-1.cuda12.0 +- name: cuda_version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 jobs: - job: Linux_py_GPU_Wheels_${{ parameters.arch }} timeoutInMinutes: 240 @@ -26,7 +43,13 @@ jobs: pool: ${{ parameters.machine_pool }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. 
- skipComponentGovernanceDetection: true + - name: skipComponentGovernanceDetection + value: true + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - checkout: self clean: true @@ -40,12 +63,12 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + --build-arg TRT_VERSION=${{ parameters.trt_version }} --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }} " - Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} - task: Bash@3 @@ -53,8 +76,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - # please check ONNXRUNTIME_CUDA_VERSION in tools/ci_build/github/linux/build_linux_arm64_python_package.sh - arguments: -i onnxruntimecuda118xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}" + arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index 0774c3350b9b1..db3782c69cf62 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -46,9 +46,17 @@ jobs: pool: ${{ parameters.machine_pool }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. 
- skipComponentGovernanceDetection: true - ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -82,7 +90,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}" + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) ${{ if eq(parameters.with_cache, 'true') }}: env: ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index 919749cac15b6..501251eaff20f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -14,21 +14,32 @@ parameters: - name: ENV_SETUP_SCRIPT type: string + default: '' - name: BUILD_PY_PARAMETERS displayName: > Extra parameters to pass to build.py. Don't put newlines in here. type: string default: '' - +- name: CudaVersion + type: string + default: '11.8' + values: + - 11.8 + - 12.2 jobs: - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.MACHINE_POOL }} + pool: + name: ${{ parameters.MACHINE_POOL }} +# demands: +# - ImageVersionOverride -equals 1.0.367516 variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' + CUDA_MODULE_LOADING: 'LAZY' steps: - checkout: self clean: true @@ -61,10 +72,21 @@ jobs: - template: download-deps.yml - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - DownloadCUDA: true + - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: + - template: jobs/set-winenv.yml + parameters: + EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: + DownloadCUDA: true + + - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}: + - template: jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: + DownloadCUDA: true + ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}: + DownloadTRT: true - task: PythonScript@0 displayName: 'Update deps.txt' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index ed010b5619db5..d7ffc1828c943 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -40,7 +40,6 
@@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'Debug' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_java --build_nodejs --build_wheel --disable_memleak_checker msbuildPlatform: x64 @@ -59,7 +58,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 # Compare to our Nuget packaging pipeline, this job has "--build_wheel" but doesn't have "--enable_lto --disable_rtti --use_telemetry --enable_wcos" # Python bindings use typeid so I can't disable RTTI here. If it causes a problem, we will need to split this job to two jobs. @@ -80,7 +78,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_wheel --use_dnnl --build_java msbuildPlatform: x64 @@ -101,7 +98,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_wheel --use_xnnpack msbuildPlatform: x64 @@ -120,7 +116,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --use_winml --enable_wcos --disable_rtti --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.22000.0 msbuildPlatform: x64 @@ -160,7 +155,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'Debug' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training --build_wheel --disable_memleak_checker msbuildPlatform: x64 @@ -179,7 +173,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training --build_wheel msbuildPlatform: x64 @@ -198,7 +191,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training_apis msbuildPlatform: x64 @@ -215,10 +207,17 @@ stages: - stage: x64_release_azure dependsOn: [] jobs: + - job: + steps: + - powershell: | + Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin" + $env:PATH + Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin" + $env:PATH + displayName: 'Append x64-windows and x86-windows to PATH' - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_azure.bat buildArch: x64 additionalBuildFlags: --use_azure --use_lock_free_queue msbuildPlatform: x64 @@ -231,3 +230,5 @@ stages: GenerateDocumentation: false WITH_CACHE: true MachinePool: 'onnxruntime-Win-CPU-2022' + + diff --git a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh similarity index 78% rename from tools/ci_build/github/linux/build_linux_arm64_python_package.sh rename to tools/ci_build/github/linux/build_linux_python_package.sh index 516f320cd64c4..3c1c65c9a6862 100755 --- a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -15,9 +15,11 @@ do case "${parameter_Option}" in #GPU or CPU. 
d) BUILD_DEVICE=${OPTARG};; -p) PYTHON_EXES=(${OPTARG});; -x) EXTRA_ARG=(${OPTARG});; +p) PYTHON_EXES=${OPTARG};; +x) EXTRA_ARG=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; +*) echo "Usage: $0 -d <GPU|CPU> [-p <python_exe>] [-x <extra_build_arg>] [-c <build_config>]" + exit 1;; esac done @@ -48,7 +50,7 @@ if [ "$ARCH" == "x86_64" ] && [ "$GCC_VERSION" -ge 9 ]; then fi echo "EXTRA_ARG:" -echo $EXTRA_ARG +echo "$EXTRA_ARG" if [ "$EXTRA_ARG" != "" ]; then BUILD_ARGS+=("$EXTRA_ARG") @@ -60,19 +62,19 @@ if [ "$ARCH" == "x86_64" ]; then fi if [ "$BUILD_DEVICE" == "GPU" ]; then + SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') #Enable CUDA and TRT EPs. - ONNXRUNTIME_CUDA_VERSION="11.8" - BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") + BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") fi export CFLAGS export CXXFLAGS for PYTHON_EXE in "${PYTHON_EXES[@]}" do - rm -rf /build/$BUILD_CONFIG + rm -rf /build/"$BUILD_CONFIG" ${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}" - cp /build/$BUILD_CONFIG/dist/*.whl /build/dist + cp /build/"$BUILD_CONFIG"/dist/*.whl /build/dist done which ccache && ccache -sv && ccache -z diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh index 18ac6482827f9..ff2ce6f7ff231 100755 --- a/tools/ci_build/github/linux/run_python_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh @@ -9,24 +9,32 @@ i) DOCKER_IMAGE=${OPTARG};; d) DEVICE=${OPTARG};; x) BUILD_EXTR_PAR=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; +*) echo "Usage: $0 -i <docker_image> -d <GPU|CPU> [-x <extra_build_arg>] [-c <build_config>]" + exit 1;; esac done -mkdir -p $HOME/.onnx +mkdir -p "${HOME}/.onnx" +DOCKER_SCRIPT_OPTIONS="-d ${DEVICE} -c ${BUILD_CONFIG}" + +if [ "${BUILD_EXTR_PAR}" != "" ] ; then + DOCKER_SCRIPT_OPTIONS+=" -x ${BUILD_EXTR_PAR}" +fi + docker run --rm \ --volume /data/onnx:/data/onnx:ro \ - --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ - --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume "${BUILD_SOURCESDIRECTORY}:/onnxruntime_src" \ + --volume "${BUILD_BINARIESDIRECTORY}:/build" \ --volume /data/models:/build/models:ro \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume "${HOME}/.onnx:/home/onnxruntimedev/.onnx" \ -w /onnxruntime_src \ -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ $ADDITIONAL_DOCKER_PARAMETER \ - $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_arm64_python_package.sh -d $DEVICE -c $BUILD_CONFIG -x $BUILD_EXTR_PAR + $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_python_package.sh $DOCKER_SCRIPT_OPTIONS -sudo rm -rf $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/onnxruntime $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/pybind11 \ +sudo rm -rf "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/onnxruntime" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/pybind11" 
\ + "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/models" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/_deps" \ + "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/CMakeFiles" +cd "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}" +find -executable -type f > "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/perms.txt" diff --git a/tools/ci_build/github/windows/setup_env_azure.bat b/tools/ci_build/github/windows/setup_env_azure.bat deleted file mode 100644 index 44ba34b0bf23a..0000000000000 --- a/tools/ci_build/github/windows/setup_env_azure.bat +++ /dev/null @@ -1,4 +0,0 @@ -REM Copyright (c) Microsoft Corporation. All rights reserved. -REM Licensed under the MIT License. -set PATH=%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin;%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin;%PATH% -set GRADLE_OPTS=-Dorg.gradle.daemon=false