diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis
index 708d6ba18750b..1e5a36fb9efb9 100644
--- a/cmake/external/abseil-cpp.natvis
+++ b/cmake/external/abseil-cpp.natvis
@@ -30,7 +30,6 @@
- empty
size={ _size() }
size=({_size()})
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 8565ffbb6c379..c73f978bdf404 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2649,8 +2649,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T1 : tensor(float), tensor(float16)
-- Constrain input and output types to float/half_float tensors.
+- T1 : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float/half_float/brain_float tensors.
- T2 : tensor(uint8)
- Constrain quantized weight types to uint8.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 26b5ebbdbec36..16df788c284ee 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -840,7 +840,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
+|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/js/.eslintrc.js b/js/.eslintrc.js
index fd30cb96a5bd0..0bf47c5264f61 100644
--- a/js/.eslintrc.js
+++ b/js/.eslintrc.js
@@ -5,10 +5,18 @@
module.exports = {
root: true,
- ignorePatterns: ['**/*.js', 'ort-schema/', 'common/test/type-tests/', 'test/data/', 'node_modules/', 'dist/'],
+ ignorePatterns: [
+ '**/*.js',
+ 'node_modules/',
+ 'ort-schema/',
+ 'common/test/type-tests/',
+ 'web/types.d.ts',
+ 'test/data/',
+ 'dist/',
+ ],
env: { 'es6': true },
parser: '@typescript-eslint/parser',
- parserOptions: { 'project': 'tsconfig.json', 'sourceType': 'module' },
+ parserOptions: { 'project': true, 'sourceType': 'module' },
plugins: ['@typescript-eslint', 'prefer-arrow', 'header', 'import', 'unicorn', 'jsdoc'],
rules: {
'unicorn/filename-case': 'error',
@@ -144,15 +152,56 @@ module.exports = {
'no-unused-expressions': 'off',
}
}, {
- files: ['web/lib/**/*.ts'],
- excludedFiles: 'web/lib/wasm/proxy-worker/**/*',
- parserOptions: { 'project': 'web/tsconfig.json' },
- rules: {
- 'no-underscore-dangle': 'off',
+ files: ['web/lib/**/*.ts'], rules: {
+ 'no-underscore-dangle': ['error', {
+ 'allow': [
+ '_free',
+ '_malloc',
+ '_JsepGetNodeName',
+ '_JsepOutput',
+ '_OrtAddFreeDimensionOverride',
+ '_OrtAddRunConfigEntry',
+ '_OrtAddSessionConfigEntry',
+ '_OrtAppendExecutionProvider',
+ '_OrtBindInput',
+ '_OrtBindOutput',
+ '_OrtClearBoundOutputs',
+ '_OrtCreateBinding',
+ '_OrtCreateRunOptions',
+ '_OrtCreateSession',
+ '_OrtCreateSessionOptions',
+ '_OrtCreateTensor',
+ '_OrtEndProfiling',
+ '_OrtFree',
+ '_OrtGetInputName',
+ '_OrtGetInputOutputCount',
+ '_OrtGetLastError',
+ '_OrtGetOutputName',
+ '_OrtGetTensorData',
+ '_OrtInit',
+ '_OrtReleaseBinding',
+ '_OrtReleaseRunOptions',
+ '_OrtReleaseSession',
+ '_OrtReleaseSessionOptions',
+ '_OrtReleaseTensor',
+ '_OrtRun',
+ '_OrtRunWithBinding',
+ '_OrtTrainingCopyParametersFromBuffer',
+ '_OrtTrainingCopyParametersToBuffer',
+ '_OrtTrainingCreateSession',
+ '_OrtTrainingEvalStep',
+ '_OrtTrainingGetModelInputOutputCount',
+ '_OrtTrainingGetModelInputOutputName',
+ '_OrtTrainingGetParametersSize',
+ '_OrtTrainingLazyResetGrad',
+ '_OrtTrainingLoadCheckpoint',
+ '_OrtTrainingOptimizerStep',
+ '_OrtTrainingReleaseCheckpoint',
+ '_OrtTrainingReleaseSession',
+ '_OrtTrainingRunTrainStep'
+ ]
+ }]
}
- }, {
- files: ['web/lib/wasm/proxy-worker/**/*.ts'],
- parserOptions: { 'project': 'web/lib/wasm/proxy-worker/tsconfig.json' },
}, {
files: ['web/lib/onnxjs/**/*.ts'], rules: {
// TODO: those rules are useful. should turn on them in future (webgl refactor)
@@ -164,6 +213,7 @@ module.exports = {
'import/no-internal-modules': 'off',
'prefer-arrow/prefer-arrow-functions': 'off',
'no-param-reassign': 'off',
+ 'no-underscore-dangle': 'off',
'guard-for-in': 'off'
}
}, {
diff --git a/js/web/lib/onnxjs/attribute-with-cache-key.ts b/js/web/lib/onnxjs/attribute-with-cache-key.ts
index 6608b00471e77..5d47570f267a6 100644
--- a/js/web/lib/onnxjs/attribute-with-cache-key.ts
+++ b/js/web/lib/onnxjs/attribute-with-cache-key.ts
@@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl {
Object.assign(this, attribute);
}
- private _cacheKey: string;
+ private key: string;
public get cacheKey(): string {
- if (!this._cacheKey) {
- this._cacheKey =
+ if (!this.key) {
+ this.key =
Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record<string, unknown>)[name]}`).join(';');
}
- return this._cacheKey;
+ return this.key;
}
}
diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts
index adba0fb9d022d..ad56b92c1d869 100644
--- a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts
+++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts
@@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl {
Object.assign(this, attribute);
}
- private _cacheKey: string;
+ private key: string;
public get cacheKey(): string {
- if (!this._cacheKey) {
- this._cacheKey =
+ if (!this.key) {
+ this.key =
Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record<string, unknown>)[name]}`).join(';');
}
- return this._cacheKey;
+ return this.key;
}
}
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 9f5dceb8f4726..bac44328d8f44 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
['BiasSplitGelu', [biasSplitGelu]],
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
['Ceil', [unaryOps.ceil]],
- ['ClipV10', [unaryOps.clipV10]],
['Clip', [unaryOps.clip]],
['Concat', [concat, parseConcatAttributes]],
['Conv', [conv, parseConvAttributes]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
index 0841da11d9e86..c033c0ba05356 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -17,8 +17,9 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
const createBinaryOpProgramShader =
(shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
- vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number,
- typeOutput: number, useShapesUniforms: boolean, additionalImplementation?: string) => {
+ vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall,
+ typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean,
+ additionalImplementation?: string) => {
let expressionScalar: BinaryCustomExpression;
let expressionVector: BinaryCustomExpression;
if (typeof funcCall === 'string') {
@@ -42,6 +43,8 @@ const createBinaryOpProgramShader =
if (doBroadcast) {
const isAOneElement = ShapeUtil.size(dimsA) === 1;
const isBOneElement = ShapeUtil.size(dimsB) === 1;
+ const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0;
+ const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0;
if (isAOneElement || isBOneElement) {
assignment = output.setByOffset(
'global_idx',
@@ -55,7 +58,14 @@ const createBinaryOpProgramShader =
let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)};
${
output.setByOffset(
- 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
+ 'global_idx',
+ expressionVector(
+ sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 ?
+ a.getByOffset('offsetA / 4u') :
+ `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`,
+ sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 ?
+ b.getByOffset('offsetB / 4u') :
+ `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`))}
`;
}
} else {
@@ -118,6 +128,7 @@ const createBinaryOpProgramInfo =
let outputSize = ShapeUtil.size(a.dims);
let vectorize = false;
+ let sharedDimensionDivisibleBy4 = false;
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
const cacheKeyAux = [isBroadcast];
@@ -130,8 +141,12 @@ const createBinaryOpProgramInfo =
outputSize = ShapeUtil.size(outputShape);
const isAOneElement = ShapeUtil.size(a.dims) === 1;
const isBOneElement = ShapeUtil.size(b.dims) === 1;
+ const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0;
+ const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0;
cacheKeyAux.push(isAOneElement);
cacheKeyAux.push(isBOneElement);
+ cacheKeyAux.push(aLastDimDivisibleBy4);
+ cacheKeyAux.push(bLastDimDivisibleBy4);
// check whether vectorize can be enabled
let sharedDimension = 1;
for (let i = 1; i < outputShape.length; i++) {
@@ -143,7 +158,10 @@ const createBinaryOpProgramInfo =
break;
}
}
- if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) {
+ if (sharedDimension % 4 === 0) {
+ sharedDimensionDivisibleBy4 = true;
+ vectorize = true;
+ } else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) {
vectorize = true;
}
} else {
@@ -160,8 +178,8 @@ const createBinaryOpProgramInfo =
inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'],
},
getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
- shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType,
- outputDataType, useShapesUniforms, additionalImplementation),
+ shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall,
+ a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation),
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index 4238449f9246f..119609e06f5a3 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey {
readonly max: number;
}
-export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
+const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
+ const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
+ const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
+ return createAttributeWithCacheKey({min, max});
+};
+
+export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
+ const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
@@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo
attributes.cacheKey),
{inputs: [0]});
};
-const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
- const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
- const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
- return createAttributeWithCacheKey({min, max});
-};
-
-export const clip = (context: ComputeContext): void => {
- const attributes = generateClipAttributesFromInputs(context.inputs);
- clipV10(context, attributes);
-};
export const ceil = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 7172a28316f16..108eea1a73fe9 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
@@ -313,6 +314,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc
index 251850f621361..6cdccdb1becb1 100644
--- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc
@@ -14,17 +14,23 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-#define REGISTER_KERNEL() \
- ONNX_OPERATOR_KERNEL_EX( \
- GemmFloat8, \
- kMSDomain, \
- 1, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("TA", BuildKernelDefConstraints()) \
- .TypeConstraint("TB", BuildKernelDefConstraints()) \
- .TypeConstraint("TR", BuildKernelDefConstraints()) \
- .TypeConstraint("TS", BuildKernelDefConstraints()), \
+#if !defined(DISABLE_FLOAT8_TYPES)
+#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints()
+#else
+#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints()
+#endif
+
+#define REGISTER_KERNEL() \
+ ONNX_OPERATOR_KERNEL_EX( \
+ GemmFloat8, \
+ kMSDomain, \
+ 1, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) \
+ .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) \
+ .TypeConstraint("TR", GEMM_FLOAT8_CONSTRAINTS) \
+ .TypeConstraint("TS", BuildKernelDefConstraints()), \
GemmFloat8);
REGISTER_KERNEL()
@@ -38,7 +44,7 @@ GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) {
alpha_ = info.GetAttrOrDefault("alpha", 1);
beta_ = info.GetAttrOrDefault("beta", 0);
-#if (CUDA_VERSION <= 12000)
+#if (CUDA_VERSION < 12000)
ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0.");
#endif
diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
index df25342342cd5..56b541f5256bf 100644
--- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
@@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
return 2;
-#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
+#if !defined(DISABLE_FLOAT8_TYPES)
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
return 1;
@@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
}
auto first_type = input_A->GetElementType();
+#if !defined(DISABLE_FLOAT8_TYPES)
bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
if (!is_float8)
+#endif
return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
input_C, scale_A, scale_B, scale_Y);
+#if !defined(DISABLE_FLOAT8_TYPES)
return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
input_C, scale_A, scale_B, scale_Y);
+#endif
}
Status GemmFloat8::ComputeRowMajor(
@@ -197,10 +201,15 @@ Status GemmFloat8::ComputeGemm(
switch (d_cuda_type) {
case CUDA_R_16F:
switch (a_cuda_type) {
+#if !defined(DISABLE_FLOAT8_TYPES)
+#if CUDA_VERSION < 11080
+#error CUDA_R_8F_E4M3 (float 8 types) is defined with CUDA>=11.8. Set flag DISABLE_FLOAT8_TYPES.
+#endif
case CUDA_R_8F_E4M3:
case CUDA_R_8F_E5M2:
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
break;
+#endif
default:
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
break;
@@ -267,7 +276,7 @@ Status GemmFloat8::ComputeGemm(
sizeof(p_scale_b)));
// float 8
-#if CUDA_VERSION >= 11080
+#if !defined(DISABLE_FLOAT8_TYPES)
if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) {
// For FP8 output, cuBLAS requires C_type to be same as bias_type
@@ -280,15 +289,14 @@ Status GemmFloat8::ComputeGemm(
CUBLAS_RETURN_IF_ERROR(
cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
}
- } else {
- CUBLAS_RETURN_IF_ERROR(
- cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
- }
#else
- // An output is still needed but it is not initialized.
CUBLAS_RETURN_IF_ERROR(
cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
#endif
+ } else {
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
+ }
if (row_major_compute) {
cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW;
@@ -345,7 +353,7 @@ Status GemmFloat8::ComputeGemm(
". Check NVIDIA documentation to see what combination is valid: ",
"https://docs.nvidia.com/cuda/cublas/"
"index.html?highlight=cublasLtMatmulAlgoGetHeuristic#"
- "cublasltmatmulalgogetheuristic.");
+ "cublasltmatmulalgogetheuristic. CUDA>=11.8 is required to use float 8 types.");
void* workspace = nullptr;
if (workspaceSize > 0) {
@@ -381,7 +389,8 @@ Status GemmFloat8::ComputeGemm(
", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x",
shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb,
", ldd=", ldd, ", workspaceSize=", workspaceSize,
- ", rowMajorCompute=", (row_major_compute ? 1 : 0), ".");
+ ", rowMajorCompute=", (row_major_compute ? 1 : 0),
+ ". CUDA>=11.8 is required to use float 8 types.");
if (workspaceSize > 0) {
CUDA_RETURN_IF_ERROR(cudaFree(workspace));
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu
index e58723f0b31e1..2f74dd41f0759 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu
@@ -35,6 +35,8 @@ template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, c
template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream);
+template Status SetBnbQuantMap(int quant_type, BFloat16* quant_map_buffer, cudaStream_t stream);
+
template
__global__ void kDequantizeBlockwise(
const T* quant_map,
@@ -62,22 +64,15 @@ __global__ void kDequantizeBlockwise(
valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i;
valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2;
- local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]);
+ local_abs_max = absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)];
__syncthreads();
LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128);
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
- #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
- vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max;
- vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
- #else
- // half multiplication not supported
- vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max));
- vals[j * 2 + 1] =
- static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max));
- #endif
+ vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max);
+ vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max);
}
__syncthreads();
@@ -86,7 +81,7 @@ __global__ void kDequantizeBlockwise(
}
template
-Status DequantizeBnb4(
+void CallkDequantizeBlockwise(
const T* quant_map,
T* output,
const uint8_t* quant_data,
@@ -102,6 +97,18 @@ Status DequantizeBnb4(
absmax,
block_size / 2,
numel);
+}
+
+template
+Status DequantizeBnb4(
+ const T* quant_map,
+ T* output,
+ const uint8_t* quant_data,
+ const T* absmax,
+ int block_size,
+ int numel,
+ cudaStream_t stream) {
+ CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream);
return Status::OK();
}
@@ -119,11 +126,36 @@ template Status DequantizeBnb4(
const half* quant_map,
half* output,
const uint8_t* quant_data,
- const half *absmax,
+ const half* absmax,
int block_size,
int numel,
cudaStream_t stream);
+template <>
+Status DequantizeBnb4<BFloat16>(
+ const BFloat16* quant_map,
+ BFloat16* output,
+ const uint8_t* quant_data,
+ const BFloat16* absmax,
+ int block_size,
+ int numel,
+ cudaStream_t stream) {
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
+ CallkDequantizeBlockwise(
+ reinterpret_cast<const nv_bfloat16*>(quant_map),
+ reinterpret_cast<nv_bfloat16*>(output),
+ quant_data,
+ reinterpret_cast<const nv_bfloat16*>(absmax),
+ block_size,
+ numel,
+ stream);
+ #else
+ CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream);
+ #endif
+
+ return Status::OK();
+}
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
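
The kernel above unpacks two 4-bit codes per byte (high nibble first), looks each code up in quant_map, and scales it by the absmax value of its block; ScalarMul performs that product per element type. A plain C++ reference of the same mapping may help when reasoning about the CUDA code. It is only a sketch: the function name and the flat-layout assumption (one absmax entry per block_size output elements) are mine, not part of the patch.

#include <cstdint>

// Reference sketch of Bnb4 blockwise dequantization (not part of the patch).
template <typename T>
void DequantizeBnb4Reference(const T* quant_map, const uint8_t* quant_data,
                             const T* absmax, int block_size, int numel, T* output) {
  for (int i = 0; i < numel; ++i) {
    const uint8_t packed = quant_data[i / 2];
    const int code = (i % 2 == 0) ? (packed >> 4) : (packed & 0x0F);
    // The float round-trip mirrors what ScalarMul does for half/BFloat16.
    output[i] = static_cast<T>(static_cast<float>(quant_map[code]) *
                               static_cast<float>(absmax[i / block_size]));
  }
}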
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh
index 4aef3ab699f9c..a0d38c9853cd6 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh
@@ -11,6 +11,38 @@ namespace cuda {
template
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream);
+// templated scalar multiply function
+template
+__device__ inline T ScalarMul(T a, T b);
+
+template <>
+__device__ inline float ScalarMul(float a, float b) {
+ return a * b;
+}
+
+template <>
+__device__ inline half ScalarMul(half a, half b) {
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+ return a * b;
+ #else
+ // half multiplication not supported
+ return static_cast<half>(static_cast<float>(a) * static_cast<float>(b));
+ #endif
+}
+
+template <>
+__device__ inline BFloat16 ScalarMul(BFloat16 a, BFloat16 b) {
+ return a * b;
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+// will use the native bfloat16 multiply instruction on sm_80+
+template <>
+__device__ inline nv_bfloat16 ScalarMul(nv_bfloat16 a, nv_bfloat16 b) {
+ return a * b;
+}
+#endif
+
template
Status DequantizeBnb4(
const T* quant_map,
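
The ScalarMul specializations added here let the same kernels multiply float, half, and bfloat16 operands. For bfloat16 there are two routes: kernels instantiated with nv_bfloat16 use the native multiply (the specialization guarded for sm_80 and newer), while kernels instantiated with onnxruntime's BFloat16 fall back to multiplication through float. A condensed sketch of that choice, mirroring the guard the .cu files below use; the alias name is mine and is not part of the patch.

// Hypothetical alias showing which element type the bfloat16 kernels are built with.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
using Bf16KernelElem = nv_bfloat16;  // ScalarMul on nv_bfloat16: native bfloat16 multiply
#else
using Bf16KernelElem = BFloat16;     // ScalarMul on BFloat16: multiply via float conversion
#endif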
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
index ecf332715d470..bbcb7de99781f 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
@@ -145,6 +145,17 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
.TypeConstraint("T2", DataTypeImpl::GetTensorType()),
MatMulBnb4);
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ MatMulBnb4,
+ kMSDomain,
+ 1,
+ BFloat16,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType<BFloat16>())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+ MatMulBnb4<BFloat16>);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
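
With the registration above, a model that uses the com.microsoft MatMulBnb4 contrib op with bfloat16 activations can now be assigned to the CUDA execution provider. A minimal sketch with the ONNX Runtime C++ API; the model file name is a placeholder, not something shipped with this patch.

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  OrtCUDAProviderOptions cuda_options{};  // default CUDA EP settings
  session_options.AppendExecutionProvider_CUDA(cuda_options);
  // Placeholder model containing MatMulBnb4 with T1 = tensor(bfloat16).
  Ort::Session session(env, ORT_TSTR("matmul_bnb4_bf16.onnx"), session_options);
  return 0;
}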
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
index 1d9aa75ff3701..098e3618beddd 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
@@ -6,12 +6,44 @@
#include
#include
#include
+#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"
#include "matmul_bnb4.cuh"
namespace onnxruntime {
namespace contrib {
namespace cuda {
+template
+__device__ inline float ScalarMulFloatOut(T a, T b);
+
+template <>
+__device__ inline float ScalarMulFloatOut(float a, float b) {
+ return a * b;
+}
+
+template <>
+__device__ inline float ScalarMulFloatOut(half a, half b) {
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+ return static_cast<float>(a * b);
+ #else
+ // half multiplication not supported
+ return static_cast<float>(a) * static_cast<float>(b);
+ #endif
+}
+
+template <>
+__device__ inline float ScalarMulFloatOut(BFloat16 a, BFloat16 b) {
+ return a * b;
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+// will use the native bfloat16 multiply instruction on sm_80+
+template <>
+__device__ inline float ScalarMulFloatOut(nv_bfloat16 a, nv_bfloat16 b) {
+ return static_cast<float>(a * b);
+}
+#endif
+
#define num_values_4bit 32
template
__global__ void kgemm_4bit_inference_naive(
@@ -55,7 +87,7 @@ __global__ void kgemm_4bit_inference_naive(
int inner_idx_halved = inner_idx / 2;
int offset_B = ldb * row_B;
int absidx = ((2 * offset_B) + inner_idx) / block_size;
- local_absmax = __ldg(&(absmax[absidx]));
+ local_absmax = absmax[absidx];
if (row_B < N) {
if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
@@ -78,18 +110,8 @@ __global__ void kgemm_4bit_inference_naive(
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int k = 0; k < num_values_8bit / 4; k++) {
- #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
- local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
- local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
- #else
- // half multiplication not supported
- local_B[k * 2] =
- static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) *
- static_cast(local_absmax));
- local_B[k * 2 + 1] =
- static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) *
- static_cast(local_absmax));
- #endif
+ local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax);
+ local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax);
}
if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
@@ -116,12 +138,7 @@ __global__ void kgemm_4bit_inference_naive(
// accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++) {
- #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
- local_C += static_cast(local_A[k] * local_B[k]);
- #else
- // half multiplication not supported
- local_C += static_cast(local_A[k]) * static_cast(local_B[k]);
- #endif
+ local_C += ScalarMulFloatOut(local_A[k], local_B[k]);
}
}
}
@@ -131,8 +148,19 @@ __global__ void kgemm_4bit_inference_naive(
if (row_B < N && warp_lane == 0) out[row_B] = T(local_C);
}
+bool CheckDims(int m, int k, int block_size) {
+ if (k % block_size != 0 || m > 1) {
+ return false;
+ }
+ // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32]
+ if (block_size % 32 != 0 || block_size > 4096) {
+ return false;
+ }
+ return true;
+}
+
template
-bool TryMatMulBnb4(
+void Callkgemm_4bit_inference_naive(
const T* quant_map,
T* output,
const T* a_data,
@@ -143,22 +171,34 @@ bool TryMatMulBnb4(
int k,
int block_size,
cudaStream_t stream) {
- if (k % block_size != 0 || m > 1) {
- return false;
- }
- // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32]
- if (block_size % 32 != 0 || block_size > 4096) {
- return false;
- }
-
int lda = k;
int ldb = (k + 1) / 2;
int ldc = n;
int num_blocks = (n + 3) / 4;
- constexpr int bits = std::is_same_v<T, half> ? 16 : 32;
+ constexpr int bits = std::is_same_v<T, float> ? 32 : 16;
kgemm_4bit_inference_naive<<>>(
m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size);
+}
+
+template
+bool TryMatMulBnb4(
+ const T* quant_map,
+ T* output,
+ const T* a_data,
+ const uint8_t* b_data_quant,
+ const T* absmax,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream) {
+ if (!CheckDims(m, k, block_size)) {
+ return false;
+ }
+
+ Callkgemm_4bit_inference_naive(
+ quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream);
return true;
}
@@ -187,6 +227,42 @@ template bool TryMatMulBnb4(
int block_size,
cudaStream_t stream);
+template <>
+bool TryMatMulBnb4<BFloat16>(
+ const BFloat16* quant_map,
+ BFloat16* output,
+ const BFloat16* a_data,
+ const uint8_t* b_data_quant,
+ const BFloat16* absmax,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream) {
+ if (!CheckDims(m, k, block_size)) {
+ return false;
+ }
+
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
+ Callkgemm_4bit_inference_naive(
+ reinterpret_cast<const nv_bfloat16*>(quant_map),
+ reinterpret_cast<nv_bfloat16*>(output),
+ reinterpret_cast<const nv_bfloat16*>(a_data),
+ b_data_quant,
+ reinterpret_cast<const nv_bfloat16*>(absmax),
+ m,
+ n,
+ k,
+ block_size,
+ stream);
+ #else
+ Callkgemm_4bit_inference_naive(
+ quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream);
+ #endif
+
+ return true;
+}
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
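
For m == 1 the kernel above is effectively a GEMV over the quantized weight: each output row dequantizes its 4-bit codes through quant_map and the per-block absmax, then accumulates the dot product in float. Below is a scalar reference of that computation; the helper name is mine, while the packed layout (ldb = (k + 1) / 2 bytes per output row, high nibble first) and the absmax indexing follow the kernel.

#include <cstdint>

// Reference sketch of the fused dequantize + GEMV path (m == 1), accumulating in float.
template <typename T>
void MatMulBnb4Reference(const T* quant_map, const T* a, const uint8_t* b_quant,
                         const T* absmax, int n, int k, int block_size, T* out) {
  const int ldb = (k + 1) / 2;
  for (int row = 0; row < n; ++row) {
    float acc = 0.0f;
    for (int i = 0; i < k; ++i) {
      const uint8_t packed = b_quant[row * ldb + i / 2];
      const int code = (i % 2 == 0) ? (packed >> 4) : (packed & 0x0F);
      const float b_val = static_cast<float>(quant_map[code]) *
                          static_cast<float>(absmax[(2 * row * ldb + i) / block_size]);
      acc += static_cast<float>(a[i]) * b_val;
    }
    out[row] = T(acc);
  }
}

Since TryMatMulBnb4 returns false whenever CheckDims rejects the shape or block size, the calling operator (not shown in this patch) presumably falls back to DequantizeBnb4 followed by a regular GEMM in that case.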
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index db0b13b0e1d27..4c0d78f0ee297 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3431,7 +3431,7 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4
.Input(1, "B", "1-dimensional quantized data for weight", "T2")
.Input(2, "absmax", "quantization constants", "T1")
.Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
- .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
+ .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float/half_float/brain_float tensors.")
.TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 3763e0758cc5c..d489a59c4b798 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMapExists()) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
- input_def = hit->second;
+ // Make sure we create a NodeArg definition local to this subgraph
+ const auto* new_name_arg = hit->second;
+ input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto());
}
}
}
@@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
Graph& graph_to_inline = *sub_graph;
- std::string unique_id{if_node.Name()};
+ std::string unique_id{"_if_"};
if (condition_value) {
unique_id.append(then_branch);
} else {
@@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
// Reason: there are no explicit inputs to the subgraphs, and the subgraph's
// implicit inputs must be covered by the implicit inputs of the If node.
InlinedHashMap outer_scope_values;
- const auto if_implicit_inputs = if_node.MutableImplicitInputDefs();
+ const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs();
outer_scope_values.reserve(if_implicit_inputs.size());
for (auto* input : if_implicit_inputs) {
@@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
// We are going to map the outputs of the graph to inline to the outputs of the If node.
// They are assumed to be in the same order.
- const auto node_output_defs = if_node.MutableOutputDefs();
- const auto graph_output_defs = graph_to_inline.GetOutputs();
+ const auto& node_output_defs = if_node.MutableOutputDefs();
+ const auto& graph_output_defs = graph_to_inline.GetOutputs();
for (size_t i = 0; i < graph_output_defs.size(); ++i) {
name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]);
}
@@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
}
}
+ auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr);
// We want to make sure we get nodes in topological order
// because Constant folding may cause the nodes appear in
// a different order.
@@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
auto* node = graph_to_inline.GetNode(node_idx);
assert(node->OpType() != kConstant);
- InlinedVector new_node_input_defs;
- for (const auto* input_def : node->InputDefs()) {
+ // Inputs
+ // Chop off trailing non-existing defs, but preserve non-existing in the middle
+ auto& input_defs = node->MutableInputDefs();
+ auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(),
+ [](const NodeArg* node_arg) { return node_arg->Exists(); });
+ input_defs.resize(std::distance(input_defs.begin(), last_existing.base()));
+
+ InlinedVector<NodeArg*> new_input_defs;
+ for (auto* input_def : node->InputDefs()) {
if (input_def->Exists()) {
// Check if this is one of the implicit graph inputs
- // then leave the name as is and re-use the NodeArg
+ // then re-assign the def to the outer scope value.
const auto& input_name = input_def->Name();
auto outer_hit = outer_scope_values.find(input_name);
if (outer_hit != outer_scope_values.cend()) {
- new_node_input_defs.push_back(outer_hit->second);
+ // get/create local definition
+ NodeArg* outer_arg = outer_hit->second;
+ auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto());
+ new_input_defs.push_back(&this_scope_arg);
} else {
auto hit = name_to_nodearg.find(input_name);
if (hit != name_to_nodearg.cend()) {
- // This is other node output, constant node or initializer that was renamed.
- new_node_input_defs.push_back(hit->second);
+ // This is other node output in the dest graph,
+ // constant node or initializer that was renamed.
+ new_input_defs.push_back(hit->second);
} else {
ORT_THROW("Node's: ", node->Name(), " input: ", input_name,
" is not If node's input or previous node output in this subgraph");
}
}
+ } else {
+ new_input_defs.push_back(non_existing_arg);
}
}
- InlinedVector new_node_output_defs;
- for (const auto* output_def : node->OutputDefs()) {
- const auto& output_name = output_def->Name();
- auto hit = name_to_nodearg.find(output_name);
- if (hit != name_to_nodearg.cend()) {
- // This is one of the graph outputs, we rename it to
- // If node output.
- new_node_output_defs.push_back(hit->second);
+ // Outputs
+ // Chop off trailing non-existing defs
+ auto& output_defs = node->MutableOutputDefs();
+ last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(),
+ [](const NodeArg* node_arg) { return node_arg->Exists(); });
+ output_defs.resize(std::distance(output_defs.begin(), last_existing.base()));
+
+ InlinedVector<NodeArg*> new_output_defs;
+ for (auto* output_def : node->OutputDefs()) {
+ if (output_def->Exists()) {
+ const auto& output_name = output_def->Name();
+ auto hit = name_to_nodearg.find(output_name);
+ if (hit != name_to_nodearg.cend()) {
+ // This is one of the If node outputs, simply reassign the def.
+ // If node defs are already in the destination graph
+ new_output_defs.push_back(hit->second);
+ } else {
+ // We generate an output to downstream nodes.
+ auto new_name = GenerateNodeArgName(make_unique(output_name));
+ NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
+ new_output_defs.push_back(&new_arg);
+ ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
+ }
} else {
- // We generate an output to downstream nodes.
- auto new_name = GenerateNodeArgName(make_unique(output_name));
- NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
- new_node_output_defs.push_back(&new_arg);
- ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
+ new_output_defs.push_back(non_existing_arg);
}
}
const auto new_node_name = GenerateNodeName(make_unique(node->OpType()));
Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(),
- new_node_input_defs,
- new_node_output_defs,
+ new_input_defs,
+ new_output_defs,
nullptr,
node->Domain());
+ new_node.SetSinceVersion(node->SinceVersion());
+ new_node.op_ = node->op_;
+
if (!is_this_main_graph) {
map_defs(new_node, input_args, true);
map_defs(new_node, output_args, false);
new_nodes.push_back(&new_node);
}
- new_node.SetSinceVersion(node->SinceVersion());
- new_node.op_ = node->op_;
-
if (node->ContainsSubgraph()) {
auto& subgraphs = node->MutableSubgraphs();
// Check if any of this node implicit inputs of this graph is in the renaming map
+ // that would mean they come from the destination graph, not from the parent
+ // of the destination graph.
int renames_subgraph_names = 0;
- auto& new_implicit_defs = node->MutableImplicitInputDefs();
- for (auto& input_def : new_implicit_defs) {
+ auto& implicit_defs = node->MutableImplicitInputDefs();
+ for (auto& input_def : implicit_defs) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
@@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
new_node.MutableSubgraphs() = std::move(subgraphs);
new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph());
- new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs);
+ new_node.MutableImplicitInputDefs() = std::move(implicit_defs);
}
new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes());
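
The input and output handling above applies the same trimming rule twice: non-existing defs at the tail of a def list are dropped, while missing defs between existing ones are kept (as the shared non_existing_arg) so that argument positions, and therefore the rebuilt edges' destination arg indices, stay aligned. A standalone sketch of that rule follows; the helper name is mine and the container is left generic because the patch operates on the node's own def vectors.

#include <algorithm>
#include <iterator>

// Drop trailing non-existing (optional) defs; gaps in the middle are preserved by the caller.
template <typename NodeArgPtrContainer>
void TrimTrailingNonExisting(NodeArgPtrContainer& defs) {
  auto last_existing = std::find_if(defs.rbegin(), defs.rend(),
                                    [](const auto* node_arg) { return node_arg->Exists(); });
  defs.resize(std::distance(defs.begin(), last_existing.base()));
}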
diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc
index 288ca8e97e34d..33f2938940e4d 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.cc
+++ b/onnxruntime/core/providers/cuda/cuda_common.cc
@@ -62,7 +62,8 @@ const char* CudaDataTypeToString(cudaDataType_t dt) {
return "CUDA_R_16BF";
case CUDA_R_32F:
return "CUDA_R_32F";
-#if (CUDA_VERSION >= 11080)
+#if !defined(DISABLE_FLOAT8_TYPES)
+ // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
case CUDA_R_8F_E4M3:
return "CUDA_R_8F_E4M3";
case CUDA_R_8F_E5M2:
@@ -101,7 +102,7 @@ cudaDataType_t ToCudaDataType(int32_t element_type) {
return CUDA_R_16F;
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
return CUDA_R_16BF;
-#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
+#if !defined(DISABLE_FLOAT8_TYPES)
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
return CUDA_R_8F_E4M3;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h
index 9cd4e721ccab8..707099bac3ce0 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.h
+++ b/onnxruntime/core/providers/cuda/cuda_common.h
@@ -58,6 +58,8 @@ class ToCudaType {
}
};
+#if !defined(DISABLE_FLOAT8_TYPES)
+
template <>
class ToCudaType {
public:
@@ -76,6 +78,8 @@ class ToCudaType {
}
};
+#endif
+
inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) {
int stride = 1;
if (dims.empty() || p.size() < dims.size())
diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cu b/onnxruntime/core/providers/cuda/tensor/cast_op.cu
index 7542fb55757c6..f2c2e6d7458f9 100644
--- a/onnxruntime/core/providers/cuda/tensor/cast_op.cu
+++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cu
@@ -141,7 +141,7 @@ struct CastSat {
#endif
-#endif
+#endif // DISABLE_FLOAT8_TYPES
template
__global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastStd cast) {
diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu
index ad2a44793fe26..1da308811fa48 100644
--- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu
+++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu
@@ -104,7 +104,7 @@ struct RoundSat {
#endif
-#endif
+#endif // DISABLE_FLOAT8_TYPES
template <>
struct RoundStd {
@@ -189,7 +189,7 @@ __global__ void QuantizeLinearKernelAxisSat(const InT* input, OutT* output, cons
}
}
-#endif
+#endif // DISABLE_FLOAT8_TYPES
template
Status CudaQuantizeLinearStd(cudaStream_t stream, const InT* input, OutT* output, const InT* scale, const OutT* zero_point, size_t num_of_element) {
diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc
index e9bbfabcf86bd..78563d30b0136 100644
--- a/onnxruntime/core/providers/js/operators/unary.cc
+++ b/onnxruntime/core/providers/js/operators/unary.cc
@@ -123,7 +123,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)
// activation
-JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
+JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
JSEP_KERNEL_IMPL(Clip, Clip)
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,
diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py
index 3b344d6dc9342..407c3b80e153f 100644
--- a/onnxruntime/python/tools/transformers/large_model_exporter.py
+++ b/onnxruntime/python/tools/transformers/large_model_exporter.py
@@ -157,14 +157,14 @@ def hook_for_inputs(_, inputs, kwargs):
for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)):
if type(value) is torch.Tensor:
value.to(model.device)
- # Didn't touch past_key_value now, please change it if you want
if "use_cache" in key:
onnx_inputs[idx] = with_past
+ out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out
return input_keys, onnx_inputs, out.past_key_values
-def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
+def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
"""
According to the model size, we will upload it to
CPU if has no GPU or enough GPU memory,
@@ -307,7 +307,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit
"""
model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
- model = move_to_approprate_device(model, sample_inputs_tp)
+ model = move_to_appropriate_device(model, sample_inputs_tp)
sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)
diff --git a/onnxruntime/test/contrib_ops/gemm_float8_test.cc b/onnxruntime/test/contrib_ops/gemm_float8_test.cc
new file mode 100644
index 0000000000000..c022736075cde
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/gemm_float8_test.cc
@@ -0,0 +1,126 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(USE_CUDA) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
+
+TEST(GemmFloat8OpTest, BFloat16) {
+ OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain);
+ test.AddAttribute("transA", (int64_t)0);
+ test.AddAttribute("transB", (int64_t)0);
+ test.AddAttribute("alpha", 1.0f);
+ test.AddAttribute("beta", 1.0f);
+ test.AddAttribute("activation", "NONE");
+ test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16));
+ test.AddInput("A", {2, 4}, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}));
+ test.AddInput("B", {4, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
+ test.AddInput("C", {2, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
+ test.AddOutput("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}));
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(GemmFloat8OpTest, Float) {
+ OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain);
+ test.AddAttribute("transA", (int64_t)0);
+ test.AddAttribute("transB", (int64_t)0);
+ test.AddAttribute("alpha", 1.0f);
+ test.AddAttribute("beta", 1.0f);
+ test.AddAttribute("activation", "NONE");
+ test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
+ test.AddInput("A", {2, 4}, std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}));
+ test.AddInput("B", {4, 3}, std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
+ test.AddInput("C", {2, 3}, std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
+ test.AddOutput("Y", {2, 3}, std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}));
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+std::vector<MLFloat16> _Cvt(const std::vector<float>& tensor) {
+  std::vector<MLFloat16> fp16_data(tensor.size());
+  ConvertFloatToMLFloat16(tensor.data(), fp16_data.data(), static_cast<int>(tensor.size()));
+  return fp16_data;
+}
+
+TEST(GemmFloat8OpTest, Float16) {
+ OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain);
+ test.AddAttribute("transA", (int64_t)0);
+ test.AddAttribute("transB", (int64_t)0);
+ test.AddAttribute("alpha", 1.0f);
+ test.AddAttribute("beta", 1.0f);
+ test.AddAttribute("activation", "NONE");
+ test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
+ test.AddInput("A", {2, 4}, _Cvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f})));
+ test.AddInput("B", {4, 3}, _Cvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ test.AddInput("C", {2, 3}, _Cvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ test.AddOutput("Y", {2, 3}, _Cvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f})));
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+#if (!defined(DISABLE_FLOAT8_TYPES)) && (CUDA_VERSION >= 12000)
+
+template <typename TType>
+std::vector<TType> _TypedCvt(const std::vector<float>& tensor);
+
+template <>
+std::vector<float> _TypedCvt<float>(const std::vector<float>& tensor) {
+  return tensor;
+}
+
+template <>
+std::vector<Float8E4M3FN> _TypedCvt<Float8E4M3FN>(const std::vector<float>& tensor) {
+  std::vector<Float8E4M3FN> out(tensor.size());
+  for (size_t i = 0; i < tensor.size(); ++i) {
+    out[i] = Float8E4M3FN(tensor[i]);
+  }
+  return out;
+}
+
+template
+void TestGemmFloat8WithFloat8(int64_t dtype) {
+ int min_cuda_architecture = 11080;
+ if (!HasCudaEnvironment(min_cuda_architecture)) {
+ LOGS_DEFAULT(WARNING) << "Hardware NOT support Matrix Multiplication for FLOAT8";
+ return;
+ }
+ OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain);
+ test.AddAttribute("transA", (int64_t)0);
+ test.AddAttribute("transB", (int64_t)1);
+ test.AddAttribute("alpha", 1.0f);
+ test.AddAttribute("beta", 1.0f);
+ test.AddAttribute("activation", "NONE");
+ test.AddAttribute("dtype", dtype);
+ test.AddInput("A", {2, 4}, _TypeCvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f})));
+ test.AddInput("B", {3, 4}, _TypeCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ test.AddInput("C", {2, 3}, _TypeCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ test.AddOutput("Y", {2, 3}, _TypeCvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f})));
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(GemmFloat8OpTest, Float8E4M3FNToFloat) {
+ TestGemmFloat8WithFloat8(static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
+}
+
+TEST(GemmFloat8OpTest, Float8E4M3FNToFloat8E4M3FN) {
+ TestGemmFloat8WithFloat8(static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN));
+}
+
+#endif
+
+#endif
+
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 17b26ed7ca4ca..ef6e2d531bc1a 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges)
ASSERT_EQ(op_to_count["Cast"], 2);
}
+TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) {
+ // This model has a Resize() call whose middle (optional) argument is omitted.
+ // We want to make sure that the input edges for that Resize() node
+ // are properly rebuilt, with the middle argument still absent,
+ // during If constant folding.
+ // This test is only valid when the Resize() node resides in a nested subgraph that gets inlined,
+ // while the destination graph is not the main graph. We then test that the edges are rebuilt
+ // properly. Resize() should also not be the first node in the resulting subgraph, so that it has input edges.
+ const char* code = R"(
+ <
+ ir_version: 8,
+ opset_import: [ "" : 16, "local" : 1 ]
+ >
+ agraph (float[128] x, float[128] x1) => (float[N] y)
+ {
+ y = local.aten_gather (x, x1)
+ }
+ <
+ opset_import: [ "" : 16, "local" : 1],
+ domain: "local"
+ >
+ aten_gather (self, index) => (result_16)
+ {
+ resize_scales = Constant ()
+ tmp_0 = Size (index)
+ int64_0 = Constant ()
+ int64_0_cast = CastLike (int64_0, tmp_0)
+ cond = Equal (tmp_0, int64_0_cast)
+ result_16 = If (cond) ( result) {
+ result = Identity (self)
+ }, else_branch: graph = elseGraph_10 () => ( result_15) {
+ tmp_1 = Shape (self)
+ tmp_2 = Size (tmp_1)
+ int64_0_3 = Constant ()
+ int64_0_3_cast = CastLike (int64_0_3, tmp_2)
+ cond_4 = Equal (tmp_2, int64_0_3_cast)
+ self_8 = If (cond_4) ( self_6) {
+ tmp_5 = Constant ()
+ self_6 = Reshape (self, tmp_5)
+ }, else_branch: graph = elseGraph_13 () => ( self_7) {
+ self_71 = Mul(self, self)
+ float_size = CastLike (tmp_0, resize_scales)
+ non_constant_resize_scales = Mul(float_size, resize_scales)
+ self_7 = Resize(self_71,, non_constant_resize_scales)
+ }>
+ tmp_9 = Size (index)
+ int64_0_10 = Constant ()
+ int64_0_10_cast = CastLike (int64_0_10, tmp_9)
+ cond_11 = Equal (tmp_9, int64_0_10_cast)
+ result_15 = If (cond_11) ( result_12) {
+ result_12 = CastLike (index, self_8)
+ }, else_branch: graph = elseGraph_15 () => ( result_14) {
+ index_13 = Cast (index)
+ result_14 = GatherElements (self_8, index_13)
+ }>
+ }>
+ }
+ )";
+
+ /** Optimized model graph
+ <
+ ir_version: 8,
+ opset_import: ["" : 16,
+ "local" : 1,
+ "com.microsoft.nchwc" : 1,
+ "ai.onnx.ml" : 4,
+ "ai.onnx.training" : 1,
+ "ai.onnx.preview.training" : 1,
+ "com.microsoft" : 1,
+ "com.microsoft.experimental" : 1, "org.pytorch.aten" : 1]
+ >
+ agraph (float[128] x, float[128] x1) => (float[128] y)
+
+ {
+ _inlfunc_aten_gather_tmp_0 = Size (x1)
+ _inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8)
+ y = If (_inlfunc_aten_gather_cond)
+ (float[128] _inlfunc_aten_gather_result) {
+ _inlfunc_aten_gather_result = Identity (x)
+ }, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15)
+
+ {
+ _if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x)
+ _if_else_branch__inlfunc_aten_gather_float_size = Cast (_inlfunc_aten_gather_tmp_0)
+ _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul (
+ _if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales)
+ _inlfunc_aten_gather_self_8 = Resize (
+ _if_else_branch__inlfunc_aten_gather_self_71, ,
+ _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales)
+ _inlfunc_aten_gather_tmp_9 = Size (x1)
+ _inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10)
+ _inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11)
+ (float[128] _inlfunc_aten_gather_result_12) {
+ _inlfunc_aten_gather_result_12 = Cast (x1)
+ }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) {
+ _inlfunc_aten_gather_index_13 = Cast (x1)
+ _inlfunc_aten_gather_result_14 = GatherElements (
+ _inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13)
+ }>
+ }>
+ }
+
+ */
+
+ ONNX_NAMESPACE::OnnxParser parser(code);
+ ONNX_NAMESPACE::ModelProto model_proto;
+ auto parse_status = parser.Parse(model_proto);
+ ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage();
+ ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
+
+ std::string serialized_model;
+ const bool serialization_status = model_proto.SerializeToString(&serialized_model);
+ ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string";
+
+ // AOT inlining is necessary in this case, so the If nodes within the function
+ // are brought out to the outer scope. So we load this into a session object.
+ SessionOptions session_options;
+ InferenceSessionWrapper session_object{session_options, GetEnvironment()};
+ std::stringstream sstr(serialized_model);
+ ASSERT_STATUS_OK(session_object.Load(sstr));
+ ASSERT_STATUS_OK(session_object.Initialize());
+
+ // Let's verify the correctness of the rebuilt edges in the Resize node that still
+ // resides within an If else-branch subgraph.
+ auto& graph = session_object.GetModel().MainGraph();
+ auto op_to_count = CountOpsInGraph(graph);
+ ASSERT_EQ(op_to_count["If"], 2);
+ ASSERT_EQ(op_to_count["Resize"], 1);
+
+ auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(),
+ [](const auto& node) { return node.OpType() == "If"; });
+ ASSERT_NE(graph.Nodes().cend(), if_node);
+ // Resize is in the else branch
+ auto subgraph_map = if_node->GetAttributeNameToSubgraphMap();
+ auto branch = subgraph_map.find("else_branch");
+ ASSERT_NE(subgraph_map.cend(), branch);
+
+ auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(),
+ [](const auto& node) { return node.OpType() == "Resize"; });
+ ASSERT_NE(branch->second->Nodes().cend(), resize_node);
+
+ // Check the edges
+ ASSERT_EQ(2U, resize_node->GetInputEdgesCount());
+ // Should have input edges with arg_pos 0 and 2,
+ // while arg_pos 1 is missing
+ InlinedHashSet<int> dest_edges;
+ auto zero_edge = resize_node->InputEdgesBegin();
+ dest_edges.insert(zero_edge->GetDstArgIndex());
+ ++zero_edge;
+ dest_edges.insert(zero_edge->GetDstArgIndex());
+ ASSERT_TRUE(dest_edges.find(0) != dest_edges.end());
+ ASSERT_TRUE(dest_edges.find(2) != dest_edges.end());
+}
+
// Check transformations in the case of a subgraph with constant inputs.
TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx";
diff --git a/onnxruntime/test/python/onnxruntime_test_float8.py b/onnxruntime/test/python/onnxruntime_test_float8.py
index 76ca5d9538374..bb63ea234498f 100644
--- a/onnxruntime/test/python/onnxruntime_test_float8.py
+++ b/onnxruntime/test/python/onnxruntime_test_float8.py
@@ -334,7 +334,7 @@ def test_model_cast_cast_cpu(self, name: str, float_name: str, saturate: int):
]
)
@unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0")
- @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.")
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_cast_cast_cuda(self, name: str, float_name: str, saturate: int, provider: str):
so = onnxruntime.SessionOptions()
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
@@ -373,7 +373,7 @@ def test_model_cast_cast_cuda(self, name: str, float_name: str, saturate: int, p
]
)
@unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0")
- @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.")
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_cast_cast_cuda_ortvalue(self, name: str, float_name: str, saturate: int, provider: str):
so = onnxruntime.SessionOptions()
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
@@ -627,7 +627,7 @@ def test_model_cast_like_x2_cpu(self, name: str, float_name: str, saturate: int)
]
)
@unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0")
- @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.")
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_qdq_cuda(self, name: str, float_name: str, saturate: int, provider: str):
so = onnxruntime.SessionOptions()
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
@@ -693,7 +693,7 @@ def test_model_qdq_cuda_ortvalue(self, name: str, float_name: str, saturate: int
self.assertEqual(expect.shape, y.shape)
self.assertEqual(expect.dtype, y.dtype)
- @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.")
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_compare_cpu_cuda_e4m3fn(self):
folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8")
model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx")
diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
index 784ae8ce70bd8..7dffad8f84c83 100644
--- a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
+++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
@@ -17,7 +17,9 @@
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info
from onnx.numpy_helper import from_array
-from onnxruntime import InferenceSession
+from onnxruntime import InferenceSession, get_available_providers
+
+available_providers = [provider for provider in get_available_providers()]
class TestFloat8Gemm8(unittest.TestCase):
@@ -192,21 +194,27 @@ def check(f):
self.assertEqual(expected.shape, y.shape)
self.assertEqual(expected.dtype, y.dtype)
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float(self):
self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3)
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_default_values(self):
self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None)
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_relu(self):
self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU")
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_gelu(self):
self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU")
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_bias(self):
self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3)
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float16(self):
self.common_test_model_gemm(
"FLOAT16",
@@ -215,6 +223,8 @@ def test_model_gemm_float16(self):
transB=1,
)
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
+ @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0")
def test_model_gemm_float8_e4m3(self):
self.common_test_model_gemm(
"FLOAT8E4M3FN",
@@ -226,6 +236,7 @@ def test_model_gemm_float8_e4m3(self):
)
@parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1])))
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_combinations_square_matrices(self, transA, transB):
self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3)
@@ -237,6 +248,7 @@ def test_combinations_square_matrices(self, transA, transB):
((2, 3), (2, 5), 1, 0),
]
)
+ @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_combinations(self, shapeA, shapeB, transA, transB):
model = make_model(
make_graph(
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
index 4977272de5ac9..8efbe16d7d61d 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
@@ -412,14 +412,24 @@ def _matmul4bit_export(g, n, *args, **kwargs):
return None
quant_state = args[4]
- absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+ if isinstance(quant_state, list):
+        # bitsandbytes version <= 0.41.1: quant_state is a list
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+ nested = compressed_stats is not None
+ else:
+        # bitsandbytes version > 0.41.1: quant_state is an object with named attributes
+ absmax = quant_state.absmax
+ shape = quant_state.shape
+ blocksize = quant_state.blocksize
+ nested = quant_state.nested
+ quant_type = quant_state.quant_type
# MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16
if blocksize < 16 or blocksize & (blocksize - 1) != 0:
return None
# MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too)
- if compressed_stats is not None:
+ if nested:
return None
# The PyTorch linear weight shape is [out_feature, in_feature]
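The exporter keeps the existing guard that MatMulBnb4 only supports block sizes that are powers of two and at least 16, now applied after normalizing `quant_state`. As a side note, `blocksize & (blocksize - 1)` is the standard power-of-two bit trick; a minimal standalone sketch (the helper name below is ours, not part of the patch):

```python
def is_valid_matmul_bnb4_blocksize(blocksize: int) -> bool:
    """Return True when blocksize is a power of two and at least 16."""
    # Note: `&` binds looser than `==` in Python, so the bit test gets its own parentheses.
    return blocksize >= 16 and (blocksize & (blocksize - 1)) == 0


assert is_valid_matmul_bnb4_blocksize(16)
assert is_valid_matmul_bnb4_blocksize(64)
assert not is_valid_matmul_bnb4_blocksize(24)  # not a power of two
assert not is_valid_matmul_bnb4_blocksize(8)   # power of two, but too small
```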
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 6bd3e2533c045..3b1a0317c58f1 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -14,6 +14,15 @@
import sys
from pathlib import Path
+
+def version_to_tuple(version: str) -> tuple:
+ v = []
+ for s in version.split("."):
+ with contextlib.suppress(ValueError):
+ v.append(int(s))
+ return tuple(v)
+
+
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
@@ -1084,6 +1093,12 @@ def generate_build_tree(
if args.use_cuda:
nvcc_threads = number_of_nvcc_threads(args)
cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(nvcc_threads))
+ if not disable_float8_types and args.cuda_version:
+ if version_to_tuple(args.cuda_version) < (11, 8):
+ raise BuildError(
+                    f"Float 8 types require CUDA >= 11.8 and must be disabled when building with CUDA {args.cuda_version}. "
+                    f"Add '--disable_types float8' to your command line (see option --disable_types)."
+ )
if args.use_rocm:
cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home)
cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version)
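For reference, the `version_to_tuple` helper added above turns a dotted version string into a tuple of integers, skipping any non-numeric components, so the float8 gate can compare it against `(11, 8)`. A small self-contained illustration of the intended behaviour (the assertions are ours):

```python
import contextlib


def version_to_tuple(version: str) -> tuple:
    # Same logic as the helper added to build.py above.
    v = []
    for s in version.split("."):
        with contextlib.suppress(ValueError):
            v.append(int(s))
    return tuple(v)


assert version_to_tuple("11.8") == (11, 8)
assert version_to_tuple("12.2.140") == (12, 2, 140)
assert version_to_tuple("11.8.rc1") == (11, 8)  # non-numeric component ignored
assert version_to_tuple("11.4") < (11, 8)       # float8 types would be rejected
```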
diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml
new file mode 100644
index 0000000000000..aee42d3675087
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml
@@ -0,0 +1,39 @@
+trigger: none
+
+parameters:
+ - name: enable_linux_gpu
+ type: boolean
+ default: true
+ - name: enable_windows_gpu
+ type: boolean
+ default: true
+ - name: cmake_build_type
+ type: string
+ default: 'Release'
+ values:
+ - Debug
+ - Release
+ - RelWithDebInfo
+ - MinSizeRel
+ - name: cuda_version
+ type: string
+ default: '12.2'
+ values:
+ - 11.8
+ - 12.2
+
+resources:
+ repositories:
+ - repository: manylinux
+ type: Github
+ endpoint: Microsoft
+ name: pypa/manylinux
+ ref: 5eda9aded5462201e6310105728d33016e637ea7
+
+stages:
+ - template: stages/py-cuda-packaging-stage.yml
+ parameters:
+ enable_linux_gpu: ${{ parameters.enable_linux_gpu }}
+ enable_windows_gpu: ${{ parameters.enable_windows_gpu }}
+ cmake_build_type: ${{ parameters.cmake_build_type }}
+ cuda_version: ${{ parameters.cuda_version }}
\ No newline at end of file
diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml
new file mode 100644
index 0000000000000..f3d68957d649c
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml
@@ -0,0 +1,105 @@
+parameters:
+- name: build_py_parameters
+ displayName: >
+ Extra parameters to pass to build.py. Don't put newlines in here.
+ type: string
+ default: ''
+
+- name: enable_linux_gpu
+  displayName: 'Whether the Linux GPU package is built.'
+ type: boolean
+ default: true
+
+- name: enable_windows_gpu
+  displayName: 'Whether the Windows GPU package is built.'
+ type: boolean
+ default: true
+
+# TODO: The Windows jobs currently use a different cmake build type. Consider merging them.
+- name: cmake_build_type
+ type: string
+ displayName: 'Linux packages cmake build type. Linux Only.'
+ default: 'Release'
+ values:
+ - Debug
+ - Release
+ - RelWithDebInfo
+ - MinSizeRel
+
+- name: cuda_version
+ type: string
+ displayName: 'CUDA version. Windows Only.'
+ default: '12.2'
+ values:
+ - 11.8
+ - 12.2
+
+stages:
+- stage: Python_Packaging
+ dependsOn: []
+ variables:
+ - name: docker_base_image
+ ${{ if eq(parameters.cuda_version, '11.8') }}:
+ value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8
+ ${{ if eq(parameters.cuda_version, '12.2') }}:
+ value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8
+ - name: linux_trt_version
+ ${{ if eq(parameters.cuda_version, '11.8') }}:
+ value: 8.6.1.6-1.cuda11.8
+ ${{ if eq(parameters.cuda_version, '12.2') }}:
+ value: 8.6.1.6-1.cuda12.0
+ - name: win_trt_home
+ ${{ if eq(parameters.cuda_version, '11.8') }}:
+ value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8
+ ${{ if eq(parameters.cuda_version, '12.2') }}:
+ value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0
+ - name: win_cuda_home
+ ${{ if eq(parameters.cuda_version, '11.8') }}:
+ value: $(Agent.TempDirectory)\v11.8
+ ${{ if eq(parameters.cuda_version, '12.2') }}:
+ value: $(Agent.TempDirectory)\v12.2
+ jobs:
+ - ${{ if eq(parameters.enable_windows_gpu, true) }}:
+ - template: ../templates/py-win-gpu.yml
+ parameters:
+ MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
+ PYTHON_VERSION: '3.8'
+ EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
+ EP_NAME: gpu
+ CudaVersion: ${{ parameters.cuda_version }}
+
+ - template: ../templates/py-win-gpu.yml
+ parameters:
+ MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
+ PYTHON_VERSION: '3.9'
+ EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
+ EP_NAME: gpu
+ CudaVersion: ${{ parameters.cuda_version }}
+
+ - template: ../templates/py-win-gpu.yml
+ parameters:
+ MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
+ PYTHON_VERSION: '3.10'
+ EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
+ EP_NAME: gpu
+ CudaVersion: ${{ parameters.cuda_version }}
+
+ - template: ../templates/py-win-gpu.yml
+ parameters:
+ MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
+ PYTHON_VERSION: '3.11'
+ EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
+ EP_NAME: gpu
+ CudaVersion: ${{ parameters.cuda_version }}
+
+
+ - ${{ if eq(parameters.enable_linux_gpu, true) }}:
+ - template: ../templates/py-linux-gpu.yml
+ parameters:
+ arch: 'x86_64'
+ machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU'
+ extra_build_arg: ${{ parameters.build_py_parameters }}
+ cmake_build_type: ${{ parameters.cmake_build_type }}
+ docker_base_image: ${{ variables.docker_base_image }}
+ trt_version: ${{ variables.linux_trt_version }}
+ cuda_version: ${{ parameters.cuda_version }}
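The compile-time `${{ if eq(...) }}` conditionals in this new stage select the Docker base image, TensorRT version, and Windows CUDA/TensorRT directories from the single `cuda_version` parameter. A plain-Python restatement of that mapping, purely for readability (the pipeline itself resolves these as template expressions, with the Windows paths rooted at `$(Agent.TempDirectory)`):

```python
# Restatement of the template-expression conditionals above, for reference only.
CUDA_VARIANTS = {
    "11.8": {
        "docker_base_image": "nvidia/cuda:11.8.0-cudnn8-devel-ubi8",
        "linux_trt_version": "8.6.1.6-1.cuda11.8",
        "win_trt_dir": "TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8",
        "win_cuda_dir": "v11.8",
    },
    "12.2": {
        "docker_base_image": "nvidia/cuda:12.2.2-cudnn8-devel-ubi8",
        "linux_trt_version": "8.6.1.6-1.cuda12.0",
        "win_trt_dir": "TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0",
        "win_cuda_dir": "v12.2",
    },
}

print(CUDA_VARIANTS["12.2"]["linux_trt_version"])  # 8.6.1.6-1.cuda12.0
```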
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml
index 4573c56963e34..ff7f0957e94ba 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml
@@ -34,7 +34,7 @@ steps:
displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8'
- powershell: |
Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib"
- displayName: 'Append CUDA SDK Directory to PATH'
+ displayName: 'Append TensorRT Directory to PATH'
- ${{ if eq(parameters.CudaVersion, '12.2') }}:
- powershell: |
@@ -42,7 +42,7 @@ steps:
displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0'
- powershell: |
Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib"
- displayName: 'Append CUDA SDK Directory to PATH'
+ displayName: 'Append TensorRT Directory to PATH'
- task: CmdLine@2
inputs:
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml
index 9282cfccd02f0..e40c4d0e95dc5 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml
@@ -4,6 +4,7 @@ parameters:
- name: EnvSetupScript
type: string
+ default: setup_env.bat
- name: job_name_suffix
type: string
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml
index f68847afff379..8cc48aac7a3b9 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml
@@ -17,7 +17,24 @@ parameters:
- Release
- RelWithDebInfo
- MinSizeRel
-
+- name: docker_base_image
+ type: string
+ default: 'nvidia/cuda:11.8.0-cudnn8-devel-ubi8'
+ values:
+ - nvidia/cuda:11.8.0-cudnn8-devel-ubi8
+ - nvidia/cuda:12.2.2-cudnn8-devel-ubi8
+- name: trt_version
+ type: string
+ default: '8.6.1.6-1.cuda11.8'
+ values:
+ - 8.6.1.6-1.cuda11.8
+ - 8.6.1.6-1.cuda12.0
+- name: cuda_version
+ type: string
+ default: '11.8'
+ values:
+ - 11.8
+ - 12.2
jobs:
- job: Linux_py_GPU_Wheels_${{ parameters.arch }}
timeoutInMinutes: 240
@@ -26,7 +43,13 @@ jobs:
pool: ${{ parameters.machine_pool }}
variables:
# The build machine pool doesn't have dotnet, so it can't run CG.
- skipComponentGovernanceDetection: true
+ - name: skipComponentGovernanceDetection
+ value: true
+ - name: extra_build_args
+ ${{ if ne(parameters.extra_build_arg, '') }}:
+ value: -x ${{ parameters.extra_build_arg }}
+ ${{ if eq(parameters.extra_build_arg, '') }}:
+ value: ''
steps:
- checkout: self
clean: true
@@ -40,12 +63,12 @@ jobs:
Context: tools/ci_build/github/linux/docker
DockerBuildArgs: "
--network=host
- --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8
- --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8
+ --build-arg BASEIMAGE=${{ parameters.docker_base_image }}
+ --build-arg TRT_VERSION=${{ parameters.trt_version }}
--build-arg BUILD_UID=$( id -u )
--build-arg PLATFORM=${{ parameters.arch }}
"
- Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }}
+ Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }}
- task: Bash@3
@@ -53,8 +76,7 @@ jobs:
inputs:
targetType: filePath
filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh
- # please check ONNXRUNTIME_CUDA_VERSION in tools/ci_build/github/linux/build_linux_arm64_python_package.sh
- arguments: -i onnxruntimecuda118xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}"
+ arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args)
- task: PublishBuildArtifacts@1
displayName: 'Publish Artifact: ONNXRuntime python wheel'
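The Docker repository name is now derived from `cuda_version` by dropping the dot, so the 11.8 and 12.2 flavours build and push distinct images. A quick sketch of the naming scheme (illustration only; the pipeline uses the `replace()` template function):

```python
def image_repository(cuda_version: str, arch: str = "x86_64") -> str:
    # Mirrors: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }}
    return f"onnxruntimecuda{cuda_version.replace('.', '')}xtrt86build{arch}"


assert image_repository("11.8") == "onnxruntimecuda118xtrt86buildx86_64"
assert image_repository("12.2") == "onnxruntimecuda122xtrt86buildx86_64"
```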
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml
index 0774c3350b9b1..db3782c69cf62 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml
@@ -46,9 +46,17 @@ jobs:
pool: ${{ parameters.machine_pool }}
variables:
# The build machine pool doesn't have dotnet, so it can't run CG.
- skipComponentGovernanceDetection: true
- ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache
- TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)]
+ - name: skipComponentGovernanceDetection
+ value: true
+ - name: ORT_CACHE_DIR
+ value: $(Agent.TempDirectory)/ort_ccache
+ - name: TODAY
+ value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)]
+ - name: extra_build_args
+ ${{ if ne(parameters.extra_build_arg, '') }}:
+ value: -x ${{ parameters.extra_build_arg }}
+ ${{ if eq(parameters.extra_build_arg, '') }}:
+ value: ''
steps:
- task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
displayName: 'Clean Agent Directories'
@@ -82,7 +90,7 @@ jobs:
inputs:
targetType: filePath
filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh
- arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}"
+ arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args)
${{ if eq(parameters.with_cache, 'true') }}:
env:
ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1"
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml
index 919749cac15b6..501251eaff20f 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml
@@ -14,21 +14,32 @@ parameters:
- name: ENV_SETUP_SCRIPT
type: string
+ default: ''
- name: BUILD_PY_PARAMETERS
displayName: >
Extra parameters to pass to build.py. Don't put newlines in here.
type: string
default: ''
-
+- name: CudaVersion
+ type: string
+ default: '11.8'
+ values:
+ - 11.8
+ - 12.2
jobs:
- job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}
timeoutInMinutes: 240
workspace:
clean: all
- pool: ${{ parameters.MACHINE_POOL }}
+ pool:
+ name: ${{ parameters.MACHINE_POOL }}
+# demands:
+# - ImageVersionOverride -equals 1.0.367516
variables:
+ GRADLE_OPTS: '-Dorg.gradle.daemon=false'
VSGenerator: 'Visual Studio 17 2022'
+ CUDA_MODULE_LOADING: 'LAZY'
steps:
- checkout: self
clean: true
@@ -61,10 +72,21 @@ jobs:
- template: download-deps.yml
- - template: jobs/set-winenv.yml
- parameters:
- EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }}
- DownloadCUDA: true
+ - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}:
+ - template: jobs/set-winenv.yml
+ parameters:
+ EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }}
+ ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}:
+ DownloadCUDA: true
+
+ - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}:
+ - template: jobs/download_win_gpu_library.yml
+ parameters:
+ CudaVersion: ${{ parameters.CudaVersion }}
+ ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}:
+ DownloadCUDA: true
+ ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}:
+ DownloadTRT: true
- task: PythonScript@0
displayName: 'Update deps.txt'
diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
index ed010b5619db5..d7ffc1828c943 100644
--- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
@@ -40,7 +40,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'Debug'
- EnvSetupScript: setup_env.bat
buildArch: x64
additionalBuildFlags: --build_java --build_nodejs --build_wheel --disable_memleak_checker
msbuildPlatform: x64
@@ -59,7 +58,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env.bat
buildArch: x64
# Compare to our Nuget packaging pipeline, this job has "--build_wheel" but doesn't have "--enable_lto --disable_rtti --use_telemetry --enable_wcos"
# Python bindings use typeid so I can't disable RTTI here. If it causes a problem, we will need to split this job to two jobs.
@@ -80,7 +78,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env.bat
buildArch: x64
additionalBuildFlags: --build_wheel --use_dnnl --build_java
msbuildPlatform: x64
@@ -101,7 +98,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env.bat
buildArch: x64
additionalBuildFlags: --build_wheel --use_xnnpack
msbuildPlatform: x64
@@ -120,7 +116,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env.bat
buildArch: x64
additionalBuildFlags: --use_winml --enable_wcos --disable_rtti --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.22000.0
msbuildPlatform: x64
@@ -160,7 +155,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'Debug'
- EnvSetupScript: setup_env.bat
buildArch: x64
additionalBuildFlags: --enable_training --build_wheel --disable_memleak_checker
msbuildPlatform: x64
@@ -179,7 +173,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env.bat
buildArch: x64
additionalBuildFlags: --enable_training --build_wheel
msbuildPlatform: x64
@@ -198,7 +191,6 @@ stages:
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env.bat
buildArch: x64
additionalBuildFlags: --enable_training_apis
msbuildPlatform: x64
@@ -215,10 +207,17 @@ stages:
- stage: x64_release_azure
dependsOn: []
jobs:
+ - job:
+ steps:
+ - powershell: |
+ Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin"
+ $env:PATH
+ Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin"
+ $env:PATH
+ displayName: 'Append x64-windows and x86-windows to PATH'
- template: templates/jobs/win-ci-vs-2022-job.yml
parameters:
BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env_azure.bat
buildArch: x64
additionalBuildFlags: --use_azure --use_lock_free_queue
msbuildPlatform: x64
@@ -231,3 +230,5 @@ stages:
GenerateDocumentation: false
WITH_CACHE: true
MachinePool: 'onnxruntime-Win-CPU-2022'
+
+
diff --git a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh
similarity index 78%
rename from tools/ci_build/github/linux/build_linux_arm64_python_package.sh
rename to tools/ci_build/github/linux/build_linux_python_package.sh
index 516f320cd64c4..3c1c65c9a6862 100755
--- a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh
+++ b/tools/ci_build/github/linux/build_linux_python_package.sh
@@ -15,9 +15,11 @@ do case "${parameter_Option}"
in
#GPU or CPU.
d) BUILD_DEVICE=${OPTARG};;
-p) PYTHON_EXES=(${OPTARG});;
-x) EXTRA_ARG=(${OPTARG});;
+p) PYTHON_EXES=${OPTARG};;
+x) EXTRA_ARG=${OPTARG};;
c) BUILD_CONFIG=${OPTARG};;
+*) echo "Usage: $0 -d <device> [-p <python_exes>] [-x <extra_build_arg>] [-c <build_config>]"
+ exit 1;;
esac
done
@@ -48,7 +50,7 @@ if [ "$ARCH" == "x86_64" ] && [ "$GCC_VERSION" -ge 9 ]; then
fi
echo "EXTRA_ARG:"
-echo $EXTRA_ARG
+echo "$EXTRA_ARG"
if [ "$EXTRA_ARG" != "" ]; then
BUILD_ARGS+=("$EXTRA_ARG")
@@ -60,19 +62,19 @@ if [ "$ARCH" == "x86_64" ]; then
fi
if [ "$BUILD_DEVICE" == "GPU" ]; then
+ SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/')
#Enable CUDA and TRT EPs.
- ONNXRUNTIME_CUDA_VERSION="11.8"
- BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80")
+ BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80")
fi
export CFLAGS
export CXXFLAGS
for PYTHON_EXE in "${PYTHON_EXES[@]}"
do
- rm -rf /build/$BUILD_CONFIG
+ rm -rf /build/"$BUILD_CONFIG"
${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}"
- cp /build/$BUILD_CONFIG/dist/*.whl /build/dist
+ cp /build/"$BUILD_CONFIG"/dist/*.whl /build/dist
done
which ccache && ccache -sv && ccache -z
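Instead of hard-coding CUDA 11.8, the GPU branch now trims the full `CUDA_VERSION` down to `major.minor` with `sed` before passing it to build.py. The same extraction expressed in Python, for clarity (the script itself stays with sed):

```python
import re


def short_cuda_version(full_version: str) -> str:
    # Keep only the leading "major.minor" of a version such as "12.2.140".
    match = re.match(r"(\d+\.\d+)", full_version)
    return match.group(1) if match else full_version


assert short_cuda_version("12.2.140") == "12.2"
assert short_cuda_version("11.8.0") == "11.8"
assert short_cuda_version("11.8") == "11.8"  # already major.minor, left unchanged
```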
diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh
index 18ac6482827f9..ff2ce6f7ff231 100755
--- a/tools/ci_build/github/linux/run_python_dockerbuild.sh
+++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh
@@ -9,24 +9,32 @@ i) DOCKER_IMAGE=${OPTARG};;
d) DEVICE=${OPTARG};;
x) BUILD_EXTR_PAR=${OPTARG};;
c) BUILD_CONFIG=${OPTARG};;
+*) echo "Usage: $0 -i <docker_image> -d <device> [-x <extra_build_args>] [-c <build_config>]"
+ exit 1;;
esac
done
-mkdir -p $HOME/.onnx
+mkdir -p "${HOME}/.onnx"
+DOCKER_SCRIPT_OPTIONS="-d ${DEVICE} -c ${BUILD_CONFIG}"
+
+if [ "${BUILD_EXTR_PAR}" != "" ] ; then
+ DOCKER_SCRIPT_OPTIONS+=" -x ${BUILD_EXTR_PAR}"
+fi
+
docker run --rm \
--volume /data/onnx:/data/onnx:ro \
- --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \
- --volume $BUILD_BINARIESDIRECTORY:/build \
+ --volume "${BUILD_SOURCESDIRECTORY}:/onnxruntime_src" \
+ --volume "${BUILD_BINARIESDIRECTORY}:/build" \
--volume /data/models:/build/models:ro \
- --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
+ --volume "${HOME}/.onnx:/home/onnxruntimedev/.onnx" \
-w /onnxruntime_src \
-e NIGHTLY_BUILD \
-e BUILD_BUILDNUMBER \
$ADDITIONAL_DOCKER_PARAMETER \
- $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_arm64_python_package.sh -d $DEVICE -c $BUILD_CONFIG -x $BUILD_EXTR_PAR
+ $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_python_package.sh $DOCKER_SCRIPT_OPTIONS
-sudo rm -rf $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/onnxruntime $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/pybind11 \
- $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/models $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/_deps \
- $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/CMakeFiles
-cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG
-find -executable -type f > $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/perms.txt
+sudo rm -rf "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/onnxruntime" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/pybind11" \
+ "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/models" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/_deps" \
+ "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/CMakeFiles"
+cd "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}"
+find -executable -type f > "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/perms.txt"
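run_python_dockerbuild.sh now assembles the inner script's options incrementally and only appends `-x` when extra build parameters were actually supplied, so an empty argument is never passed through getopts. A rough Python equivalent of that option-building logic (the function name is ours, for illustration):

```python
def docker_script_options(device: str, build_config: str, extra_build_par: str = "") -> list:
    # Always forward -d and -c; only forward -x when extra build parameters were supplied.
    options = ["-d", device, "-c", build_config]
    if extra_build_par:
        options += ["-x", extra_build_par]
    return options


assert docker_script_options("GPU", "Release") == ["-d", "GPU", "-c", "Release"]
assert docker_script_options("GPU", "Release", "--enable_training") == [
    "-d", "GPU", "-c", "Release", "-x", "--enable_training",
]
```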
diff --git a/tools/ci_build/github/windows/setup_env_azure.bat b/tools/ci_build/github/windows/setup_env_azure.bat
deleted file mode 100644
index 44ba34b0bf23a..0000000000000
--- a/tools/ci_build/github/windows/setup_env_azure.bat
+++ /dev/null
@@ -1,4 +0,0 @@
-REM Copyright (c) Microsoft Corporation. All rights reserved.
-REM Licensed under the MIT License.
-set PATH=%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin;%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin;%PATH%
-set GRADLE_OPTS=-Dorg.gradle.daemon=false