diff --git a/.gitmodules b/.gitmodules index 036a248070855..7bb49e98bfec1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,6 +8,3 @@ path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git branch = 3.1.44 -[submodule "cmake/external/onnxruntime-extensions"] - path = cmake/external/onnxruntime-extensions - url = https://github.com/microsoft/onnxruntime-extensions.git diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 15b989e398fc7..092afa15df4df 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.16.0 +1.17.0 diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index b483e6de81a47..cf71b6bcf7c7d 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -11,6 +11,8 @@ set(contrib_ops_excluded_files "bert/attention_softmax.h" "bert/attention_softmax.cu" "bert/attention_prepare_qkv.cu" + "bert/decoder_attention_impl.h" + "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" "bert/decoder_masked_multihead_attention.cc" "bert/decoder_masked_self_attention.h" diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index ac790242409e3..1868ff509bfc3 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -62,10 +62,10 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(16 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/); OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); - trainingApiPtr = OrtGetTrainingApi(16 /*ORT_API_VERSION*/); + trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/); if (trainingApiPtr != IntPtr.Zero) { trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi)); diff --git a/docs/python/README.rst b/docs/python/README.rst index 7d978b0941235..32bb3729e01d0 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime string; @@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record): CastAt export const cast = (context: ComputeContext, attributes: CastAttributes): void => { let func: ElementwiseFunctionCall; switch (attributes.to) { + case DataType.float16: + func = 'vec4'; + break; case DataType.float: func = 'vec4'; break; @@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey { } export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfoLoader( context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` - const clip_min_: vec4 = vec4(f32(${attributes.min})); - const clip_max_: vec4 = vec4(f32(${attributes.max})); + const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min})); + const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max})); `, attributes.cacheKey), {inputs: [0]}); @@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, 
attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string) => ` -const r0: f32 = 0.3275911; -const r1: f32 = 0.254829592; -const r2: f32 = -0.284496736; -const r3: f32 = 1.421413741; -const r4: f32 = -1.453152027; -const r5: f32 = 1.061405429; +export const erfImpl = (dataType: string, varType = 'f32') => ` +const r0: ${varType} = 0.3275911; +const r1: ${varType} = 0.254829592; +const r2: ${varType} = -0.284496736; +const r3: ${varType} = 1.421413741; +const r4: ${varType} = -1.453152027; +const r5: ${varType} = 1.061405429; fn erf_vf32(v: ${dataType}) -> ${dataType} { let absv = abs(v); @@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { }`; export const erf = (context: ComputeContext): void => { - context.compute( - createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4'))); + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfoLoader( + context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); }; export const exp = (context: ComputeContext): void => { @@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => { }; export const gelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfoLoader( context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl('vec4'))); + erfImpl(`vec4<${dataType}>`, dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { diff --git a/js/web/package-lock.json b/js/web/package-lock.json index eabd641914170..9567bc172c9ed 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.16.0", + "version": "1.17.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.16.0", + "version": "1.17.0", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -49,7 +49,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.17.0", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/web/package.json b/js/web/package.json index ce06475f672fd..8ae5b733e5f21 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -8,7 +8,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.16.0", + "version": "1.17.0", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 94592884ccad6..6e65645ef4756 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -602,6 +602,11 @@ // // "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", + "test_if", + // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra + // supports Sequence and Optional types + // "test_if_seq", + // "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", // "test_isinf_negative", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index d39d8edf0b73a..fd147eaa11f3f 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. 
""" -__version__ = "1.16.0" +__version__ = "1.17.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 366d8fee1473b..b4a4ae208ceb1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -153,408 +153,285 @@ size_t GetAttentionWorkspaceSize( } template -__global__ void AddBiasTransAppendKvToPresentSmall( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int N = blockDim.y; - const int S = gridDim.x; - const int B = gridDim.y; - - constexpr int M = 3; // Matrix count in qkv - const int m = blockIdx.z + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + assert(data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.mask_index == nullptr); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? 
biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, + sequence_length, stream, + data.scratch); -template -__global__ void AddBiasTransAppendKvToPresent( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = blockIdx.x; - const int s = blockIdx.y; - const int b = (blockIdx.z >> 1); - const int N = gridDim.x; - const int S = gridDim.y; - const int B = (gridDim.z >> 1); - - constexpr int M = 3; // Matrix count in qkv - const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, + data.mask_index, batch_size, parameters.kv_sequence_length, stream, + kv_sequence_offset); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); -// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. -// bias is of shape (3, NxH) or nullptr -// append to present of (2, B, N, (P..T) of M, H), -template -Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int past_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const T* biases, - const T* qkv_buffer, - T* present) { - assert(head_size <= (1 << 30)); - - int64_t nh = (int64_t)head_size * num_heads; - if (nh <= max_threads_per_block) { - const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - - AddBiasTransAppendKvToPresentSmall<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); - } else { - const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v - const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); - AddBiasTransAppendKvToPresent<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = + reinterpret_cast(data.fused_cross_attention_kernel); + + // When there is no bias, we can directly use q and packed kv from inputs. 
+ void const* query = data.q; + void const* packed_kv = data.k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; } - return CUDA_CALL(cudaGetLastError()); + run_fused_cross_attention( + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + parameters.num_heads, // number of heads + parameters.head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + parameters.kv_sequence_length, // sequence length of KV + stream); + + DUMP_TENSOR("trt cross output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + return Status::OK(); } -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* bias, - const float* qkv_buffer, - float* present); - -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* bias, - const half* qkv_buffer, - half* present); +template <> +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused cross attention does not support float tensor"); +} template -Status QkvToContext( - const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, +Status FusedTrtSelfAttention( + cudaStream_t stream, contrib::AttentionParameters& parameters, AttentionData& data) { - auto stream = static_cast(ort_stream->GetHandle()); - constexpr size_t element_size = sizeof(T); - const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int total_sequence_length = parameters.total_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - const float mask_filter_value = parameters.mask_filter_value; - void* fused_runner = data.fused_runner; - - // At most one fused kernel is enabled. - assert((int(data.use_flash_attention) + - int(data.use_memory_efficient_attention) + - int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr)) <= 1); - - const int batches = batch_size * num_heads; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - QkvData qkv; - ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, qkv)); - T* scratch1 = data.has_qkv_workspace ? 
qkv.after_v : data.workspace; - - int present_size_per_batch_k = 0; - int present_size_per_batch_v = 0; - if (!past_present_share_buffer) { - present_size_per_batch_k = total_sequence_length * qk_head_size; - present_size_per_batch_v = total_sequence_length * v_head_size; - ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, parameters.pass_past_in_kv, - stream, max_threads_per_block, data, qkv)); - - } else { // past_present_share_buffer - assert(qk_head_size == v_head_size); - assert(data.fused_cross_attention_kernel == nullptr); - assert(!use_fused_kernel); - assert(data.gemm_buffer != nullptr); - assert(!data.use_memory_efficient_attention); - assert(!data.use_flash_attention); - assert(data.has_qkv_workspace); - - if (nullptr != data.past_key || nullptr != data.present_key) { - // TODO: support this case. - ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); - } - - if (data.present != data.past) { - // For easy testing. Production should better avoid this path. - int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; - cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } - - // append last k v to present - ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( - stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer - data.gemm_buffer, data.present)); + const bool causal = parameters.is_unidirectional; - present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; - present_size_per_batch_v = present_size_per_batch_k; - qkv.k = data.present; - qkv.v = data.present + batches * present_size_per_batch_k; - } + int* sequence_offset = reinterpret_cast(data.scratch); - // Q, K and V are ready now DUMP_TENSOR_INIT(); - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qkv.format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.mask_index == nullptr); - - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - scratch1); - - DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); - - FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = - reinterpret_cast(data.fused_cross_attention_kernel); - - // When there is no bias, we can directly use q and packed kv from inputs. 
- void const* query = qkv.q; - void const* packed_kv = qkv.k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV - stream); - - DUMP_TENSOR("trt cross output", data.output, batch_size, sequence_length, num_heads, v_head_size); - return Status::OK(); + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); + LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } else { + sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + sequence_offset); } + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - // Run TRT fused attention. - if (use_fused_kernel || use_fused_causal) { - int* sequence_offset = reinterpret_cast(scratch1); - if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); - } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); - } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); - FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); + const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); - const int S = use_fused_causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); + // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. + const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. - const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); + fused_fp16_runner->setup(S, B); - fused_fp16_runner->setup(S, B); + if (!causal) { + assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - if (use_fused_kernel) { - assert(qkv.format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. 
- void const* packed_qkv = qkv.q; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size); - } else { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size); + // When there is no bias, we can directly use packed qkv from inputs. + void const* packed_qkv = data.q; + if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { + packed_qkv = data.query; } - return Status::OK(); + + fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); + DUMP_TENSOR("fused output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + } else { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); + DUMP_TENSOR("fused causal output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } + return Status::OK(); +} - // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; +// Template Specialization for float type +template <> +Status FusedTrtSelfAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused attention does not support float tensor"); +} #if USE_FLASH_ATTENTION - if (data.use_flash_attention) { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - assert(nullptr == data.mask_index); - assert(nullptr == data.relative_position_bias); - assert(parameters.head_size == parameters.v_head_size); - - void* query = reinterpret_cast(qkv.q); - void* key = reinterpret_cast(qkv.k); - void* value = reinterpret_cast(qkv.v); - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - constexpr bool is_causal = false; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(scratch1), - parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, is_causal)); +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(nullptr == data.mask_index); + assert(nullptr == data.relative_position_bias); + assert(parameters.head_size == parameters.v_head_size); + + void* query = reinterpret_cast(data.q); + void* key = reinterpret_cast(data.k); + void* value = reinterpret_cast(data.v); + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { + query = reinterpret_cast(const_cast(data.query)); + } - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + + DUMP_TENSOR("flash attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + + return Status::OK(); +} - return Status::OK(); - } +template <> +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor"); +} #endif #if USE_MEMORY_EFFICIENT_ATTENTION - if (data.use_memory_efficient_attention) { - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = qkv.q; - const void* key = qkv.k; - const void* value = qkv.v; - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + + const void* query = data.q; + const void* key = data.k; + const void* value = data.v; + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { + assert(data.bias == nullptr); + query = data.query; + } - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - MemoryEfficientAttentionParams p; - p.sm = device_prop.major * 10 + device_prop.minor; - p.is_half = sizeof(T) == 2; - p.batch_size = parameters.batch_size; - p.num_heads = parameters.num_heads; - p.sequence_length = parameters.sequence_length; - p.kv_sequence_length = parameters.total_sequence_length; - p.qk_head_size = parameters.head_size; - p.v_head_size = parameters.v_head_size; - p.causal = parameters.is_unidirectional; - p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = parameters.batch_size; + p.num_heads = parameters.num_heads; + p.sequence_length = parameters.sequence_length; + p.kv_sequence_length = parameters.total_sequence_length; + p.qk_head_size = parameters.head_size; + p.v_head_size = parameters.v_head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr - : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; - p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; - p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) - ? 
scratch1 - : nullptr; - p.stream = stream; - run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); - - return Status::OK(); - } + : const_cast(reinterpret_cast( + data.mask_index + parameters.batch_size)); + p.seqstart_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast( + data.mask_index + 2 * parameters.batch_size + 1)); + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; + p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) + ? data.scratch + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + DUMP_TENSOR("efficient attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + + return Status::OK(); +} #endif - // The following are unfused attention. - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH); +template +Status UnfusedAttention( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + + auto stream = static_cast(ort_stream->GetHandle()); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + const int batches = batch_size * num_heads; + const int* mask_index = data.mask_index; gsl::span& mask_index_dims = data.mask_index_dims; // Raw attention mask could be 2D (BxT) or 3D (BxSxT) or 4D(Bx1xMxM), where M is the max sequence length. bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); - // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxT + // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch: BxNxSxT // Q: BxNxSxH, K (present_k): BxNxTxH, Q*K': BxNxSxT float one = 1.0f; float zero = 0.f; @@ -563,22 +440,31 @@ Status QkvToContext( cublasSetStream(cublas, stream); - DUMP_TENSOR_D("q[BNSH]", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", k, batch_size, num_heads, total_sequence_length, qk_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); + + const int present_sequence_length = parameters.past_present_share_buffer + ? 
parameters.max_sequence_length + : total_sequence_length; + const int present_size_per_batch_k = present_sequence_length * qk_head_size; + const int present_size_per_batch_v = present_sequence_length * v_head_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, - &alpha, qkv.k, qk_head_size, present_size_per_batch_k, - qkv.q, qk_head_size, sequence_length * qk_head_size, - &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &alpha, data.k, qk_head_size, present_size_per_batch_k, + data.q, qk_head_size, sequence_length * qk_head_size, + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", qkv.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", qkv.k, batch_size, num_heads, qk_head_size, sequence_length); - DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); + constexpr size_t element_size = sizeof(T); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); - T* scratch2 = scratch1 + (bytes / element_size); + T* scratch2 = data.scratch + (bytes / element_size); // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask @@ -588,14 +474,15 @@ Status QkvToContext( const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. + // replace Q*K' in place with masked score for persistent softmax. + T* persistent_softmax_workspace = data.scratch; ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask( ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, + data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, - mask_filter_value)); + parameters.mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index assert(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. 
@@ -603,277 +490,123 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional)); + data.scratch, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( ComputeSoftmax( stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional)); + parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional)); } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", qkv.v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = qkv.q; + T* temp_output = data.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, - &one, qkv.v, v_head_size, present_size_per_batch_v, + &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, temp_output, data.output); + device_prop.maxThreadsPerBlock, false, temp_output, data.output); DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } template -Status DecoderQkvToContext( +Status QkvToContext( const cudaDeviceProp& device_prop, - Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); const int max_threads_per_block = device_prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); + // At most one fused kernel is enabled. 
+ assert((int(data.use_flash_attention) + + int(data.use_memory_efficient_attention) + + int(fused_runner != nullptr) + + int(data.fused_cross_attention_kernel != nullptr)) <= 1); - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - key_cache, - temp_qkv_buffer, - qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } + ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block)); + + if (!parameters.past_present_share_buffer) { + ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, + sequence_length, total_sequence_length, parameters.pass_past_in_kv, + stream, max_threads_per_block, data)); + + } else { // past_present_share_buffer + assert(qk_head_size == v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(nullptr == fused_runner || parameters.is_unidirectional); + assert(data.gemm_buffer != nullptr); + assert(!data.use_memory_efficient_attention); + assert(!data.use_flash_attention); + assert(data.has_qkv_workspace); + + if (nullptr != data.past_key || nullptr != data.present_key) { + // TODO: support this case. + ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); } - } - if (has_layer_state) { - if (use_past && static_kv) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - } else { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + if (data.present != data.past) { + // For easy testing. Production should better avoid this path. 
+ int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); } - } - // scratch1: BxNxSxL buffer - // scratch2: BxNxSxL buffer - // scratch3: BxNxSxH buffer - T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL - // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - float one = 1.0f; - float zero = 0.f; + // For fused causal, bias has been added to gemm_buffer. + const T* bias = (nullptr != fused_runner && parameters.is_unidirectional) ? nullptr : data.bias; - float alpha = rsqrt_head_size; - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, key_cache, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, k, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + // append last k v to present + ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( + stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, max_threads_per_block, + bias, data.gemm_buffer, data.present)); + + data.k = data.present; + data.v = data.present + batch_size * num_heads * parameters.max_sequence_length * qk_head_size; } - constexpr bool is_unidirectional = false; - const T* add_before_softmax = nullptr; - if (has_key_padding_mask) { - constexpr int mask_dimension = 2; - constexpr int max_sequence_length = 0; - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, - 1.0f, mask_dimension, max_sequence_length, false, nullptr, - mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax( - stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + // Q, K and V are ready now + if (data.fused_cross_attention_kernel != nullptr) { + return FusedTrtCrossAttention(stream, parameters, data); } - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, v, head_size, strideA, - scratch2, 
kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + // Run TRT fused attention. + if (nullptr != fused_runner) { + return FusedTrtSelfAttention(stream, parameters, data); } - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, scratch3, output); -} + // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& device_prop, - Stream* stream, - cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } +#endif + + return UnfusedAttention(device_prop, cublas, ort_stream, parameters, data, scale); } // Template Instantiation diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index c361a47c364d3..d0a5fb51a25d6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -81,24 +81,20 @@ struct AttentionData { mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; -}; -// Intermediate data pointers available after PrepareQKV -template -struct QkvData { + // Intermediate data T* q = nullptr; T* k = nullptr; T* v 
= nullptr; - T* after_v = nullptr; // pointer right after v - AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH; + T* scratch = nullptr; + AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; }; template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv_data); + int max_threads_per_block); template Status QkvToContext( @@ -108,33 +104,6 @@ Status QkvToContext( contrib::AttentionParameters& parameters, AttentionData& data); -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& prop, // Device Properties - Stream* stream, // ORT Stream - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - // BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true) Status LaunchTransCtx(cudaStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, @@ -184,14 +153,27 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); + +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present); template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) T* out, longlong4 out_strides, // coord (b,n,s,h) int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu similarity index 67% rename from onnxruntime/contrib_ops/cuda/bert/attention_concat.cu rename to onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 8378ee2691c6b..89be0f1115f41 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" using namespace onnxruntime::cuda; @@ -244,48 +245,48 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, present); } -#ifndef USE_ROCM // exclude from hipify +#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP + template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv) { + AttentionData& data) { // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. if (nullptr != data.present) { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + ORT_RETURN_IF_ERROR( LaunchConcatPastToPresent( stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, qkv.k, data.present)); + max_threads_per_block, data.past, data.k, data.present)); // Update pointers to present_k and present_v. - qkv.k = data.present; - qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } - - if (nullptr != data.past_key || nullptr != data.present_key) { + data.k = data.present; + data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; + } else if (nullptr != data.past_key || nullptr != data.present_key) { if (nullptr != data.past_key && nullptr == data.present_key) { - qkv.k = const_cast(data.past_key); - qkv.v = const_cast(data.past_value); + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) { - qkv.k = data.present_key; - qkv.v = data.present_value; + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + data.k = data.present_key; + data.v = data.present_value; } else { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - qkv.k = data.temp_k_workspace; - qkv.v = data.temp_v_workspace; + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + data.k = data.temp_k_workspace; + data.v = data.temp_v_workspace; } } else if (pass_past_in_kv) { // past_key and past_value are used directly as key and value in attention computations - qkv.k = const_cast(data.past_key); - qkv.v = const_cast(data.past_value); + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); // This path has a memory copy from past_key and past_value to present_key and present_value // Avoid this path since the memory copy is unnecessary because past_key == present_key and @@ -298,14 +299,14 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, qkv.k, data.present_key)); + max_threads_per_block, 1, 
data.past_key, data.k, data.present_key)); ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, qkv.v, data.present_value)); + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); // Update pointers to present_k and present_v. - qkv.k = data.present_key; - qkv.v = data.present_value; + data.k = data.present_key; + data.v = data.present_value; } } @@ -317,15 +318,147 @@ template Status ConcatPastToPresent(int batch_size, int num_heads, int qk int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); + +// ---------------------------------------------------------------------------------- +// Below kernels are for past and present sharing buffer +// ---------------------------------------------------------------------------------- + +template +__global__ void AddBiasTransAppendKvToPresentSmall( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + constexpr int M = 3; // Matrix count in qkv + const int m = blockIdx.z + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? 
biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +template +__global__ void AddBiasTransAppendKvToPresent( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = blockIdx.x; + const int s = blockIdx.y; + const int b = (blockIdx.z >> 1); + const int N = gridDim.x; + const int S = gridDim.y; + const int B = (gridDim.z >> 1); + + constexpr int M = 3; // Matrix count in qkv + const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. +// bias is of shape (3, NxH) or nullptr +// append to present of (2, B, N, (P..T) of M, H), +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present) { + assert(head_size <= (1 << 30)); + + int64_t nh = (int64_t)head_size * num_heads; + if (nh <= max_threads_per_block) { + const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + + AddBiasTransAppendKvToPresentSmall<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } else { + const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v + const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); + AddBiasTransAppendKvToPresent<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* bias, + const float* qkv_buffer, + float* present); + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* bias, + const half* qkv_buffer, + half* present); #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index cd4137ab11de9..5c65a30918ece 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -1,9 +1,8 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT License. -#include -#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" @@ -406,22 +405,25 @@ Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, // Query (BxSxNxH) => Q (BxNxSxH) constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, + true, -1); // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); @@ -435,8 +437,8 @@ template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv) { + int max_threads_per_block) { + data.scratch = data.workspace; if (data.has_qkv_workspace) { const int size_per_batch_q = parameters.sequence_length * parameters.head_size; const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; @@ -445,27 +447,27 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - qkv.q = data.workspace; - qkv.k = data.workspace + elements_q; - qkv.v = qkv.k + elements_k; - qkv.after_v = qkv.v + elements_v; + data.q = data.workspace; + data.k = data.workspace + elements_q; + data.v = data.k + elements_k; + data.scratch = data.v + elements_v; } if (nullptr != data.gemm_buffer) { // Attention operator ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, - qkv.format)); + data.qkv_format)); } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv 
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else { // multihead attention operator, no past, separated Q/K/V inputs ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } CUDA_RETURN_IF_ERROR(cudaGetLastError()); @@ -477,15 +479,13 @@ template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv); + int max_threads_per_block); template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv); + int max_threads_per_block); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index e7d2255fb46b8..01ea02f48d3ab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -18,7 +18,6 @@ limitations under the License. */ #include -#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index f907d300607f9..3f703ae3d05e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
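Stepping back to the attention_kv_cache.cu hunk above: LaunchAddBiasTransAppendKvToPresent chooses between two launch shapes depending on whether all heads of one token (num_heads * head_size threads) fit in a single thread block. A host-side sketch of that decision is given below; the struct and function names are illustrative stand-ins, with a plain struct used in place of CUDA's dim3.

#include <algorithm>
#include <cstdint>

struct LaunchDims { unsigned x, y, z; };  // stand-in for CUDA's dim3

struct KvAppendLaunchConfig {
  LaunchDims grid;
  LaunchDims block;
  bool use_small_kernel;  // true -> AddBiasTransAppendKvToPresentSmall, false -> general kernel
};

inline KvAppendLaunchConfig ChooseKvAppendLaunch(int batch_size, int sequence_length,
                                                 int num_heads, int head_size,
                                                 int max_threads_per_block) {
  const int64_t nh = static_cast<int64_t>(head_size) * num_heads;
  if (nh <= max_threads_per_block) {
    // One block covers every head of one (token, batch) pair; grid.z = 2 selects K or V.
    return {{static_cast<unsigned>(sequence_length), static_cast<unsigned>(batch_size), 2u},
            {static_cast<unsigned>(max_threads_per_block / num_heads),
             static_cast<unsigned>(num_heads), 1u},
            true};
  }
  // Otherwise one block per (head, token), with batch and the K/V selector folded into grid.z.
  return {{static_cast<unsigned>(num_heads), static_cast<unsigned>(sequence_length),
           static_cast<unsigned>(batch_size) * 2u},
          {static_cast<unsigned>(std::min(head_size, max_threads_per_block)), 1u, 1u},
          false};
}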
-#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/decoder_attention.h" +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" #include "core/framework/op_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -85,7 +85,8 @@ Status CheckInputs(const TensorShape& query_shape, } if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast(hidden_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_weights shall have shape (hidden size, 2 * hidden size)"); } const auto& bias_dims = bias_shape.GetDims(); @@ -137,7 +138,8 @@ Status CheckInputs(const TensorShape& query_shape, const auto& value_cache_dims = value_cache->Shape().GetDims(); if (value_cache_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' is expected to have 4 dimension, got ", value_cache_dims.size()); } @@ -353,10 +355,12 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { } } - size_t bytes = element_size * batch_size * (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; + size_t bytes = element_size * batch_size * + (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; auto qkv_buffer_p = GetScratchBuffer(bytes, context->GetComputeStream()); - bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast(2) * head_size + static_cast(kv_sequence_length)); + bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * + (static_cast(2) * head_size + static_cast(kv_sequence_length)); auto workspace_p = GetScratchBuffer(bytes, context->GetComputeStream()); Tensor* output(context->Output(0, query_shape)); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu new file mode 100644 index 0000000000000..1dc22a9c8ea98 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" + +using namespace onnxruntime::contrib::attention_softmax_cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status DecoderQkvToContext( + const cudaDeviceProp& device_prop, + Stream* ort_stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int BN = batch_size * num_heads; + const int BHN = BN * head_size; + const int BNS = BN * sequence_length; + const int k_buffer_offset = sequence_length * BHN; + const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + + T* temp_qkv_buffer = workspace_buffer; + auto stream = static_cast(ort_stream->GetHandle()); + + const T* q = qkv_buffer; + // transpose q and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + + const T* k = qkv_buffer + k_buffer_offset; + const T* v = qkv_buffer + v_buffer_offset; + if (!has_layer_state || !use_past) { + if (!static_kv) { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } else { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } + } else { + if (!static_kv) { + // transpose kv and copy them to temp_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + // concat cache-k with k and copy to qkv_buffer + if (nullptr != key_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + key_cache, + temp_qkv_buffer, + qkv_buffer + k_buffer_offset)); + } + // concat cache-v with v and copy to qkv_buffer + if (nullptr != value_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); + } + } + } + + if (has_layer_state) { + if (use_past && static_kv) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } else { + 
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } + } + + // scratch1: BxNxSxL buffer + // scratch2: BxNxSxL buffer + // scratch3: BxNxSxH buffer + T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; + T* scratch2 = scratch1 + BNS * kv_sequence_length; + T* scratch3 = scratch2 + BNS * kv_sequence_length; + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL + // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * kv_sequence_length; + float one = 1.0f; + float zero = 0.f; + + float alpha = rsqrt_head_size; + const int strideA = kv_sequence_length * head_size; + const int strideB = sequence_length * head_size; + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, key_cache, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, k, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } + + constexpr bool is_unidirectional = false; + const T* add_before_softmax = nullptr; + if (has_key_padding_mask) { + constexpr int mask_dimension = 2; + constexpr int max_sequence_length = 0; + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, key_padding_mask, add_before_softmax, + false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + 1.0f, mask_dimension, max_sequence_length, false, nullptr, + mask_filter_value)); + } else { + ORT_RETURN_IF_ERROR(ComputeSoftmax( + stream, kv_sequence_length, sequence_length, batch_size, num_heads, + add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, + is_unidirectional)); + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } + + // scratch3 is BxNxSxH, transpose to output SxBxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, scratch3, output); +} + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& device_prop, + Stream* stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, 
+ const bool has_key_padding_mask, + const float mask_filter_value, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { + if (element_size == 2) { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } else { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..9db9ccb45e330 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
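Tying the implementation above back to the allocations made in DecoderAttention::ComputeInternal earlier in this diff: qkv_buffer holds the transposed Q, K and V back to back, and workspace_buffer holds a 3xBxNxSxH staging area followed by two BxNxSxL score buffers and one BxNxSxH context buffer. The small sketch below (names and dimension letters are illustrative: B batch, S query length, L key/value length, N heads, H head size) adds the regions up and notes that the totals are exactly the two element counts behind the byte formulas used for the scratch allocations.

#include <cstddef>

// Illustrative element counts for the decoder attention buffers, mirroring the
// offsets used in DecoderQkvToContext.
struct DecoderAttentionSizes {
  std::size_t qkv_buffer_elements;  // Q (S tokens) + K (L tokens) + V (L tokens), each B*N*H wide
  std::size_t workspace_elements;   // temp QKV staging + scratch1 + scratch2 + scratch3
};

inline DecoderAttentionSizes ComputeDecoderAttentionSizes(
    std::size_t B, std::size_t S, std::size_t L, std::size_t N, std::size_t H) {
  const std::size_t BHN = B * N * H;
  const std::size_t BNS = B * N * S;

  const std::size_t qkv = (S + 2 * L) * BHN;  // == B * (S + 2*L) * hidden_size
  const std::size_t temp_qkv = 3 * BHN * S;   // transposed Q/K/V staging (scratch1 starts here)
  const std::size_t scratch12 = 2 * BNS * L;  // two BxNxSxL attention-score buffers
  const std::size_t scratch3 = BNS * H;       // BxNxSxH context buffer
  // temp_qkv + scratch12 + scratch3 == 2 * B * S * N * (2*H + L), the same count
  // used for the workspace allocation in ComputeInternal.
  return {qkv, temp_qkv + scratch12 + scratch3};
}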
+ +#pragma once + +#include "contrib_ops/cuda/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& prop, // Device Properties + Stream* stream, // ORT Stream + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 9a150c9e6cd77..b0ed3ff82226a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -30,6 +30,7 @@ limitations under the License. 
#include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/rocm/bert/attention_impl.h" #include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/decoder_attention_impl.h" using namespace onnxruntime::rocm; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 19b2bc34efaec..3164e8c211099 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -28,34 +28,6 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length); -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - Status LaunchTransCtx(hipStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..d71c6d8440a44 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, // Device Properties + RocmTuningContext* tuning_ctx, // context for tuning + Stream* stream, // ORT Stream + rocblas_handle& rocblas, // Rocblas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index b7d26d87f2705..f0e5fbbd38721 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() { for (auto& node : graph_.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { const auto& ep = node.GetExecutionProviderType(); - if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) { + if (!ep.empty() && + ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && + ep != kRocmExecutionProvider && ep != kDmlExecutionProvider && + ep != kJsExecutionProvider) { // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow // node containing the subgraph it will create whatever state it needs internally. 
continue; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 9dccd7c47fbb6..0674fe02d093d 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -580,15 +584,17 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc index 98f7ca6e613b0..e61cb1094736d 100644 --- a/onnxruntime/core/providers/js/operators/binary.cc +++ b/onnxruntime/core/providers/js/operators/binary.cc @@ -6,14 +6,13 @@ namespace onnxruntime { namespace js { -#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, \ - kOnnxDomain, \ - VERSION, \ - kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}), \ +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \ KERNEL_CLASS); #define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ @@ -22,8 +21,7 @@ namespace js { kOnnxDomain, \ VERSION_FROM, VERSION_TO, \ kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}), \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \ KERNEL_CLASS); JSEP_KERNEL_IMPL(Add, Add) diff --git a/onnxruntime/core/providers/js/operators/if.cc b/onnxruntime/core/providers/js/operators/if.cc new file mode 100644 index 0000000000000..ef072bb1635dd --- /dev/null +++ 
b/onnxruntime/core/providers/js/operators/if.cc @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "if.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); +// output shape rules requiring the output shapes of the 'THEN' and 'ELSE' +// branches to be the same were relaxed in opset-11 +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-13 supports sequence type for If's subgraph outputs +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 13, 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-19 supports float8 +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 19, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +Status If::Compute(OpKernelContext* ctx) const { + // call the base CPU version. + return onnxruntime::If::Compute(ctx); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/if.h b/onnxruntime/core/providers/js/operators/if.h new file mode 100644 index 0000000000000..d060444ccc1d2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include + +#include "core/providers/js/js_kernel.h" +#include "core/common/common.h" +#include "core/providers/cpu/controlflow/if.h" + +namespace onnxruntime { +class SessionState; + +namespace js { + +// Use the CPU implementation for the logic +class If final : public onnxruntime::If { + public: + If(const OpKernelInfo& info) : onnxruntime::If(info) {} + + Status Compute(OpKernelContext* ctx) const override; +}; +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 5e972e43e4566..e9bbfabcf86bd 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -6,22 +6,29 @@ namespace onnxruntime { namespace js { -#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ +#define JSEP_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ ONNX_OPERATOR_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); -#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \ + KERNEL_CLASS); + +#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \ KERNEL_CLASS); #define JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ ONNX_OPERATOR_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); @@ -29,6 +36,7 @@ namespace js { ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); // math @@ -42,115 +50,115 @@ JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(Neg, 6, 12, Neg) JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(Neg, 13, Neg) JSEP_KERNEL_IMPL(Floor, Floor) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor) -JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, Floor) +JSEP_ELEMENTWISE_KERNEL(Floor, 13, Floor) JSEP_KERNEL_IMPL(Ceil, Ceil) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil) -JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, Ceil) +JSEP_ELEMENTWISE_KERNEL(Ceil, 13, Ceil) JSEP_KERNEL_IMPL(Reciprocal, Reciprocal) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal) -JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, Reciprocal) +JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, Reciprocal) JSEP_KERNEL_IMPL(Sqrt, Sqrt) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 
12, float, Sqrt) -JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, Sqrt) +JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, Sqrt) JSEP_KERNEL_IMPL(Exp, Exp) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp) -JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, Exp) +JSEP_ELEMENTWISE_KERNEL(Exp, 13, Exp) JSEP_KERNEL_IMPL(Erf, Erf) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf) -JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, Erf) +JSEP_ELEMENTWISE_KERNEL(Erf, 13, Erf) JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid) -JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) +JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) JSEP_KERNEL_IMPL(Log, Log) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log) -JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) +JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) JSEP_KERNEL_IMPL(Sin, Sin) -JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin) +JSEP_ELEMENTWISE_KERNEL(Sin, 7, Sin) JSEP_KERNEL_IMPL(Cos, Cos) -JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos) +JSEP_ELEMENTWISE_KERNEL(Cos, 7, Cos) JSEP_KERNEL_IMPL(Tan, Tan) -JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan) +JSEP_ELEMENTWISE_KERNEL(Tan, 7, Tan) JSEP_KERNEL_IMPL(Asin, Asin) -JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin) +JSEP_ELEMENTWISE_KERNEL(Asin, 7, Asin) JSEP_KERNEL_IMPL(Acos, Acos) -JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos) +JSEP_ELEMENTWISE_KERNEL(Acos, 7, Acos) JSEP_KERNEL_IMPL(Atan, Atan) -JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan) +JSEP_ELEMENTWISE_KERNEL(Atan, 7, Atan) JSEP_KERNEL_IMPL(Sinh, Sinh) -JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh) +JSEP_ELEMENTWISE_KERNEL(Sinh, 9, Sinh) JSEP_KERNEL_IMPL(Cosh, Cosh) -JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh) +JSEP_ELEMENTWISE_KERNEL(Cosh, 9, Cosh) JSEP_KERNEL_IMPL(Asinh, Asinh) -JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh) +JSEP_ELEMENTWISE_KERNEL(Asinh, 9, Asinh) JSEP_KERNEL_IMPL(Acosh, Acosh) -JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh) +JSEP_ELEMENTWISE_KERNEL(Acosh, 9, Acosh) JSEP_KERNEL_IMPL(Atanh, Atanh) -JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh) +JSEP_ELEMENTWISE_KERNEL(Atanh, 9, Atanh) JSEP_KERNEL_IMPL(Tanh, Tanh) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh) -JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, Tanh) +JSEP_ELEMENTWISE_KERNEL(Tanh, 13, Tanh) JSEP_KERNEL_IMPL(Not, Not) -JSEP_ELEMENTWISE_KERNEL(Not, 1, bool, Not) +JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); 
ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0) -JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu) +JSEP_ELEMENTWISE_KERNEL(Elu, 6, Elu) JSEP_KERNEL_IMPL(Relu, Relu) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu) -JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, Relu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, Relu) +JSEP_ELEMENTWISE_KERNEL(Relu, 14, Relu) JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu) -JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, LeakyRelu) +JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, LeakyRelu) JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0) -JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu) +JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, ThresholdedRelu) } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 15fe5acfe0fd2..4c0adcdd374aa 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2397,7 +2397,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_16 = { +static constexpr OrtApi ort_api_1_to_17 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. // Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -2745,16 +2745,16 @@ static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size static_assert(offsetof(OrtApi, GetCUDAProviderOptionsByName) / sizeof(void*) == 264, "Size of version 16 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.16.0", +static_assert(std::string_view(ORT_VERSION) == "1.17.0", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it -// 2. If there were any APIs added to ort_api_1_to_16 above: +// 2. If there were any APIs added to ort_api_1_to_17 above: // a. Add the 'End of version #' markers (pattern above should be obvious) // b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version >= 1 && version <= ORT_API_VERSION) - return &ort_api_1_to_16; + return &ort_api_1_to_17; fprintf(stderr, "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build." 
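For context on the version bump above: applications obtain the API table through OrtGetApiBase()->GetApi(), passing the ORT_API_VERSION their headers were compiled against, and a runtime older than that version reports the mismatch and hands back a null pointer, so a defensive check is the usual pattern. Below is a minimal consumer sketch (error handling trimmed); only public C API entry points are used.

#include <cstdio>
#include "onnxruntime_c_api.h"

int main() {
  // Request the API table matching the headers this program was built against.
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  if (ort == nullptr) {
    std::fprintf(stderr, "onnxruntime binary is older than the headers (need API version %d)\n",
                 ORT_API_VERSION);
    return 1;
  }
  // All further calls go through the returned table, e.g. ort->CreateEnv or ort->GetBuildInfoString.
  return 0;
}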
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 63c991167d235..c0cabbb5e9759 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -993,7 +993,11 @@ def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto encoder.remove_duplicated_initializer(signature_cache1) decoder.remove_duplicated_initializer(signature_cache2) initializers = remove_shared_initializers( - decoder.model.graph, encoder.model.graph, "s_", signature_cache1, signature_cache2 + decoder.model.graph, + encoder.model.graph, + shared_prefix="s_", + signature_cache1=signature_cache1, + signature_cache2=signature_cache2, ) return initializers diff --git a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py index a378721932a35..5e20d6b4e692a 100644 --- a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py +++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py @@ -49,9 +49,7 @@ def to_numpy(tensor): # These set of tests verify ONNX model export and compares outputs between # PyTorch and ORT. class ONNXExporterTest(unittest.TestCase): - from torch.onnx.symbolic_helper import _export_onnx_opset_version - - opset_version = _export_onnx_opset_version + opset_version = 17 keep_initializers_as_inputs = True # For IR version 3 type export. def setUp(self): diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index ec85002503c0e..2e7ac9508a41e 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -73,7 +73,8 @@ stages: project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' definition: '940' specificBuildWithTriggering: true - buildVersionToDownload: 'latest' + buildVersionToDownload: 'latestFromBranch' + branchName: 'refs/heads/main' artifactName: 'NPM_packages' targetPath: '$(Pipeline.Workspace)' displayName: 'Download onnxruntime-node Pipeline Artifact' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml index a81dd1e9cf240..2f67398908d5d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-build-steps.yml @@ -12,6 +12,10 @@ parameters: type: string default: "$(Agent.TempDirectory)/ort_ccache" +- name: DebugCache + type: boolean + default: false + - name: AdditionalKey type: string default: "" @@ -45,6 +49,18 @@ parameters: steps: - ${{ if eq(parameters.WithCache, true) }}: + - powershell: | + if ([string]::IsNullOrEmpty((Get-Command ccache -errorAction SilentlyContinue))) + { + choco install ccache -y --version 4.7.4 + $ccache_path = (Get-Command ccache).Source + $ccache_parent_dir = (Split-Path -parent $ccache_path) + Copy-Item "C:\ProgramData\chocolatey\lib\ccache\tools\ccache-4.7.4-windows-x86_64\ccache.exe" -Destination "C:\ProgramData\chocolatey\bin\cl.exe" + Get-ChildItem $ccache_parent_dir + ccache --version + } + displayName: Install ccache + - task: Cache@2 inputs: ${{if eq(variables['Build.SourceBranchName'], 'merge')}}: @@ -83,6 +99,11 @@ steps: createLogFile: true env: CCACHE_DIR: ${{parameters.CacheDir}} + CCACHE_SLOPPINESS: 
file_macro,time_macros,include_file_mtime,include_file_ctime + CCACHE_COMPILERCHECK: content + ${{if eq(parameters.DebugCache, true)}}: + CCACHE_DEBUG: 1 + CCACHE_DEBUGDIR: $(Agent.TempDirectory)/cache_debug - ${{ if eq(parameters.WithCache, true) }}: - powershell: | @@ -91,3 +112,10 @@ steps: displayName: cache stat env: CCACHE_DIR: ${{parameters.CacheDir}} + + - ${{if eq(parameters.DebugCache, true)}}: + - task: PublishPipelineArtifact@0 + displayName: 'publish cache log' + inputs: + artifactName: 'cache-log' + targetPath: $(Agent.TempDirectory)/cache_debug diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index e29d9d2cab91c..8868e671a5fa5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -1,6 +1,7 @@ parameters: - name: EnvSetupScript type: string + default: setup_env.bat - name: BuildConfig type: string diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 7db0c7302cd6f..68e0d51480a63 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -46,7 +46,8 @@ jobs: BuildConfig: 'RelWithDebInfo' ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' QNN_SDK_ROOT: 'C:\data\qnnsdk\${{parameters.QnnSdk}}' - timeoutInMinutes: 150 + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + timeoutInMinutes: 120 workspace: clean: all steps: @@ -70,29 +71,19 @@ jobs: # '_Ret &std::_Visit_strategy<1>::_Visit2<_Ret,_ListOfIndexVectors,_Callable,const # std::variant&>(size_t,_Callable &&, # const std::variant &)' - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel' - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\$(BuildConfig)\onnxruntime.sln' - platform: 'x64' - configuration: $(BuildConfig) - msbuildArgs: $(MsbuildArguments) - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: false - workingFolder: '$(Build.BinariesDirectory)\$(BuildConfig)' - createLogFile: true + - template: templates/jobs/win-ci-build-steps.yml + parameters: + WithCache: True + Today: $(TODAY) + AdditionalKey: "win-qnn | $(BuildConfig)" + BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel' + MsbuildArguments: $(MsbuildArguments) + BuildArch: $(buildArch) + Platform: 'x64' + BuildConfig: $(BuildConfig) - powershell: | python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run unit tests' diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh 
b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh index b2e575f38e425..3bca6413100a2 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh @@ -20,6 +20,10 @@ export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" /opt/python/cp39-cp39/bin/python3.9 -m pip install transformers +# beartype is installed here so that the onnxscript installation step won't +# install a version PyTorch doesn't like. Once beartype fixes this problem, +# we can remove this line. +/opt/python/cp39-cp39/bin/python3.9 -m pip install beartype==0.15.0 cd /usr/local/ echo "Cloning ONNX Script" diff --git a/tools/python/util/logger.py b/tools/python/util/logger.py index 9deb4475721ee..15e04528ac7ac 100644 --- a/tools/python/util/logger.py +++ b/tools/python/util/logger.py @@ -5,6 +5,7 @@ def get_logger(name): - logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) - - return logging.getLogger(name) + logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s") + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + return logger