From c65e892089e2e6f383e80339334450ac06be5ba0 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 20 Sep 2023 10:35:15 -0700 Subject: [PATCH 01/58] [CUDA] Fix performance bug in DecoderMaskedMultiheadAttention for BeamSearch (#17613) --- ...decoder_masked_multihead_attention_impl.cu | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 5827bdfee1ab5..c8877a5e3f872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -174,7 +174,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio q = add_vec(q, q_bias); } - T* params_k_cache = reinterpret_cast(params.k_cache); const float inv_sqrt_dh = params.scale; @@ -350,24 +349,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // The keys loaded from the key cache. K_vec_k k_vec[K_VECS_PER_THREAD]; + if (ti < tlength) { + if (has_beams) { + const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; - if (has_beams) { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { - const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); } - } - } else { + } else { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); } From c55da45e20435b8aa9edb78179b6027502b778b0 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 20 Sep 2023 14:31:01 -0700 Subject: [PATCH 02/58] [QNN EP] Add more op unit tests (fix Clip, TopK, Tile) (#17457) ### Description Adds more operator unit tests (all op types should now have at least 1 unit test): - [x] Reshape - [x] Flatten - [x] Squeeze - [x] Unsqueeze - [x] Gemm - [x] Clip - Enable QDQ Clip on HTP backend (when not optimized away by L1 ClipQuantFusion optimizer) - Add support for 16-bit QDQ Clip to ClipQuantFusion optimizer - [x] Split - [x] Topk - Enable QDQ TopK on HTP backend - [x] Tile - Enable QDQ Tile on HTP backend ### Motivation and Context Increase QNN operator support and test coverage. 
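Note for reviewers (illustrative only, not part of the patch): the ClipQuantFusion change in `clip_quantizelinear.cc` below swaps hard-coded quantization ranges (-128/127, 0/255) for `std::numeric_limits`, which is what makes the new int16/uint16 cases a simple addition. The sketch below uses a hypothetical helper name, `QuantizedRangeToFloat`, to show the computation the updated `GetQConstantLowerUpper` performs for each zero-point type:

```cpp
// Illustrative sketch only; QuantizedRangeToFloat is a hypothetical name, not code from this PR.
// Given a scale and zero point, compute the float interval a quantized type can represent:
//   lower = scale * (qmin - zero_point), upper = scale * (qmax - zero_point)
#include <cstdint>
#include <limits>
#include <utility>

template <typename QuantType>
std::pair<float, float> QuantizedRangeToFloat(float scale, QuantType zero_point) {
  const float zp = static_cast<float>(zero_point);
  const float lower = scale * (static_cast<float>(std::numeric_limits<QuantType>::lowest()) - zp);
  const float upper = scale * (static_cast<float>(std::numeric_limits<QuantType>::max()) - zp);
  return {lower, upper};
}

// Example: for uint16 data with scale ~9.1554e-5 and zero point 0 (the values used by the new
// 16-bit QDQ Clip transformer tests), the representable range is roughly [0, 6], so a following
// Clip to [0, 6] is redundant and can be fused into the QuantizeLinear node.
```

When the Clip's min/max cover the whole interval computed this way, the Clip adds nothing and ClipQuantFusion can remove it, keeping only the QuantizeLinear; this is why the optimizer reads the Q node's scale and zero point from constant initializers.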
--- .../qdq_transformer/clip_quantizelinear.cc | 25 +- .../selectors_actions/qdq_selectors.cc | 36 ++ .../selectors_actions/qdq_selectors.h | 8 + .../selectors_actions/shared/utils.cc | 18 +- .../qnn/builder/opbuilder/clip_op_builder.cc | 127 +++--- .../providers/qnn/builder/opbuilder/topk.cc | 15 +- .../test/optimizer/qdq_transformer_test.cc | 7 + .../test/providers/qnn/average_pool_test.cc | 6 +- .../test/providers/qnn/clip_op_test.cc | 188 +++++++++ .../test/providers/qnn/flatten_op_test.cc | 202 +++++++++ .../test/providers/qnn/gather_op_htp_test.cc | 64 +-- .../test/providers/qnn/gemm_op_test.cc | 341 +++++++++++++++ .../providers/qnn/instance_norm_htp_test.cc | 21 +- .../test/providers/qnn/layer_norm_test.cc | 4 +- .../providers/qnn/leakyrelu_op_htp_test.cc | 46 +-- .../test/providers/qnn/max_min_op_test.cc | 9 +- .../test/providers/qnn/pool_op_test.cpp | 19 +- .../test/providers/qnn/qnn_test_utils.cc | 14 +- .../test/providers/qnn/qnn_test_utils.h | 82 ++-- .../test/providers/qnn/reshape_op_test.cc | 225 ++++++++++ .../test/providers/qnn/simple_op_htp_test.cc | 16 +- .../test/providers/qnn/slice_htp_test.cc | 64 +-- .../test/providers/qnn/split_op_test.cc | 387 ++++++++++++++++++ .../qnn/squeeze_unsqueeze_op_test.cc | 324 +++++++++++++++ .../test/providers/qnn/tile_op_test.cc | 132 ++++++ .../test/providers/qnn/topk_op_test.cc | 209 ++++++++++ 26 files changed, 2273 insertions(+), 316 deletions(-) create mode 100644 onnxruntime/test/providers/qnn/clip_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/flatten_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/gemm_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/reshape_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/split_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/tile_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/topk_op_test.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index a0942c31b0161..50653b368857d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" + +#include + +#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" @@ -50,14 +53,26 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& switch (zp_initializer.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_INT8: { const int8_t zero_point = zp_initializer.data()[0]; - lower = scale * (-128 - zero_point); - upper = scale * (127 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { const uint8_t zero_point = zp_initializer.data()[0]; - lower = scale * (0 - zero_point); - upper = scale * (255 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + const int16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + const uint16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } default: diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 16c7bd5fce960..5015e48fdb7b8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -496,6 +496,42 @@ bool LogicalComparisonNodeGroupSelector::Check(const GraphViewer& graph_viewer, return dt_input_1 == dt_input_2; } +bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + constexpr int num_dq_inputs = 1; + constexpr int num_q_outputs = 1; + if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { + return false; + } + + if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); + !dq_validation_status.IsOK()) { + return false; + } + + if (num_q_outputs != gsl::narrow_cast(q_nodes.size())) { + return false; + } + + const Node& dq_node = *dq_nodes.front(); + const Node& q_node = *q_nodes.front(); + + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (dt_input != dt_output) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); +} + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index d8fefdd8dc3d9..be7f7e0288eda 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ 
b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -220,6 +220,14 @@ class LogicalComparisonNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// TopK has 1 DQ input node and 1 Q output node. +// Zero point and scale are constant scalars and must match +class TopKNodeGroupSelector : public NodeGroupSelector { + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + /* * NodeSelector instances for use in the QDQ::SelectorActionTransformer. */ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index f1bdd7a99c329..3f1b2f0458bc0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -36,7 +36,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Resize", {}}, {"Split", {}}, {"Squeeze", {}}, - {"Unsqueeze", {}}}; + {"Unsqueeze", {}}, + {"Tile", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { @@ -78,7 +79,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Abs", {}}, {"Neg", {}}, {"DepthToSpace", {}}, - {"SpaceToDepth", {}}}; + {"SpaceToDepth", {}}, + {"Clip", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, @@ -127,6 +129,10 @@ static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() { return {{"Pad", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetTopKOpVersionsMap() { + return {{"TopK", {}}}; +} + /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { /* register selectors for miscellaneous ops */ @@ -227,6 +233,13 @@ void RegisterPadSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterTopKSelector(Selectors& qdq_selectors) { + /* register selector for TopK op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetTopKOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterDropDQSelectors(qdq_selectors_); @@ -242,6 +255,7 @@ void SelectorManager::CreateSelectors() { RegisterLogicalComparisonSelectors(qdq_selectors_); RegisterWhereSelectors(qdq_selectors_); RegisterPadSelectors(qdq_selectors_); + RegisterTopKSelector(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc index 92a7feea7fc54..df4c718949269 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include +#include + #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" @@ -9,8 +12,6 @@ #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class ClipOpBuilder : public BaseOpBuilder { @@ -33,8 +34,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: Status ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const; - mutable float min_value_ = std::numeric_limits::lowest(); - mutable float max_value_ = std::numeric_limits::max(); }; Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { @@ -61,61 +60,8 @@ Status ClipOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (do_op_validation) { ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit)); } - Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; - - auto inputs = node_unit.Inputs(); - for (size_t input_i = 0; input_i < inputs.size(); ++input_i) { - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = inputs[input_i].quant_param.has_value(); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor); - - auto& input_name = inputs[input_i].node_arg.Name(); - if (input_name.empty()) { - // Ignore unspecified/unused optional input - continue; - } - if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { - LOGS(logger, VERBOSE) << "Tensor already added or the input is not named, skip it: " << input_name; - input_names.push_back(input_name); - continue; - } - - const auto* type_proto = inputs[input_i].node_arg.TypeAsProto(); - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type)); - - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[input_i].node_arg, input_shape), "Cannot get shape"); - - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[input_i].quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - - float* ini_data = nullptr; - std::vector unpacked_tensor; - bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); - if (is_initializer_input) { - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); - ini_data = reinterpret_cast(unpacked_tensor.data()); - if (input_i == 1) { - min_value_ = *ini_data; - continue; - } else if (input_i == 2) { - max_value_ = *ini_data; - continue; - } - } - ORT_ENFORCE(input_i == 0, "QNN ReluMinMax operator expects only one input. Min and max are expected to be parameters, ie. 
initializer inputs in ONNX model"); - - Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name); - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param, - std::move(input_shape), std::move(unpacked_tensor)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); - input_names.push_back(input_name); - } - return Status::OK(); + return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names); } Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, @@ -123,20 +69,59 @@ Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const { + const auto& inputs = node_unit.Inputs(); + const size_t num_inputs = inputs.size(); + + const Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; std::vector param_tensor_names; - Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT; - min_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; - min_qnn_scalar.floatValue = min_value_; - QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, min_qnn_scalar); - param_tensor_names.push_back(min_value_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(min_value_param)); - - Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT; - max_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; - max_qnn_scalar.floatValue = max_value_; - QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, max_qnn_scalar); - param_tensor_names.push_back(max_value_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(max_value_param)); + + auto get_f32_from_bytes = [](const std::vector& bytes, float default_val) -> float { + return bytes.empty() ? default_val : *reinterpret_cast(bytes.data()); + }; + + // Set the 'min' parameter. + { + std::vector min_val_bytes; + + if (num_inputs > 1 && !inputs[1].node_arg.Name().empty()) { + OnnxInputInfo min_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], min_input_info)); + ORT_RETURN_IF_NOT(min_input_info.qnn_data_type == qnn_data_type, + "QNN EP: The 'min' input of the Clip operator must be of type float32."); + assert(min_input_info.is_initializer); // Checked by ExplicitOpCheck(). + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*min_input_info.initializer_tensor, min_val_bytes)); + } + + Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT; + min_qnn_scalar.dataType = qnn_data_type; + min_qnn_scalar.floatValue = get_f32_from_bytes(min_val_bytes, std::numeric_limits::lowest()); + QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, + min_qnn_scalar); + param_tensor_names.push_back(min_value_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(min_value_param)); + } + + // Set the 'max' parameter. + { + std::vector max_val_bytes; + + if (num_inputs > 2 && !inputs[2].node_arg.Name().empty()) { + OnnxInputInfo max_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], max_input_info)); + ORT_RETURN_IF_NOT(max_input_info.qnn_data_type == qnn_data_type, + "QNN EP: The 'max' input of the Clip operator must of type float32."); + assert(max_input_info.is_initializer); // Checked by ExplicitOpCheck(). 
+ ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*max_input_info.initializer_tensor, max_val_bytes)); + } + + Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT; + max_qnn_scalar.dataType = qnn_data_type; + max_qnn_scalar.floatValue = get_f32_from_bytes(max_val_bytes, std::numeric_limits::max()); + QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, + max_qnn_scalar); + param_tensor_names.push_back(max_value_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(max_value_param)); + } ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index 6ca36736f2f7f..047972294f78c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -63,9 +63,20 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N auto rank = input_shape.size(); auto axis = node_helper.Get("axis", -1); - if (-1 == axis && axis != static_cast(rank - 1)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN TopK axis is always the last dimension"); + ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast(rank - 1), + "QNN TopK's axis is always the last dimension"); + + // ONNX TopK outputs int64 indices, but the equivalent QNN op outputs uint32 indices. + // The QNN HTP backend does not generally support the int64 type, but QNN EP can just use the uint32 type + // for TopK ops within the graph. However, if the TopK op **generates** a graph output, + // then we cannot support it on the HTP backend. + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + if (is_npu_backend) { + const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(output_name), + "QNN EP does not support TopK ops that generate a graph output."); } + return Status::OK(); } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index a438a61cb9b36..d3616a14d8a5d 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2497,10 +2497,15 @@ TEST(QDQTransformerTests, Clip) { epsilon); }; + constexpr int16_t int16_min = std::numeric_limits::min(); + constexpr uint16_t uint16_min = std::numeric_limits::min(); + std::vector opsets{12, 18, 19}; for (auto opset : opsets) { test_case(.0235294122248888f, static_cast(-128), 0, opset); // [0, 6] test_case(.0235294122248888f, static_cast(-128), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, int16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, int16_min, 1, opset, true); // [0, 58.98] contrib 16-bit qdq test_case(.02f, static_cast(-128), 0, opset); // [0, 5.1] test_case(.02f, static_cast(-128), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(-128), 1, opset); // [0, 7.65] @@ -2513,6 +2518,8 @@ TEST(QDQTransformerTests, Clip) { test_case(.04f, static_cast(-97), 1, opset, true); // [-1.24, 8.96] contrib qdq test_case(.02352941176f, static_cast(0), 0, opset); // [0, 6] test_case(.02352941176f, static_cast(0), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, uint16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, uint16_min, 1, opset, true); // [0, 58.98] 
contrib 16-bit qdq test_case(.02f, static_cast(0), 0, opset); // [0, 5.1] test_case(.02f, static_cast(0), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(0), 1, opset); // [0, 7.65] diff --git a/onnxruntime/test/providers/qnn/average_pool_test.cc b/onnxruntime/test/providers/qnn/average_pool_test.cc index 79ec07796c0e8..0ee52f7fec21a 100644 --- a/onnxruntime/test/providers/qnn/average_pool_test.cc +++ b/onnxruntime/test/providers/qnn/average_pool_test.cc @@ -32,7 +32,7 @@ static void RunAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -53,8 +53,8 @@ static void RunQDQAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs), - BuildQDQOpTestCase(op_type, input_defs, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc new file mode 100644 index 0000000000000..15ba3b5de2fa1 --- /dev/null +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Clip operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunClipTestOnCPU(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Clip with a dynamic min or max input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Clip_Dynamic_MinMax_Unsupported) { + // Dynamic min input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, false /* is_initializer */, {-5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + // Dynamic max input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, false, {5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Clip with default min/max. +TEST_F(QnnCPUBackendTests, Clip_4D_f32_DefaultMinMax) { + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test Clip with 5D input. 
+TEST_F(QnnCPUBackendTests, Clip_5D_f32) { + RunClipTestOnCPU(TestInputDef({1, 1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a QDQ Clip model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQClipTestOnHTP(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Clip", {input_def}, {min_max_defs}, {}); + auto qdq_model_builder = BuildQDQOpTestCase("Clip", {input_def}, {min_max_defs}, {}, + kOnnxDomain, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U8_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U16_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U8_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U16_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Clip of rank 5. 
+TEST_F(QnnHTPBackendTests, Clip_U8_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Clip -> Q + // QDQ node group, which gets lowered to a single QNN Clip node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 1, 2, 2, 2}, {0, 1, 6, 10, 20, 100, 128, 255}); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Min/Max initializers + NodeArg* min_input = builder.MakeScalarInitializer(5.0f); + NodeArg* max_input = builder.MakeScalarInitializer(100.0f); + + // Clip -> + NodeArg* clip_output = builder.MakeIntermediate(); + builder.AddNode("Clip", {input_dq, min_input, max_input}, {clip_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(clip_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/flatten_op_test.cc b/onnxruntime/test/providers/qnn/flatten_op_test.cc new file mode 100644 index 0000000000000..637d3257ddea7 --- /dev/null +++ b/onnxruntime/test/providers/qnn/flatten_op_test.cc @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Flatten operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnCPU(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Flatten input (rank4) with axis == 0. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_Axis0) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank4) with axis == -1. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_AxisNeg1) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank5) with axis == 2. 
+TEST_F(QnnCPUBackendTests, Flatten_Rank5_Axis2) { + RunFlattenTestOnCPU(TestInputDef({1, 2, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a model with a non-QDQ Flatten operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Flatten model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Flatten", {input_def}, {}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Flatten", {input_def}, {}, attrs, kOnnxDomain, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten with an input of rank5. +TEST_F(QnnHTPBackendTests, Flatten_QDQ8bit_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. 
Therefore, we have to create a test model that only instantiates the DQ -> Flatten -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 2, 3, 4, 5}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Flatten -> + NodeArg* flatten_output = builder.MakeIntermediate(); + Node& flatten_node = builder.AddNode("Flatten", {input_dq}, {flatten_output}); + flatten_node.AddAttribute("axis", static_cast(2)); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(flatten_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test that int32 non-QDQ Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank4_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +// Test that rank 5 int32 Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank5_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 5b05b39f34a27..37e0db906d054 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -14,47 +15,14 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float model with a Gather op. -template -static GetTestModelFn BuildGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* indices = MakeTestInput(builder, indices_def); - NodeArg* output = builder.MakeOutput(); - - Node& gather_node = builder.AddNode("Gather", {input, indices}, {output}); - gather_node.AddAttribute("axis", axis); - }; -} - -// Function that builds a QDQ model with a Gather op. 
-template -static GetTestQDQModelFn BuildQDQGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - NodeArg* indices = MakeTestInput(builder, indices_def); - - NodeArg* gather_output = builder.MakeIntermediate(); - Node& gather_node = builder.AddNode("Gather", {input_qdq, indices}, {gather_output}); - gather_node.AddAttribute("axis", axis); - - AddQDQNodePairWithOutputAsGraphOutput(builder, gather_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - // Test the accuracy of a QDQ Gather model on QNN EP. Checks if the QDQ model on QNN EP as accurate as the QDQ model on CPU EP // (compared to float32 model). template -static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestInputDef& indices_def, - int64_t axis, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { +static void RunQDQGatherOpTest(const TestInputDef& input_def, + const TestInputDef& indices_def, + const std::vector& attrs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -62,12 +30,14 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildGatherOpTestCase(input_def, indices_def, axis), - BuildQDQGatherOpTestCase(input_def, indices_def, axis), + auto f32_model_builder = BuildOpTestCase("Gather", {input_def}, {indices_def}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Gather", {input_def}, {indices_def}, attrs); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -77,7 +47,7 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -86,7 +56,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::None); } @@ -98,7 +68,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -110,7 +80,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, 
false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -122,7 +92,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis1) { RunQDQGatherOpTest(TestInputDef({3, 3}, false, {1.0f, 1.2f, 1.9f, 2.3f, 3.4f, 3.9f, 4.5f, 5.7f, 5.9f}), TestInputDef({1, 2}, true, {0, 2}), - 1, + {utils::MakeAttribute("axis", static_cast(1))}, 13, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc new file mode 100644 index 0000000000000..15f26717b06fd --- /dev/null +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Gemm operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunGemmTestOnCPU(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Gemm", input_defs, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Gemm with non-default 'alpha' or 'beta' attributes is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) { + // Check that alpha != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("alpha", 1.5f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // Check that beta != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 4}, false, -1.0f, 1.0f)}, + {utils::MakeAttribute("beta", 1.2f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1). +// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` +TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12); + + // 2D matrix mul with bias not supported. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data), + TestInputDef({2, 4}, false, -1.0f, 1.0f)}, + {}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data)}, + {}, + ExpectedEPNodeAssignment::All); // Assigned to QNN EP. +} + +// Test Gemm with dynamic (i.e., not initializer) inputs (A, B, Bias). 
+TEST_F(QnnCPUBackendTests, Gemm_Dynamic_A_B_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that builds a model with a QDQ Gemm node. 
+template +inline GetTestQDQModelFn BuildQDQGemmTestCase(const std::vector>& input_defs, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_defs, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + const size_t num_inputs = input_defs.size(); + assert(num_inputs == 2 || num_inputs == 3); + + std::vector op_inputs; + op_inputs.reserve(num_inputs); + + // Process input 0 + NodeArg* input0 = MakeTestInput(builder, input_defs[0]); + QuantParams input0_qparams = GetTestInputQuantParams(input_defs[0]); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, + input0_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input0_after_qdq); + + // Process input 1 + NodeArg* input1 = MakeTestInput(builder, input_defs[1]); + QuantParams input1_qparams = GetTestInputQuantParams(input_defs[1]); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input1_after_qdq); + + // Process bias + if (num_inputs == 3) { + NodeArg* bias_input = MakeTestQDQBiasInput(builder, input_defs[2], input0_qparams.scale * input1_qparams.scale, + use_contrib_qdq); + op_inputs.push_back(bias_input); + } + + // Op -> op_output + auto* gemm_output = builder.MakeIntermediate(); + Node& gemm_node = builder.AddNode("Gemm", op_inputs, {gemm_output}); + + for (const auto& attr : attrs) { + gemm_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, gemm_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Gemm model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQGemmTestOnHTP(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + float f32_abs_err = 1e-4f, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + auto f32_model_builder = BuildOpTestCase("Gemm", input_defs, {}, attrs); + auto qdq_model_builder = BuildQDQGemmTestCase(input_defs, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment, + f32_abs_err); +} + +// Test 8-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. 
+// Expected val: 120.73912048339844 +// QNN QDQ val: 0 (err 120.73912048339844) +// CPU QDQ val: 120.73889923095703 (err 0.00022125244140625) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 1e-4f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm (16bit act, 8bit weight) with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 120.48043823242188 (err 0.2586822509765625) +// CPU QDQ val: 120.48980712890625 (err 0.2493133544921875) +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with dynamic A and B inputs. The Bias is static. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.48132994771003723, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 77.012794494628906 (err 43.726325988769531) +// CPU QDQ val: 119.85115814208984 (err 0.88796234130859375) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_B_Static_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), // Dynamic => inaccuracy + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm with static B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Gemm with transposed A/B and static B and Bias inputs. 
+TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm (16bit activation, 8bit weight) with transposed A/B and static B and Bias inputs. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.00047966410056687891, zero_point=0. +// Expected val: 29.434776306152344 +// QNN QDQ val: 29.191877365112305 (err 0.24289894104003906) +// CPU QDQ val: 29.197153091430664 (err 0.23762321472167969) +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc index 594973e37ef0b..f662ac14336f8 100644 --- a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc @@ -16,25 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float32 model with an InstanceNormalization operator. 
-GetTestModelFn BuildInstanceNormTestCase(const TestInputDef& input_def, - const TestInputDef& scale_def, - const TestInputDef& bias_def, - const std::vector& attrs) { - return [input_def, scale_def, bias_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* scale = MakeTestInput(builder, scale_def); - NodeArg* bias = MakeTestInput(builder, bias_def); - - NodeArg* output = builder.MakeOutput(); - Node& op_node = builder.AddNode("InstanceNormalization", {input, scale, bias}, {output}); - - for (const auto& attr : attrs) { - op_node.AddAttributeProto(attr); - } - }; -} - // Function that builds a QDQ model with an InstanceNormalization operator. template static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInputDef& input_def, @@ -93,7 +74,7 @@ static void RunInstanceNormQDQTest(const TestInputDef& input_def, #endif // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. - TestQDQModelAccuracy(BuildInstanceNormTestCase(input_def, scale_def, bias_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs), BuildQDQInstanceNormTestCase(input_def, scale_def, bias_def, attrs), provider_options, 18, diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index aa6c6a142e6d1..085454004e5a5 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -29,7 +29,7 @@ static void RunLayerNormCpuTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), provider_options, 17, expected_ep_assignment); @@ -114,7 +114,7 @@ static void RunLayerNormQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), BuildQDQLayerNormTestCase(input_def, scale_def, attrs), provider_options, 17, // opset diff --git a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc index a8237817c71df..e3077ec569923 100644 --- a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -15,42 +16,10 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Creates a function that builds a model with a LeakyRelu operator. -static GetTestModelFn BuildLeakyReluOpTestCase(const TestInputDef& input_def, float alpha) { - return [input_def, alpha](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input}, {output}); - leakyrelu_node.AddAttribute("alpha", alpha); - }; -} - -// Creates a function that builds a QDQ model with a LeakyRelu operator. 
-template -static GetTestQDQModelFn BuildQDQLeakyReluOpTestCase(const TestInputDef& input_def, - float alpha) { - return [input_def, alpha](ModelTestBuilder& builder, - std::vector>& output_qparams) { - // input => Q => DQ => - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - // LeakryRelu - auto* leakyrelu_output = builder.MakeIntermediate(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input_qdq}, {leakyrelu_output}); - leakyrelu_node.AddAttribute("alpha", alpha); - - // => Q => DQ -> final output - AddQDQNodePairWithOutputAsGraphOutput(builder, leakyrelu_output, output_qparams[0].scale, - output_qparams[0].zero_point); - }; -} - // Checks the accuracy of a QDQ LeakyRelu model by comparing to ORT CPU EP. template static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, - float alpha, + const std::vector& attrs, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; @@ -60,12 +29,11 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildLeakyReluOpTestCase(input_def, alpha), - BuildQDQLeakyReluOpTestCase(input_def, alpha), + TestQDQModelAccuracy(BuildOpTestCase("LeakyRelu", {input_def}, {}, attrs), + BuildQDQOpTestCase("LeakyRelu", {input_def}, {}, attrs), provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -74,7 +42,7 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, // - Uses uint8 as the quantization type. TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 15, ExpectedEPNodeAssignment::All); } @@ -85,7 +53,7 @@ TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { // - Uses uint8 as the quantization type. 
TEST_F(QnnHTPBackendTests, LeakyReluOpSet16) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 16, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/max_min_op_test.cc b/onnxruntime/test/providers/qnn/max_min_op_test.cc index 09ea71e5f03eb..3deff121f3c72 100644 --- a/onnxruntime/test/providers/qnn/max_min_op_test.cc +++ b/onnxruntime/test/providers/qnn/max_min_op_test.cc @@ -27,7 +27,7 @@ static void RunCPUMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), provider_options, opset, expected_ep_assignment); @@ -48,12 +48,11 @@ static void RunQDQMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), // baseline float32 model - BuildQDQOpTestCase(op_type, input_defs, {}, kOnnxDomain), // QDQ model + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // baseline float32 model + BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // QDQ model provider_options, opset, - expected_ep_assignment, - 1e-4f); + expected_ep_assignment); } // diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index fee10a542fb82..7ed9072a95b32 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -17,21 +17,6 @@ namespace onnxruntime { namespace test { -// Returns a function that creates a graph with a single MaxPool operator. -static GetTestModelFn BuildPoolTestCase(const std::string& op_type, - const TestInputDef& input_def, - const std::vector& attrs) { - return [op_type, input_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& pool_node = builder.AddNode(op_type, {input}, {output}); - - for (const auto& attr : attrs) { - pool_node.AddAttributeProto(attr); - } - }; -} - // Returns a function that creates a graph with a QDQ MaxPool operator. 
template GetTestQDQModelFn BuildPoolQDQTestCase(const std::string& op_type, @@ -74,7 +59,7 @@ static void RunPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildPoolTestCase(op_type, input_def, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -95,7 +80,7 @@ static void RunQDQPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildPoolTestCase(op_type, input_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, attrs), BuildPoolQDQTestCase(op_type, input_def, attrs), provider_options, opset, diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 724e9a11cd781..51df93f8853ec 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -73,7 +73,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOption void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals) { + std::vector& output_vals) { SessionOptions so; so.session_logid = log_id; RunOptions run_options; @@ -102,14 +102,12 @@ void InferenceModel(const std::string& model_data, const char* log_id, } const auto& outputs = graph.GetOutputs(); + std::vector output_names; - // fetch all outputs if necessary. - if (output_names.empty()) { - output_names.reserve(outputs.size()); - for (const auto* node_arg : outputs) { - if (node_arg->Exists()) { - output_names.push_back(node_arg->Name()); - } + output_names.reserve(outputs.size()); + for (const auto* node_arg : outputs) { + if (node_arg->Exists()) { + output_names.push_back(node_arg->Name()); } } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index fd572fa17f2b1..14c62f98f6a3e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -213,13 +213,12 @@ inline QuantParams GetTestInputQuantParams(const TestInputDef& inp * \param execution_provider The EP on which to run the model. Set to nullptr for CPU EP. * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP. * \param feeds The input feeds. - * \param output_names If empty, the function will write the output names. * \param output_vals Initialized to the inference results. */ void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals); + std::vector& output_vals); /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: @@ -263,9 +262,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run f32 model on CPU EP and collect outputs. 
std::vector cpu_f32_outputs; - std::vector output_names; InferenceModel(f32_model_data, "f32_model_logger", nullptr, ExpectedEPNodeAssignment::All, - f32_helper.feeds_, output_names, cpu_f32_outputs); + f32_helper.feeds_, cpu_f32_outputs); ASSERT_FALSE(cpu_f32_outputs.empty()); const size_t num_outputs = cpu_f32_outputs.size(); @@ -304,13 +302,13 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. std::vector qnn_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), - expected_ep_assignment, qdq_helper.feeds_, output_names, qnn_qdq_outputs); + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { // Run QDQ model on CPU EP and collect outputs. std::vector cpu_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", nullptr, ExpectedEPNodeAssignment::All, - qdq_helper.feeds_, output_names, cpu_qdq_outputs); + qdq_helper.feeds_, cpu_qdq_outputs); ASSERT_EQ(cpu_qdq_outputs.size(), num_outputs); ASSERT_EQ(qnn_qdq_outputs.size(), num_outputs); @@ -320,7 +318,9 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Compare accuracy of QDQ results with float model. // QNN EP must be at least as accurate as CPU EP when running the QDQ model. + const std::string base_output_name = "output_"; for (size_t i = 0; i < num_outputs; i++) { + std::string debug_output_name = base_output_name + std::to_string(i); auto& cpu_qdq_tensor = cpu_qdq_outputs[i].Get(); auto& qnn_qdq_tensor = qnn_qdq_outputs[i].Get(); @@ -353,8 +353,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe } EXPECT_TRUE(is_as_accurate_as_cpu_qdq) - << "Inaccuracy detected for output '" - << output_names[i] + << "Inaccuracy detected for output '" << debug_output_name << "', element " << j << ".\nOutput quant params: scale=" << output_qparams[i].scale << ", zero_point=" << static_cast(output_qparams[i].zero_point) @@ -363,7 +362,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe << "CPU QDQ val: " << cpu_qdq_val << " (err " << cpu_err << ")"; } } else { - VerifyOutput(output_names[i], cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); + VerifyOutput(debug_output_name, cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); } } } @@ -438,25 +437,33 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef +template inline GetTestModelFn BuildOpTestCase(const std::string& op_type, - const std::vector>& input_defs, + const std::vector>& input_defs_1, + const std::vector>& input_defs_2, const std::vector& attrs, const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs, attrs, op_domain](ModelTestBuilder& builder) { + return [op_type, input_defs_1, input_defs_2, attrs, op_domain](ModelTestBuilder& builder) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(input_defs_1.size() + input_defs_2.size()); + + for (const auto& input_def : input_defs_1) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } - for (const auto& input_def : input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + for (const auto& input_def : input_defs_2) { + NodeArg* input = MakeTestInput(builder, input_def); op_inputs.push_back(input); } @@ -470,7 +477,8 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, } /** - * 
Returns a function that builds a model with a single QDQ operator with N inputs of the same element type. + * Returns a function that builds a model with a single QDQ operator with N float (quantizeable) inputs + * and M inputs of a potentially different type. * * \param op_type The operator to instantiate. * \param input_defs List of input definitions. @@ -478,25 +486,33 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, * \param op_domain The operator's domain. Defaults to the ONNX domain (i.e., ""). * \returns A model building function. */ -template -inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, - const std::vector>& input_defs, - const std::vector& attrs, - const std::string& op_domain = kOnnxDomain, - bool use_contrib_qdq = false) { - return [op_type, input_defs, attrs, op_domain, - use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { +template +inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, + const std::vector>& quant_input_defs, + const std::vector>& non_quant_input_defs, + const std::vector& attrs, + const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false) { + return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, + use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size()); - for (const auto& input_def : input_defs) { + // Create QDQ inputs + for (const auto& input_def : quant_input_defs) { NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, - input_qparams.zero_point, use_contrib_qdq); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); op_inputs.push_back(input_after_qdq); } + // Create non-QDQ inputs + for (const auto& input_def : non_quant_input_defs) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } + // Op -> op_output auto* op_output = builder.MakeIntermediate(); Node& onnx_node = builder.AddNode(op_type, op_inputs, {op_output}, op_domain); @@ -506,8 +522,8 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty } // op_output -> Q -> DQ -> output - AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point, use_contrib_qdq); + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } diff --git a/onnxruntime/test/providers/qnn/reshape_op_test.cc b/onnxruntime/test/providers/qnn/reshape_op_test.cc new file mode 100644 index 0000000000000..eb495e44ec770 --- /dev/null +++ b/onnxruntime/test/providers/qnn/reshape_op_test.cc @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Reshape operator on the QNN CPU backend. 
Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeTestOnCPU(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test Reshape of rank 4 -> rank 2. +TEST_F(QnnCPUBackendTests, Reshape_4D_f32) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Reshape operator. +template +GetTestQDQModelFn BuildQDQReshapeTestCase(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, shape_def, attrs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // shape input + NodeArg* shape_input = MakeTestInput(builder, shape_def); + + // Reshape op + NodeArg* reshape_output = builder.MakeIntermediate(); + Node& reshape_node = builder.AddNode("Reshape", {input_qdq, shape_input}, {reshape_output}); + + for (const auto& attr : attrs) { + reshape_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, reshape_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ Reshape operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. 
+template +static void RunReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Reshape model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs); + auto qdq_model_builder = BuildQDQReshapeTestCase(input_def, shape_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test 8-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u8) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test 16-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u16) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // Opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Reshape runs on HTP backend. 
+TEST_F(QnnHTPBackendTests, Reshape_4D_int32) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunReshapeTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({3}, true, {1, 1, 12}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of 0 (copy dimension from input) +TEST_F(QnnHTPBackendTests, Reshape_4D_0MeansCopy) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 0, 16}), // zero means copy => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of -1 (dimension is inferred from the expected number of elements) +TEST_F(QnnHTPBackendTests, Reshape_4D_Neg1MeansInfer) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 3, -1}), // -1 means infer => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 63498982930f5..f77c098f72116 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -32,7 +32,7 @@ static void RunOpTestOnCPU(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -113,8 +113,8 @@ static void RunQDQOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs, op_domain), - BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain, use_contrib_qdq), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs, op_domain, use_contrib_qdq), provider_options, opset_version, expected_ep_assignment, @@ -137,7 +137,7 @@ static void RunOpTest(const std::string& op_type, #endif // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -698,8 +698,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
// 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); @@ -708,8 +708,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run will load and run from Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); diff --git a/onnxruntime/test/providers/qnn/slice_htp_test.cc b/onnxruntime/test/providers/qnn/slice_htp_test.cc index f7163f04736a5..edc079dc65276 100644 --- a/onnxruntime/test/providers/qnn/slice_htp_test.cc +++ b/onnxruntime/test/providers/qnn/slice_htp_test.cc @@ -16,51 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a model with a Slice operator. -template -GetTestModelFn BuildSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder) { - NodeArg* data = MakeTestInput(builder, data_def); - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - NodeArg* output = builder.MakeOutput(); - builder.AddNode("Slice", {data, starts, ends, axes, steps}, {output}); - }; -} - -// Function that builds a QDQ model with a Slice operator. -template -static GetTestQDQModelFn BuildQDQSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* data = MakeTestInput(builder, data_def); - QuantParams data_qparams = GetTestInputQuantParams(data_def); - NodeArg* data_qdq = AddQDQNodePair(builder, data, data_qparams.scale, data_qparams.zero_point); - - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - auto* slice_output = builder.MakeIntermediate(); - builder.AddNode("Slice", {data_qdq, starts, ends, axes, steps}, {slice_output}); - - // Add output -> Q -> output_u8 - AddQDQNodePairWithOutputAsGraphOutput(builder, slice_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - /** * Runs an Slice model on the QNN HTP backend. Checks the graph node assignment, and that inference * outputs for QNN and CPU match. @@ -86,13 +41,14 @@ static void RunSliceQDQTest(const TestInputDef& data_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - // Runs model with DQ-> Slice -> Q and compares the outputs of the CPU and QNN EPs. 
- TestQDQModelAccuracy(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), - BuildQDQSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + const std::vector> f32_inputs = {data_def}; + const std::vector> int64_inputs = {starts_def, ends_def, axes_def, steps_def}; + + TestQDQModelAccuracy(BuildOpTestCase("Slice", f32_inputs, int64_inputs, {}), + BuildQDQOpTestCase("Slice", f32_inputs, int64_inputs, {}), provider_options, 18, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } /** @@ -119,12 +75,12 @@ static void RunSliceNonQDQOnHTP(const TestInputDef& data_def, #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - - RunQnnModelTest(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + auto f32_model_builder = BuildOpTestCase("Slice", {data_def}, + {starts_def, ends_def, axes_def, steps_def}, {}); + RunQnnModelTest(f32_model_builder, provider_options, 13, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Check that QNN compiles DQ -> Slice -> Q as a single unit. diff --git a/onnxruntime/test/providers/qnn/split_op_test.cc b/onnxruntime/test/providers/qnn/split_op_test.cc new file mode 100644 index 0000000000000..57e4b211777bb --- /dev/null +++ b/onnxruntime/test/providers/qnn/split_op_test.cc @@ -0,0 +1,387 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +template +GetTestModelFn BuildSplitTestCase(const TestInputDef& input_def, + const std::vector& split, bool split_is_input, + int64_t axis, int64_t num_outputs) { + return [input_def, split, split_is_input, axis, num_outputs](ModelTestBuilder& builder) { + std::vector op_inputs; + + op_inputs.push_back(MakeTestInput(builder, input_def)); + + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? 
static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeOutput()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + }; +} + +template +static void RunSplitOpTestOnCPU(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test Split opset 18 on CPU backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + 
ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Return function that builds a model with a QDQ Split. +template +GetTestQDQModelFn BuildQDQSplitTestCase(const TestInputDef& input_def, + const std::vector& split, + bool split_is_input, + int64_t axis, + int64_t num_outputs, + bool use_contrib_qdq = false) { + return [input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector op_inputs; + + // Add QDQ input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input_after_qdq); + + // Add split input + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeIntermediate()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + + // op_output -> Q -> DQ -> output + assert(output_qparams.size() == actual_num_outputs); + for (size_t i = 0; i < actual_num_outputs; i++) { + // NOTE: Input and output quantization parameters must be equal for Split. + output_qparams[i] = input_qparams; + AddQDQNodePairWithOutputAsGraphOutput(builder, split_outputs[i], output_qparams[i].scale, + output_qparams[i].zero_point, use_contrib_qdq); + } + }; +} + +// Runs a non-QDQ Split operator on the HTP backend. 
+template +static void RunSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Split operator on the HTP backend. +template +static void RunQDQSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + auto f32_model_builder = BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs); + auto qdq_model_builder = BuildQDQSplitTestCase(input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that HTP can run non-QDQ Split (int32 input). +TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) { + // Equal split. + RunSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18_U16) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 13. 
+TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 11. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc new file mode 100644 index 0000000000000..33d2f64c0315e --- /dev/null +++ b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc @@ -0,0 +1,324 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Squeeze (or Unsqueeze) operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnCPU(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. 
+} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnCPUBackendTests, Squeeze_Rank5_Rank2_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 1, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Squeeze axes 0 and 2 => (3, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnCPUBackendTests, Squeeze_Rank4_Rank3_NegAxes_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank5_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({3, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Add 1's => (1, 3, 1, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ (Un)Squeeze operator. +template +GetTestQDQModelFn BuildQDQSqueezeTestCase(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + bool use_contrib_qdq = false) { + return [op_type, input_def, axes_def, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // axes input + NodeArg* axes_input = MakeTestInput(builder, axes_def); + + // (Un)Squeeze op + NodeArg* op_output = builder.MakeIntermediate(); + builder.AddNode(op_type, {input_qdq, axes_input}, {op_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for (Un)Squeeze. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ (Un)Squeeze operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnHTP(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ (Un)Squeeze model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and +// that inference running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP +// (when compared to the baseline float32 model). 
+template +static void RunQDQSqueezeTestOnHTP(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase(op_type, {input_def}, {axes_def}, {}); + auto qdq_model_builder = BuildQDQSqueezeTestCase(op_type, input_def, axes_def, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnHTPBackendTests, Squeeze_Rank5_Rank2_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Squeeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 3, 1, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Squeeze axes 0 and 2 => (3, 2, 4) + + // Squeeze -> + NodeArg* squeeze_output = builder.MakeIntermediate(); + builder.AddNode("Squeeze", {input_dq, axes_input}, {squeeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(squeeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. 
+TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank5_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Unsqueeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({3, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Add 1's => (1, 3, 1, 2, 4) + + // Unsqueeze -> + NodeArg* unsqueeze_output = builder.MakeIntermediate(); + builder.AddNode("Unsqueeze", {input_dq, axes_input}, {unsqueeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(unsqueeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Squeeze runs on HTP backend. +TEST_F(QnnHTPBackendTests, Squeeze_Int32_Rank4_Rank3) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Squeeze 0th axis => (3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +// Test that int32 Unsqueeze runs on HTP backend. 
+TEST_F(QnnHTPBackendTests, Unsqueeze_Int32_Rank3_Rank4) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Unsqueeze", + TestInputDef({3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Unsqueeze 0th axis => (1, 3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/tile_op_test.cc b/onnxruntime/test/providers/qnn/tile_op_test.cc new file mode 100644 index 0000000000000..2b35c730ee5fe --- /dev/null +++ b/onnxruntime/test/providers/qnn/tile_op_test.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Tile operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTileTestOnCPU(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Test that Tile with a dynamic repeats input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Tile_DynamicRepeats_Unsupported) { + RunTileTestOnCPU(TestInputDef({2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), + TestInputDef({2}, false /* is_initializer */, {1, 2}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Tile with rank 4 float input. +TEST_F(QnnCPUBackendTests, Tile_F32_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunTileTestOnCPU(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Tile operator. +template +GetTestQDQModelFn BuildQDQTileTestCase(const TestInputDef& input_def, + const TestInputDef& repeats_def, + bool use_contrib_qdq = false) { + return [input_def, repeats_def, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // repeats input + NodeArg* repeats_input = MakeTestInput(builder, repeats_def); + + // Tile op + NodeArg* tile_output = builder.MakeIntermediate(); + builder.AddNode("Tile", {input_qdq, repeats_input}, {tile_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Tile. + output_qparams[0] = input_qparams; // Overwrite! 
+ AddQDQNodePairWithOutputAsGraphOutput(builder, tile_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Tile model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTileTestOnHTP(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}); + auto qdq_model_builder = BuildQDQTileTestCase(input_def, repeats_def, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U8_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U16_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/topk_op_test.cc b/onnxruntime/test/providers/qnn/topk_op_test.cc new file mode 100644 index 0000000000000..93e725af5f20e --- /dev/null +++ b/onnxruntime/test/providers/qnn/topk_op_test.cc @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Returns a function that builds a model with a TopK operator. +template +inline GetTestModelFn BuildTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool cast_output_indices = true) { + return [input_def, k_def, attrs, cast_output_indices](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* k_input = MakeTestInput(builder, k_def); + + NodeArg* values_output = builder.MakeOutput(); + NodeArg* indices_output = cast_output_indices ? 
builder.MakeIntermediate() : builder.MakeOutput(); + Node& topk_node = builder.AddNode("TopK", {input, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // Cast indices to uint32 + if (cast_output_indices) { + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + } + }; +} + +// Runs a model with a TopK operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTopKTestOnCPU(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildTopKTestCase(input_def, k_def, attrs, false /*cast_output_indices*/), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that TopK with a dynamic K input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_DynamicK_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, false /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK with an axis attribute that is not the last dimension is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_NonLastAxis_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("axis", static_cast(1))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK that returns the top k minimum values is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_MinValues_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("largest", static_cast(0))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test TopK on CPU backend: top 2 largest floats from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestFloats_LastAxis) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test TopK on CPU backend: top 2 largest int32s from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestInt32s_LastAxis) { + std::vector input_data = {-6, -5, -4, -3, -2, 0, 1, 2, 3, 4, 5, 6}; + RunTopKTestOnCPU(TestInputDef({1, 2, 2, 3}, false, input_data), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ TopK operator. 
+template +GetTestQDQModelFn BuildQDQTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, k_def, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // K input + NodeArg* k_input = MakeTestInput(builder, k_def); + + // Reshape op + NodeArg* values_output = builder.MakeIntermediate(); + NodeArg* indices_output = builder.MakeIntermediate(); + Node& topk_node = builder.AddNode("TopK", {input_qdq, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + + // Cast indices to uint32 (HTP backend does not support int64 graph outputs) + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + }; +} + +// Runs a QDQ TopK model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTopKTestOnHTP(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildTopKTestCase(input_def, k_def, attrs, true /*cast_output_indices*/); + auto qdq_model_builder = BuildQDQTopKTestCase(input_def, k_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) { + RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +// TODO: Inaccuracy detected for output 'output_0', element 6. +// Output quant params: scale=0.00061036087572574615, zero_point=32768. 
+// Expected val: -7.2340402603149414
+// QNN QDQ val: -17.446556091308594 (err 10.212515830993652)
+// CPU QDQ val: -7.2339968681335449 (err 4.3392181396484375e-05)
+TEST_F(QnnHTPBackendTests, DISABLED_TopK_LargestFloats_U16_LastAxis) {
+  RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)),
+                      TestInputDef({1}, true /* is_initializer */, {2}),
+                      {},  // Attributes
+                      ExpectedEPNodeAssignment::All,
+                      19,  // opset
+                      true);  // Use com.microsoft Q/DQ ops
+}
+
+#endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+}  // namespace test
+}  // namespace onnxruntime
+#endif  // !defined(ORT_MINIMAL_BUILD)

From dd561f201524ea5c78d9fa26df818397139f80af Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Wed, 20 Sep 2023 18:44:23 -0700
Subject: [PATCH 03/58] Upgrade sympy (#17639)

AB#17015
---
 .../docker/inference/x64/python/cpu/scripts/requirements.txt | 2 +-
 .../github/linux/docker/scripts/manylinux/requirements.txt   | 2 +-
 tools/ci_build/github/linux/docker/scripts/requirements.txt  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt
index c0c6505ca010d..8a9c4dac1dd58 100644
--- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt
+++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt
@@ -6,5 +6,5 @@ setuptools>=41.4.0
 wheel
 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx
 protobuf==3.20.2
-sympy==1.10.1
+sympy==1.12
 flatbuffers
diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
index c8ff7a804e1df..6b8003c01c24d 100644
--- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
+++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
@@ -6,6 +6,6 @@ setuptools>=41.4.0
 wheel
 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx
 protobuf==3.20.2
-sympy==1.10.1
+sympy==1.12
 flatbuffers
 neural-compressor>=2.2.1
diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt
index 2248652c98043..9dbe856753faa 100644
--- a/tools/ci_build/github/linux/docker/scripts/requirements.txt
+++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt
@@ -7,7 +7,7 @@ setuptools>=41.4.0
 wheel>=0.35.1
 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx
 argparse
-sympy==1.10.1
+sympy==1.12
 flatbuffers
 protobuf==3.20.2
 packaging

From 1f991f27f1d4a8d19c877bc4d33457dd45994c80 Mon Sep 17 00:00:00 2001
From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com>
Date: Thu, 21 Sep 2023 10:45:16 +0800
Subject: [PATCH 04/58] [ROCm] add manylinux build test for ROCm CI (#17621)

The manylinux build is used for nightly packaging generation, and it's hard to catch issues in time when related files change. This PR adds a manylinux build to CI.
--- .../orttraining-pai-ci-pipeline.yml | 107 +++++++++++++++++- .../docker/Dockerfile.manylinux2_28_rocm | 7 ++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 523390debc887..3333a7d22a41b 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -11,6 +11,14 @@ pr: - 'onnxruntime/core/providers/js' name: 'orttraining_ci_$(Date:yyyyMMdd)_$(Rev:r)' +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + variables: - name: video value: 44 @@ -22,7 +30,101 @@ variables: value: Release jobs: -- job: Linux_Build +- job: Linux_Build_manylinux + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all + pool: onnxruntime-Ubuntu2004-AMD-CPU + timeoutInMinutes: 120 + + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: recursive + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur + --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 + --build-arg BUILD_UID=$(id -u) + --build-arg ROCM_VERSION=$(RocmVersion) + --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root + --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: + --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib + Repository: onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build + + - task: Cache@2 + inputs: + key: '"manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + cacheHitVar: CACHE_RESTORED + restoreKeys: | + "manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" + "manylinux" | "$(TODAY)" | + displayName: Cache Task + + - script: mkdir -p $(CCACHE_DIR) + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - task: CmdLine@2 + inputs: + script: |- + export ROCM_HOME=/opt/rocm + docker run --rm \ + --ipc=host \ + --network=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --user $UID:$(id -g $USER) \ + -e CC=/opt/rh/gcc-toolset-12/root/usr/bin/cc -e CXX=/opt/rh/gcc-toolset-12/root/usr/bin/c++ -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" \ + -e CCACHE_DIR=/cache \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume $(CCACHE_DIR):/cache \ + --workdir /onnxruntime_src \ + onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build \ + /bin/bash 
-c " + set -ex; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 tools/ci_build/build.py \ + --config $(BuildConfig) \ + --enable_training \ + --mpi_home /opt/ompi \ + --cmake_extra_defines \ + CMAKE_HIP_COMPILER=${ROCM_HOME}/llvm/bin/clang++ \ + onnxruntime_BUILD_UNIT_TESTS=OFF \ + FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER \ + --use_cache \ + --use_rocm \ + --rocm_version=$(RocmVersion) \ + --rocm_home ${ROCM_HOME} \ + --nccl_home ${ROCM_HOME}\ + --update \ + --build_dir /build \ + --build \ + --parallel \ + --build_wheel \ + --skip_submodule_sync \ + --skip_tests; \ + ccache -sv; \ + ccache -z" + displayName: 'Build onnxruntime' + + - template: templates/explicitly-defined-final-tasks.yml + +- job: Linux_Build_ubuntu variables: skipComponentGovernanceDetection: true CCACHE_DIR: $(Pipeline.Workspace)/ccache @@ -115,8 +217,7 @@ jobs: - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test +- job: Linux_Test_ubuntu workspace: clean: all pool: AMD-GPU diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 10ce8f0ed65f7..19599c9f613d4 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -185,6 +185,13 @@ RUN cd /tmp/scripts && \ rm -rf /tmp/scripts +# Install ccache to reuse this dockerfile for CI +RUN mkdir -p /tmp/ccache && \ + cd /tmp/ccache && \ + wget -q -O - https://github.com/ccache/ccache/releases/download/v4.7.4/ccache-4.7.4-linux-x86_64.tar.xz | tar --strip 1 -J -xf - && \ + cp /tmp/ccache/ccache /usr/bin && \ + rm -rf /tmp/ccache + ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER From 4f3f4366d5838f80b45060c26bcd648ee387a252 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 20 Sep 2023 19:51:50 -0700 Subject: [PATCH 05/58] Fix API 16's marker (#17640) --- onnxruntime/core/session/onnxruntime_c_api.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4c0adcdd374aa..60b6296f7f539 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2711,9 +2711,8 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::GetTensorRTProviderOptionsByName, &OrtApis::UpdateCUDAProviderOptionsWithValue, &OrtApis::GetCUDAProviderOptionsByName, - // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) - &OrtApis::KernelContext_GetResource, + // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
@@ -2742,7 +2741,7 @@ static_assert(offsetof(OrtApi, ReleaseKernelInfo) / sizeof(void*) == 218, "Size static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 224, "Size of version 13 API cannot change"); static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); -static_assert(offsetof(OrtApi, GetCUDAProviderOptionsByName) / sizeof(void*) == 264, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.17.0", From 038c76378fdee45261d43af45466a0797e6ad124 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Thu, 21 Sep 2023 00:08:10 -0700 Subject: [PATCH 06/58] Include onnxruntime_float16.h in the package. (#17637) ### Description Include onnxruntime_float16.h in the package. ### Motivation and Context This was missed in the recently released 1.16 pkgs (except Nuget). --- tools/ci_build/github/linux/copy_strip_binary.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index b875a3937aaa9..63690b69fc91a 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -48,6 +48,7 @@ fi cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include From 57dfd15d7bc9d9c5779896f6685ec473875dc6e1 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 21 Sep 2023 07:33:29 -0700 Subject: [PATCH 07/58] Remove dnf update from docker build scripts (#17551) ### Description 1. Remove 'dnf update' from docker build scripts, because it upgrades TRT packages from CUDA 11.x to CUDA 12.x. To reproduce it, you can run the following commands in a CentOS CUDA 11.x docker image such as nvidia/cuda:11.8.0-cudnn8-devel-ubi8. 
```
export v=8.6.1.6-1.cuda11.8
dnf install -y libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v} libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v}
dnf update -y
```
The last command generates the following output:
```
========================================================================================================================
 Package                            Architecture   Version               Repository   Size
========================================================================================================================
Upgrading:
 libnvinfer-devel                   x86_64         8.6.1.6-1.cuda12.0    cuda         542 M
 libnvinfer-headers-devel           x86_64         8.6.1.6-1.cuda12.0    cuda         118 k
 libnvinfer-headers-plugin-devel    x86_64         8.6.1.6-1.cuda12.0    cuda          14 k
 libnvinfer-plugin-devel            x86_64         8.6.1.6-1.cuda12.0    cuda          13 M
 libnvinfer-plugin8                 x86_64         8.6.1.6-1.cuda12.0    cuda          13 M
 libnvinfer-vc-plugin-devel         x86_64         8.6.1.6-1.cuda12.0    cuda         107 k
 libnvinfer-vc-plugin8              x86_64         8.6.1.6-1.cuda12.0    cuda         251 k
 libnvinfer8                        x86_64         8.6.1.6-1.cuda12.0    cuda         543 M
 libnvonnxparsers-devel             x86_64         8.6.1.6-1.cuda12.0    cuda         467 k
 libnvonnxparsers8                  x86_64         8.6.1.6-1.cuda12.0    cuda         757 k
 libnvparsers-devel                 x86_64         8.6.1.6-1.cuda12.0    cuda         2.0 M
 libnvparsers8                      x86_64         8.6.1.6-1.cuda12.0    cuda         854 k
Installing dependencies:
 cuda-toolkit-12-0-config-common    noarch         12.0.146-1            cuda         7.7 k
 cuda-toolkit-12-config-common      noarch         12.2.140-1            cuda         7.9 k
 libcublas-12-0                     x86_64         12.0.2.224-1          cuda         361 M
 libcublas-devel-12-0               x86_64         12.0.2.224-1          cuda         397 M

Transaction Summary
========================================================================================================================
```
As you can see from the output, they are CUDA 12 packages.

The problem can also be solved by locking the packages' versions with the "dnf versionlock" command right after installing the CUDA/TRT packages. However, going forward, to get better reproducibility, I suggest manually fixing the dnf package versions in the installation scripts, as we do for TRT now.
```bash
v="8.6.1.6-1.cuda11.8" &&\
    yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo &&\
    yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v}\
        libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v}
```
When we need to upgrade a package due to a security alert or for some other reason, we manually change the version string instead of relying on "dnf update". Though this approach increases effort, it can make our pipelines more stable.

2.
Move python tests to docker

### Motivation and Context
Right now the nightly GPU package mixes CUDA 11.x and CUDA 12.x, and the resulting package is totally unusable (it crashes every time).
---
 .../azure-pipelines/linux-ci-pipeline.yml     |   7 +-
 .../py-package-test-pipeline.yml              |  37 +++---
 .../templates/c-api-linux-cpu.yml             |   2 +-
 .../templates/py-package-smoking-test.yml     |  28 ++---
 .../templates/py-packaging-linux-test-cpu.yml | 117 ++++++++++++++++++
 .../py-packaging-linux-test-cuda.yml          |  98 +++++++++++++++
 .../templates/py-packaging-linux-test.yml     |  85 -------------
 .../linux/docker/Dockerfile.manylinux2_28_cpu |   9 +-
 .../docker/Dockerfile.manylinux2_28_cuda11    |   5 +-
 ...kerfile.manylinux2_28_cuda11_6_tensorrt8_4 |   5 +-
 ...kerfile.manylinux2_28_cuda11_6_tensorrt8_5 |   5 +-
 ...kerfile.manylinux2_28_cuda11_8_tensorrt8_6 |   5 +-
 ...Dockerfile.manylinux2_28_training_cuda11_8 |   3 -
 ...erfile.package_ubuntu_cuda11_8_tensorrt8_6 |  20 +--
 .../default/cpu/scripts/install_centos.sh     |   7 +-
 .../default/cpu/scripts/install_deps.sh       |  24 ++--
 .../inference/x64/default/cpu/Dockerfile      |   4 +-
 .../x64/default/cpu/scripts/install_centos.sh |   8 +-
 .../inference/x64/default/gpu/Dockerfile      |   2 +
 .../x64/default/gpu/scripts/install_centos.sh |   8 +-
 .../python/cpu/Dockerfile.manylinux2_28_cpu   |   3 -
 .../x64/python/cpu/scripts/install_centos.sh  |   6 +-
 .../github/linux/docker/manylinux.patch       |   9 +-
 .../linux/docker/scripts/install_dotnet.sh    |  10 +-
 .../scripts/manylinux/install_centos.sh       |   9 +-
 .../docker/scripts/manylinux/install_deps.sh  |  26 ++--
 .../scripts/manylinux/install_deps_aten.sh    |   2 +-
 .../scripts/manylinux/install_deps_eager.sh   |   2 +-
 .../github/linux/run_python_dockertest.sh     |  29 +++++
 .../ci_build/github/linux/run_python_tests.sh |  20 ++-
 tools/scripts/python_test.sh                  |   0
 tools/scripts/symbolic_shape_infer_test.sh    |   0
 32 files changed, 351 insertions(+), 244 deletions(-)
 create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml
 create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml
 delete mode 100644 tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml
 create mode 100755 tools/ci_build/github/linux/run_python_dockertest.sh
 mode change 100644 => 100755 tools/scripts/python_test.sh
 mode change 100644 => 100755 tools/scripts/symbolic_shape_infer_test.sh

diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
index 21bc1c481b3e6..33fc9d94bac09 100644
--- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
@@ -200,8 +200,11 @@ stages:
 - stage: arm64_test
   dependsOn: ['arm64_build']
   jobs:
-  - template: templates/py-packaging-linux-test.yml
+  - template: templates/py-packaging-linux-test-cpu.yml
     parameters:
       arch: 'aarch64'
       machine_pool: 'onnxruntime-linux-ARM64-CPU-2019'
-      device: 'CPU'
+      base_image: 'arm64v8/almalinux:8'
+      devtoolset_rootpath: /opt/rh/gcc-toolset-12/root
+      ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64
+      prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:'
diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml
index c684e08ba1258..2161a9205f22d 100644
---
a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -3,24 +3,38 @@ resources: - pipeline: build source: 'Python packaging pipeline' trigger: true + branch: main # branch to pick the artifact, Used only for manual triggered pipeline runs for testing the pipeline itself + #TODO: Remove the following dependency. Running python tests should not need to use manylinux. + repositories: + - repository: manylinux # The name used to reference this repository in the checkout step + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: - stage: Linux_Test_CPU_x86_64_stage jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' - device: 'CPU' + base_image: 'registry.access.redhat.com/ubi8/ubi' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:' - stage: Linux_Test_CPU_aarch64_stage dependsOn: [] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' machine_pool: 'aiinfra-linux-ARM64-CPU-2019' - device: 'CPU' + base_image: 'arm64v8/almalinux:8' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:' - stage: Packages_Somking_Test dependsOn: [] @@ -31,19 +45,6 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_64_Wheels - itemPattern: '*/*win_amd64.whl' - machine_pool: - vmImage: 'windows-2022' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_32_Wheels - itemPattern: '*/*win32.whl' - python_arch: 'x86' - machine_pool: - vmImage: 'windows-2022' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels @@ -61,7 +62,7 @@ stages: - Linux_Test_CPU_aarch64_stage - Packages_Somking_Test jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cuda.yml parameters: arch: 'x86_64' machine_pool: 'Onnxruntime-Linux-GPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index 796938dc22a67..15fcec0511741 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -68,7 +68,7 @@ jobs: script: | mkdir -p $HOME/.onnx docker run --rm -e CFLAGS="${{parameters.OnnxruntimeCFlags}}" -e CXXFLAGS="${{parameters.OnnxruntimeCXXFlags}}" --volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3 \ + --volume 
$HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3.9 \ /onnxruntime_src/tools/ci_build/build.py --build_java --build_nodejs --build_dir /build --config Release \ --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.AdditionalBuildFlags }} && cd /build/Release && make install DESTDIR=/build/linux-${{parameters.OnnxruntimeArch}}" workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index cee3bd9c9e968..8d5ca19a73535 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -39,36 +39,22 @@ jobs: versionSpec: $(PythonVersion) architecture: ${{ parameters.python_arch }} - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime' - targetPath: '$(Build.BinariesDirectory)/whl' - itemPattern: ${{parameters.itemPattern}} - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific + - download: build # pipeline resource identifier. + artifact: 'onnxruntime' - task: Bash@3 inputs: targetType: 'inline' script: | set -ex - files=(whl/*.whl) + files=(*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Build.BinariesDirectory)/whl" $PYTHON_PACKAGE_NAME - pip show $PYTHON_PACKAGE_NAME - python -c "import onnxruntime as ort; print(ort.__version__)" - workingDirectory: $(Build.BinariesDirectory) + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME + python3 -m pip show $PYTHON_PACKAGE_NAME + python3 -c "import onnxruntime as ort; print(ort.__version__)" + workingDirectory: $(Pipeline.Workspace)/build/onnxruntime displayName: Test Package Installation - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml new file mode 100644 index 0000000000000..cc90085e184dc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -0,0 +1,117 @@ +parameters: +- name: arch + type: string + +- name: base_image + type: string + +- name: devtoolset_rootpath + type: string + +- name: ld_library_path_arg + type: string + +- name: prepend_path + type: string + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_CPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ 
parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + - download: current # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: current # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + #- task: PostAnalysis@2 + # inputs: + # GdnBreakAllTools: true + # GdnBreakPolicy: M365 + # GdnBreakPolicyMinSev: Error + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu + Context: tools/ci_build/github/linux/docker/inference/x64/python/cpu + DockerBuildArgs: "--build-arg POLICY=manylinux_2_28 --build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ parameters.base_image }} --build-arg PLATFORM=${{ parameters.arch }} --build-arg PREPEND_PATH=${{ parameters.prepend_path }} --build-arg LD_LIBRARY_PATH_ARG=${{ parameters.ld_library_path_arg }} --build-arg DEVTOOLSET_ROOTPATH=${{ parameters.devtoolset_rootpath }}" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + ${{ if eq(parameters.arch, 'aarch64') }}: + UpdateDepsTxt: false + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d CPU -c ${{parameters.cmake_build_type}} -i onnxruntimecpubuildpython${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml new file mode 100644 index 
0000000000000..43ed0172825bc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -0,0 +1,98 @@ +parameters: +- name: arch + type: string + +- name: device + type: string + values: + - CPU + - GPU + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_GPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + # - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-gpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + #- task: PostAnalysis@2 + # inputs: + # GdnBreakAllTools: true + # GdnBreakPolicy: M365 + # GdnBreakPolicyMinSev: Error + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d GPU -c ${{parameters.cmake_build_type}} -i onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml deleted file mode 100644 index 8ddc917e8591e..0000000000000 --- 
a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml +++ /dev/null @@ -1,85 +0,0 @@ -parameters: -- name: arch - type: string - -- name: device - type: string - -- name: machine_pool - type: string - -- name: extra_job_id - type: string - default: '' - -- name: python_wheel_suffix - type: string - default: '' - - -# TODO: Ideally it should fetch information from the build that triggers it -- name: cmake_build_type - type: string - default: 'Release' - values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel - -- name: timeout - type: number - default: 120 - -jobs: -- job: Linux_Test_${{ parameters.device }}${{ parameters.extra_job_id }}_${{ parameters.arch }} - timeoutInMinutes: ${{ parameters.timeout }} - variables: - skipComponentGovernanceDetection: true - workspace: - clean: all - pool: ${{ parameters.machine_pool }} - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'drop-linux-${{ lower(parameters.device) }}-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/${{parameters.cmake_build_type}}' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime${{ parameters.python_wheel_suffix }}' - targetPath: '$(Build.BinariesDirectory)/whl' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - - task: Bash@3 - displayName: 'Bash Script' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_tests.sh - arguments: -d ${{ parameters.device }} -c ${{parameters.cmake_build_type}} - - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index a9a1e6b39a8cb..af87852561e0a 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,9 +1,9 @@ -ARG BASEIMAGE=amd64/almalinux:8 +ARG BASEIMAGE=registry.access.redhat.com/ubi8/ubi ARG POLICY=manylinux_2_28 ARG PLATFORM=x86_64 ARG DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root ARG LD_LIBRARY_PATH_ARG=${DEVTOOLSET_ROOTPATH}/usr/lib64:${DEVTOOLSET_ROOTPATH}/usr/lib:${DEVTOOLSET_ROOTPATH}/usr/lib64/dyninst:${DEVTOOLSET_ROOTPATH}/usr/lib/dyninst:/usr/local/lib64 -ARG PREPEND_PATH=${DEVTOOLSET_ROOTPATH}/usr/bin: +ARG PREPEND_PATH=/usr/lib/jvm/msopenjdk-11/bin:${DEVTOOLSET_ROOTPATH}/usr/bin: #Build manylinux2014 docker image begin FROM $BASEIMAGE AS runtime_base @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` 
with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -137,9 +135,7 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ build_scripts/requirements3.10.txt \ @@ -156,6 +152,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ENV PATH ${DEVTOOLSET_ROOTPATH}/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 index dab8df6703c4f..933b0211b0e6c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -156,7 +153,7 @@ ENV SSL_CERT_FILE=/opt/_internal/certs.pem CMD ["/bin/bash"] #Build manylinux2014 docker image end - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 index 303e83eb23bca..003bb2324c049 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ 
@@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -163,7 +160,7 @@ RUN v="8.4.1-1.cuda11.6" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 index d17e4b24582fe..0337ffc5e00a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -163,7 +160,7 @@ RUN v="8.5.1-1.cuda11.8" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 index 3c0ac22e38b5a..2c953a10cbf64 100644 --- 
a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -147,7 +145,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.7.txt \ build_scripts/requirements3.8.txt \ @@ -171,7 +168,7 @@ RUN v="8.6.1.6-1.cuda11.8" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v}\ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 index 326e15d58456a..09ab7951552a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 index c211fa9b9e2b8..83a974469234f 100644 --- 
a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 @@ -7,40 +7,30 @@ # Build base image with required system packages FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base -# The local directory into which to build and install CMAKE -ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code - -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ - apt-get install -y sudo git bash unattended-upgrades wget -RUN unattended-upgrade + apt-get install -y git bash wget # Install python3 RUN apt-get install -y --no-install-recommends \ python3 \ python3-pip \ python3-dev \ - python3-wheel &&\ - cd /usr/local/bin &&\ - ln -s /usr/bin/python3 python &&\ - ln -s /usr/bin/pip3 pip; + python3-wheel + RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 # Install TensorRT RUN v="8.6.1.6-1+cuda11.8" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ - sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ + apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v} libnvinfer-dispatch-dev=${v}\ python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v} -# Install Valgrind -RUN apt-get install -y valgrind - ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh index a1ade39e57e16..adb0464d6496a 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh @@ -1,9 +1,8 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale \ No newline at end of file +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran +locale diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh index 7ecd0525c7e7e..7598ab0a7a536 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh @@ -14,20 +14,20 @@ function GetFile { echo "File '$path' already exists. Skipping download" return 0 else - rm -rf $path + rm -rf "$path" fi fi if [[ -f $uri ]]; then echo "'$uri' is a file path, copying file to '$path'" - cp $uri $path + cp "$uri" "$path" return $? fi echo "Downloading $uri" # Use aria2c if available, otherwise use curl if command -v aria2c > /dev/null; then - aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" + aria2c -q -d "$(dirname $path)" -o "$(basename $path)" "$uri" else curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail fi @@ -38,9 +38,10 @@ mkdir -p /tmp/src cd /tmp/src +CPU_ARCH=$(uname -m) echo "Installing cmake" -GetFile https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-`uname -m`.tar.gz /tmp/src/cmake-3.27.3-linux-`uname -m`.tar.gz -tar -zxf /tmp/src/cmake-3.27.3-linux-`uname -m`.tar.gz --strip=1 -C /usr +GetFile "https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" +tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz @@ -52,7 +53,7 @@ mv ./build-cmake/ninja /usr/bin popd echo "Installing Node.js" -CPU_ARCH=`uname -m` + if [[ "$CPU_ARCH" = "x86_64" ]]; then NODEJS_ARCH=x64 elif [[ "$CPU_ARCH" = "aarch64" ]]; then @@ -64,16 +65,5 @@ fi GetFile https://nodejs.org/dist/v18.17.1/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz tar --strip 1 -xf /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr -# The Python version in CentOS 7's python3 package is no longer supported (3.6) so we will build Python from source. 
-echo "Installing Python" -PYTHON_VERSION="3.8.17" -GetFile https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz /tmp/src/Python-${PYTHON_VERSION}.tgz -tar -zxf Python-${PYTHON_VERSION}.tgz -pushd Python-${PYTHON_VERSION} -./configure -make -make install -popd - cd / rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile index 0324f377b8e9e..caf9583807b62 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile @@ -5,10 +5,10 @@ ARG BASEIMAGE=amd64/almalinux:8 FROM $BASEIMAGE -ENV PATH /opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh index 8e18a237a807e..b5f8bf1a49a19 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh @@ -1,9 +1,9 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale \ No newline at end of file +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran +locale diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile index 386759890d085..318791072f46d 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile @@ -4,8 +4,10 @@ # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh index 3cf259dc7240e..31e3e40f1b7ee 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh @@ -1,9 +1,9 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y python39-devel python3-devel glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel -locale \ No newline at end of file +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 +locale diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu index 33660cbb3f2e5..06e75ee1a39f6 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -132,7 +130,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh index 98bb730a43776..c81e57c60c9da 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh @@ -1,11 +1,11 @@ #!/bin/bash set -e -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* -yum install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget +yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget # export PATH=/opt/python/cp38-cp38/bin:$PATH @@ -17,4 +17,4 @@ mkdir build cd build cmake .. cmake --install . -cd ../.. \ No newline at end of file +cd ../.. 
diff --git a/tools/ci_build/github/linux/docker/manylinux.patch b/tools/ci_build/github/linux/docker/manylinux.patch index f1821f9197525..75923e746f93c 100644 --- a/tools/ci_build/github/linux/docker/manylinux.patch +++ b/tools/ci_build/github/linux/docker/manylinux.patch @@ -94,7 +94,7 @@ index 9ef1e99..ec52833 100755 +fi \ No newline at end of file diff --git a/install-runtime-packages.sh b/install-runtime-packages.sh -index 137d2e2..4269afb 100755 +index 137d2e2..203b4bc 100755 --- a/install-runtime-packages.sh +++ b/install-runtime-packages.sh @@ -33,7 +33,7 @@ source $MY_DIR/build_utils.sh @@ -130,7 +130,7 @@ index 137d2e2..4269afb 100755 elif [ "${AUDITWHEEL_ARCH}" == "aarch64" ] || [ "${AUDITWHEEL_ARCH}" == "ppc64le" ] || [ "${AUDITWHEEL_ARCH}" == "s390x" ]; then # Software collection (for devtoolset-10) yum -y install centos-release-scl-rh -@@ -86,19 +88,18 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then +@@ -86,19 +88,21 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then fi elif [ "${AUDITWHEEL_POLICY}" == "manylinux_2_28" ]; then PACKAGE_MANAGER=dnf @@ -148,6 +148,9 @@ index 137d2e2..4269afb 100755 - TOOLCHAIN_DEPS="gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran" - if [ "${AUDITWHEEL_ARCH}" == "x86_64" ]; then - TOOLCHAIN_DEPS="${TOOLCHAIN_DEPS} yasm" ++ if test -f "/etc/yum.repos.d/ubi.repo"; then ++ sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/ubi.repo ++ fi + if [[ -d /usr/local/cuda ]]; then + TOOLCHAIN_DEPS="gcc gcc-c++" + else @@ -155,7 +158,7 @@ index 137d2e2..4269afb 100755 fi elif [ "${AUDITWHEEL_POLICY}" == "musllinux_1_1" ]; then TOOLCHAIN_DEPS="binutils gcc g++ gfortran" -@@ -121,12 +122,6 @@ else +@@ -121,12 +125,6 @@ else exit 1 fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh index b9accb134b26d..c4689ed19c148 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh @@ -2,13 +2,15 @@ set -e -x if [ -f /etc/redhat-release ]; then - dnf update --refresh -y \ - && dnf install -y dotnet-sdk-6.0 + # If you found the following command went successfully but dotnet command still reports no sdk was found, most likely + # it was because the dotnet packages were installed from more than one dnf repos. 
+ dnf install -y dotnet-sdk-6.0 dotnet-runtime-6.0 elif [ -f /etc/os-release ]; then # Get Ubuntu version - declare repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) + declare repo_version + repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) # Download Microsoft signing key and repository - wget https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + wget "https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb" -O packages-microsoft-prod.deb # Install Microsoft signing key and repository dpkg -i packages-microsoft-prod.deb # Clean up diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh index 4f544a50cb94d..63b953a95add6 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh @@ -1,17 +1,18 @@ #!/bin/bash set -e -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" if [ "$os_major_version" -gt 7 ]; then PACKAGE_MANAGER="dnf" - $PACKAGE_MANAGER install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget + $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget else PACKAGE_MANAGER="yum" - $PACKAGE_MANAGER install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make libunwind bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget + $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make libunwind bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget fi +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm # Install Java # Install automatic documentation generation dependencies -$PACKAGE_MANAGER install -y java-11-openjdk-devel graphviz +$PACKAGE_MANAGER install -y msopenjdk-11 graphviz diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh index a1cb4be5b72c9..8c79918120d8d 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh @@ -3,18 +3,20 @@ set -e -x # Development tools and libraries if [ -f /etc/redhat-release ]; then - yum update && yum -y install graphviz - os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) + dnf -y install graphviz elif [ -f /etc/os-release ]; then apt-get update && apt-get install -y graphviz - os_major_version=$(cat /etc/os-release | tr -dc '0-9.'|cut -d \. -f1) else echo "Unsupported OS" exit 1 fi # Install dotnet -source $(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)/install_dotnet.sh +LOCAL_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" +PARENT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)" +# ShellCheck is unable to follow dynamic paths, such as source "$somedir/file". 
+# shellcheck disable=SC1091 +source "$PARENT_DIR/install_dotnet.sh" if [ ! -d "/opt/conda/bin" ]; then PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11") @@ -22,23 +24,17 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src -GLIBC_VERSION=$(getconf GNU_LIBC_VERSION | cut -f 2 -d \.) - -if [[ $SYS_LONG_BIT = "64" ]]; then - LIBDIR="lib64" -else - LIBDIR="lib" -fi cd /tmp/src -source $(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)/install_shared_deps.sh +# shellcheck disable=SC1091 +source "$LOCAL_DIR/install_shared_deps.sh" cd /tmp/src if ! [ -x "$(command -v protoc)" ]; then - source ${0/%install_deps.sh/..\/install_protobuf.sh} +# shellcheck disable=SC1091 + source "$PARENT_DIR/install_protobuf.sh" fi export ONNX_ML=1 @@ -46,7 +42,7 @@ export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" for PYTHON_EXE in "${PYTHON_EXES[@]}" do - ${PYTHON_EXE} -m pip install -r ${0/%install_deps\.sh/requirements\.txt} + ${PYTHON_EXE} -m pip install -r "${0/%install_deps\.sh/requirements\.txt}" done cd / diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh index ed220b487d06c..1f85f72aef423 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh @@ -11,7 +11,7 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh index e141e0793a2bd..ad3366b0bb3b6 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh @@ -11,7 +11,7 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. 
-f1) SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src diff --git a/tools/ci_build/github/linux/run_python_dockertest.sh b/tools/ci_build/github/linux/run_python_dockertest.sh new file mode 100755 index 0000000000000..332dd9c7284c0 --- /dev/null +++ b/tools/ci_build/github/linux/run_python_dockertest.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e -x +BUILD_CONFIG="Release" + +while getopts "i:d:x:c:" parameter_Option +do case "${parameter_Option}" +in +i) DOCKER_IMAGE=${OPTARG};; +d) DEVICE=${OPTARG};; +c) BUILD_CONFIG=${OPTARG};; +esac +done + +if [ $DEVICE = "GPU" ]; then + ADDITIONAL_DOCKER_PARAMETER="--gpus all" +fi + +mkdir -p $HOME/.onnx +docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ + --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -w /onnxruntime_src \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + $ADDITIONAL_DOCKER_PARAMETER \ + $DOCKER_IMAGE tools/ci_build/github/linux/run_python_tests.sh -d $DEVICE -c $BUILD_CONFIG diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh index c11ea42cd0541..f080c7e8c39d8 100755 --- a/tools/ci_build/github/linux/run_python_tests.sh +++ b/tools/ci_build/github/linux/run_python_tests.sh @@ -15,7 +15,8 @@ c) BUILD_CONFIG=${OPTARG};; esac done -cd $BUILD_BINARIESDIRECTORY +export PATH=/opt/python/cp38-cp38/bin:$PATH +cd /build files=(whl/*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) @@ -23,7 +24,7 @@ PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') echo "Package name:$PYTHON_PACKAGE_NAME" -BUILD_ARGS="--build_dir $BUILD_BINARIESDIRECTORY --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " +BUILD_ARGS="--build_dir /build --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " ARCH=$(uname -m) @@ -35,20 +36,15 @@ if [ $BUILD_DEVICE == "GPU" ]; then BUILD_ARGS="$BUILD_ARGS --use_cuda --use_tensorrt --cuda_version=11.8 --tensorrt_home=/usr --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8" fi # We assume the machine doesn't have gcc and python development header files, so we don't build onnxruntime from source -sudo rm -rf /build /onnxruntime_src -sudo ln -s $BUILD_SOURCESDIRECTORY /onnxruntime_src python3 -m pip install --upgrade pip -python3 -m pip uninstall -y $PYTHON_PACKAGE_NAME ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq # Install the packages that are needed for installing the onnxruntime python package -python3 -m pip install -r $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/requirements.txt +python3 -m pip install -r /build/$BUILD_CONFIG/requirements.txt # Install the packages that are needed for running test scripts -# Install the latest ONNX release which may contain not fixed bugs. However, it is what most people use. -python3 -m pip install onnx pytest +python3 -m pip install pytest # The "--no-index" flag is crucial. The local whl folder is just an additional source. 
Pypi's doc says "there is no # ordering in the locations that are searched" if we don't disable the default one with "--no-index" -python3 -m pip install --no-index --find-links $BUILD_BINARIESDIRECTORY/whl $PYTHON_PACKAGE_NAME -ln -s /data/models $BUILD_BINARIESDIRECTORY -cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG +python3 -m pip install --no-index --find-links /build/whl $PYTHON_PACKAGE_NAME +cd /build/$BUILD_CONFIG # Restore file permissions xargs -a perms.txt chmod a+x -python3 $BUILD_SOURCESDIRECTORY/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' +python3 /onnxruntime_src/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' diff --git a/tools/scripts/python_test.sh b/tools/scripts/python_test.sh old mode 100644 new mode 100755 diff --git a/tools/scripts/symbolic_shape_infer_test.sh b/tools/scripts/symbolic_shape_infer_test.sh old mode 100644 new mode 100755 From 5b9cd91a9cddbe7c461c1ad7ca44edd5111ea920 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 21 Sep 2023 22:37:50 +0800 Subject: [PATCH 08/58] [ROCm] fix CI (#17648) fix CI, follow #17621 --- .../github/azure-pipelines/orttraining-pai-ci-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 3333a7d22a41b..8dd1f0c5c6461 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -222,7 +222,7 @@ jobs: clean: all pool: AMD-GPU dependsOn: - - Linux_Build + - Linux_Build_ubuntu timeoutInMinutes: 120 steps: From f299016cbe87a5341e0a8aa69b621555c9d49a35 Mon Sep 17 00:00:00 2001 From: George Nash Date: Thu, 21 Sep 2023 09:25:41 -0700 Subject: [PATCH 09/58] Fix crash on Windows server 2016 on Intel Gen4 Xeon processors (#17611) This adds an additional check before enabling MlasGemmU8S8DispatchAmx for GEMM operations. After checking the CPUID for AMX-TILE and AMX-INT8, an additional check is added that checks value of the XCR0 register. The value in the OXR0 register is set by the OS and indicates support for various CPU features. In this case the bits indicating XTILECFG and XTILEDATA support are checked. ### Description This adds an additional check before enabling MlasGemmU8S8DispatchAmx for GEMM operations. After checking the CPUID for AMX-TILE and AMX-INT8, an additional check is added that checks value of the XCR0 register. The value in the OXR0 register is set by the OS and indicates support for various CPU features. In this case the bits indicating XTILECFG and XTILEDATA support are checked. ### Motivation and Context Fix for crash reported directly by customer. When running older Windows server OS on newer Gen4 Xeon processors. 
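For reference, here is a minimal standalone sketch of the kind of gating described above. It is illustrative only and not the MLAS code itself: the helper names `ReadXcr0` and `OsEnabledAmxTileState` are made up for this example, while the `_xgetbv` intrinsic (MSVC), the `xgetbv` instruction, and the XCR0 bit positions 17/18 for XTILECFG/XTILEDATA follow the Intel SDM.

```
// Illustrative sketch: AMX dispatch should only be enabled when, in addition to the
// CPUID AMX-TILE/AMX-INT8 feature bits, the OS has enabled the AMX tile state in XCR0.
#include <cstdint>

#if defined(_MSC_VER)
#include <intrin.h>  // _xgetbv
static uint64_t ReadXcr0() { return _xgetbv(0); }
#else
static uint64_t ReadXcr0() {
  uint32_t eax, edx;
  __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));  // read XCR0 (ECX = 0)
  return (static_cast<uint64_t>(edx) << 32) | eax;
}
#endif

constexpr uint64_t kXTileCfgMask = 1ULL << 17;   // XCR0 bit 17: XTILECFG state
constexpr uint64_t kXTileDataMask = 1ULL << 18;  // XCR0 bit 18: XTILEDATA state

bool OsEnabledAmxTileState() {
  const uint64_t xcr0 = ReadXcr0();
  return (xcr0 & (kXTileCfgMask | kXTileDataMask)) == (kXTileCfgMask | kXTileDataMask);
}
```

On an OS that never enables the AMX tile state (e.g. Windows Server 2016), this check fails even though CPUID reports AMX-TILE and AMX-INT8, so the AMX GEMM kernels are skipped instead of faulting on the first tile instruction.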
Signed-off-by: Nash --- onnxruntime/core/mlas/lib/platform.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 7e2b117d6f249..96bc1d8010bed 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -112,6 +112,14 @@ MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[ #define _XCR_XFEATURE_ENABLED_MASK 0 #endif +#if !defined(XFEATURE_MASK_XTILE) +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#endif + inline uint64_t MlasReadExtendedControlRegister( @@ -142,11 +150,6 @@ bool MlasInitAMX() { #if defined(__linux__) -#define XFEATURE_XTILECFG 17 -#define XFEATURE_XTILEDATA 18 -#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) -#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) -#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) #define ARCH_GET_XCOMP_PERM 0x1022 #define ARCH_REQ_XCOMP_PERM 0x1023 @@ -417,7 +420,9 @@ Return Value: // Check if the processor supports AMX-TILE and AMX-INT8 // features. // - if ((Cpuid7[3] & 0b1 << 24) != 0 && (Cpuid7[3] & 0b1 << 25) != 0) { + if ((Cpuid7[3] & 0b1 << 24) != 0 && + (Cpuid7[3] & 0b1 << 25) != 0 && + (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) { if (MlasInitAMX()) { this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; From d56fc7ebf5377abc96db728eafaffd8bf79a3b81 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Thu, 21 Sep 2023 14:16:41 -0700 Subject: [PATCH 10/58] Layer norm fusion deepspeed stage3 changes (#17614) ### Description Layer norm fusion changes required for DeepSpeed stage 3; also includes a test case. ### Motivation and Context It helps fuse layer norm for DeepSpeed stage 3. Added a test case that ensures the fusion works properly in this scenario.
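For context, the decomposed subgraph that LayerNormFusion collapses (roughly ReduceMean, Sub, Pow, ReduceMean, Add of epsilon, Sqrt, Div, then Mul by scale and Add of bias, possibly with Casts) computes the usual layer normalization. The sketch below is a reference only and is not part of this change; the helper name `LayerNormRow`, the single-row simplification, and the epsilon value are assumptions for illustration. This change only relaxes how the scale/bias NodeArgs are located (they may now come from upstream nodes such as Identity under DeepSpeed stage 3 instead of being constants or graph inputs); the fused math itself is unchanged.

```
// Reference LayerNorm over the last axis for a single row; scale and bias are 1-D
// with the same length as that axis, matching the fused LayerNormalization inputs.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> LayerNormRow(const std::vector<float>& x,
                                const std::vector<float>& scale,
                                const std::vector<float>& bias,
                                float epsilon = 1e-5f) {
  const size_t n = x.size();

  // Mean over the normalization axis (the first ReduceMean).
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);

  // Variance over the same axis (Sub -> Pow -> second ReduceMean).
  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(n);

  // Normalize (Add epsilon -> Sqrt -> Div), then apply scale and bias (Mul -> Add).
  const float inv_std = 1.0f / std::sqrt(var + epsilon);
  std::vector<float> y(n);
  for (size_t i = 0; i < n; ++i) {
    y[i] = (x[i] - mean) * inv_std * scale[i] + bias[i];
  }
  return y;
}
```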
--- .../core/optimizer/layer_norm_fusion.cc | 42 ++++----- .../graph_transform_test_layernorm.cc | 34 ++++++++ .../fusion/layer_norm_fusion_scale_bias.onnx | Bin 0 -> 854 bytes .../fusion/layer_norm_fusion_scale_bias.py | 81 ++++++++++++++++++ 4 files changed, 136 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index bf36f11521be2..159e3b23d1ab0 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* scale = nullptr; NodeArg* bias = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; } } for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) { - if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - bias = last_add_node.MutableInputDefs()[i]; - } + if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + bias = last_add_node.MutableInputDefs()[i]; } } if (scale == nullptr || bias == nullptr) { @@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // because SkipLayerNorm kernel, for example, has dependency on single dim size NodeArg* scale = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } #ifdef ENABLE_TRAINING_CORE - if (axes_values.empty() || - mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (axes_values.empty() || + mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; + } #else - // Scale must be 1d. - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { - scale = mul_node.MutableInputDefs()[i]; - } -#endif + // Scale must be 1d. 
+ if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { + scale = mul_node.MutableInputDefs()[i]; } +#endif } if (scale == nullptr) { diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 1f671e90090ba..a55238396cea3 100755 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { } } +// It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph +// To test this added a Identity node after Scale and Bias terms to ensure LayerNormFusion works properly +TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["ReduceMean"], 0); + ASSERT_EQ(op_to_count["Sub"], 0); + ASSERT_EQ(op_to_count["Cast"], 0); + ASSERT_EQ(op_to_count["Pow"], 0); + ASSERT_EQ(op_to_count["Add"], 0); + ASSERT_EQ(op_to_count["Sqrt"], 0); + ASSERT_EQ(op_to_count["Div"], 0); + ASSERT_EQ(op_to_count["Mul"], 0); + ASSERT_EQ(op_to_count["LayerNormalization"], 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + // LayerNormalization should have three inputs. + EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + } + } +} + // If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization // doesn't support input and scale having different data types. 
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ec0f9a97815b888701198d94c92e3f61c581d6dc GIT binary patch literal 854 zcmbVLO;6h}7_J-BxG#kTYZE93gnT22m1Yu$ooG5~nzW*c-nc|*?V^aLfqd|B%VCEd z_!0b!9ry|SC$N(kk+Bn&96fr!@;r}i(*63k1K$A+X(!=+oM*O~2%gWxfWb)##v)ic z9{~q9B0YN23*95r`2gfxhzlM@>6Q$%VOtJ@dJr|!d|FO4Bw)rQpTZvWWxz-=(HP>i32O%=i^w!!hU}H52Z>Cg;9~+&<_rur`aAl7<+#{``weNx=D_oR1Y^ z#*lN^ftN5P>1C2t1qv}dk>9s!bQLvucvY#9fEnMyE9gV-EQq4O4@;YCBkDS8M){&@ zkboKEd;zSzP?b?MvK3XgqUwV7nl}8kiFTXek@Vf^LOYAAqdEXhvhLB8 zJ7tgC=m2%NeOM_K(1s9uUCR>7EX-~h`N1m$Ln*TIxg<{C$gn@X&P!+h9YMQ4gIkdt z$4TT+3o+bkwT`@(TjOl1*yF?9p4U84XNzD99K0cy*C0`6R*RdWK*euV{6StN>vU7S p0tyxZ+JiQ+ Date: Fri, 22 Sep 2023 01:52:13 +0400 Subject: [PATCH 11/58] [js/web] fp16 Pool & Reduce (#17512) ### Description Two more ops to support fp16 --- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 6 +- js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 14 +- .../providers/js/js_execution_provider.cc | 256 +++++++++--------- .../core/providers/js/operators/pool.cc | 112 ++++---- .../core/providers/js/operators/pool.h | 8 +- .../core/providers/js/operators/reduce.cc | 28 +- .../core/providers/js/operators/reduce.h | 2 +- 7 files changed, 206 insertions(+), 220 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 8c8c12fc54ddb..120a0e9de5490 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -22,9 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims.length !== 4) { throw new Error('Pool ops supports 2-D inputs only for now.'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Invalid input type.'); - } }; const getAdjustedPoolAttributesAndOutputShape = ( @@ -248,7 +244,7 @@ const createAveragePoolProgramInfo = const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); const x = inputVariable('x', input.dataType, input.dims); - const dataType = 'f32'; + const dataType = x.type.value; const op1 = 'value += x_val;'; let op2 = ''; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 0b8d03ea73b6b..598b1db033c61 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -17,10 +17,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs.length === 2 && inputs[1].dims.length !== 1) { throw new Error('Invalid axes input dims.'); } - - if (inputs[0].dataType !== DataType.float) { - throw new Error('Invalid input type.'); - } }; export interface ReduceAttributes extends AttributeWithCacheKey { @@ -161,7 +157,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => - [`var t = f32(0); var value = ${output.type.storage}(0);`, + [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', `t = ${input.getByOffset('inputOffset')}; value += (t * t);`, 'value = sqrt(value);', @@ -212,10 +208,10 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes } return [ - `var value = ${output.type.storage}(0);`, + 'var sum = f32(0);', '', - `value += ${input.getByOffset('inputOffset')};`, - `value = value / ${size}.;`, + `sum += f32(${input.getByOffset('inputOffset')});`, + `let value = ${output.type.value}(sum / ${size});`, ]; }; context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceMean', attributes, reduceOp), {inputs: [0]}); @@ -266,7 +262,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes) export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => - [`var t = f32(0); var value = ${output.type.storage}(0);`, + [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', `t = ${input.getByOffset('inputOffset')}; value += t * t;`, '', diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 0674fe02d093d..72e36a161e9aa 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -129,56 +129,56 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Rel class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LeakyRelu); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMean); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMin); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceProd); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ReduceSum); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceL1); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceL2); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); @@ -234,11 +234,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tra class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Conv); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv); @@ -251,16 +251,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gem class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); @@ -438,71 +438,71 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -515,16 +515,16 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/pool.cc b/onnxruntime/core/providers/js/operators/pool.cc index 03e6caef7e5b8..7fdb4e5d114ea 100644 --- a/onnxruntime/core/providers/js/operators/pool.cc +++ b/onnxruntime/core/providers/js/operators/pool.cc @@ -8,69 +8,65 @@ namespace onnxruntime { namespace js { -#define POOLING_KERNEL(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL(op_name, domain, is_channels_last, pool_type, since_version) \ + ONNX_OPERATOR_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), \ + Pool); -#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - 
domain, \ - since_version, \ - end_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()), \ + Pool); -#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version) \ + ONNX_OPERATOR_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); -#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - end_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); -POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 10, 10) -POOLING_KERNEL(AveragePool, kOnnxDomain, false, float, AveragePool, 11) -POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 11) -POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, float, AveragePool, 1) -POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 1) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10) +POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11) +POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11) +POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1) +POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1) -POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, float, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, 
false, float, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 12) -POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, float, MaxPool<1>, 1) -POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, float, MaxPool<1>, 1) +POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11) +POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12) +POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 12) +POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, MaxPool<1>, 1) +POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1) } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h index 5dbe5d0b8881d..5723123c0c3b8 100644 --- a/onnxruntime/core/providers/js/operators/pool.h +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -41,7 +41,7 @@ namespace js { #define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"}) #define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast(is_channels_last) -template +template class Pool : public JsKernel, public PoolBase { public: Pool(const OpKernelInfo& info) : JsKernel(info), PoolBase(info) { @@ -65,10 +65,10 @@ class Pool : public JsKernel, public PoolBase { } }; -template -class Pool, is_channels_last> final : public Pool, is_channels_last> { +template +class Pool, is_channels_last> final : public Pool, is_channels_last> { public: - Pool(const OpKernelInfo& info) : Pool, is_channels_last>(info) {} + Pool(const OpKernelInfo& info) : Pool, is_channels_last>(info) {} }; } // namespace js diff --git a/onnxruntime/core/providers/js/operators/reduce.cc b/onnxruntime/core/providers/js/operators/reduce.cc index 21854fccc37ca..2679cfed86124 100644 --- a/onnxruntime/core/providers/js/operators/reduce.cc +++ b/onnxruntime/core/providers/js/operators/reduce.cc @@ -7,32 +7,30 @@ namespace onnxruntime { namespace js { #define REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, sinceVersion, endVersion) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ ReduceOp, \ kOnnxDomain, \ sinceVersion, endVersion, \ - float, \ kJsExecutionProvider, \ (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ReduceOp); + .TypeConstraint("T", JsepSupportedFloatTypes()), \ + ReduceOp); // macro REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL does not set .InputMemoryType(OrtMemTypeCPU, 1), so in future if // a new opset version update applies to Reduce* operators, we may need to add another macro like // REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT to set input memory type. // i.e. we cannot use REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL to version 18 when the opset version is increased. 
-#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ReduceOp, \ - kOnnxDomain, \ - sinceVersion, \ - float, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPU, 1), \ - ReduceOp); +#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \ + ONNX_OPERATOR_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + sinceVersion, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + ReduceOp); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12); diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 19a6d298c7696..a5a4aa834c2ca 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { #define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ + template \ class ReduceKernel : public JsKernel, public ReduceKernelBase { \ public: \ using ReduceKernelBase::axes_; \ From 6b7bce5ec992f2b3333ee22066201f53e7978faf Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 22 Sep 2023 08:54:25 +0800 Subject: [PATCH 12/58] Model post process for zero stage3 training (#17187) ### Model post process for zero stage3 training This is the last change to make single GPU/Multiple GPUs run pass. Design details: https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9 `PyTorch` runs with ZeROOffloadSubscriber: ``` model = prepare_model(...) from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 configure_ort_compatible_zero_stage3() ``` `ORTModule` runs with ZeROOffloadSubscriber: ``` os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1' from onnxruntime.training.ortmodule import ORTModule model = ORTModule(self.model) ``` It will be fairly easy to debug convergence issue if both ORT and PyTorch can run the same offload path. 
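For reference, below is a minimal single-GPU training-step sketch that stitches the two snippets above together. It is only an illustration: `TinyModel`, the optimizer settings and the `ds_config` dict are hypothetical placeholders, and the exact wiring with `deepspeed.initialize` is an assumption — only the `ORTMODULE_ENABLE_ZERO_STAGE3` switch and the `ORTModule` wrapper come from this change.

```
# Illustrative sketch only. TinyModel, ds_config and the deepspeed.initialize wiring
# are assumptions; ORTMODULE_ENABLE_ZERO_STAGE3 and the ORTModule wrapper are the
# pieces exercised by this change.
import os

import deepspeed
import torch

os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = "1"  # set before wrapping the model

from onnxruntime.training.ortmodule import ORTModule


class TinyModel(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x).sum()


ds_config = {  # hypothetical minimal ZeRO stage3 config
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

model = ORTModule(TinyModel())
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)

x = torch.randn(1, 16, device=engine.device)
loss = engine(x)       # forward through ORTModule under the DeepSpeed engine
engine.backward(loss)  # backward and optimizer step handled by DeepSpeed
engine.step()
```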
### Motivation and Context --- .../_custom_autograd_function_exporter.py | 28 +- .../_custom_autograd_function_runner.py | 10 + .../ortmodule/_graph_execution_manager.py | 62 +++- .../training/ortmodule/_inference_manager.py | 4 + .../python/training/ortmodule/_io.py | 8 +- .../training/ortmodule/_training_manager.py | 4 + .../ortmodule/_zero_stage3_compatibility.py | 312 ++++++++++++++++++ .../python/training/utils/__init__.py | 3 +- .../utils/hooks/_statistics_subscriber.py | 171 +++++----- .../utils/hooks/_subscriber_manager.py | 17 +- .../utils/hooks/_zero_offload_subscriber.py | 155 ++++++--- .../python/training/utils/torch_type_map.py | 9 + .../torch_custom_function_kernel_base.cc | 7 +- 13 files changed, 619 insertions(+), 171 deletions(-) create mode 100644 orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 4c72b6d98a088..f75d553a5f460 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -28,7 +28,8 @@ class PythonOpShapeInferStore: @classmethod def register(cls, kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. + """Register a shape inference function for a torch.autograd.Function if there is staticmethod + "infer_shape" defined. The signature of the shape inference function should be: @staticmethod @@ -51,6 +52,11 @@ def infer_shape( if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP: cls._CLASS_MAP[kclass_name] = kclass.infer_shape + @classmethod + def register_func(cls, name: str, func: Callable) -> None: + """Register a shape inference function for a torch.autograd.Function by name.""" + cls._CLASS_MAP[name] = func + @classmethod def get_shape_infer(cls, name: str) -> Optional[Callable]: return cls._CLASS_MAP.get(name, None) @@ -228,9 +234,9 @@ def _export_pt_1_10(g, n, *args, **kwargs): input_float_tuples.extend(list(arg)) continue - is_inspect_activation = ( - func_full_qual_name == "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation" - ) + from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation + + is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation) if is_inspect_activation and isinstance(arg, str): # _InspectActivation is a special case where the first argument is a string # that is used to determine the activation name to be inspected. @@ -307,14 +313,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): _export = wrap_custom_export_function(_export_pt_1_10) -def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto: - """Post process the exported model.""" - if enable_custom_autograd_function: - exported_model = _post_process_enabling_autograd_function(exported_model) - return exported_model - - -def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: +def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. 
index = 0 for node in exported_model.graph.node: @@ -330,8 +329,7 @@ def _post_process_enabling_autograd_function(exported_model: ModelProto) -> Mode op_name_prefix = kclass_name break - if not node.name: - node.name = f"{op_name_prefix}_id_{index}" - index += 1 + node.name = f"{op_name_prefix}_id_{index}" + index += 1 return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 845c7d83c2e7b..a5b96c4e37140 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -376,6 +376,16 @@ def wrap_all_outputs(result): result = backward_function(*wrapped_args) # Extract results as DLPack tensor list. + if isinstance(result, torch.Tensor): + result = [result] + elif isinstance(result, (tuple, list)): + result = list(result) + else: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(result)}."), + ) + wrapped_returned_args = wrap_all_outputs(result) torch_interop_utils.unregister_grad_fn(id(ctx)) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 2227b630aee23..dfaac5f0fa836 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,11 +19,10 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType +from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils -from ._custom_autograd_function_exporter import _post_process_after_export from ._fallback import ( ORTModuleDeviceException, ORTModuleONNXModelException, @@ -141,9 +140,14 @@ def __init__( register_triton_op_executor() + self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: # Cannot toggle feature enabling/disabling after the first time enabled. - configure_ort_compatible_zero_stage3() + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params + + self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) + + configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): @@ -345,7 +349,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu ) if os.path.exists(cache_dir) and os.path.isfile(filename): self._logger.info( - f"Cached model detected! Cached model will be used to save export and initialization time. If you want the model to be re-exported then DELETE {filename}." + f"Cached model detected! Cached model will be used to save export and initialization time." + f"If you want the model to be re-exported then DELETE {filename}." 
) exported_model = onnx.load(filename) return exported_model @@ -409,9 +414,24 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu ) exported_model = onnx.load_model_from_string(f.getvalue()) - exported_model = _post_process_after_export( - exported_model, self._runtime_options.enable_custom_autograd_function - ) + if self._runtime_options.enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + exported_model = post_process_enabling_autograd_function(exported_model) + + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + + exported_model = post_processing_enable_zero_stage3_compat( + exported_model, + self._zero_stage3_param_map, + [name for name, _ in self._flattened_module.named_parameters()], + ) + + # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( + # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) + # find input info mismatch, will re-initialize the graph builder. + # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) # Cache model for future runs if cache_dir: @@ -477,7 +497,14 @@ def _initialize_graph_builder(self): grad_builder_config = C.OrtModuleGraphBuilderConfiguration() grad_builder_config.initializer_names = initializer_names grad_builder_config.initializer_names_to_train = initializer_names_to_train - grad_builder_config.input_names_require_grad = self._input_info.require_grad_names + + input_names_require_grad = self._input_info.require_grad_names + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME + + # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. + input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel( @@ -553,6 +580,9 @@ def _enable_conditional_optimizations( inputs, kwargs ) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, detected_device) + _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_builder.get_graph_info().user_input_names, @@ -562,6 +592,7 @@ def _enable_conditional_optimizations( kwargs, detected_device, self._runtime_inspector, + self._zero_stage3_param_map, ) # Enable sparsity-based optimization when applicable. 
@@ -587,6 +618,21 @@ def _enable_conditional_optimizations( if self._runtime_options.print_memory_stat: self._runtime_inspector.enable_memory_inspector(self._original_module) + def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): + from ._zero_stage3_compatibility import ( + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) + + kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + + return kwargs + def _log_feature_stats(self): if get_rank() != 0: return diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index b7c01a1f5baf9..8d8be81c549d1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -159,6 +159,9 @@ def forward(self, *inputs, **kwargs): # Assert that the input and model device match _utils._check_same_device(self._device, "Input argument to forward", *inputs) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -168,6 +171,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, self._runtime_inspector, + self._zero_stage3_param_map, ) user_outputs, _ = InferenceManager.execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 18b965c549645..e7c1b30daae0d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -168,6 +168,7 @@ def _combine_input_buffers_initializers( kwargs: Mapping[str, ORTModelInputOutputType], device: torch.device, rt_inspector: RuntimeInspector, + zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]], ): """Creates forward `*inputs` list from user input and PyTorch initializers @@ -254,7 +255,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""): ) # params is a list of all initializers known to the onnx graph - result.extend(params) + if zero_stage3_offload_param_map: + for p in params: + if p not in zero_stage3_offload_param_map.values(): + result.append(p) + else: + result.extend(params) return result, embed_sparsity_results, label_sparsity_results diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3be4c05797978..19effe2086e0a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -311,6 +311,9 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -320,6 +323,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, 
self._runtime_inspector, + self._zero_stage3_param_map, ) outputs = unflatten_user_output( diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py new file mode 100644 index 0000000000000..17756600d601e --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -0,0 +1,312 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper + +from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.training.utils import pytorch_dtype_to_onnx + +from ._custom_autograd_function_exporter import PythonOpShapeInferStore +from ._utils import get_fully_qualified_class_name + +STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger" +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1] + + +def post_processing_enable_zero_stage3_compat( + exported_model: ModelProto, + zero_stage3_named_params: Dict[str, torch.nn.parameter.Parameter], + all_param_names: List[str], +) -> ModelProto: + """This function is used to enable zero stage3 compatibility. + + Args: + exported_model (ModelProto): The exported model. + zero_stage3_named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The offload named parameters. + all_param_names (List[str]): All parameter names. + """ + + # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3. + _register_symbolic_shape_infer_functions() + + # Create weight retrieving function using zero_stage3_named_params. + func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params) + + consumer_map = {} + for node in exported_model.graph.node: + for inp in node.input: + if inp not in consumer_map: + consumer_map[inp] = [] + + if node not in consumer_map[inp]: + consumer_map[inp].append(node) + + def _get_param_pull_trigger_name(param_name: str) -> str: + return f"pull_{param_name}" + + def _get_func_name(node: NodeProto) -> Optional[str]: + for attr in node.attribute: + if attr.name == "func_name": + return attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + return None + + # Create weight retrieving PythonOp. + new_input, weight_pull_node = _create_weight_retrieval_pythonop( + zero_stage3_named_params, + func_full_qual_name, + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + [_get_param_pull_trigger_name(pname) for pname in zero_stage3_named_params], + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) + + from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction + + prefowrad_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) + + # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. 
+ for graph_input in exported_model.graph.input: + if graph_input.name not in zero_stage3_named_params: + continue + + if graph_input.name not in consumer_map: + continue + + consumers = consumer_map[graph_input.name] + pre_forward_pythonop_node = None + + for c in consumers: + if c.op_type != "PythonOp": + continue + + func_name = _get_func_name(c) + if func_name == prefowrad_function_name: + assert ( + pre_forward_pythonop_node is None + ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen" + pre_forward_pythonop_node = c + + if pre_forward_pythonop_node is None: + raise RuntimeError( + "Fail to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name + ) + + index_offset_on_python_op_input = [] + for i, input_name in enumerate(pre_forward_pythonop_node.input): + if input_name == graph_input.name: + index_offset_on_python_op_input.append(i) + + assert ( + len(index_offset_on_python_op_input) == 1 + ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}" + + reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input) + new_input_name = _get_param_pull_trigger_name(graph_input.name) + pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = new_input_name + + _update_python_op_input_related_attributes( + pre_forward_pythonop_node, + new_input_name, + len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE), # new rank + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, # new data type + ) + + output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) + pre_forward_pythonop_node.output[output_index] = graph_input.name + + # If the consumer of original `graph_input.name` is PythonOp, we need also update its attributes because now + # `graph_input.name` as output of pre_forward_pythonop_node, is full-sized parameter, the rank might differ + # from the original one. + for c in consumers: + if c == pre_forward_pythonop_node or c.op_type != "PythonOp": + continue + _update_python_op_input_related_attributes( + c, + graph_input.name, + len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank + pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type + ) + + # Delete exported_model.graph.input + graph_inputs_to_remove = [ + graph_input for graph_input in exported_model.graph.input if graph_input.name in zero_stage3_named_params + ] + for input_to_remove in graph_inputs_to_remove: + exported_model.graph.input.remove(input_to_remove) + + # Re-order graph input to make sure the weight pull trigger is before all parameter inputs. 
+ offset = 0 + for graph_input in exported_model.graph.input: + if graph_input.name in all_param_names: + break + offset += 1 + + exported_model.graph.input.insert(offset, new_input) + exported_model.graph.node.insert(0, weight_pull_node) + + return exported_model + + +def _create_weight_retrieval_function( + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]] +) -> str: + """This function is used to create a weight retrieving function using zero_stage3_named_params.""" + + class WeightRetrievalFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, weight_in_trigger): + params = list(zero_stage3_named_params.values()) + ctx.params = params + ctx.dtype = weight_in_trigger.dtype + ctx.device = weight_in_trigger.device + ctx.shape = weight_in_trigger.shape + return (torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype),) * len(params) + + @staticmethod + def backward(ctx, *grad_outputs): + return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype) + + @staticmethod + def infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + param_count = len(zero_stage3_named_params.values()) + tensor_output_shapes = [ + tensor_input_shapes[0], + ] * param_count + tensor_output_dtypes = [ + tensor_input_dtypes[0], + ] * param_count + return tensor_output_shapes, tensor_output_dtypes + + func_full_qual_name = get_fully_qualified_class_name(WeightRetrievalFunction) + register_torch_autograd_function(func_full_qual_name, WeightRetrievalFunction) + PythonOpShapeInferStore.register(WeightRetrievalFunction) + + return func_full_qual_name + + +def _register_symbolic_shape_infer_functions(): + """This function is used to register symbolic shape inference functions for PythonOp used in + DeepSpeed ZeRO stage3.""" + + def _simple_pass_through_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape + ) + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape + ) + + def _linear_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + # output = input.matmul(weight.t()) + tensor_input_shapes[0] # input + shape2 = tensor_input_shapes[1] # weight + output_shape = tensor_input_shapes[0] + output_shape[-1] = shape2[-2] + return [output_shape], [tensor_input_dtypes[0]] + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape + ) + + +def _create_weight_retrieval_pythonop( + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + func_full_qual_name: str, + input_name: str, + output_names: List[str], + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int], +) -> Tuple[ValueInfoProto, NodeProto]: + """This 
function is used to create a weight retrieving PythonOp.""" + offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params) + new_input = helper.make_tensor_value_info( + input_name, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE + ) + output_rank_for_pull_weight_trigger = len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE) + output_dtype_for_pull_weight_trigger = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE + output_tensor_ranks = [ + output_rank_for_pull_weight_trigger, + ] * offload_param_count + output_tensor_types = [ + output_dtype_for_pull_weight_trigger, + ] * offload_param_count + + node_attributes = { + "comment": "", + "inplace": 0, + "input_convention": "d", + "input_tensor_ranks": [len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)], + "input_tensor_types": [STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE], + "output_tensor_ranks": output_tensor_ranks, + "output_tensor_types": output_tensor_types, + "training_mode": 1, + "func_name": func_full_qual_name, + } + + weight_pull_node = helper.make_node( + "PythonOp", + [input_name], + ["pull_weight_trigger_ctx", *output_names], + "pull_weight_trigger", # node name + "PythonOp for weight retrieving.", + "com.microsoft", + **node_attributes, + ) + + return new_input, weight_pull_node + + +def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int): + """This function is used to update PythonOp's input related attributes, e.g. + input_tensor_ranks and input_tensor_types. + + Args: + node (NodeProto): The PythonOp node. + input_name (str): The input name to be updated. + new_rank (int): The new rank of the input, to be used in input_tensor_ranks. + new_dtype (int): The new data type of the input, to be used in input_tensor_types. 
+ """ + input_tensor_ranks = None + input_tensor_dtypes = None + rank_attr = None + dtype_attr = None + for attr in node.attribute: + if attr.name == "input_tensor_ranks": + input_tensor_ranks = attr.ints + rank_attr = attr + if attr.name == "input_tensor_types": + input_tensor_dtypes = attr.ints + dtype_attr = attr + + assert input_tensor_ranks is not None, "input_tensor_ranks is None" + assert input_tensor_dtypes is not None, "input_tensor_dtypes is None" + + for index, node_input_name in enumerate(node.input): + if node_input_name == input_name: + input_tensor_ranks[index] = new_rank + input_tensor_dtypes[index] = new_dtype + + node.attribute.remove(rank_attr) + node.attribute.remove(dtype_attr) + node.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks)) + node.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes)) diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index acf2698d55eaf..fa7c9f2750cdd 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,7 +9,7 @@ extract_data_and_schema, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx +from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx __all__ = [ "PrimitiveType", @@ -18,4 +18,5 @@ "extract_data_and_schema", "unflatten_data_using_schema", "pytorch_dtype_to_onnx", + "onnx_dtype_to_pytorch", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index 6c8027b2fefaa..db1c69cf95ba4 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -6,6 +6,7 @@ import os import shutil import warnings +from io import TextIOWrapper from pathlib import Path from typing import List, Optional, Tuple, Union @@ -178,87 +179,97 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st order_file_path = step_path / "order.txt" tensor_file_path = step_path / output_file_name - # This is to try the best effort to align the count of numbers per line for easier comparison in diff views, - # though it does not always guarantee to do this way. - torch.set_printoptions(precision=6, linewidth=128) - - tensor_shape = tensor.shape - tensor_dtype = tensor.dtype - flatten_array = tensor.flatten().view(-1) - - if self._run_on_cpu: - flatten_array = flatten_array.to("cpu") - - if self._run_on_cpu: - num_nan = torch.isnan(flatten_array).sum() - num_inf = torch.isinf(flatten_array).sum() - num_neg = (flatten_array < 0).sum() - num_pos = (flatten_array > 0).sum() - num_zero = (flatten_array == 0).sum() - min_value = flatten_array.min() - max_value = flatten_array.max() - mean_value = flatten_array.mean() - std_value = flatten_array.std() - else: - # Split the calculation for each bucket, then do another round of calculation on the bucket results. - # This can at the best effort reduce the peak memory impact. 
- bucket_size = self._bucket_size - element_count = flatten_array.numel() - ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size) - nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - - # Summary for each bucket - element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - for i in range(ceil_bucket_count): - end = min((i + 1) * bucket_size, element_count) - bucket = flatten_array[i * bucket_size : end] - element_count_per_bucket[i] = bucket.numel() - - nan_buckets[i] = torch.isnan(bucket).sum() - inf_buckets[i] = torch.isinf(bucket).sum() - neg_buckets[i] = (bucket < 0).sum() - pos_buckets[i] = (bucket > 0).sum() - zero_buckets[i] = (bucket == 0).sum() - min_buckets[i] = bucket.min() - max_buckets[i] = bucket.max() - mean_buckets[i] = bucket.sum() - std_buckets[i] = bucket.std() - - # Reduction across all buckets - num_nan = nan_buckets.sum() - num_inf = inf_buckets.sum() - num_neg = neg_buckets.sum() - num_pos = pos_buckets.sum() - num_zero = zero_buckets.sum() - min_value = min_buckets.min() - max_value = max_buckets.max() - mean_value = float(mean_buckets.sum()) / float(element_count) - # Here we refer to - # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups - # to calculate the combined standard deviation of all buckets. - s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( - (mean_buckets - mean_value) ** 2 - ) - std_value = torch.sqrt(s.sum() / (element_count - 1)) - with order_file_path.open(mode="a", encoding="utf-8") as f: f.write(f"{output_file_name}\n") with tensor_file_path.open(mode="w", encoding="utf-8") as f: - f.write( - f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n" - f"min: {min_value} max: {max_value}, mean: {mean_value}, " - f"std: {std_value} \n" - f"nan: {num_nan}, inf: {num_inf}\n" - ) - f.write(f"samples(top 128): {flatten_array[:128]}\n") - f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n") - f.write(f"{'='*16}\n") + _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size) + + +def _summarize_tensor( + display_name: str, + tensor: torch.Tensor, + f: TextIOWrapper, + depth: int = 0, + run_on_cpu: bool = False, + bucket_size: int = 1024 * 1024 * 1024 // 2, +): + # This is to try the best effort to align the count of numbers per line for easier comparison in diff views, + # though it does not always guarantee to do this way. 
+ torch.set_printoptions(precision=6, linewidth=128) + + tensor_shape = tensor.shape + tensor_dtype = tensor.dtype + flatten_array = tensor.flatten().view(-1) + + if run_on_cpu: + flatten_array = flatten_array.to("cpu") + + if run_on_cpu: + num_nan = torch.isnan(flatten_array).sum() + num_inf = torch.isinf(flatten_array).sum() + num_neg = (flatten_array < 0).sum() + num_pos = (flatten_array > 0).sum() + num_zero = (flatten_array == 0).sum() + min_value = flatten_array.min() + max_value = flatten_array.max() + mean_value = flatten_array.mean() + std_value = flatten_array.std() + else: + # Split the calculation for each bucket, then do another round of calculation on the bucket results. + # This can at the best effort reduce the peak memory impact. + element_count = flatten_array.numel() + ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size) + nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + + # Summary for each bucket + element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + for i in range(ceil_bucket_count): + end = min((i + 1) * bucket_size, element_count) + bucket = flatten_array[i * bucket_size : end] + element_count_per_bucket[i] = bucket.numel() + + nan_buckets[i] = torch.isnan(bucket).sum() + inf_buckets[i] = torch.isinf(bucket).sum() + neg_buckets[i] = (bucket < 0).sum() + pos_buckets[i] = (bucket > 0).sum() + zero_buckets[i] = (bucket == 0).sum() + min_buckets[i] = bucket.min() + max_buckets[i] = bucket.max() + mean_buckets[i] = bucket.sum() + std_buckets[i] = bucket.std() + + # Reduction across all buckets + num_nan = nan_buckets.sum() + num_inf = inf_buckets.sum() + num_neg = neg_buckets.sum() + num_pos = pos_buckets.sum() + num_zero = zero_buckets.sum() + min_value = min_buckets.min() + max_value = max_buckets.max() + mean_value = float(mean_buckets.sum()) / float(element_count) + # Here we refer to + # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups + # to calculate the combined standard deviation of all buckets. 
+ s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( + (mean_buckets - mean_value) ** 2 + ) + std_value = torch.sqrt(s.sum() / (element_count - 1)) + + f.write( + f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n" + f"min: {min_value} max: {max_value}, mean: {mean_value}, " + f"std: {std_value} \n" + f"nan: {num_nan}, inf: {num_inf}\n" + ) + f.write(f"samples(top 128): {flatten_array[:128]}\n") + f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n") + f.write(f"{'='*16}\n") diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index db38f58d8f324..b2bc64be42fc1 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -29,14 +29,6 @@ def no_increase_global_step(): finally: ORT_NO_INCREASE_GLOBAL_STEP[0] = False - @staticmethod - def infer_shape( - node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: - return tensor_input_shapes, tensor_input_dtypes - class _IncrementStep(torch.autograd.Function): """This class is used to manage the global execution step, e.g. @@ -55,8 +47,9 @@ def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, ctx.current_step = run_ctx.global_states.execution_step ctx.run_ctx = run_ctx - if ctx.current_step >= 0: - print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") + # Uncomment the following line for debugging purposes. + # if ctx.current_step >= 0: + # print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") if ORT_NO_INCREASE_GLOBAL_STEP[0] is False: ctx.run_ctx.global_states.execution_step += 1 @@ -191,7 +184,7 @@ def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: L next_module_index: list of int, carrying a global unique module index that can be used next. """ module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#1: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index @@ -217,7 +210,7 @@ def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_ next_module_index: list of int, carrying a global unique module index that can be used next. 
""" module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#2: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 3d42e172eea82..ad1297962db71 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -23,25 +23,37 @@ from ._subscriber_base import RuntimeStates, SubscriberBase -# Used to monkey patch the original function -# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 -def _setup_zero_stage3_ort_compatible_hooks(self): - self.hierarchy = 0 +def _get_ort_compatible_zero_stage3_hook_function(debug, stats_output_dir, stats_overwrite): + """Create ort compatible hook function for DeepSpeed ZeRO stage3. - from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber - from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer + Args: + debug: whether to enable convergence debugging. + stats_output_dir: the directory to store convergence stats. + stats_overwrite: whether to overwrite the stats file if it already exists. + """ + + # Used to monkey patch the original function + # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 + def _setup_zero_stage3_ort_compatible_hooks(self): + self.hierarchy = 0 + + from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager, ZeROOffloadSubscriber + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer - # Each DeepSpeed engine has a separate subscriber manager. - self._offload_subscriber_manager = SubscriberManager() - self._offload_subscriber_manager.subscribe( - self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] - ) - self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) - self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) + subscribers = [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] + if debug is True: + subscribers.append(StatisticsSubscriber(output_dir=stats_output_dir, override_output_dir=stats_overwrite)) + # Each DeepSpeed engine has a separate subscriber manager. 
+ self._offload_subscriber_manager = SubscriberManager() + self._offload_subscriber_manager.subscribe(self.module, subscribers) + self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) + self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) - # Add top module to stack trace - global FWD_MODULE_STACK # noqa: PLW0602 - FWD_MODULE_STACK.append(self.module) + # Add top module to stack trace + global FWD_MODULE_STACK # noqa: PLW0602 + FWD_MODULE_STACK.append(self.module) + + return _setup_zero_stage3_ort_compatible_hooks # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104 @@ -86,14 +98,16 @@ def collect_code(self, function: Callable): _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks) # This is the function to enable ORT ZeRO offload. - def configure_ort_compatible_zero_stage3(): + def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False): """Configure ZeRO stage3 to be ORT compatible. This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible. """ # Only done once no matter how many times this function is called for different modules. - DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks + DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _get_ort_compatible_zero_stage3_hook_function( + debug, stats_output_dir, stats_overwrite + ) from deepspeed.runtime.zero.linear import zero3_linear_wrap @@ -103,7 +117,7 @@ def configure_ort_compatible_zero_stage3(): except ImportError as e: warnings.warn(f"DeepSpeed import error {e}") - def configure_ort_compatible_zero_stage3(): + def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False): raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") @@ -115,13 +129,13 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par """ from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params - # Retrive the parameters that are not available for this module. + # Retrieve all parameters for this module. partitioned_params = [param for param in iter_params(module)] return partitioned_params -def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: +def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: """Retrieve all the parameters that are offloaded.""" from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus @@ -134,16 +148,13 @@ def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.par class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): - """This function is a common bridge to call original PyTorch's - pre_forward_function and post_backward_function. 
- """ + """This function is a common bridge to call original PyTorch's pre_forward_function""" @staticmethod def forward( ctx, module, pre_forward_with_kwargs_function, - post_backward_function, args_schema, kwargs_schema, args_tensor_count, @@ -155,7 +166,6 @@ def forward( ctx: context object module: the module to be called pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function) - post_backward_function: the function to be called after backward (PyTorch's post_backward_function) args_schema: the schema of the args, used to reconstruct the args in original form in PyTorch's pre_forward_function's inputs. kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in @@ -168,6 +178,17 @@ def forward( args_tensors = tensor_list[:args_tensor_count] kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + # For PyTorch runs, the sizes are all 0, it does not need a gradient because + # param._detach().requires_grad_(False) is called. + # But for ORT runs, the sizes are all [1], as output of weight retrieval function. + # So we keep track of the shapes and dtypes of the passed-in tensors, then generate the grads in backward. + # While for both PyTorch and ORT runs, the grad is not important because they are not param grads + # anymore, they are only used for completing the full backward propagation. + passed_in_param_tensors = tensor_list[args_tensor_count + kwargs_tensor_count :] + ctx.shapes = [p.shape for p in passed_in_param_tensors] + ctx.dtypes = [p.dtype for p in passed_in_param_tensors] + ctx.devices = [p.device for p in passed_in_param_tensors] + args = unflatten_data_using_schema(args_tensors, args_schema) kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) @@ -179,6 +200,8 @@ def forward( partitioned_params = _get_params_for_current_module(module) ctx.partitioned_params = partitioned_params + assert len(partitioned_params) == len(passed_in_param_tensors) + f_ret = pre_forward_with_kwargs_function(module, args, kwargs) if f_ret is None: @@ -188,7 +211,6 @@ def forward( updated_args, updated_kwargs = f_ret ctx.module = module - ctx.post_backward_function = post_backward_function updated_args_tensors, _ = extract_data_and_schema(updated_args) updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) @@ -203,17 +225,32 @@ def forward( @staticmethod def backward(ctx, *grads): updated_grads = grads - if ctx.post_backward_function is not None: - ret = ctx.post_backward_function(ctx.module, grads) - if ret is not None: - updated_grads = ret - # TODO(pengwa) Update grad for partitioned parameters. input_count = len(updated_grads) - len(ctx.partitioned_params) - zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params] - zero_grads = updated_grads[:input_count] + tuple(zeros) - - return (None, None, None, None, None, None, None, *zero_grads) + param_start_offset = input_count + + # Only need to accumulate grad explicitly for ORT run (e.g. ctx.shapes[0] == (1,)); + # In the PyTorch run, the accumulation happens automatically. + need_manual_grad_acc = len(ctx.shapes) > 0 and ctx.shapes[0] == (1,) + if need_manual_grad_acc: + for param_index, p in enumerate(ctx.partitioned_params): + g = updated_grads[param_index + param_start_offset] + if g is None: + raise RuntimeError(f"param {p} has no grad, this should not happen.") + # Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch. 
+ assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}" + p.backward(g) + + # At this point, the **real** param grads are already updated, the following grads are only used for + # completing the full backward propagation, will not affect parameter updates. + passed_in_param_grad = [ + torch.zeros(shape, dtype=dtype, device=device) + for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices) + ] + + zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad) + + return (None, None, None, None, None, None, *zero_grads) @staticmethod def infer_shape( @@ -258,14 +295,14 @@ def forward( module: the module to be called post_forward_function: the function to be called after forward (PyTorch's post_forward_function) pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function) - output_schema: the schema of the output, used to reconstruct the output in original form in + output_schema: the schema of the output, used to reconstruct the output in its original form in PyTorch's post_forward_function's inputs. output_tensors: the list of tensors. """ outputs = unflatten_data_using_schema(output_tensors, output_schema) - # STAGE3WARN: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN#3: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. updated_outputs = post_forward_function(module, None, outputs) if updated_outputs is None: @@ -341,11 +378,19 @@ def pre_forward_module_apply_impl( input and output for torch.autograd.Function, so we do flatten and unflatten here. """ + ## Handle `_post_backward_module_hook` - args_tensors, args_schema = extract_data_and_schema(args) - kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) + # Put `_post_backward_module_hook` first because in backward, it is responsible for unloading parameters, + # we want ORTZeROOffloadPreForwardFunction's backward still be able to access the full sized parameters. + _post_backward_module_hook = self._functions.get("_post_backward_module_hook") + # STAGE3WARN#4: most logic in _post_backward_module_hook can be traced correctly so we don't need to + # wrap with PythonOp. For those cannot be traced, we handle them in STAGE3WARN#5. + updated_args = _post_backward_module_hook(module, args) - partitioned_params = _get_params_for_current_module(module) + ## Handle `_pre_forward_module_hook` + + args_tensors, args_schema = extract_data_and_schema(updated_args) + kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) _pre_forward_module_hook = self._functions.get("_pre_forward_module_hook") @@ -358,18 +403,29 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): if rets is not None: updated_args = rets - # STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. + # STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. 
module.ds_grads_remaining = 0 + return updated_args, updated_kwargs - all_tensors = args_tensors + kwargs_tensors + partitioned_params + # Need to pass the parameters as input to let the exporter trace the related weights for + # current ORTZeROOffloadPreForwardFunction + partitioned_params = _get_params_for_current_module(module) + # Don't require grad for passed-in parameter, otherwise it will be treated as a leaf node, in backward + # returned 0-sized grad did not match the param's gradient accumulator function's input shape metadata, + # PyTorch run will fail during backward. + # This will not harm parameter gradient build either in ORT or PyTorch, imagine the weights are used by + # computation anyway, so the gradient will be built. This hook only references the parameter, but won't + # generate a gradient path for it. + detached_partitioned_params = [p.detach().requires_grad_(False) for p in partitioned_params] + + all_tensors = args_tensors + kwargs_tensors + detached_partitioned_params self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check") rets = ORTZeROOffloadPreForwardFunction.apply( module, _wrap_pre_forward_module_hook, - None, args_schema, kwargs_schema, args_tensor_count, @@ -385,11 +441,6 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema) updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema) - _post_backward_module_hook = self._functions.get("_post_backward_module_hook") - # STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to - # wrap with PythonOp. - updated_args = _post_backward_module_hook(module, updated_args) - return updated_args, updated_kwargs def post_forward_module_apply_impl( @@ -411,7 +462,7 @@ def post_forward_module_apply_impl( _post_forward_module_hook = self._functions.get("_post_forward_module_hook") def _wrap_post_forward_module_hook(module, input, outputs): - # STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. + # STAGE3WARN#6: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. from deepspeed.runtime.zero.partition_parameters import is_zero_param updated_outputs = _post_forward_module_hook(module, input, outputs) @@ -438,8 +489,8 @@ def _wrap_post_forward_module_hook(module, input, outputs): updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) _pre_backward_module_hook = self._functions.get("_pre_backward_module_hook") - # STAGE3WARN: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. - # STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into + # STAGE3WARN#7: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN#8: part of the original _pre_backward_module_hook can be traced correctly so we moved them into # _wrap_post_forward_module_hook above. 
updated_outputs = _pre_backward_module_hook(module, None, updated_outputs) diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py index 699747723f457..bdacab8ad04fe 100644 --- a/orttraining/orttraining/python/training/utils/torch_type_map.py +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -33,6 +33,8 @@ _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _CAST_PYTORCH_TO_ONNX.items()} +_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()} + def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: """Converts a pytorch dtype or scalar type string to an onnx dtype.""" @@ -45,3 +47,10 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc if dtype not in _DTYPE_TO_ONNX: raise RuntimeError(f"Unsupported dtype {dtype}") return _DTYPE_TO_ONNX[dtype] + + +def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: + """Converts an onnx dtype to a pytorch dtype.""" + if dtype not in _ONNX_TO_DTYPE: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _ONNX_TO_DTYPE[dtype] diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index 4e7fcbc95bb1d..e1d4be24861f5 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -153,8 +153,11 @@ void PythonOpBase::RunForward(OpKernelContext* context, inplace_ != 0, kernel_invoke_id_); - ORT_ENFORCE(1 + returned_ortvalues.size() == static_cast(context->OutputCount()), - "Output count mismatch for PythonOp run"); + const size_t returned_output_count = 1 + returned_ortvalues.size(); + const size_t kernel_output_count = static_cast(context->OutputCount()); + ORT_ENFORCE(returned_output_count == kernel_output_count, "Output count mismatch for PythonOp run, ", + "returned_output_count: ", returned_output_count, ", expected kernel_output_count: ", + kernel_output_count); } void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vector& returned_args) const { From 1bc215e1d1c1e3509a1dd0bc413b1537563dedb5 Mon Sep 17 00:00:00 2001 From: Yiming Hu Date: Thu, 21 Sep 2023 19:22:28 -0700 Subject: [PATCH 13/58] [VITISAI] add float16 and bfloat16 support (#17438) ### Description Add float16 and bfloat16 data type support for VitisAI ep ### Motivation and Context The VitisAI ep has added the bfloat datatype support. So we would like to register the datatype from onnxruntime side to enable them. 
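The change below extends the `xir_shape_infer` type dispatch with `float16`/`bfloat16` branches and replaces the `abort()` fallback with `vai_assert`. A minimal standalone sketch of that dispatch, with a plain enum standing in for ONNX's `TensorProto` element types (names here are illustrative only, not the real API):

```
// Standalone illustration of the xir data_type -> element type dispatch.
// ElemType is a stand-in for ONNX_NAMESPACE::TensorProto's element types.
#include <iostream>
#include <stdexcept>
#include <string>

enum class ElemType { kBool, kFloat16, kBFloat16 /* other types elided */ };

ElemType ElemTypeFromXirString(const std::string& data_type) {
  if (data_type == "int1") return ElemType::kBool;
  if (data_type == "float16") return ElemType::kFloat16;    // newly registered
  if (data_type == "bfloat16") return ElemType::kBFloat16;  // newly registered
  // Report an error instead of calling abort(), mirroring the vai_assert change.
  throw std::runtime_error("not supported data_type: " + data_type);
}

int main() {
  std::cout << (ElemTypeFromXirString("bfloat16") == ElemType::kBFloat16) << "\n";  // prints 1
  return 0;
}
```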
--------- Signed-off-by: Yiming Hu --- onnxruntime/core/providers/vitisai/README.md | 2 +- onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/vitisai/README.md b/onnxruntime/core/providers/vitisai/README.md index 15e0c804489c5..6ddb58b8d96ae 100644 --- a/onnxruntime/core/providers/vitisai/README.md +++ b/onnxruntime/core/providers/vitisai/README.md @@ -1,4 +1,4 @@ -VitsAI Execution Prividers +VitisAI Execution Provider ============================ diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index 544e18350635d..ee8dfc6d03d12 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -34,9 +34,12 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); } else if (data_type->s() == "int1") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + } else if (data_type->s() == "bfloat16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); + } else if (data_type->s() == "float16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); } else { - std::cerr << "not supported data_type " << data_type->s(); - abort(); + vai_assert(false, ", not supported data_type: " + data_type->s()); } if (shape != nullptr) { for (auto i = 0; i < shape->ints_size(); ++i) { From cd3fb377ea867570796cf61bc420cd985129a2a0 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 22 Sep 2023 11:55:08 +0800 Subject: [PATCH 14/58] [js/webgpu] Allow binary ops with scalar to use the vectorize path (#17589) ### Description 1. For binary ops, the components is always 4. So the dispatchGroup should be : `{x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}` instead of `{x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}`. 2. If any of a or b only has one element, we still can use the vectorize path since the same value will be broadcasted. --- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 13d3a91bb339e..9c05080f7e118 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -62,14 +62,24 @@ const createBinaryOpProgramShader = let assignment: string; if (vectorize) { if (doBroadcast) { - assignment = ` + const isAOneElement = ShapeUtil.size(dimsA) === 1; + const isBOneElement = ShapeUtil.size(dimsB) === 1; + if (isAOneElement || isBOneElement) { + assignment = output.setByOffset( + 'global_idx', + expressionVector( + isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'), + isBOneElement ? 
`${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'))); + } else { + assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; let offsetA = calcOffsetA(outputIndices); let offsetB = calcOffsetB(outputIndices); ${ - output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + output.setByOffset( + 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} `; + } } else { assignment = output.setByOffset( 'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx'))); @@ -141,6 +151,8 @@ const createBinaryOpProgramInfo = } outputShape = calculatedShape; outputSize = ShapeUtil.size(outputShape); + const isAOneElement = ShapeUtil.size(a.dims) === 1; + const isBOneElement = ShapeUtil.size(b.dims) === 1; // check whether vectorize can be enabled let sharedDimension = 1; @@ -153,7 +165,7 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0) { + if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) { vectorize = true; } } else { @@ -167,8 +179,7 @@ const createBinaryOpProgramInfo = shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, outputDataType, additionalImplementation), outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => - ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}) }; }; From 891fba3b9cd71e2e1afdeab9fb3c5b5497db20cf Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 22 Sep 2023 12:00:36 +0800 Subject: [PATCH 15/58] [js/webgpu] Optimize Gather op (#17625) ### Description This PR optimizes the gather op, which is improved ~6ms in segment anything model in ADL. The problem in original algorithm is that it includes a for loop to calculate a block size of data. However, the block size may be very large, like `65536`. In GPU shader, we should try to avoid large loop in shader and try to use more threads to do it parallelly. Before: ``` [profiling] kernel "41771992|[Gather] 41771992" input[0]: [4,65536] | float32, input[1]: [1] | int64, output[0]: [1,65536] | float32, execution time: 6886207 ns ``` After: ``` [profiling] kernel "41771992|[Gather] 41771992" input[0]: [4,65536] | float32, input[1]: [1] | int64, output[0]: [1,65536] | float32, execution time: 11719 ns --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 91 ++++++++++------------- 2 files changed, 42 insertions(+), 51 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c054da51a3098..0ab777bfbdee9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -366,7 +366,7 @@ const createIndicesHelper = const getByIndicesImplementation = rank < 2 ? '' : ` fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { - return ${name}[i2o_${name}(indices)]; + return ${getByOffset(`i2o_${name}(indices)`)}; }`; const getImplementation = rank < 2 ? 
'' : (() => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 0db060dbec54a..47aae13d6799d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -30,63 +29,55 @@ const createGatherProgramInfo = const outputShape = inputShape.slice(0); outputShape.splice(axis, 1, ...indicesShape); - const inputDataType = inputs[0].dataType; - const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); - const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; - const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1; - const blockSize = elementSize * block; - const M = ShapeUtil.sizeToDimension(inputShape, axis); - const N = ShapeUtil.size(indicesShape); - const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize; - const gatheredBatchElements = N * block * elementSize; const axisDimLimit = inputShape[axis]; + const outputSize = ShapeUtil.size(outputShape); + + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims); + const output = outputVariable('output', inputs[0].dataType, outputShape); + const calcDataIndices = (): string => { + const indicesRank = indicesShape.length; + let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; + for (let i = 0; i < indicesRank; i++) { + calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ + outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`; + } + calcStr += ` + var idx = ${indices.getByIndices('indicesIndices')}; + if (idx < 0) { + idx = idx + ${axisDimLimit}; + } + var dataIndices = ${data.type.indices}(0); + `; + for (let i = 0, j = 0; i < inputRank; i++) { + if (i === axis) { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`; + j += indicesRank; + } else { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${ + outputShape.length > 1 ? 
`outputIndices[${j}]` : 'outputIndices'};`; + j++; + } + } + return calcStr; + }; - const inputSize = ShapeUtil.size(inputShape) * elementSize; - const outputSize = ShapeUtil.size(outputShape) * elementSize; - - const totalGathers = M * N; - // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits - // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor - // Input data will be treated as u32 or two u32 for 8-byte tensors const getShaderSource = (shaderHelper: ShaderHelper) => ` - const N: u32 = ${N}; - const elementSize: u32 = ${elementSize}; - const indicesElementSize: u32 = ${indicesElementSize}; - - @group(0) @binding(0) var input : array; - @group(0) @binding(1) var inputIndices : array; - @group(0) @binding(2) var output: array; - - ${shaderHelper.mainStart()} - let batch: u32 = global_idx / N; - let i: u32 = global_idx % N; - - let srcOffsetBatch: u32 = batch * ${dataBatchElements}; - let dstOffsetBatch: u32 = batch * ${gatheredBatchElements}; - var idx = inputIndices[i * indicesElementSize]; - if (idx < 0) { - idx = idx + ${axisDimLimit}; - } - - let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize}; - let dstOffset = dstOffsetBatch + i * ${blockSize}; - if (srcOffset >= ${inputSize}) { - return; - } - if (dstOffset >= ${outputSize}) { - return; - } - for (var j: u32 = 0; j < ${blockSize}; j++) { - output[dstOffset + j] = input[srcOffset + j]; - } - }`; + ${shaderHelper.declareVariables(data, indices, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let outputIndices = ${output.offsetToIndices('global_idx')}; + ${calcDataIndices()}; + let value = ${data.getByIndices('dataIndices')}; + ${output.setByOffset('global_idx', 'value')}; + }`; return { ...metadata, outputs: [ {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, ], getShaderSource, - dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }; }; From 55b16d347cbcde41b35c3ed12f34eeca1a1b05d6 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sat, 23 Sep 2023 00:50:36 +0800 Subject: [PATCH 16/58] Read model zoo test (#17666) --- onnxruntime/test/providers/cpu/model_tests.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index ef2d7e31654ba..9b41ba8c0d2ba 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1133,11 +1133,15 @@ ::std::vector<::std::basic_string> GetParameterStrings() { #if defined(NDEBUG) || defined(RUN_MODELTEST_IN_DEBUG_MODE) #ifdef _WIN32 ORT_STRING_VIEW model_test_root_path = ORT_TSTR("..\\models"); + // thus, only the root path should be mounted. + ORT_STRING_VIEW model_zoo_path = ORT_TSTR("..\\models\\zoo"); #else ORT_STRING_VIEW model_test_root_path = ORT_TSTR("../models"); + ORT_STRING_VIEW model_zoo_path = ORT_TSTR("../models/zoo"); #endif for (auto p : kvp.second) { paths.push_back(ConcatPathComponent(model_test_root_path, p)); + paths.push_back(ConcatPathComponent(model_zoo_path, p)); } #endif From 6d7bc2a097a1a08541cd0d4628831c79ab8092d5 Mon Sep 17 00:00:00 2001 From: Lukas Berbuer <36054362+lukasberbuer@users.noreply.github.com> Date: Fri, 22 Sep 2023 18:54:38 +0200 Subject: [PATCH 17/58] Fix ARMv7 build (#13891) Fix ARMv7 build error on Linux. 
### Description `cpuinfo_*` functions are only available if `CPUINFO_SUPPORTED` set and therefore `"cpuinfo.h"` included. Fixed with extended conditional code. ### Motivation and Context Compilation with ARMv7 on Linux system fails. --- onnxruntime/core/common/cpuid_info.cc | 54 +++++++++++++-------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index a23409292bb74..6a82b3fcc734d 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -135,38 +135,34 @@ void CPUIDInfo::ArmLinuxInit() { LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; return; } + is_hybrid_ = cpuinfo_get_uarchs_count() > 1; + has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + const uint32_t core_cnt = cpuinfo_get_cores_count(); + core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); + is_armv8_narrow_ld_.resize(core_cnt, false); + for (uint32_t c = 0; c < core_cnt; c++) { + const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); + if (proc == nullptr) { + continue; + } + const struct cpuinfo_core* corep = proc->core; + if (corep == nullptr) { + continue; + } + auto coreid = proc->linux_id; + auto uarch = corep->uarch; + core_uarchs_[coreid] = uarch; + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + is_armv8_narrow_ld_[coreid] = true; + } + } #else pytorch_cpuinfo_init_ = false; + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_fp16_ |= has_arm_neon_dot_; #endif - - if (pytorch_cpuinfo_init_) { - is_hybrid_ = cpuinfo_get_uarchs_count() > 1; - has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); - has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); - const uint32_t core_cnt = cpuinfo_get_cores_count(); - core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); - is_armv8_narrow_ld_.resize(core_cnt, false); - for (uint32_t c = 0; c < core_cnt; c++) { - const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); - if (proc == nullptr) { - continue; - } - const struct cpuinfo_core* corep = proc->core; - if (corep == nullptr) { - continue; - } - auto coreid = proc->linux_id; - auto uarch = corep->uarch; - core_uarchs_[coreid] = uarch; - if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || - uarch == cpuinfo_uarch_cortex_a55) { - is_armv8_narrow_ld_[coreid] = true; - } - } - } else { - has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); - has_fp16_ |= has_arm_neon_dot_; - } } #elif defined(_WIN32) From e70a23f8dc6fc181218106f0e12730f980cc867e Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 22 Sep 2023 10:52:47 -0700 Subject: [PATCH 18/58] [QNN EP] Integrate Resize op fixes from QNN 2.14.1 (#17641) ### Description QNN SDK version 2.14.1 fixed several issues with the QNN Resize operator. This PR integrates the fixes and simplifies the implementation. ### Motivation and Context Improve Resize operator and test coverage. 
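Part of the simplification below replaces the index-based mode arrays with explicit lookup tables from ONNX attribute strings to QNN parameter values. A minimal standalone sketch of that lookup pattern (the integer values are placeholders for the real `QNN_OP_RESIZE_*` constants):

```
// Standalone illustration of the ONNX-attribute-string -> QNN-parameter-value lookup.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Returns false (instead of silently picking a default) when the mode is unsupported.
bool GetQnnModeValFromOnnxString(const std::unordered_map<std::string, uint32_t>& supported_modes,
                                 const std::string& onnx_attr_value, uint32_t& qnn_mode_value) {
  const auto it = supported_modes.find(onnx_attr_value);
  if (it == supported_modes.end()) {
    return false;
  }
  qnn_mode_value = it->second;
  return true;
}

int main() {
  // Placeholder values; the real tables map to QNN_OP_RESIZE_TRANSFORMATION_MODE_* enums.
  const std::unordered_map<std::string, uint32_t> coord_transf_modes = {
      {"half_pixel", 0}, {"pytorch_half_pixel", 1}, {"align_corners", 2}, {"asymmetric", 3}};

  uint32_t qnn_value = 0;
  if (GetQnnModeValFromOnnxString(coord_transf_modes, "align_corners", qnn_value)) {
    std::cout << "align_corners -> " << qnn_value << "\n";  // prints 2
  }
  return 0;
}
```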
--- .../builder/opbuilder/resize_op_builder.cc | 379 ++++++------------ .../providers/cpu/tensor/resize_op_test.cc | 38 +- onnxruntime/test/providers/qnn/resize_test.cc | 224 ++++++++--- 3 files changed, 308 insertions(+), 333 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index 511f2a5149f2e..4039c4fbf8d70 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -2,7 +2,8 @@ // Licensed under the MIT License. #include -#include +#include +#include #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" @@ -42,76 +43,6 @@ class ResizeOpBuilder : public BaseOpBuilder { bool do_op_validation) const override ORT_MUST_USE_RESULT; private: - /** - * Returns the QNN integer value that corresponds to the given ONNX mode (string). - * - * /param onnx_modes Array of ONNX modes supported by QNN. The index of each mode corresponds to the QNN value. - * /param onnx_mode The ONNX mode for which to get the corresponding QNN value. - * /param onnx_model_label Mode label to print out in case of error (e.g., "nearest_mode"). - * /param qnn_mode Output parameter that is set to the appropriate QNN value from the given ONNX mode. - * - * /returns A status indicating failure or success. - */ - template - Status GetQnnModeFromString(const std::array& onnx_modes, std::string_view onnx_mode, - const char* onnx_mode_label, QnnValType& qnn_mode) const ORT_MUST_USE_RESULT; - - /** - * Called by IsOpSupported to validate the op for non-quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator. - * - * /returns A status indicating failure or success. - */ - Status ValidateOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT; - - /** - * Called by IsOpSupported to validate the op for quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator and its Q/DQ nodes. - * - * /returns A status indicating failure or success. - */ - Status ValidateQDQOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT; - - /** - * Called by ProcessAttributesAndOutputs to process the op's attributes and outputs - * for non-quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator. - * /param input_names The operator's input names. - * /param logger A logger. - * /param do_op_validation Set to true if the op should be validated using QNN's validation API. - * - * /returns A status indicating failure or success. - */ - Status ProcessOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const ORT_MUST_USE_RESULT; - - /** - * Called by ProcessAttributesAndOutputs to process the op's attributes and outputs - * for quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator and its Q/DQ nodes. - * /param input_names The operator's input names. - * /param logger A logger. 
- * /param do_op_validation Set to true if the op should be validated using QNN's validation API. - * - * /returns A status indicating failure or success. - */ - Status ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const ORT_MUST_USE_RESULT; - // Info for each ONNX attribute of interest (attribute name + default value) static const OnnxAttrInfo onnx_mode_attr; static const OnnxAttrInfo onnx_coord_transf_mode_attr; @@ -119,21 +50,29 @@ class ResizeOpBuilder : public BaseOpBuilder { static const OnnxAttrInfo onnx_antialias_attr; static const OnnxAttrInfo onnx_exclude_outside_attr; - // Arrays of supported QNN modes for QNN's Resize op. The index of each mode is used as the corresponding - // QNN parameter value. Ex: The "nearest" mode is represented as the value 0 in QNN. Note, that - // not all modes are supported by every QNN backend. + // Tables that map an ONNX attribute value (string) to the corresponding integer (enum) QNN parameter value. + // Ex: The "half_pixel" coordinate_transformation_mode is represented as the value 0 in QNN. + // Only the modes supported by QNN Resize are mapped by these tables. + static const std::unordered_map supported_modes; + static const std::unordered_map supported_coord_transf_modes; + static const std::unordered_map supported_nearest_modes; +}; - // QNN values: NEAREST = 0, LINEAR = 1 - static constexpr std::array supported_modes = {"nearest", "linear"}; +const std::unordered_map ResizeOpBuilder::supported_modes = { + {"nearest", QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST}, + {"linear", QNN_OP_RESIZE_INTERPOLATION_MODE_LINEAR}}; - // QNN values: HALF_PIXEL = 0, PYTORCH_HALF_PIXEL = 1, ALIGN_CORNERS = 2, ASYMMETRIC = 3 - static constexpr std::array supported_coord_transf_modes = {"half_pixel", "pytorch_half_pixel", - "align_corners", "asymmetric"}; +const std::unordered_map ResizeOpBuilder::supported_coord_transf_modes = { + {"half_pixel", QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL}, + {"pytorch_half_pixel", QNN_OP_RESIZE_TRANSFORMATION_MODE_PYTORCH_HALF_PIXEL}, + {"align_corners", QNN_OP_RESIZE_TRANSFORMATION_MODE_ALIGN_CORNERS}, + {"asymmetric", QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC}}; - // QNN values: ROUND_PREFER_FLOOR = 0, ROUND_PREFER_CEIL = 1, FLOOR = 2, CEIL = 3 - static constexpr std::array supported_nearest_modes = {"round_prefer_floor", "round_prefer_ceil", - "floor", "ceil"}; -}; +const std::unordered_map ResizeOpBuilder::supported_nearest_modes = { + {"round_prefer_floor", QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR}, + {"round_prefer_ceil", QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_CEIL}, + {"floor", QNN_OP_RESIZE_NEAREST_MODE_FLOOR}, + {"ceil", QNN_OP_RESIZE_NEAREST_MODE_CEIL}}; const OnnxAttrInfo ResizeOpBuilder::onnx_mode_attr = {"mode", "nearest"}; const OnnxAttrInfo ResizeOpBuilder::onnx_coord_transf_mode_attr = {"coordinate_transformation_mode", @@ -143,19 +82,26 @@ const OnnxAttrInfo ResizeOpBuilder::onnx_nearest_mode_attr = {"near const OnnxAttrInfo ResizeOpBuilder::onnx_antialias_attr = {"antialias", 0}; const OnnxAttrInfo ResizeOpBuilder::onnx_exclude_outside_attr = {"exclude_outside", 0}; -template -Status ResizeOpBuilder::GetQnnModeFromString(const std::array& onnx_modes, - std::string_view onnx_mode, const char* onnx_mode_label, - QnnValType& qnn_mode) const { - for (size_t i = 0; i < onnx_modes.size(); ++i) { - if (onnx_modes[i] == onnx_mode) { - qnn_mode = SafeInt(i); - return 
Status::OK(); - } +// Returns the QNN parameter integer value that corresponds to the given ONNX attribute mode string value. +static Status GetQnnModeValFromOnnxString(const std::unordered_map& supported_qnn_modes, + const std::string& onnx_attr_value, + const char* onnx_attr_name, + uint32_t& qnn_mode_value) { + auto it = supported_qnn_modes.find(onnx_attr_value); + if (it != supported_qnn_modes.end()) { + qnn_mode_value = it->second; + return Status::OK(); } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Resize operator does not support ", onnx_mode_label, - " ", std::string(onnx_mode)); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Resize operator does not support ", onnx_attr_name, + " ", std::string(onnx_attr_value)); +} + +// Returns true if the given ONNX attribute mode value is generally supported on QNN. Note that +// different QNN backends may support a smaller subset of modes. +static bool IsOnnxAttrModeSupported(const std::unordered_map& supported_qnn_modes, + const std::string& onnx_attr_value) { + return supported_qnn_modes.find(onnx_attr_value) != supported_qnn_modes.end(); } // Resize ops are sensitive with data layout, no special validation so far @@ -169,118 +115,95 @@ Status ResizeOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + NodeAttrHelper node_helper(node_unit); + // QNN doesn't support anti-aliasing (added in opset 18) if (node_unit.SinceVersion() >= 18) { - NodeAttrHelper node_helper(node_unit); const bool antialias = GetOnnxAttr(node_helper, onnx_antialias_attr) != 0; ORT_RETURN_IF(antialias, "QNN EP: Resize doesn't support anti-aliasing."); } - // The QNN Resize op does not currently work with the QNN cpu backend, but works with the HTP backend. Therefore, we - // currently use QNN's Resize op for quantized models and either ResizeBilinear or ResizeNearestNeighbor for - // non-quantized models. This requires separate validation for quantized models. - // TODO: Use only Resize once QNN's Resize op works in the QNN cpu backend. - bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); - return is_npu_backend ? ValidateQDQOp(qnn_model_wrapper, node_unit) : ValidateOp(qnn_model_wrapper, node_unit); -} - -Status ResizeOpBuilder::ValidateOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { - NodeAttrHelper node_helper(node_unit); - const std::string resize_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - ORT_RETURN_IF((resize_mode != "nearest") && (resize_mode != "linear"), - "QNN EP: Resize doesn't support mode '", resize_mode.c_str(), "'.", - "Only 'nearest' and 'linear' are supported."); - - const std::string coordinate_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - ORT_RETURN_IF((coordinate_mode != "half_pixel") && (coordinate_mode != "align_corners"), - "QNN EP: coordinate transformation mode '", coordinate_mode.c_str(), "' not supported for Resize op.", - "Only 'align_corners' and 'half_pixel' are supported."); - - // Check for a valid "nearest_mode" if the mode is "nearest". - if (resize_mode == "nearest") { - // NOTE: QNN's ResizeNearestNeighbor operator does not have a way to specify rounding (i.e., "nearest_mode"). - // The output of the QNN ResizeNearestNeighbor operator is not always equivalent to ONNX's Resize - // operator with any single specific "nearest_mode". 
- // - // For some input/output shapes, QNN's ResizeNearestNeighbor is equivalent to ONNX's Resize with "round_prefer_floor". - // For other shapes, QNN's ResizeNearestNeighbor is equivalent to ONNX Resize with "round_prefer_ceil". - // - // From unit tests, I've found a relationship between input/output shapes and the equivalent ONNX "nearest_mode". - // If the new and old spatial dimensions are evenly divisible, the "nearest_mode" is "round_prefer_floor". - // Otherwise, the "nearest_mode" is "round_prefer_ceil". - // - // This relationship is probably incomplete/wrong. - // - // TODO: Ask Qualcomm what the correct "nearest_mode" should be, - // OR use QNN's own Resize operator once it works on QnnCpu. - const std::string& nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); - ORT_RETURN_IF_NOT("floor" == nearest_mode, "QNN Resize only supports nearest_mode: floor!"); // This is wrong! - } - - auto& input_0 = node_unit.Inputs()[0]; - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), - "QNN EP: Cannot get input shape for Resize op"); - - const auto& output_0 = node_unit.Outputs()[0]; - std::vector output_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output_0.node_arg, output_shape), - "QNN EP: Cannot get output shape for Resize op"); - - ORT_RETURN_IF(input_shape.size() != 4 || output_shape.size() != 4, "QNN Resize only supports 4D!"); - - ONNX_NAMESPACE::DataType input_data_type = input_0.node_arg.Type(); - ORT_RETURN_IF(input_data_type != ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"), - "QNN EP: Data type ", input_data_type->c_str(), - " is not supported for Resize operator in CPU backend."); - - return Status::OK(); -} - -Status ResizeOpBuilder::ValidateQDQOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { - NodeAttrHelper node_helper(node_unit); - - using namespace onnxruntime::qnn::utils; // Check mode const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_modes, interp_mode), "QNN EP: Resize does not support mode ", + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_modes, interp_mode), "QNN EP: Resize does not support mode ", interp_mode.c_str()); // Check coordinate transformation mode const std::string transformation_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_coord_transf_modes, transformation_mode), + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_coord_transf_modes, transformation_mode), "QNN EP: Resize does not support coordinate_transformation_mode ", transformation_mode.c_str()); - // Check nearest mode + const auto& input_0 = node_unit.Inputs()[0]; + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), + "QNN EP: Cannot get shape for Resize input"); + const size_t input_rank = input_shape.size(); + + // Validate Resize w/ "nearest" mode. + // Translation matrix of ONNX Resize w/ "nearest" mode on HTP backend. + // Table entries correspond to the QNN operator used for the given configuration + // (Resize = QNN Resize op, RNN = QNN ResizeNearestNeighbor op, X = Unsupported). 
+ // + // nearest_mode: + // coordinate_transformation_mode: | round_prefer_floor round_prefer_ceil floor ceil + // ----------------------------------------------------------------------------------------- + // half_pixel | Resize X RNN X + // pytorch_half_pixel | Resize X X X + // align_corners | Resize X RNN X + // asymmetric | Resize X RNN X + if (interp_mode == "nearest") { const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_nearest_modes, nearest_mode), + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_nearest_modes, nearest_mode), "QNN EP: Resize does not support nearest_mode ", nearest_mode.c_str()); - // TODO: Support 'asymmetric' transformation mode with nearest_mode != 'floor'. - // - // QNN's ONNX converter tool translates 'nearest' + 'asymmetric' (regardless of rounding mode) - // to QNN's ResizeNearestNeighbor with {align_corners: 0, half_pixel: 0}. - // This is only accurate if the rounding mode is "floor". Need to investigate how to handle - // other rounding modes with Qualcomm. Ideally, we would use QNN's Resize operator, but it doesn't support - // the "asymmetric" coordinate transformation mode on HTP. - ORT_RETURN_IF(transformation_mode == "asymmetric" && nearest_mode != "floor", - "QNN EP: Resize with coordinate_transformation_mode 'asymmetric' and nearest_mode '", nearest_mode, - "' is not currently supported on the HTP backend."); + if (is_npu_backend) { + // QNN only supports the following nearest_mode values on HTP: + // - "round_prefer_floor" via QNN's Resize operator + // - "floor" via QNN's ResizeNearestNeighbor operator + // + // QNN validation does not throw an error if unsupported nearest_mode values are used, so we have to + // catch them here. Otherwise, accuracy is significantly degraded. + ORT_RETURN_IF_NOT(nearest_mode == "round_prefer_floor" || nearest_mode == "floor", + "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); + + const bool use_resize_nn_op = nearest_mode == "floor"; + + // If HTP uses ResizeNearestNeighbor ("floor"), then the "pytorch_half_pixel" coordinate_transformation_mode + // is not supported. + ORT_RETURN_IF(use_resize_nn_op && transformation_mode == "pytorch_half_pixel", + "QNN EP: Resize on the NPU does not support the combination of nearest_mode == 'floor' ", + " and coordinate_transformation_mode == 'pytorch_half_pixel'."); + + // QNN's ResizeNearestNeighbor requires rank 4 inputs. + ORT_RETURN_IF(use_resize_nn_op && input_rank != 4, + "QNN EP: Resize on the NPU with nearest_mode == 'floor' requires an input with rank 4."); + } } - // Check that input shape has at least a rank of 3. - const auto& input_0 = node_unit.Inputs()[0]; - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), - "QNN EP: Cannot get shape for Resize input"); - ORT_RETURN_IF(input_shape.size() < 3, "QNN EP: Resize input must have a rank >= 3."); + // Check that the input shape has at least a rank of 3 (and a max of 5 on HTP). + ORT_RETURN_IF(input_rank < 3 || (is_npu_backend && input_rank > 5), + "QNN EP: Resize input must have a rank >= 3. 
The maximum rank is 5 on the NPU."); const auto& output_0 = node_unit.Outputs()[0]; std::vector output_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output_0.node_arg, output_shape), "QNN EP: Cannot get shape for Resize output"); - ORT_RETURN_IF(output_shape.size() < 3, "QNN EP: Resize output must have a rank >= 3."); + + // Check that only the spatial dimensions (width, height) are resized. The batch_size (N) and channels (C) should + // be untouched. This code runs before layout transformation, so we know that the current layout is "channel first" + // (e.g., N, C, S1, S2, ..., SN), and that the minimum rank is 3. + assert(node_unit.Domain() != kMSInternalNHWCDomain); + ORT_RETURN_IF_NOT(input_shape[0] == output_shape[0] && input_shape[1] == output_shape[1], + "QNN EP: Resize may only change the spatial dimensions."); + + if (!is_npu_backend) { + ONNX_NAMESPACE::DataType input_data_type = input_0.node_arg.Type(); + ORT_RETURN_IF(input_data_type != ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"), + "QNN EP: Data type ", input_data_type->c_str(), + " is not supported for Resize operator in CPU backend."); + } return Status::OK(); } @@ -305,92 +228,34 @@ Status ResizeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const { - // The QNN Resize op does not currently work with the QNN cpu backend, but works with the HTP backend. Therefore, we - // currently use QNN's Resize op for quantized models and either ResizeBilinear or ResizeNearestNeighbor for - // non-quantized models. This requires separate handling for quantized models. - // TODO: Use only Resize once QNN's Resize op works in the QNN cpu backend. - bool is_quantized_node = NodeUnit::Type::QDQGroup == node_unit.UnitType(); - return is_quantized_node ? 
ProcessQDQOpAttrsAndOutputs(qnn_model_wrapper, node_unit, std::move(input_names), logger, do_op_validation) : ProcessOpAttrsAndOutputs(qnn_model_wrapper, node_unit, std::move(input_names), logger, do_op_validation); -} - -Status ResizeOpBuilder::ProcessOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { - ORT_UNUSED_PARAMETER(logger); - NodeAttrHelper node_helper(node_unit); - const std::string resize_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - std::string qnn_node_type = "ResizeNearestNeighbor"; - if ("linear" == resize_mode) { - qnn_node_type = "ResizeBilinear"; - } - - const std::string coordinate_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - - Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; - qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = static_cast(0); - - Qnn_Scalar_t qnn_half_pixel = QNN_SCALAR_INIT; - qnn_half_pixel.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel.bool8Value = static_cast(0); - - if ("align_corners" == coordinate_mode) { - qnn_align_corners.bool8Value = static_cast(1); - } else if ("half_pixel" == coordinate_mode) { - qnn_half_pixel.bool8Value = static_cast(1); - } - QnnParamWrapper qnn_align_corners_param(node_unit.Index(), node_unit.Name(), - QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, qnn_align_corners); - QnnParamWrapper qnn_half_pixel_param(node_unit.Index(), node_unit.Name(), - QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, qnn_half_pixel); - - std::vector param_tensor_names; - param_tensor_names.push_back(qnn_align_corners_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_align_corners_param)); - param_tensor_names.push_back(qnn_half_pixel_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_half_pixel_param)); - - return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), - logger, do_op_validation, qnn_node_type); -} - -Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { std::vector param_tensor_names; NodeAttrHelper node_helper(node_unit); const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); const std::string transformation_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); + const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); std::string qnn_op_type = "Resize"; - // Handle Resize with {mode: "nearest", coordinate_transformation_mode: "asymmetric"} uniquely. - // QNN's ONNX converter tool translates this configuration (regardless of rounding mode) - // to QNN's ResizeNearestNeighbor with {align_corners: 0, half_pixel: 0}. - // - // NOTE: This is only accurate if the rounding mode is "floor". Need to investigate how to handle - // other rounding modes with Qualcomm. Ideally, we would use QNN's Resize operator, but it doesn't support - // the "asymmetric" coordinate transformation mode on HTP. - if (interp_mode == "nearest" && transformation_mode == "asymmetric") { + // Translate Resize with {mode: "nearest", nearest_mode: "floor", coordinate_transformation_mode: XXX} to + // QNN's ResizeNearestNeighbor operator on the HTP backend. 
This combination of parameters is not supported on HTP + // via QNN's Resize operator. Note that QNN's ResizeNearestNeighbor operator always uses "floor" rounding. + if (is_npu_backend && interp_mode == "nearest" && nearest_mode == "floor") { qnn_op_type = "ResizeNearestNeighbor"; - // Set parameter 'align_corners' to 0 + // Parameter 'align_corners' Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = static_cast(0); + qnn_align_corners.bool8Value = static_cast(transformation_mode == "align_corners"); QnnParamWrapper qnn_align_corners_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, qnn_align_corners); param_tensor_names.push_back(qnn_align_corners_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(qnn_align_corners_param)); - // Set parameter 'half_pixel_centers' to 0 + // Parameter 'half_pixel_centers' Qnn_Scalar_t qnn_half_pixel = QNN_SCALAR_INIT; qnn_half_pixel.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel.bool8Value = static_cast(0); + qnn_half_pixel.bool8Value = static_cast(transformation_mode == "half_pixel"); QnnParamWrapper qnn_half_pixel_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, qnn_half_pixel); param_tensor_names.push_back(qnn_half_pixel_param.GetParamTensorName()); @@ -399,11 +264,12 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'transformation_mode' Qnn_Scalar_t qnn_transformation_mode = QNN_SCALAR_INIT; qnn_transformation_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_coord_transf_modes, transformation_mode, - "coordinate_transformation_mode", qnn_transformation_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_coord_transf_modes, transformation_mode, + "coordinate_transformation_mode", + qnn_transformation_mode.uint32Value)); - QnnParamWrapper qnn_transformation_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, - qnn_transformation_mode); + QnnParamWrapper qnn_transformation_mode_param(node_unit.Index(), node_unit.Name(), + QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, qnn_transformation_mode); param_tensor_names.push_back(qnn_transformation_mode_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(qnn_transformation_mode_param)); @@ -420,7 +286,7 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'interpolation_mode' Qnn_Scalar_t qnn_interp_mode = QNN_SCALAR_INIT; qnn_interp_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_modes, interp_mode, "mode", qnn_interp_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_modes, interp_mode, "mode", qnn_interp_mode.uint32Value)); QnnParamWrapper qnn_interp_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE, qnn_interp_mode); @@ -429,11 +295,10 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'nearest_mode'. Processed only when 'interpolation_mode' is NEAREST(0). 
if (qnn_interp_mode.uint32Value == 0) { - const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); Qnn_Scalar_t qnn_nearest_mode = QNN_SCALAR_INIT; qnn_nearest_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_nearest_modes, nearest_mode, "nearest_mode", - qnn_nearest_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_nearest_modes, nearest_mode, "nearest_mode", + qnn_nearest_mode.uint32Value)); QnnParamWrapper qnn_nearest_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_NEAREST_MODE, qnn_nearest_mode); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 832a8a744c08b..0434b16dc66ce 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -99,9 +99,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch // ROCm: results mismatch - // QNN: conflict with layout transformer, need furture investigation test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_uint8) { @@ -131,7 +130,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_int8) { @@ -159,7 +158,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr 10, 10, 10}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_uint8) { @@ -188,7 +187,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -215,7 +214,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e 0, 0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { @@ -261,9 +260,8 @@ 
TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - // QNN: conflict with layout transformer, need furture investigation test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { @@ -287,7 +285,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { @@ -309,7 +307,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { std::vector Y = {0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } // Since NNAPI(TFLite) only using the scale calculate using the input/output size @@ -399,7 +397,9 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) { std::vector Y = {1.0f, 4.0f}; test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - test.Run(); + + // QNN: result mismatch ("NaN" instead of 1.0f on QNN CPU backend) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); }; run_test(false); @@ -435,7 +435,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -465,7 +465,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // TensorRT: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); }; run_test(false); @@ -532,7 +532,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -560,7 +560,7 @@ TEST(ResizeOpTest, 
NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe std::vector Y = {0, 2, -9}; test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); // TensorRT: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: results mismatch } TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) { @@ -641,7 +641,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { Y, false, .0f, 1.0f); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -683,7 +683,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_int8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y, false, .0f, 1.0f); // TensorRT: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); }; run_test(false); @@ -1079,7 +1079,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { 13.0f, 13.0f, 13.0f, 14.0f, 14.0f, 15.0f, 15.0f, 16.0f}; test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); // QNN: result diff + test.Run(); } TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { @@ -1887,7 +1887,7 @@ void TestAntialiasing(std::map attributes, test.AddOutput("Y", output_shape, output_data); // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accurarcy issue. 
- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { diff --git a/onnxruntime/test/providers/qnn/resize_test.cc b/onnxruntime/test/providers/qnn/resize_test.cc index cf336ca9eeb8b..cd6865d443cc0 100644 --- a/onnxruntime/test/providers/qnn/resize_test.cc +++ b/onnxruntime/test/providers/qnn/resize_test.cc @@ -120,7 +120,7 @@ static void RunCPUResizeOpTest(const TestInputDef& input_def, const std:: const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 11) { + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnCpu.dll"; @@ -138,7 +138,7 @@ static void RunCPUResizeOpTestWithScales(const TestInputDef& input_def, c const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 11) { + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnCpu.dll"; @@ -157,7 +157,8 @@ static void RunQDQResizeOpTest(const TestInputDef& input_def, const std::vector& sizes_data, const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -169,27 +170,20 @@ static void RunQDQResizeOpTest(const TestInputDef& input_def, GetQDQResizeModelBuilder(input_def, sizes_data, mode, coordinate_transformation_mode, nearest_mode), provider_options, - 18, // opset - expected_ep_assignment, - 1e-5f); + opset, + expected_ep_assignment); } // // CPU tests: // -// TODO: Our QNN CPU translation of ONNX Resize with "nearest" mode uses QNN's ResizeNearestNeighbor -// operator, which does not have a way to specify rounding (i.e., "nearest_mode" in ONNX). It is not clear -// what kind of rounding QNN's ResizeNearestNeighbor uses. Therefore, we do not yet know how to compare -// ONNX Resize to QNN ResizeNearestNeighbor. These tests should remain disabled until this behavior is -// clarified. If, for example, it turns out that ResizeNearestNeighbor uses "floor" rounding, then we should -// only compare against ONNX resize with "floor" rounding. - // Upsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, -10.0f, 10.0f), // Random input w/ range [-10, 10] - {1, 2, 21, 10}, // Sizes +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestHalfPixel_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), + {1, 2, 21, 10}, // Sizes "nearest", "half_pixel", "round_prefer_floor", @@ -198,57 +192,72 @@ TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpf) { // Upsample that uses "round_prefer_ceil" as the "nearest_mode". 
// coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestHalfPixel_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 7, 5}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestHalfPixel_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestHalfPixel_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 3}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestHalfPixel_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestHalfPixel_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 2}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Upsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -// QNN v2.13: index #50 don't match, which is 4.67152 from -1.93515 -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestAlignCorners_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAlignCorners_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), {1, 2, 21, 10}, "nearest", "align_corners", "round_prefer_floor", ExpectedEPNodeAssignment::All); } +// Upsample that uses "round_prefer_floor" as the "nearest_mode". +// coordinate_transformation_mode: "asymmetric" +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAsymmetric_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), + {1, 2, 21, 10}, "nearest", "asymmetric", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + // Upsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestAlignCorners_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAlignCorners_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 7, 5}, "nearest", "align_corners", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_ceil" as the "nearest_mode". 
// coordinate_transformation_mode: "align_corners" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestAlignCorners_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 3}, "nearest", "align_corners", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestAlignCorners_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 2}, "nearest", "align_corners", "round_prefer_floor", ExpectedEPNodeAssignment::All); } @@ -258,76 +267,177 @@ TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpf) { // TEST_F(QnnCPUBackendTests, Resize2xLinearHalfPixel) { - RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, input_data), {1, 3, 8, 10}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All); } TEST_F(QnnCPUBackendTests, Resize2xLinearHalfPixel_scales) { - RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, input_data), {1.0f, 1.0f, 2.0f, 2.0f}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All); } TEST_F(QnnCPUBackendTests, Resize2xLinearAlignCorners) { - RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, input_data), {1, 3, 8, 10}, "linear", "align_corners", "", ExpectedEPNodeAssignment::All); } TEST_F(QnnCPUBackendTests, Resize2xLinearAlignCorners_scales) { - RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, input_data), {1.0f, 1.0f, 2.0f, 2.0f}, "linear", "align_corners", "", ExpectedEPNodeAssignment::All); } +// Test Resize downsample with mode: "linear", coordinate_transformation_mode: "align_corners" +// TODO: Enable ResizeOpTest.ResizeOpLinearDownSampleTest_4DBilinear_align_corners in cpu resize_op tests when fixed. +// +// Input f32[1,1,2,4]: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 +// Expected output f32[1, 1, 1, 2]: 1.0, 4.0 +// Actual output f32[1, 1, 1, 2]: NaN, NaN +TEST_F(QnnCPUBackendTests, DISABLED_Resize_DownSample_Linear_AlignCorners_scales) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunCPUResizeOpTestWithScales(TestInputDef({1, 1, 2, 4}, false, input_data), + {1.0f, 1.0f, 0.6f, 0.6f}, "linear", "align_corners", "", + ExpectedEPNodeAssignment::All); +} + +// Test Resize downsample with mode: "linear", coordinate_transformation_mode: "half_pixel" +// TODO: Enable ResizeOpTest.ResizeOpLinearDownSampleTest_4DBilinear cpu resize_op tests when fixed. 
+// +// Input f32[1,1,2,4]: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 +// Expected output f32[1, 1, 1, 2]: 2.6666 4.3333 +// Actual output f32[1, 1, 1, 2]: NaN, NaN +TEST_F(QnnCPUBackendTests, DISABLED_Resize_DownSample_Linear_HalfPixel_scales) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunCPUResizeOpTestWithScales(TestInputDef({1, 1, 2, 4}, false, input_data), + {1.0f, 1.0f, 0.6f, 0.6f}, "linear", "half_pixel", "", + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: // +// Test QDQ Resize downsample with mode: "linear", coordinate_transformation_mode: "align_corners" +TEST_F(QnnHTPBackendTests, Resize_DownSample_Linear_AlignCorners) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunQDQResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), + {1, 1, 1, 2}, "linear", "align_corners", "", + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Resize downsample with mode: "linear", coordinate_transformation_mode: "half_pixel" +TEST_F(QnnHTPBackendTests, Resize_DownSample_Linear_HalfPixel) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunQDQResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), + {1, 1, 1, 2}, "linear", "half_pixel", "", + ExpectedEPNodeAssignment::All); +} + +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "pytorch_half_pixel" +// QNN EP uses QNN's Resize op. TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearPytorchHalfPixel) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "linear", "pytorch_half_pixel", "", ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_floor", +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "half_pixel" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearHalfPixel) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {1, 3, 8, 8}, "nearest", "asymmetric", "floor", +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "align_corners" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAlignCorners) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "linear", "align_corners", "", ExpectedEPNodeAssignment::All); } -// TODO: Investigate with Qualcomm. The qnn-onnx-converter tool translates ONNX Resize [nearest, asymmetric, ceil] to -// QNN ResizeNearestNeighbor {align_corners: 0, half_pixel: 0}, which is NOT equivalent. It would be better to use -// QNN's own Resize operator (instead of ResizeNearestNeighbor), but it doesn't support the "asymmetric" coordinate -// transform mode. -// -// QNN v2.13: Inaccuracy detected for output 'output', element 189. -// Output quant params: scale=0.078431375324726105, zero_point=127. 
-// Expected val: -2.663428783416748 -// QNN QDQ val: 7.4509806632995605 (err 10.114409446716309) -// CPU QDQ val: -2.6666667461395264 (err 0.0032379627227783203) -TEST_F(QnnHTPBackendTests, DISABLED_ResizeU8_2xNearestAsymmetricCeil) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {1, 3, 8, 8}, "nearest", "asymmetric", "ceil", +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "asymmetric" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAsymmetric) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "linear", "asymmetric", "", ExpectedEPNodeAssignment::All); } +// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_floor" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + +// Test that the nearest_mode "ceil" is not supported on the HTP backend. +TEST_F(QnnHTPBackendTests, ResizeU8_NearestModeCeil_Unsupported) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "nearest", "asymmetric", "ceil", + ExpectedEPNodeAssignment::None); +} + +// Test 3x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "floor". +// QNN EP uses QNN's ResizeNearestNeighbor op. TEST_F(QnnHTPBackendTests, ResizeU8_3xNearestAsymmetricFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 12, 12}, "nearest", "asymmetric", "floor", ExpectedEPNodeAssignment::All); } +// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "round_prefer_floor" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricRoundPreferFloor) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunQDQResizeOpTest(TestInputDef({1, 2, 2, 2}, false, input_data), + {1, 2, 4, 4}, "nearest", "asymmetric", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + +// Test 3x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "round_prefer_floor" +// QNN EP uses QNN's Resize op. +// +// TODO: Inaccuracy detected for output 'output_0', element 2. +// Output quant params: scale=0.078431375324726105, zero_point=127. +// Expected val: -3.3333334922790527 +// QNN QDQ val: -9.960784912109375 (err 6.6274514198303223) +// CPU QDQ val: -3.2941176891326904 (err 0.039215803146362305) +// +// More debugging info: +// Input elements f32[1,1,2,2] = -10.0000000 -3.33333349 3.33333302 10.0000000 +// ORT CPU EP (f32 model) outputs: -10.0000000 -10.0000000 -3.33333349 -3.33333349 -3.33333349 -3.33333349 -10.00 ... +// ORT CPU EP (qdq model) outputs: -9.96078491 -9.96078491 -3.29411769 -3.29411769 -3.29411769 -3.29411769 -9.961 ... +// ORT QNN EP (qdq model) outputs: -9.96078491 -9.96078491 -9.96078491 -3.37254906 -3.37254906 -3.37254906 -9.961 ... 
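+// (For reference, these values are consistent with uint8 QDQ parameters derived from the
+//  input range [-10, 10]: scale = 20/255 = 0.0784314..., zero_point = 127. Dequantized values
+//  therefore lie on the grid (q - 127) * 20/255, so -10.0 round-trips to about -9.9608 and
+//  -3.3333 to about -3.2941, which matches the CPU QDQ reference above.)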
+TEST_F(QnnHTPBackendTests, DISABLED_ResizeU8_3xNearestAsymmetricRoundPreferFloor) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 4); + RunQDQResizeOpTest(TestInputDef({1, 1, 2, 2}, false, input_data), + {1, 1, 6, 6}, "nearest", "asymmetric", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + +// Test 0.5x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "floor" +// QNN EP uses QNN's ResizeNearestNeighbor op. TEST_F(QnnHTPBackendTests, ResizeU8_HalfNearestAsymmetricFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 2, 2}, "nearest", "asymmetric", "floor", ExpectedEPNodeAssignment::All); } From ce287a4e77895e7f6147a044ae5c723a48cb8277 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 23 Sep 2023 07:06:04 +0800 Subject: [PATCH 19/58] [WebNN EP] Remove workaround for dynamic shape (#17644) As now we have the FreeDimensionOverrides option to support dynamic shape, we can remove the previous workaround. --- onnxruntime/core/providers/webnn/builders/helper.cc | 7 +++++-- .../core/providers/webnn/builders/model_builder.cc | 9 +++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 31453e005272e..774df067fe347 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -53,9 +53,12 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons } for (const auto& dim : shape_proto->dim()) { - // For now we workaround dynamic shape support by assuming 1. + // WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape. if (!dim.has_dim_value()) { - LOGS(logger, VERBOSE) << "Dynamic shape is not supported for now, assume to be 1, for input:" << input_name; + LOGS(logger, VERBOSE) << "Dynamic shape is not supported, " + << "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: " + << input_name; + return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 14ca4f1a1e674..2eae8cebbbd66 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -218,12 +218,9 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i } else { dims.reserve(shape.size()); for (const auto& dim : shape) { - if (!dim.has_dim_value()) { - // FIXME: support dyanmic shape. - dims.push_back(1); - } else { - dims.push_back(SafeInt(dim.dim_value())); - } + // dim_param free dimensions should have already been excluded by IsInputSupported(). 
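+      // (Models that still contain symbolic dims can be run by pinning those dims at session
+      //  creation time, e.g. via the freeDimensionOverrides session option referenced in helper.cc.)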
+ assert(dim.has_dim_value()); + dims.push_back(SafeInt(dim.dim_value())); } } } From 216214b7d302cb504d1e5a647f65b6fe49c22dbb Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:35:28 +0800 Subject: [PATCH 20/58] [ROCm] Remove ROCm5.4.2, ROCm 5.5 and add ROCm5.7 to python package pipeline (#17668) - Remove ROCm5.4.2, ROCm 5.5 and add ROCm5.7 to python package pipeline - Remove redundant arg --- ...orttraining-py-packaging-pipeline-rocm.yml | 30 ++++++------------- .../github/azure-pipelines/templates/rocm.yml | 1 - 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml index eb837b35af428..a45b7d57205d1 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml @@ -14,51 +14,39 @@ stages: - template: templates/rocm.yml parameters: PythonVersion: '3.8' - RocmVersion: '5.4.2' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.9' - RocmVersion: '5.4.2' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.10' - RocmVersion: '5.4.2' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.8' - RocmVersion: '5.5' + RocmVersion: '5.6' - template: templates/rocm.yml parameters: PythonVersion: '3.9' - RocmVersion: '5.5' + RocmVersion: '5.6' - template: templates/rocm.yml parameters: PythonVersion: '3.10' - RocmVersion: '5.5' + RocmVersion: '5.6' - template: templates/rocm.yml parameters: PythonVersion: '3.8' - RocmVersion: '5.6' + RocmVersion: '5.7' - template: templates/rocm.yml parameters: PythonVersion: '3.9' - RocmVersion: '5.6' + RocmVersion: '5.7' - template: templates/rocm.yml parameters: PythonVersion: '3.10' - RocmVersion: '5.6' + RocmVersion: '5.7' - template: templates/rocm.yml parameters: PythonVersion: '3.8' - RocmVersion: '5.6' + RocmVersion: '5.7' BuildConfig: 'RelWithDebInfo' - template: templates/rocm.yml parameters: PythonVersion: '3.9' - RocmVersion: '5.6' + RocmVersion: '5.7' BuildConfig: 'RelWithDebInfo' - template: templates/rocm.yml parameters: PythonVersion: '3.10' - RocmVersion: '5.6' + RocmVersion: '5.7' BuildConfig: 'RelWithDebInfo' diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml index cc2e8745e8946..d43029266b4b0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml +++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml @@ -51,7 +51,6 @@ jobs: --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur --build-arg BUILD_UID=$(id -u) --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg ROCM_VERSION=$(RocmVersion) --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib From df15a3a335cb4e6703404beeb405f987521e4cf9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 25 Sep 2023 09:22:00 -0700 Subject: [PATCH 21/58] [js/web] configure 5GB memory space for webpack build (#17684) ### Description ort-web build step - webpack 
consumes the amount of memory on the edge of Node.js(V8)'s default max-old-space-size, so increase the default memory size to 5GB to avoid this issue. --- js/web/script/build.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/js/web/script/build.ts b/js/web/script/build.ts index d3a5be429bfa1..03510ae86b85f 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -176,7 +176,12 @@ npmlog.info('Build', 'Building bundle...'); webpackArgs.push('--env', `-f=${FILTER}`); } npmlog.info('Build.Bundle', `CMD: npx ${webpackArgs.join(' ')}`); - const webpack = spawnSync('npx', webpackArgs, {shell: true, stdio: 'inherit', cwd: ROOT_FOLDER}); + const webpack = spawnSync('npx', webpackArgs, { + shell: true, + stdio: 'inherit', + cwd: ROOT_FOLDER, + env: {...process.env, NODE_OPTIONS: (process.env.NODE_OPTIONS ?? '') + ' --max-old-space-size=5120'} + }); if (webpack.status !== 0) { console.error(webpack.error); process.exit(webpack.status === null ? undefined : webpack.status); From 905faea3b2683383bdb71f4ef3bb0a8f0a0832c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 25 Sep 2023 19:11:58 +0200 Subject: [PATCH 22/58] Fix static quantization for QDQ and Percentile distribution (#17649) ### Description One quantization case was not covered by the current list of unit tests. This PR adds a unit test to cover that case with the fix. It fixes the issue #17619. ### Motivation and Context --- .../providers/cpu/quantization/qlinearconv.cc | 3 +- .../python/tools/quantization/calibrate.py | 4 +- .../tools/quantization/operators/conv.py | 2 +- .../tools/quantization/operators/lstm.py | 4 +- .../tools/quantization/qdq_quantizer.py | 8 +- .../test/python/quantization/resnet_code.py | 13757 ++++++++++++++++ .../test_quantize_static_resnet.py | 138 + 7 files changed, 13909 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/test/python/quantization/resnet_code.py create mode 100644 onnxruntime/test/python/quantization/test_quantize_static_resnet.py diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index e9fc8d857b831..21a256eee6f14 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -77,7 +77,8 @@ class QLinearConv : public OpKernel { W_zero_point_value = W_zero_point_data[0]; for (int64_t i = 1; i < W_zero_point_size; i++) { ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value, - "QLinearConv : zero point of per-channel filter must be same"); + "QLinearConv : zero point of per-channel filter must be same. 
" + "This happens by design if the quantization is symmetric."); } } diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index bdf00f21100bf..26e74a6dfbac9 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -22,7 +22,7 @@ class TensorData: - _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"]) + _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"]) def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -55,7 +55,7 @@ def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]] self.data[k] = TensorData(lowest=v[0], highest=v[1]) continue if len(v) == 4: - self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3]) + self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3]) continue raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.") if not isinstance(v, TensorData): diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index d23459b478e6a..23f9eaf4b0e0b 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -157,7 +157,7 @@ def quantize(self): nodes, ) = self.quantizer.quantize_activation(node, [0]) quant_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quantized_input_names.append(quant_weight_tuple[0]) zero_point_names.append(quant_weight_tuple[1]) diff --git a/onnxruntime/python/tools/quantization/operators/lstm.py b/onnxruntime/python/tools/quantization/operators/lstm.py index 7e91f9b76ca36..90a52cb528b32 100644 --- a/onnxruntime/python/tools/quantization/operators/lstm.py +++ b/onnxruntime/python/tools/quantization/operators/lstm.py @@ -47,10 +47,10 @@ def quantize(self): R.dims[0] = R_num_dir * R_4_hidden_size quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[2], onnx_proto.TensorProto.INT8, 0 + node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806 diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index e595b580b20df..5c97dd20cf507 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -283,7 +283,13 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( weight_name, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, + # Quantization type is forced to be TensorProto.INT8. + # when the expected value would be (see below) + # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. + # QLinearConv expects to have a unique value for all channels. 
+ # This code does not enforce that but it is necessarily the case when the + # quantization is symmetric (as for INT8). + onnx_proto.TensorProto.INT8, axis, keep_float_weight=self.add_qdq_pair_to_weight, ) diff --git a/onnxruntime/test/python/quantization/resnet_code.py b/onnxruntime/test/python/quantization/resnet_code.py new file mode 100644 index 0000000000000..2f78047c824a6 --- /dev/null +++ b/onnxruntime/test/python/quantization/resnet_code.py @@ -0,0 +1,13757 @@ +import numpy +from onnx import numpy_helper +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info, set_model_props + + +def create_model(): + initializers = [] + nodes = [] + inputs = [] + outputs = [] + functions = [] + + # opsets + opsets = {"": 13} + + # initializers + + list_value = [ + -0.013732648454606533, + -0.005861935671418905, + 0.06889285147190094, + -0.1172710582613945, + 0.08841240406036377, + -0.03748627379536629, + 0.016256270930171013, + -0.1059316024184227, + 0.08246039599180222, + 0.14295539259910583, + -0.32958757877349854, + 0.1631188541650772, + 0.05412565544247627, + -0.10758306831121445, + 0.12607362866401672, + -0.4987836182117462, + 0.7441706657409668, + -0.24774713814258575, + -0.30415549874305725, + 0.4033295810222626, + -0.13447114825248718, + 0.04623159021139145, + 0.2380414456129074, + -1.226112723350525, + 2.150630235671997, + -1.702580213546753, + 0.5305419564247131, + -0.06836353242397308, + -0.20055373013019562, + 0.7035881280899048, + -0.8389442563056946, + -0.1904432326555252, + 1.2609282732009888, + -1.0670661926269531, + 0.4142579436302185, + 0.04739700257778168, + -0.3265092074871063, + 1.1873037815093994, + -1.6817731857299805, + 0.9709527492523193, + -0.09095840901136398, + -0.12556785345077515, + 0.0835147574543953, + -0.24109329283237457, + 0.032948240637779236, + 0.46304041147232056, + -0.6594106554985046, + 0.349990576505661, + -0.04113377630710602, + 0.016451245173811913, + 0.008994563482701778, + -0.028321878984570503, + -0.05336569994688034, + 0.16036668419837952, + -0.12088149785995483, + 0.031160499900579453, + -0.0618649423122406, + 0.07205374538898468, + 0.15965768694877625, + -0.3389044404029846, + 0.21603335440158844, + 0.04029613360762596, + -0.0813034325838089, + 0.1019665077328682, + -0.4873599112033844, + 0.7873126268386841, + -0.2951086163520813, + -0.43754327297210693, + 0.5905176401138306, + -0.21821773052215576, + 0.06022067740559578, + 0.26326146721839905, + -1.6453089714050293, + 2.606400728225708, + -1.8939754962921143, + 0.5196341276168823, + 0.0055860355496406555, + -0.2335057258605957, + 0.9807199239730835, + -1.2137882709503174, + -0.2699125409126282, + 1.7379733324050903, + -1.4401814937591553, + 0.435971736907959, + -0.04829222336411476, + -0.24543480575084686, + 1.3292583227157593, + -2.0375823974609375, + 1.2458536624908447, + -0.08251484483480453, + -0.14181238412857056, + 0.10612589120864868, + -0.21671657264232635, + 0.1129523366689682, + 0.3666985034942627, + -0.7546612024307251, + 0.42979565262794495, + -0.0976259633898735, + -0.0008812264422886074, + 0.02994859404861927, + -0.07027778774499893, + 0.01393035613000393, + 0.07363647222518921, + -0.10249849408864975, + 0.06602989137172699, + -0.012129798531532288, + 0.10730132460594177, + -0.04546127840876579, + -0.16065146028995514, + 0.14788293838500977, + -0.05488971993327141, + 0.03601694852113724, + 0.07513345777988434, + -0.23953600227832794, + 0.48062530159950256, + -0.42057543992996216, + -0.02402813360095024, + 0.17920851707458496, + 
-0.10703158378601074, + -0.028666120022535324, + 0.2815375030040741, + -0.860264241695404, + 1.4422725439071655, + -1.2058128118515015, + 0.5272247791290283, + -0.06504356116056442, + -0.20021803677082062, + 0.44968947768211365, + -0.3856053650379181, + -0.1589551419019699, + 0.7579770684242249, + -0.8349987268447876, + 0.3225692808628082, + 0.08153475821018219, + -0.43163740634918213, + 0.8742384910583496, + -0.9722443222999573, + 0.579015851020813, + -0.06688100844621658, + -0.12384293973445892, + 0.08289378881454468, + -0.10082041472196579, + -0.11204896867275238, + 0.3934254050254822, + -0.4511864185333252, + 0.32745760679244995, + -0.06534548103809357, + -0.028830429539084435, + 0.021844232454895973, + 0.01775779016315937, + -0.004250001162290573, + 0.013087524101138115, + -0.001250433037057519, + -0.040545206516981125, + -0.014049320481717587, + -0.024194253608584404, + -0.023865194991230965, + -0.0038033330347388983, + 0.00920871365815401, + -0.006582418456673622, + 0.0032474950421601534, + -0.0369916632771492, + -0.16640843451023102, + -0.28968843817710876, + -0.3531132638454437, + -0.26307201385498047, + -0.13392697274684906, + -0.03747623786330223, + 0.08083077520132065, + 0.2026241272687912, + 0.25018608570098877, + 0.2529378831386566, + 0.2307336926460266, + 0.13928599655628204, + 0.08631229400634766, + 0.13893137872219086, + 0.4867081344127655, + 0.7170669436454773, + 0.8331555724143982, + 0.6734364032745361, + 0.3549460768699646, + 0.16798041760921478, + -0.14487245678901672, + -0.47733625769615173, + -0.7670150995254517, + -0.875726580619812, + -0.6291986703872681, + -0.2910463213920593, + -0.09991979598999023, + -0.009158087894320488, + 0.018850643187761307, + 0.02646111696958542, + -0.009077857248485088, + 0.029430989176034927, + -0.03707962855696678, + -0.05111744999885559, + -0.02076525054872036, + 0.011828843504190445, + 0.017857171595096588, + 0.02548048458993435, + -0.009077494964003563, + 0.0022066361270844936, + -0.02064262516796589, + -0.008582246489822865, + -0.022748643532395363, + -0.03038850985467434, + 0.0006585497176274657, + -0.0016039719339460135, + -0.01612498238682747, + 0.013966801576316357, + -0.05851661041378975, + -0.21422894299030304, + -0.33863192796707153, + -0.3720807433128357, + -0.3030800521373749, + -0.1737397164106369, + -0.05903157964348793, + 0.15018144249916077, + 0.27454254031181335, + 0.31182464957237244, + 0.30118387937545776, + 0.24605700373649597, + 0.14123573899269104, + 0.14992672204971313, + 0.20660799741744995, + 0.5046274662017822, + 0.7706091403961182, + 0.8978630900382996, + 0.7368614673614502, + 0.3929724097251892, + 0.23079657554626465, + -0.21169082820415497, + -0.5920398235321045, + -0.893406867980957, + -0.9499238729476929, + -0.730407178401947, + -0.3615736961364746, + -0.15422092378139496, + -0.024615347385406494, + 0.005115498788654804, + 0.024657316505908966, + 0.028517475351691246, + 0.027910854667425156, + -0.009482389315962791, + -0.042242538183927536, + -0.017875321209430695, + 0.00430292496457696, + 0.015949612483382225, + 0.003636278910562396, + -0.018156034871935844, + -0.0009349065367132425, + -0.0010362856555730104, + -0.013051170855760574, + -0.009141271002590656, + -8.714485738892108e-05, + 0.02399279735982418, + 0.01753612607717514, + -0.013710699044167995, + -0.014245252124965191, + -0.0028008236549794674, + -0.08206935226917267, + -0.1098734438419342, + -0.10250325500965118, + -0.08874496072530746, + -0.031079040840268135, + 0.004536658991128206, + 0.03923843801021576, + 0.08478657901287079, + 
0.07715648412704468, + 0.018803801387548447, + 0.013921198435127735, + 0.015864359214901924, + 0.04947463795542717, + 0.039856068789958954, + 0.1712094396352768, + 0.362756609916687, + 0.4192918539047241, + 0.2668488621711731, + 0.11430513113737106, + 0.06648365408182144, + -0.058979276567697525, + -0.24177154898643494, + -0.3709423542022705, + -0.3979431986808777, + -0.29706764221191406, + -0.11569518595933914, + -0.01848490908741951, + -0.015523962676525116, + 0.05081642046570778, + 0.09057094901800156, + 0.08520761132240295, + 0.04497350752353668, + -0.019453801214694977, + -0.06109466031193733, + 0.011463015340268612, + -0.008522219955921173, + -0.005283404141664505, + -0.017313135787844658, + -0.0015744483098387718, + -0.011845857836306095, + -0.016727561131119728, + -0.006708915811032057, + 0.0008860539528541267, + -0.010050912387669086, + -0.028460539877414703, + -0.0165643822401762, + -0.016545938327908516, + -0.00567589420825243, + -0.0032017906196415424, + -0.0130555285140872, + -0.026848897337913513, + -0.02615198865532875, + 0.002669057110324502, + -0.027966763824224472, + -0.03851256147027016, + -0.014509409666061401, + -0.029059220105409622, + -0.007284109480679035, + 0.04045313969254494, + 0.10005538910627365, + 0.014574537053704262, + -0.044292762875556946, + -0.01750861294567585, + -0.02231375314295292, + -0.004432118032127619, + 0.10051869601011276, + 0.1443023532629013, + 0.0508832149207592, + -0.04350621998310089, + -0.0025447055231779814, + -0.014583000913262367, + -0.02153291553258896, + 0.018860718235373497, + 0.03618147224187851, + 0.007304056081920862, + -0.029104959219694138, + 0.00576505484059453, + -0.016025763005018234, + -0.025094063952565193, + -0.05296780914068222, + -0.037012189626693726, + -0.04414081946015358, + -0.053135257214307785, + -0.028890708461403847, + -0.010220452211797237, + -0.027575822547078133, + -0.01087758969515562, + -0.027209162712097168, + -0.030827227979898453, + -0.007646164856851101, + -0.016133273020386696, + 0.000639698002487421, + -0.0034172122832387686, + 0.03914793208241463, + 0.030786357820034027, + 0.005965455900877714, + 0.020923329517245293, + -0.03435938432812691, + -0.0026781477499753237, + 0.04278327897191048, + 0.20045910775661469, + 0.21770593523979187, + 0.09422573447227478, + 0.03198440372943878, + -0.021056609228253365, + 0.028007682412862778, + 0.19196027517318726, + 0.4791645109653473, + 0.5333831906318665, + 0.3014310598373413, + 0.103666290640831, + -0.03651479259133339, + 0.027079502120614052, + 0.19239209592342377, + 0.5168290138244629, + 0.5564895868301392, + 0.2977963089942932, + 0.07770062237977982, + -0.042239490896463394, + -0.017265107482671738, + 0.08760321140289307, + 0.2775075435638428, + 0.312491774559021, + 0.12284757196903229, + 0.019664151594042778, + -0.026643047109246254, + 0.0009152573184110224, + 0.016156431287527084, + 0.09042830765247345, + 0.08991760015487671, + 0.013326293788850307, + 0.02613811008632183, + 0.021025240421295166, + 0.0198842640966177, + 0.03375901281833649, + 0.028616728261113167, + 0.026605166494846344, + 0.04126269370317459, + 0.029309948906302452, + 0.01408455427736044, + -0.003831037785857916, + 0.01922326348721981, + -0.018229445442557335, + -0.013015883974730968, + 0.017597628757357597, + -0.007964612916111946, + 0.045263469219207764, + 0.0184696726500988, + -0.001163159729912877, + -0.1809321641921997, + -0.22486254572868347, + -0.08606110513210297, + 0.001087217591702938, + 0.037091098725795746, + -0.013625397346913815, + -0.178089901804924, + -0.5483279824256897, + 
-0.612791895866394, + -0.32531827688217163, + -0.06506585329771042, + 0.05076128616929054, + -0.007585812360048294, + -0.20981833338737488, + -0.6155760884284973, + -0.7119701504707336, + -0.354442298412323, + -0.04236743599176407, + 0.045713260769844055, + 0.03192479908466339, + -0.07216271013021469, + -0.310979425907135, + -0.3656359910964966, + -0.13522450625896454, + 0.008291869424283504, + 0.03362602740526199, + -0.0009240762447007, + 0.01604474149644375, + -0.055634208023548126, + -0.06180194392800331, + 0.0222025066614151, + 0.027704820036888123, + -0.034385330975055695, + -0.07050742954015732, + -0.06287489086389542, + 0.03521641716361046, + -0.00020920530369039625, + 0.05458284169435501, + 0.058752644807100296, + -0.08097169548273087, + -0.01668735221028328, + 0.18557283282279968, + 0.26208117604255676, + 0.1253771185874939, + 0.07758381962776184, + -0.022084739059209824, + 0.016727397218346596, + 0.23247942328453064, + 0.35444316267967224, + 0.21802566945552826, + -0.04409221559762955, + -0.08573070168495178, + -0.0994141548871994, + 0.07754423469305038, + 0.14311672747135162, + 0.04036660119891167, + -0.29222917556762695, + -0.38828015327453613, + -0.26185816526412964, + -0.12845511734485626, + 0.04763585329055786, + -0.017382778227329254, + -0.16010743379592896, + -0.2395028918981552, + -0.2049665004014969, + -0.041346337646245956, + 0.091490738093853, + -0.005191737785935402, + -0.07687077671289444, + -0.08105621486902237, + -0.05329642817378044, + -0.03404862806200981, + 0.11478845030069351, + 0.13328343629837036, + -0.037197597324848175, + -0.01787363924086094, + -0.016605347394943237, + 0.007853846065700054, + 0.029950136318802834, + 0.10808859020471573, + 0.02873288467526436, + -0.1766187697649002, + -0.17560969293117523, + -0.03922238200902939, + 0.14447443187236786, + 0.1534212827682495, + 0.11272227019071579, + 0.008810695260763168, + -0.1485181748867035, + 0.07839693129062653, + 0.43013128638267517, + 0.4898712635040283, + 0.26522761583328247, + 0.10202436149120331, + -0.07163076847791672, + 0.09933187812566757, + 0.47377726435661316, + 0.6340300440788269, + 0.36741772294044495, + -0.04812543839216232, + -0.17370514571666718, + -0.17513291537761688, + 0.22105705738067627, + 0.3226463794708252, + 0.09850790351629257, + -0.4044247269630432, + -0.6237908601760864, + -0.4679968059062958, + -0.1954391747713089, + 0.09878316521644592, + -0.004430827684700489, + -0.31550562381744385, + -0.5235733985900879, + -0.4510284662246704, + -0.13843706250190735, + 0.10064390301704407, + -0.006748788990080357, + -0.12714813649654388, + -0.2107744812965393, + -0.18755048513412476, + -0.05646044388413429, + 0.12781813740730286, + 0.18928050994873047, + -0.04337320104241371, + -0.04973407834768295, + -0.04690375551581383, + 0.0245530866086483, + 0.10698680579662323, + 0.1646823137998581, + 0.081840381026268, + -0.01471243891865015, + -0.03138890117406845, + -0.04195617139339447, + 0.012708203867077827, + 0.033312954008579254, + 0.02409377694129944, + -0.0036440726835280657, + -0.06239784508943558, + 0.0037516560405492783, + 0.11261500418186188, + 0.13069754838943481, + 0.05901307612657547, + 0.048614490777254105, + -0.027712708339095116, + 0.027247682213783264, + 0.19195327162742615, + 0.2688453793525696, + 0.1509387195110321, + 0.020540937781333923, + -0.004100556951016188, + -0.012650247663259506, + 0.039176344871520996, + 0.09037251025438309, + -0.004689970053732395, + -0.23859903216362, + -0.2364242821931839, + -0.15189304947853088, + -0.0761493444442749, + -0.0028172829188406467, + 
-0.04328106716275215, + -0.16187387704849243, + -0.21743592619895935, + -0.1282283067703247, + -0.024501819163560867, + 0.04029383510351181, + -0.027387680485844612, + -0.05414740741252899, + -0.08344019204378128, + -0.06591048091650009, + 0.012637111358344555, + 0.06905930489301682, + 0.08426016569137573, + -0.0030199100729078054, + 0.034059297293424606, + 0.01111840270459652, + 0.013492933474481106, + 0.0674189031124115, + 0.08242739737033844, + 0.006129032466560602, + -0.07763395458459854, + -0.03002289868891239, + -0.055725954473018646, + 0.008795201778411865, + 0.02994825504720211, + -0.06114519387483597, + -0.0560108907520771, + -0.008179228752851486, + -0.07149285078048706, + -0.02700420655310154, + -0.01306728646159172, + 0.06276566535234451, + 0.007125973701477051, + -0.03540417551994324, + -0.039717916399240494, + 0.009147526696324348, + -0.06517947465181351, + 0.0720859095454216, + -0.05035398155450821, + 0.06659520417451859, + -0.01841895841062069, + 0.004233633633702993, + -0.020911216735839844, + -0.004646372981369495, + 1.6690073013305664, + 0.4517613649368286, + -0.07667035609483719, + 0.005556757096201181, + -0.02638973295688629, + 0.044588603079319, + -0.020916732028126717, + 0.2571280598640442, + -0.009559552185237408, + -0.043380800634622574, + 0.03196016326546669, + -0.03783237189054489, + -0.03076902963221073, + 0.03180111199617386, + 0.06352709978818893, + 0.020281998440623283, + -0.00741154421120882, + -0.0009214285528287292, + -0.0476187989115715, + -0.07208544760942459, + -0.05323023349046707, + -0.011103631928563118, + 0.02877136506140232, + -0.05324484035372734, + -0.10076326876878738, + 0.026193000376224518, + 0.03536469116806984, + 0.045722659677267075, + -0.03756006807088852, + 0.022998394444584846, + 0.0019359687576070428, + 0.01654801517724991, + 0.047304198145866394, + -0.08431598544120789, + -0.0645647644996643, + -0.17326746881008148, + -0.10692577064037323, + -0.08416426181793213, + -0.04107839986681938, + -0.0012680464424192905, + -0.02600814774632454, + -0.014215772971510887, + 0.2114446610212326, + -0.040954578667879105, + -0.05050172284245491, + 0.004194092936813831, + -0.0025900816544890404, + -0.1359374076128006, + 0.03946976363658905, + 2.3023669719696045, + 0.7484877109527588, + -0.1994970589876175, + -0.06490366160869598, + 0.007983183488249779, + -0.017937449738383293, + -0.12516839802265167, + 0.3313288688659668, + 0.11946671456098557, + -0.16942338645458221, + -0.007721045054495335, + 0.02824605070054531, + -0.05310647189617157, + -0.1122083067893982, + -0.17094524204730988, + -0.08465421944856644, + -0.09679102897644043, + -0.03848385065793991, + 0.040121182799339294, + -0.06661732494831085, + 0.0005764663219451904, + -0.05729356408119202, + -0.04778655245900154, + -0.034835152328014374, + -0.07634143531322479, + -0.05054831504821777, + 0.00597620103508234, + 0.04499154910445213, + -0.03308190405368805, + -0.04915233701467514, + -0.05842791870236397, + 0.003590918146073818, + 0.055837079882621765, + -0.02547842636704445, + -0.018847621977329254, + -0.2073899656534195, + -0.14987564086914062, + -0.03971748799085617, + 0.05886378139257431, + 0.020922083407640457, + -0.039155181497335434, + -0.028855402022600174, + 0.08688661456108093, + -0.1402827501296997, + -0.05810496211051941, + 0.037841811776161194, + -0.04082907736301422, + -0.1191127747297287, + -0.10852136462926865, + 1.6274418830871582, + 0.3678200840950012, + -0.2865799367427826, + -0.05291350558400154, + 0.023858532309532166, + -0.046683818101882935, + -0.2307816743850708, + 
-0.001670230645686388, + -0.17716962099075317, + -0.16724731028079987, + 0.040194038301706314, + -0.023075448349118233, + -0.01538322027772665, + -0.07914327085018158, + -0.19621343910694122, + -0.11628971993923187, + -0.05851752683520317, + 0.06313594430685043, + 0.017808571457862854, + 0.02447943389415741, + 0.048611078411340714, + -0.009247995913028717, + 0.00789090245962143, + 0.06673033535480499, + 0.0661577433347702, + 0.019111329689621925, + 0.038164373487234116, + 0.029342610388994217, + -0.03547409921884537, + -0.11017149686813354, + -0.11077891290187836, + 0.001108204829506576, + -0.0330691784620285, + -0.05039837956428528, + 0.017638904973864555, + 0.277705579996109, + 0.5606598258018494, + 0.5469182133674622, + 0.13591277599334717, + 0.012421006336808205, + 0.046348799020051956, + -0.02721901424229145, + -0.5645118355751038, + -1.072814702987671, + -0.9852984547615051, + -0.3608386516571045, + -0.010197073221206665, + -0.09785731136798859, + -0.02597353421151638, + 0.4627133309841156, + 1.1483618021011353, + 0.9505703449249268, + 0.17471027374267578, + -0.016467586159706116, + 0.026623696088790894, + 0.04765752702951431, + -0.4000166058540344, + -0.8956774473190308, + -0.6268588304519653, + -0.09439487755298615, + 0.02861764468252659, + -0.004155704285949469, + 0.08989865332841873, + 0.27384331822395325, + 0.6518518328666687, + 0.4184596836566925, + 0.13106893002986908, + 0.0050344159826636314, + 0.007061495911329985, + -0.016157688573002815, + -0.1364346295595169, + -0.27324289083480835, + -0.14245718717575073, + -0.04623992741107941, + -0.015541884116828442, + 0.030779436230659485, + 0.03756715729832649, + 0.01957445964217186, + -0.04964561015367508, + -0.0211405660957098, + 0.044496409595012665, + -0.026335055008530617, + -0.11620140820741653, + -0.11803250014781952, + 0.18242181837558746, + 0.5057784914970398, + 0.5045838952064514, + 0.03748183697462082, + 0.05692485347390175, + 0.1608155369758606, + 0.02245517633855343, + -0.7651812434196472, + -1.5504053831100464, + -1.3563542366027832, + -0.4314505457878113, + -0.028384560719132423, + -0.12238024920225143, + 0.106974296271801, + 1.11427903175354, + 2.173083543777466, + 1.747692346572876, + 0.5455064177513123, + 0.03363418206572533, + 0.11388687789440155, + -0.05905687436461449, + -0.8059568405151367, + -1.6196117401123047, + -1.1898213624954224, + -0.2654758095741272, + -0.004251840524375439, + -0.0916782096028328, + -0.024067873135209084, + 0.22692462801933289, + 0.6695711612701416, + 0.3673460781574249, + -0.017016466706991196, + -0.029604146257042885, + 0.020365707576274872, + 0.03215239942073822, + 0.0070981839671730995, + -0.14026938378810883, + -0.02425236999988556, + 0.059152450412511826, + -0.006319367326796055, + 0.003989882301539183, + 0.048541076481342316, + 0.003988460637629032, + -0.03105335496366024, + -0.08329232037067413, + 0.03226872906088829, + 0.02119620516896248, + -0.0953872874379158, + -0.15174035727977753, + 0.07963212579488754, + 0.29094186425209045, + 0.2690921127796173, + -0.020104877650737762, + 0.024988379329442978, + 0.15326620638370514, + 0.1256464123725891, + -0.40941280126571655, + -0.946648120880127, + -0.8358487486839294, + -0.14284957945346832, + -0.07980851829051971, + -0.1435413807630539, + 0.038134895265102386, + 0.8021518588066101, + 1.552701473236084, + 1.2496209144592285, + 0.38152581453323364, + 0.07136060297489166, + 0.14329172670841217, + -0.06546801328659058, + -0.5923707485198975, + -1.253793478012085, + -0.9458200335502625, + -0.156633198261261, + -0.04217473417520523, + 
-0.11199303716421127, + -0.07520301640033722, + 0.15331010520458221, + 0.4794600307941437, + 0.2449675053358078, + -0.10396319627761841, + 0.0034801275469362736, + 0.04475663974881172, + 0.024035215377807617, + 0.056806568056344986, + -0.07363307476043701, + -0.001563104335218668, + 0.05157755687832832, + 0.043718185275793076, + 0.02102719619870186, + 0.11859089881181717, + 0.08675580471754074, + -0.13180124759674072, + -0.15522590279579163, + 0.03273458778858185, + -0.0019622649997472763, + 0.1011638194322586, + -0.10800585150718689, + -0.6884365677833557, + -0.5495791435241699, + 0.0780424103140831, + 0.33674973249435425, + -0.21274283528327942, + -0.4183696210384369, + -0.8053947687149048, + 0.03347628563642502, + 1.3938312530517578, + 0.9454176425933838, + -0.012210174463689327, + 0.04924672842025757, + 0.16284359991550446, + 1.1340152025222778, + 2.0020322799682617, + 0.2796843647956848, + -0.968036413192749, + -0.5768532752990723, + 0.17757350206375122, + 0.37485063076019287, + 0.11534234136343002, + -1.2916942834854126, + -1.692176103591919, + -0.30523377656936646, + 0.14307916164398193, + 0.03928302228450775, + -0.19196964800357819, + -0.4533900022506714, + -0.3294944167137146, + 0.5480389595031738, + 0.4497548043727875, + 0.2170887440443039, + -0.05817069113254547, + -0.06957870721817017, + 0.03169052675366402, + 0.23751793801784515, + 0.0823391005396843, + -0.04811413958668709, + -0.051265716552734375, + -0.0395645909011364, + -0.03849785774946213, + 0.04607917368412018, + 0.09946659207344055, + -0.029992828145623207, + -0.05369366332888603, + -0.005230880342423916, + 0.012808755040168762, + 0.1821947544813156, + 0.05478882044553757, + -0.47736144065856934, + -0.44480830430984497, + -0.036321353167295456, + 0.13646431267261505, + -0.04045571759343147, + -0.21837295591831207, + -0.6888197660446167, + -0.08431777358055115, + 0.96018385887146, + 0.6788493990898132, + 0.011028020642697811, + 0.05917810648679733, + 0.02488739602267742, + 0.6898419857025146, + 1.4259209632873535, + 0.13193827867507935, + -0.8078985810279846, + -0.31056249141693115, + 0.018122224137187004, + 0.137860506772995, + 0.051947757601737976, + -0.9757952094078064, + -1.1060559749603271, + 0.06675099581480026, + 0.2091575562953949, + -0.029623042792081833, + -0.0705878809094429, + -0.18514159321784973, + -0.07947035878896713, + 0.5719470381736755, + 0.2286168485879898, + -0.03433626517653465, + 0.0036030709743499756, + 0.006251791957765818, + 0.04144154116511345, + 0.08598234504461288, + -0.050599172711372375, + -0.10440917313098907, + -0.02927244082093239, + -0.04102599248290062, + -0.07101748138666153, + -0.03579306975007057, + 0.03586365282535553, + 0.06752362847328186, + 0.048901572823524475, + -0.020898710936307907, + -0.009411930106580257, + 0.10169848799705505, + 0.1812015175819397, + -0.014482695609331131, + -0.12548771500587463, + -0.060731250792741776, + -0.034499138593673706, + 0.0829617902636528, + 0.04616715386509895, + -0.20867496728897095, + -0.1990129053592682, + 0.1773940473794937, + 0.13156233727931976, + -0.03437860682606697, + 0.04012921825051308, + -0.11132699251174927, + -0.023460939526557922, + 0.2713286876678467, + -0.06662362813949585, + -0.2709292471408844, + -0.0030232456047087908, + -0.10379529744386673, + -0.07136038690805435, + 0.03757762163877487, + -0.20515622198581696, + -0.1231834888458252, + 0.26915228366851807, + 0.0998353362083435, + -0.031466737389564514, + 0.04657471179962158, + 0.07664929330348969, + 0.10308870673179626, + 0.23429608345031738, + -0.06942534446716309, + 
-0.09051290899515152, + 0.03243685141205788, + 0.04053235426545143, + -0.021392958238720894, + -0.05330868810415268, + -0.11525140702724457, + -0.03889385238289833, + 0.01636480540037155, + -0.009352890774607658, + 0.13151532411575317, + -0.14738643169403076, + -0.18289834260940552, + 0.15955400466918945, + -0.001023759599775076, + 0.028809679672122, + 0.012261062860488892, + 0.29654747247695923, + -0.285063236951828, + -0.40187928080558777, + 0.3713407516479492, + 0.009383893571794033, + -0.023022817447781563, + -0.003799814498052001, + 0.48470190167427063, + -0.43402406573295593, + -0.5858806371688843, + 0.5751441717147827, + 0.05045031011104584, + -0.05559438094496727, + -0.02045449987053871, + 0.5281224250793457, + -0.5058223605155945, + -0.5950849056243896, + 0.6492323279380798, + 0.013408469036221504, + -0.05940670147538185, + -0.0044364179484546185, + 0.3112560212612152, + -0.34908774495124817, + -0.42427319288253784, + 0.43349501490592957, + 0.03724945709109306, + -0.05263671651482582, + -0.010485195554792881, + 0.1261255145072937, + -0.1349790245294571, + -0.2524855136871338, + 0.24608080089092255, + 0.036001257598400116, + -0.028843939304351807, + 0.0056989979930222034, + 0.04458172619342804, + -0.06122935935854912, + -0.166972354054451, + 0.14557687938213348, + 0.018050044775009155, + 0.032598987221717834, + -0.0055792503990232944, + 0.24355076253414154, + -0.21433626115322113, + -0.29646870493888855, + 0.1958809792995453, + 0.015435033477842808, + 0.05235098674893379, + 0.010786890983581543, + 0.47903597354888916, + -0.4127257168292999, + -0.6203306317329407, + 0.47024452686309814, + 0.0823090448975563, + -0.04538045823574066, + -0.004072466865181923, + 0.7509317994117737, + -0.6508772969245911, + -0.8481631278991699, + 0.7875698208808899, + 0.0966777428984642, + -0.10461349785327911, + 0.0063789174892008305, + 0.7535857558250427, + -0.8082649111747742, + -0.8165622353553772, + 0.9064085483551025, + 0.04986630380153656, + -0.10200339555740356, + 0.0314355194568634, + 0.46324053406715393, + -0.5523763298988342, + -0.5632953643798828, + 0.6378755569458008, + 0.07833302766084671, + -0.07979781180620193, + 0.031164664775133133, + 0.1967470794916153, + -0.21681970357894897, + -0.29283079504966736, + 0.3367702066898346, + 0.034929461777210236, + -0.047199901193380356, + -0.0033645557705312967, + 0.05454660952091217, + -0.11264829337596893, + -0.190998375415802, + 0.17961400747299194, + 0.0009085010970011353, + -0.0001827089727157727, + 0.04841821268200874, + 0.019923821091651917, + -0.07004066556692123, + -0.10590090602636337, + 0.054114967584609985, + 0.04302384704351425, + 0.00462615629658103, + 0.022948985919356346, + 0.1673787385225296, + -0.1319379210472107, + -0.2711219787597656, + 0.2387620061635971, + 0.05667697265744209, + -0.018639734014868736, + -0.07672597467899323, + 0.3503187298774719, + -0.2981504797935486, + -0.38647517561912537, + 0.4072522521018982, + 0.010913677513599396, + -0.05246961489319801, + -0.04058554396033287, + 0.39216771721839905, + -0.3605193495750427, + -0.34857264161109924, + 0.46899959444999695, + -0.03358001261949539, + -0.05188553035259247, + -0.023204902186989784, + 0.17140533030033112, + -0.2120431810617447, + -0.2144550085067749, + 0.2837989032268524, + -0.0191226527094841, + -0.020922169089317322, + 0.004324179142713547, + 0.038136694580316544, + -0.042803723365068436, + -0.11487454175949097, + 0.11820490658283234, + 0.003412557765841484, + 0.0035020115319639444, + 0.03646541014313698, + -0.010104459710419178, + -0.010897459462285042, + 
-0.09292570501565933, + 0.06823977828025818, + 0.02677192911505699, + 0.020071662962436676, + 0.005776307079941034, + 0.02613351307809353, + 0.017107944935560226, + -0.0002623539185151458, + -0.039298396557569504, + -0.0314190648496151, + -0.019773684442043304, + -0.01924789510667324, + 0.04253160580992699, + 0.09694722294807434, + 0.1925637573003769, + 0.1901547759771347, + 0.09470294415950775, + -0.00296174269169569, + -0.03602522239089012, + 0.03572473302483559, + 0.08787581324577332, + 0.1773553043603897, + 0.20970025658607483, + 0.14899243414402008, + 0.05427362397313118, + -0.032429151237010956, + 0.023915717378258705, + 0.06557436287403107, + 0.13488733768463135, + 0.17550915479660034, + 0.17485061287879944, + 0.10260436683893204, + -0.005381361581385136, + -0.05573735386133194, + -0.09410752356052399, + -0.07940010726451874, + -0.03424998000264168, + 0.007975265383720398, + 0.028827181085944176, + 0.023788832128047943, + -0.02962818741798401, + -0.13474339246749878, + -0.22529757022857666, + -0.20413516461849213, + -0.14711618423461914, + -0.05960607901215553, + 0.04579121991991997, + 0.005325576290488243, + -0.11592217534780502, + -0.2260522097349167, + -0.2467145025730133, + -0.22054187953472137, + -0.13919179141521454, + 0.0016459478065371513, + 0.0515579916536808, + 0.060555730015039444, + 0.040788713842630386, + -0.017907800152897835, + -0.026459651067852974, + -0.02488812990486622, + 0.015644825994968414, + 0.10543125867843628, + 0.19312354922294617, + 0.28380078077316284, + 0.28878358006477356, + 0.16968156397342682, + 0.04848042502999306, + -0.00986899808049202, + 0.06337545067071915, + 0.16356752812862396, + 0.2444516271352768, + 0.29273414611816406, + 0.2314801961183548, + 0.12695762515068054, + -0.022283215075731277, + 0.018402203917503357, + 0.07152476161718369, + 0.14247483015060425, + 0.18759845197200775, + 0.20828258991241455, + 0.14114585518836975, + -0.047197990119457245, + -0.13794781267642975, + -0.17509934306144714, + -0.1696663200855255, + -0.1206701323390007, + -0.036128126084804535, + 0.007180679589509964, + 0.006984225939959288, + -0.09600912779569626, + -0.22975720465183258, + -0.33287662267684937, + -0.2942708134651184, + -0.20305578410625458, + -0.08411446958780289, + 0.042896877974271774, + -0.020053744316101074, + -0.16365791857242584, + -0.3145587742328644, + -0.3321540057659149, + -0.2667454183101654, + -0.1542910486459732, + -0.006954069249331951, + 0.020191870629787445, + 0.014010002836585045, + 0.0016916356980800629, + -0.04649524390697479, + -0.014931428246200085, + -0.017954425886273384, + -0.020003901794552803, + 0.03831968829035759, + 0.08447518199682236, + 0.14068123698234558, + 0.13400419056415558, + 0.08205568045377731, + -0.0004489773709792644, + -0.019211264327168465, + 0.023363608866930008, + 0.08738930523395538, + 0.12299696356058121, + 0.13070489466190338, + 0.09040816128253937, + 0.03286544978618622, + -0.006979941390454769, + -0.0010930931894108653, + 0.04313739389181137, + 0.10121051222085953, + 0.11390950530767441, + 0.11383924633264542, + 0.06694260239601135, + -0.00425445893779397, + -0.0666416585445404, + -0.09225274622440338, + -0.0977785512804985, + -0.07118111103773117, + -0.026749763637781143, + -0.019425569102168083, + 0.03321055322885513, + -0.0033978468272835016, + -0.08309262245893478, + -0.15557922422885895, + -0.14969374239444733, + -0.07188998907804489, + -0.018716221675276756, + 0.022834330797195435, + 0.004232254344969988, + -0.04141783341765404, + -0.125192329287529, + -0.14545302093029022, + -0.12225300818681717, + 
-0.05844716727733612, + 0.010607236064970493, + 0.024218380451202393, + -0.002702374942600727, + -0.030814893543720245, + 0.03507756441831589, + -0.0506589449942112, + 0.03415676951408386, + 0.0011444400297477841, + 0.0026324463542550802, + 0.028514407575130463, + -0.01849454641342163, + -0.030959082767367363, + -0.05565863475203514, + 0.05771413818001747, + 0.003916156478226185, + -0.004474544432014227, + 0.04403551295399666, + 0.1733711212873459, + -0.37650829553604126, + 0.22322984039783478, + 0.0032540319953113794, + -0.01139416079968214, + -0.039046600461006165, + 0.0021948080975562334, + 0.5777754783630371, + -1.1944804191589355, + 0.769478976726532, + -0.1349843591451645, + 0.0004430754925124347, + -0.0061850035563111305, + -0.08340868353843689, + 0.8327823877334595, + -1.649588942527771, + 1.126111388206482, + -0.2918313145637512, + 0.003614947199821472, + 0.0016799914883449674, + -0.03255167230963707, + 0.6123784184455872, + -1.1993682384490967, + 0.8305437564849854, + -0.13622376322746277, + 0.00905851274728775, + -0.006772476714104414, + 0.07578610628843307, + 0.05859832838177681, + -0.4543764293193817, + 0.26330503821372986, + 0.0259060300886631, + -0.0007997890934348106, + 0.01269856933504343, + 0.006897627376019955, + -0.02491801232099533, + -0.03139931708574295, + 0.0028456314466893673, + 0.0008253560517914593, + -0.01086023822426796, + -0.004186873324215412, + 0.06299160420894623, + -0.039931319653987885, + -0.09315146505832672, + 0.05495935305953026, + 0.027547571808099747, + -0.010900916531682014, + -0.025233760476112366, + 0.060600072145462036, + 0.21010243892669678, + -0.5445898771286011, + 0.35070353746414185, + -0.033771682530641556, + -0.0269146841019392, + -0.025363197550177574, + -0.021729450672864914, + 0.70921790599823, + -1.4368270635604858, + 0.9582043290138245, + -0.1708265244960785, + 0.010022420436143875, + -0.032301150262355804, + -0.08667651563882828, + 1.0338889360427856, + -1.913576364517212, + 1.262008547782898, + -0.23795078694820404, + -0.032233912497758865, + -0.01397701445966959, + -0.05402921140193939, + 0.7621430158615112, + -1.387437343597412, + 0.8621506094932556, + -0.14765247702598572, + -0.004747485741972923, + 0.0017516895895823836, + 0.08154146373271942, + 0.16601374745368958, + -0.5324177742004395, + 0.27442997694015503, + 0.03274058923125267, + -0.008812552317976952, + 0.005774920806288719, + 0.04165825620293617, + -0.011749272234737873, + -0.01953396573662758, + -0.009672109968960285, + 0.01170953270047903, + 0.003071938641369343, + -0.018979815766215324, + 0.062123894691467285, + -0.004921444226056337, + -0.03380037844181061, + 0.01310884952545166, + 0.007953890599310398, + -0.0012086924398317933, + -0.03317898139357567, + -0.0015596294542774558, + 0.08166785538196564, + -0.2291223704814911, + 0.11783571541309357, + -0.016078786924481392, + 0.018957575783133507, + 0.025793947279453278, + -0.09036394208669662, + 0.3833881616592407, + -0.5794023871421814, + 0.4610825777053833, + -0.14165280759334564, + -0.007412370759993792, + 0.05252876877784729, + -0.21435455977916718, + 0.6177686452865601, + -0.8516795635223389, + 0.667263925075531, + -0.22572898864746094, + -0.004465761594474316, + 0.02589319832623005, + -0.1893543303012848, + 0.43213585019111633, + -0.6462821364402771, + 0.434274822473526, + -0.15750259160995483, + -0.01198036689311266, + -2.4281514924950898e-05, + 0.039562296122312546, + 0.11126027256250381, + -0.23193514347076416, + 0.1412443071603775, + -0.011839920654892921, + 0.007880321703851223, + 0.02950354479253292, + 
0.011689653620123863, + -0.07272310554981232, + -0.03319466486573219, + -0.003948990721255541, + 0.03549842908978462, + -0.02165558747947216, + -0.09912239760160446, + -0.08742356300354004, + 0.30591821670532227, + 0.23934677243232727, + 0.02658180706202984, + -0.022127188742160797, + -0.02769642136991024, + 0.16399237513542175, + 0.5140998959541321, + 0.007951628416776657, + -0.5589093565940857, + -0.24106110632419586, + -0.02753414213657379, + 0.06947467476129532, + 0.048558495938777924, + -0.5370690822601318, + -0.761831521987915, + 0.16272802650928497, + 0.29426246881484985, + 0.07943751662969589, + -0.022394873201847076, + -0.217612162232399, + -0.03093647211790085, + 0.5945476293563843, + 0.2873935103416443, + -0.16481661796569824, + -0.02931203693151474, + -0.029083512723445892, + 0.06754925847053528, + 0.20200076699256897, + -0.07271742075681686, + -0.1976277083158493, + -0.04189611226320267, + 0.06403793394565582, + -0.00022445111244451255, + -0.01032529678195715, + -0.03415631130337715, + 0.009091783314943314, + 0.04317992925643921, + 0.07196266949176788, + -0.025028688833117485, + -0.02722775563597679, + -0.017168480902910233, + -0.027666645124554634, + -0.06734028458595276, + 0.10843724757432938, + 0.08066407591104507, + -0.027849983423948288, + -0.0045820740051567554, + -0.03388727456331253, + 0.16772156953811646, + 0.651636004447937, + 0.34874194860458374, + -0.1454945057630539, + -0.18056720495224, + 0.11703842133283615, + 0.43017855286598206, + 0.7624525427818298, + -0.3420296907424927, + -1.272199273109436, + -0.5284644365310669, + -0.005667245015501976, + 0.08240436762571335, + -0.13299596309661865, + -1.3164156675338745, + -1.659982442855835, + 0.19898656010627747, + 0.6253566741943359, + 0.25137946009635925, + -0.18244975805282593, + -0.5360167622566223, + -0.06195700913667679, + 1.2547520399093628, + 1.0296341180801392, + 0.10651036351919174, + -0.023540280759334564, + -0.07594245672225952, + 0.1492130160331726, + 0.5033117532730103, + 0.09394379705190659, + -0.22459803521633148, + -0.22473134100437164, + -0.04738321527838707, + 0.04127531498670578, + 0.0682951882481575, + -0.02095615118741989, + -0.1233135387301445, + -0.10028401762247086, + -0.008111395873129368, + -0.000617706507910043, + 0.018859047442674637, + 0.028446361422538757, + -0.06159031391143799, + -0.1292838156223297, + 0.051308393478393555, + 0.11001072078943253, + -0.02056661807000637, + -0.012175443582236767, + -0.1313694268465042, + 0.0067574759013950825, + 0.4612729251384735, + 0.323080450296402, + -0.09392253309488297, + -0.1256203055381775, + 0.03537299111485481, + 0.2556088864803314, + 0.6467183232307434, + -0.16340143978595734, + -0.8799455165863037, + -0.3312987685203552, + 0.01464154850691557, + 0.07046713680028915, + 0.053634822368621826, + -0.8514915108680725, + -1.176972508430481, + 0.2056443840265274, + 0.4998764395713806, + 0.1268644779920578, + -0.10905193537473679, + -0.3750888705253601, + -0.06701061874628067, + 0.9052186608314514, + 0.6792045831680298, + -0.00323892361484468, + -0.0007412935374304652, + -0.03608793020248413, + 0.1009129211306572, + 0.36775916814804077, + 0.035214491188526154, + -0.2273784875869751, + -0.15815992653369904, + -0.004773923195898533, + 0.06374036520719528, + 0.04737555980682373, + -0.0563247986137867, + -0.09587392956018448, + -0.043853096663951874, + 0.032572731375694275, + -0.0036250585690140724, + 0.07889056205749512, + -0.03589344769716263, + -0.019771328195929527, + 0.04937156289815903, + 0.039052557200193405, + -0.013377528637647629, + 
-0.0841481015086174, + -0.03358105197548866, + -0.2128981053829193, + -0.14468812942504883, + 0.14675867557525635, + 0.2550889551639557, + 0.22369499504566193, + -0.0032973098568618298, + 0.006679064594209194, + -0.11752036958932877, + 0.025247232988476753, + 0.23064176738262177, + 0.25043538212776184, + 0.3474777638912201, + 0.2151806503534317, + 0.051294319331645966, + 0.16301114857196808, + 0.25422143936157227, + -0.1796918362379074, + -0.6128425598144531, + -0.42049655318260193, + 0.07740531116724014, + -0.007960617542266846, + 0.2504507601261139, + 0.2932300865650177, + -0.5157915949821472, + -1.2904177904129028, + -1.0362532138824463, + -0.22443994879722595, + 0.007411653641611338, + 0.16024430096149445, + 0.33939966559410095, + -0.2748318016529083, + -0.8487470149993896, + -0.5955387949943542, + 0.033155132085084915, + -0.09185351431369781, + -0.05639262869954109, + 0.17084303498268127, + 0.11292264610528946, + -0.046329669654369354, + 0.11495561897754669, + 0.31740760803222656, + -0.13903948664665222, + 0.05507560819387436, + 0.10180198401212692, + -0.1369788944721222, + -0.10618618875741959, + -0.001083499751985073, + 0.16340164840221405, + 0.07591762393712997, + 0.3417445123195648, + 0.27897438406944275, + -0.32192930579185486, + -0.5731648206710815, + -0.46150147914886475, + -0.03230089321732521, + 0.04096771031618118, + 0.22242987155914307, + 0.027000218629837036, + -0.4113498628139496, + -0.433158278465271, + -0.5252256393432617, + -0.3510502874851227, + -0.133863165974617, + -0.38554033637046814, + -0.45547229051589966, + 0.2475612610578537, + 1.154951572418213, + 0.8282179236412048, + -0.13197137415409088, + -0.03350961208343506, + -0.5282800197601318, + -0.5297923684120178, + 0.9037952423095703, + 2.516275405883789, + 2.086421489715576, + 0.3573826849460602, + -0.010694397613406181, + -0.31418153643608093, + -0.5325371026992798, + 0.48083701729774475, + 1.7732245922088623, + 1.2747145891189575, + -0.06401863694190979, + 0.14296381175518036, + 0.07267159968614578, + -0.28001847863197327, + -0.29204103350639343, + 0.12853951752185822, + -0.1998838633298874, + -0.6375644207000732, + 0.06310836225748062, + -0.020014479756355286, + -0.08150970935821533, + 0.08175478130578995, + 0.07667485624551773, + 0.0025236753281205893, + -0.08504530042409897, + -0.035742271691560745, + -0.1332666128873825, + -0.15150736272335052, + 0.18459312617778778, + 0.3363596200942993, + 0.2501969635486603, + 0.029292423278093338, + -0.060296736657619476, + -0.1142202764749527, + -0.05918247997760773, + 0.18826954066753387, + 0.2183520495891571, + 0.21247169375419617, + 0.14935970306396484, + 0.09923429787158966, + 0.21808095276355743, + 0.21930061280727386, + -0.060535889118909836, + -0.5729222297668457, + -0.4199080169200897, + 0.058897778391838074, + 0.050647757947444916, + 0.2784770131111145, + 0.2754706144332886, + -0.40136128664016724, + -1.3269731998443604, + -1.124815583229065, + -0.11878778040409088, + -0.005137663800269365, + 0.17839783430099487, + 0.2115524858236313, + -0.24165289103984833, + -0.9655010104179382, + -0.7425088286399841, + 0.0304054357111454, + -0.07012742757797241, + -0.015557953156530857, + 0.1128007024526596, + 0.18957749009132385, + -0.07996463775634766, + 0.09505810588598251, + 0.34419506788253784, + -0.3072076439857483, + 0.03868290036916733, + 0.11494885385036469, + 0.03748936951160431, + 0.0797261893749237, + -0.003397951368242502, + -0.07380004972219467, + -0.11507676541805267, + -0.10298885405063629, + 0.10698320716619492, + 0.06602972000837326, + 0.08226803690195084, 
+ 0.0037747276946902275, + -0.162277951836586, + 0.01671667955815792, + 0.09137773513793945, + 0.18799471855163574, + 0.04144813120365143, + 0.1285877376794815, + 0.1820434182882309, + 0.04940629005432129, + 0.0991915687918663, + 0.10219171643257141, + -0.013141660951077938, + -0.051191627979278564, + 0.05468929558992386, + 0.087598517537117, + 0.15897324681282043, + 0.11863455921411514, + -0.00814050156623125, + -0.07701541483402252, + -0.14013728499412537, + -0.044140227138996124, + -0.05328791216015816, + 0.06760499626398087, + 0.12053386867046356, + 0.09780212491750717, + -0.053725965321063995, + -0.07915244251489639, + -0.0032519602682441473, + 0.019637396559119225, + 0.07848430424928665, + 0.019138827919960022, + 0.1460287868976593, + 0.1281038075685501, + 0.024417784065008163, + 0.059176862239837646, + 0.0658111497759819, + -0.016405148431658745, + -0.18877744674682617, + 0.16666102409362793, + 0.1610611230134964, + 0.08374520391225815, + 0.11570518463850021, + 0.11903064697980881, + 0.1294964700937271, + 0.06379758566617966, + 0.08417274057865143, + 0.12754113972187042, + 0.025328608229756355, + 0.05170705169439316, + 0.0835295170545578, + 0.07477264851331711, + 0.11244285851716995, + 0.11559426784515381, + 0.045258160680532455, + -0.14825093746185303, + -0.08153342455625534, + 0.06288623809814453, + 0.11952362209558487, + 0.11784297972917557, + 0.011141132563352585, + -0.21666541695594788, + -0.29976174235343933, + -0.2279169261455536, + -0.11828474700450897, + 0.12436322867870331, + 0.10465826094150543, + -0.09751085937023163, + -0.292611300945282, + -0.37374064326286316, + -0.31437963247299194, + -0.25637903809547424, + 0.06173908710479736, + 0.14131486415863037, + 0.008434675633907318, + -0.23816508054733276, + -0.30330890417099, + -0.22094152867794037, + -0.11608295142650604, + 0.13235151767730713, + 0.15353602170944214, + 0.15839524567127228, + 0.012247815728187561, + -0.08126968890428543, + -0.003756331978365779, + 0.10660683363676071, + 0.21976575255393982, + -0.04188326746225357, + 0.15462253987789154, + 0.06303395330905914, + 0.006879634689539671, + 0.008284888230264187, + 0.07084798067808151, + 0.1211942657828331, + 0.10190404951572418, + 0.02935362420976162, + -0.05645999684929848, + -0.16800500452518463, + -0.1850246787071228, + -0.09476880729198456, + -0.025327544659376144, + 0.054355036467313766, + -0.035813912749290466, + -0.18694879114627838, + -0.34871891140937805, + -0.3151862621307373, + -0.1943007856607437, + -0.09755205363035202, + 0.014881589449942112, + -0.14875493943691254, + -0.37112873792648315, + -0.37739917635917664, + -0.3241480886936188, + -0.2915399968624115, + -0.11268249899148941, + -0.019726404920220375, + -0.2510305941104889, + -0.38005372881889343, + -0.3622463345527649, + -0.2932804226875305, + -0.28574010729789734, + -0.1505027860403061, + -0.004947682376950979, + -0.18587322533130646, + -0.34759166836738586, + -0.28965193033218384, + -0.21052972972393036, + -0.18780536949634552, + -0.07400713860988617, + 0.11154936999082565, + -0.03556853160262108, + -0.1896934062242508, + -0.18135806918144226, + -0.10117948800325394, + -0.0393117293715477, + 0.06517928093671799, + -0.016659021377563477, + -0.011290309950709343, + -0.007930322550237179, + 0.008189777843654156, + 0.03678786754608154, + 0.021890517324209213, + 0.0034292477648705244, + 0.02200375869870186, + 0.0014921070542186499, + -0.0800287202000618, + -0.17657361924648285, + -0.18702608346939087, + -0.12880444526672363, + -0.022084584459662437, + 0.026420501992106438, + 
-0.023968446999788284, + -0.07948111742734909, + -0.16741475462913513, + -0.18733707070350647, + -0.16539834439754486, + -0.07347387820482254, + -0.009723886847496033, + -0.02016977220773697, + -0.061092622578144073, + -0.13145211338996887, + -0.15919029712677002, + -0.15043555200099945, + -0.10107766091823578, + 0.0016151965828612447, + 0.0627974420785904, + 0.08695066720247269, + 0.11727584898471832, + 0.11745581030845642, + 0.11329426616430283, + 0.0533670075237751, + -0.016355818137526512, + 0.008450252935290337, + 0.06448577344417572, + 0.1538505256175995, + 0.21232697367668152, + 0.14713847637176514, + 0.039088234305381775, + -0.015588105656206608, + 0.026483291760087013, + 0.060862988233566284, + 0.18265819549560547, + 0.23042462766170502, + 0.168768972158432, + 0.034099943935871124, + -0.018249109387397766, + -0.0321880541741848, + -0.03254542127251625, + -0.03061222843825817, + -0.0026304698549211025, + 0.017764942720532417, + 0.010707704350352287, + 0.009254949167370796, + -0.04533161595463753, + -0.1483704000711441, + -0.2637183666229248, + -0.2678598165512085, + -0.1737881749868393, + -0.049990858882665634, + 0.013515918515622616, + -0.054345693439245224, + -0.1467861533164978, + -0.24911582469940186, + -0.2831358015537262, + -0.22300836443901062, + -0.13739243149757385, + -0.017879672348499298, + -0.040345460176467896, + -0.09990613907575607, + -0.16936856508255005, + -0.2266550064086914, + -0.2020808756351471, + -0.1509508341550827, + 0.014163740910589695, + 0.07591170817613602, + 0.09185601025819778, + 0.10455341637134552, + 0.09514842182397842, + 0.09877350926399231, + 0.053898438811302185, + 0.005704578943550587, + 0.0591997392475605, + 0.13600079715251923, + 0.21777905523777008, + 0.2574957311153412, + 0.20117221772670746, + 0.11415109038352966, + -0.001181072206236422, + 0.09470006823539734, + 0.18978413939476013, + 0.3073742389678955, + 0.36875811219215393, + 0.3069853186607361, + 0.1708926260471344, + -0.0325310118496418, + -0.02656698040664196, + 0.016060845926404, + 0.02459372952580452, + 0.04165660962462425, + 0.033969976007938385, + 0.012855498120188713, + 0.030497560277581215, + 0.004896117839962244, + -0.030887477099895477, + -0.13454437255859375, + -0.1294785887002945, + -0.06398608535528183, + 0.016156472265720367, + 0.03577340394258499, + -0.0033482143189758062, + -0.07112833857536316, + -0.16465041041374207, + -0.1621057391166687, + -0.09478478878736496, + -0.03555302321910858, + -0.001592929707840085, + -0.01719600521028042, + -0.06598587334156036, + -0.1411861628293991, + -0.1496778130531311, + -0.11535074561834335, + -0.0905962884426117, + -0.013807609677314758, + 0.029542237520217896, + 0.039138730615377426, + 0.03988270089030266, + 0.02665030211210251, + 0.049553126096725464, + -0.0015685928519815207, + -0.018007200211286545, + 0.009533192962408066, + 0.06910547614097595, + 0.1034330427646637, + 0.15017645061016083, + 0.10221225768327713, + 0.020978443324565887, + -0.023747621104121208, + 0.02295384369790554, + 0.09313814342021942, + 0.1771395057439804, + 0.21169933676719666, + 0.17989481985569, + 0.05862005427479744, + -0.004540165886282921, + 0.021994179114699364, + -0.003493826137855649, + -0.000224211675231345, + 0.031808022409677505, + -0.05090906098484993, + 0.001970196608453989, + 0.01633802428841591, + 0.0049764602445065975, + 0.0006027702474966645, + -0.005952450912445784, + -0.009886081330478191, + -0.08520589768886566, + 0.030780712142586708, + 0.00037104589864611626, + 0.011886775493621826, + -0.023506291210651398, + 0.08029806613922119, + 
-0.005086984951049089, + -0.07738454639911652, + 0.06721897423267365, + -0.02397127076983452, + 0.006669329944998026, + -0.016343094408512115, + 0.06056324020028114, + 0.15656796097755432, + -0.49836501479148865, + 0.2475810945034027, + -0.009270203299820423, + -0.006855266634374857, + 0.0034896093420684338, + -0.027938276529312134, + 0.5722692012786865, + -1.1357109546661377, + 0.5644665956497192, + 0.015787361189723015, + -0.015141892246901989, + -0.0032788251992315054, + -0.04797150194644928, + 0.6196744441986084, + -1.1540743112564087, + 0.6065864562988281, + 0.0019708566833287477, + 0.006332532037049532, + 0.014192940667271614, + 0.03773411735892296, + 0.27323007583618164, + -0.594700813293457, + 0.2488076239824295, + -0.008853388018906116, + 0.005692378617823124, + 0.000576167949475348, + -0.027197014540433884, + 0.022015029564499855, + -0.02571249194443226, + 0.004507753532379866, + -0.002439734758809209, + -0.01994609646499157, + 0.03601142391562462, + 0.008136607706546783, + 0.01658148691058159, + -0.06548810750246048, + 0.022721221670508385, + -0.0038820707704871893, + -0.0007800398161634803, + 0.001392301986925304, + 0.09576108306646347, + -0.014628835022449493, + -0.14505760371685028, + 0.07135403156280518, + -0.00839388556778431, + -0.004555124789476395, + -0.04466082155704498, + 0.1456393599510193, + 0.3475525975227356, + -0.7879117131233215, + 0.36262738704681396, + 0.008226356469094753, + 0.0055343699641525745, + -0.061139706522226334, + 0.08975803852081299, + 0.9340736269950867, + -1.7307822704315186, + 0.796896755695343, + -0.024700213223695755, + -0.013090251013636589, + -0.05148586630821228, + 0.050525497645139694, + 0.927090048789978, + -1.7473385334014893, + 0.7727715373039246, + -0.005721901543438435, + 0.010676853358745575, + -0.012798544019460678, + 0.11131046712398529, + 0.4181194007396698, + -0.8475598096847534, + 0.33206430077552795, + 0.018843427300453186, + 0.0006885005859658122, + 0.027498219162225723, + 0.00207257061265409, + 0.0032615051604807377, + -0.021950624883174896, + -0.008452882058918476, + -0.007631891872733831, + -0.028561849147081375, + 0.04865337535738945, + -0.0023105579894036055, + -0.026170270517468452, + -0.011794357560575008, + 0.004327487666159868, + 0.01756221242249012, + 0.0011611212976276875, + -0.008793564513325691, + 0.0741758644580841, + -0.057649385184049606, + -0.006000686902552843, + -0.022717488929629326, + -0.0047143916599452496, + 0.005709030199795961, + -0.05611564591526985, + 0.05792170390486717, + 0.1873699128627777, + -0.3856293857097626, + 0.1371920108795166, + 0.018953431397676468, + 0.015250314958393574, + -0.0016827551880851388, + -0.08515634387731552, + 0.6517581939697266, + -0.9557326436042786, + 0.46986615657806396, + -0.014306572265923023, + -0.01625121757388115, + -0.016088897362351418, + -0.13429272174835205, + 0.6437729001045227, + -1.0167845487594604, + 0.5061463117599487, + 0.00879831612110138, + -0.008598369546234608, + 0.02747279778122902, + 0.007245234213769436, + 0.2527446150779724, + -0.47163763642311096, + 0.15560215711593628, + 0.005050336476415396, + -0.024848125874996185, + -0.0006449198699556291, + -0.008673148229718208, + -0.06940636038780212, + -0.016248086467385292, + 0.1250494420528412, + 0.026387182995676994, + 0.009615709073841572, + -0.0025482974015176296, + -0.04534498229622841, + -0.2626228630542755, + -0.2753732204437256, + 0.052055053412914276, + 0.010792221873998642, + 0.007360508665442467, + 0.10271529853343964, + 0.1113760769367218, + -0.31120774149894714, + -0.49849262833595276, + 
-0.2206398844718933, + 0.04994913563132286, + 0.054614756256341934, + 0.27786919474601746, + 0.56647789478302, + 0.20970205962657928, + -0.22717078030109406, + -0.17321231961250305, + -0.07836200296878815, + -0.09607961028814316, + 0.10685958713293076, + 0.40848156809806824, + 0.34087467193603516, + -0.005242985673248768, + -0.0682876780629158, + -0.0694413110613823, + -0.1886596381664276, + -0.04473332315683365, + 0.18096435070037842, + 0.1961163580417633, + 0.0014336564345285296, + 0.014584851451218128, + 0.0462430939078331, + -0.1556192934513092, + -0.12809665501117706, + 0.0213937908411026, + 0.10984069108963013, + -0.023050926625728607, + -0.013447473756968975, + 0.007857509888708591, + -0.027979737147688866, + -0.04768490046262741, + -0.09350565075874329, + -0.1659490317106247, + 0.007927919737994671, + 0.26641780138015747, + 0.03398526459932327, + 0.02118881419301033, + -0.006898822728544474, + -0.15209096670150757, + -0.4939330220222473, + -0.42655149102211, + 0.08215854316949844, + 0.02115131914615631, + 0.08892140537500381, + 0.2164168655872345, + 0.12431265413761139, + -0.47813764214515686, + -0.6588870882987976, + -0.3097454905509949, + 0.0837375745177269, + 0.1548176258802414, + 0.49661529064178467, + 0.7337944507598877, + 0.1966201215982437, + -0.29367199540138245, + -0.2547970116138458, + -0.11655519157648087, + -0.11720486730337143, + 0.21941716969013214, + 0.5902130603790283, + 0.42572125792503357, + 0.020460324361920357, + -0.12768393754959106, + -0.12030418962240219, + -0.2582310736179352, + -0.0355166494846344, + 0.2766987085342407, + 0.28080257773399353, + 0.08665957301855087, + 0.027141664177179337, + 0.02690703421831131, + -0.25276950001716614, + -0.23180679976940155, + 0.015180152840912342, + 0.11523276567459106, + 0.041165824979543686, + 0.017444534227252007, + 0.0009439520072191954, + -0.025763530284166336, + -0.022880665957927704, + -0.024819007143378258, + -0.04901815578341484, + 0.027672944590449333, + 0.11211585998535156, + 0.024664992466568947, + -0.010093869641423225, + 0.009466213174164295, + -0.043605536222457886, + -0.17007218301296234, + -0.1366996467113495, + 0.08740171790122986, + -0.014591479673981667, + -0.0031720874831080437, + 0.0835830345749855, + 0.028662094846367836, + -0.21436777710914612, + -0.24753160774707794, + -0.06092096120119095, + 0.03788171336054802, + 0.04295210912823677, + 0.19064708054065704, + 0.3095722496509552, + 0.08003447204828262, + -0.09509303420782089, + -0.05495578795671463, + -0.052218906581401825, + -0.07204427570104599, + 0.07710819691419601, + 0.18033725023269653, + 0.0834946483373642, + -0.049662720412015915, + -0.06561554968357086, + -0.013351643458008766, + -0.11217659711837769, + 0.031957074999809265, + 0.12180440872907639, + 0.06891122460365295, + -0.013705568388104439, + 0.0011150656500831246, + 0.03281388059258461, + -0.11285661906003952, + -0.06422404199838638, + 0.04218210279941559, + 0.014165353029966354, + -0.006244795396924019, + 0.01745765097439289, + 0.08924975246191025, + 0.01710040494799614, + -0.14013372361660004, + -0.21913501620292664, + 0.03613810986280441, + 0.14273521304130554, + 0.05801931768655777, + 0.021427493542432785, + 0.23185034096240997, + 0.2427377849817276, + -0.4384608566761017, + -0.7205182909965515, + -0.18313364684581757, + 0.033575087785720825, + -0.0809125304222107, + 0.04173902049660683, + 0.7251381874084473, + 1.1058244705200195, + -0.015065462328493595, + -0.6434917449951172, + -0.3080260753631592, + -0.090518057346344, + -0.3659006655216217, + -0.4520319700241089, + 
0.5924424529075623, + 1.4148176908493042, + 0.5285682082176208, + -0.027211233973503113, + -0.07359065115451813, + -0.08583711832761765, + -0.5631492137908936, + -1.0246236324310303, + -0.1835726648569107, + 0.3307121694087982, + 0.22562064230442047, + 0.05237145721912384, + 0.13263091444969177, + 0.13899636268615723, + -0.1626550555229187, + -0.3918432295322418, + -0.03585565462708473, + 0.06904798001050949, + 0.029870154336094856, + 0.04289601743221283, + 0.05758490040898323, + 0.10055387020111084, + -0.011962685734033585, + -0.13269846141338348, + 0.0012237781193107367, + 0.05511128902435303, + 0.03764793649315834, + -0.07580426335334778, + -0.1750984787940979, + 0.0189101230353117, + 0.08156414330005646, + 0.01691802591085434, + 0.004023027140647173, + 0.18009696900844574, + 0.22744491696357727, + -0.38747039437294006, + -0.6413040161132812, + -0.19208981096744537, + 0.01971367374062538, + -0.036756888031959534, + 0.004946697968989611, + 0.7331712245941162, + 1.1178003549575806, + 0.03220612183213234, + -0.5881579518318176, + -0.24453559517860413, + -0.11856977641582489, + -0.43593257665634155, + -0.5339378118515015, + 0.49467018246650696, + 1.3376370668411255, + 0.5238692164421082, + 0.04584280773997307, + 0.004761924035847187, + -0.032823480665683746, + -0.5419207811355591, + -1.0093209743499756, + -0.19847697019577026, + 0.20687319338321686, + 0.12301573902368546, + 0.07981085777282715, + 0.14125365018844604, + 0.19885297119617462, + -0.1678825318813324, + -0.4042292535305023, + 0.004483209457248449, + 0.03009556047618389, + 0.010802071541547775, + 0.005967534612864256, + 0.0892769992351532, + 0.07342032343149185, + -0.0588892325758934, + -0.09044717997312546, + 0.06307072192430496, + -0.012583961710333824, + -0.006880680099129677, + 0.0030021765269339085, + 0.01633061282336712, + 0.06990820169448853, + 0.0070900083519518375, + -0.03546716272830963, + -0.022131899371743202, + -0.02906683459877968, + 0.010664403438568115, + -0.18731924891471863, + -0.158770352602005, + 0.08571326732635498, + 0.039154618978500366, + 0.032578419893980026, + -0.005781106185168028, + 0.17460086941719055, + 0.2787456810474396, + -0.13416190445423126, + -0.23966801166534424, + -0.004878139588981867, + 0.02796499989926815, + -0.06610933691263199, + -0.19162042438983917, + 0.11163146048784256, + 0.371842622756958, + 0.06444671750068665, + 0.016595548018813133, + 0.01164282951503992, + 0.08330011367797852, + -0.03192862868309021, + -0.2867860198020935, + -0.07080501317977905, + -0.016348646953701973, + -0.06306261569261551, + -0.016291450709104538, + 0.010558445006608963, + 0.13014638423919678, + 0.06202690303325653, + -0.03361419215798378, + 0.0691375732421875, + 0.003561250865459442, + -0.013095442205667496, + -0.050333790481090546, + -0.019117066636681557, + 0.0012089330703020096, + -0.004555183462798595, + -0.022682132199406624, + 0.04747068136930466, + -0.06425238400697708, + -0.0010437731398269534, + -0.0071629988960921764, + -0.04302623122930527, + -0.04830477759242058, + -0.04069536179304123, + -0.06627446413040161, + -0.011470981873571873, + 0.03961857780814171, + 0.026594260707497597, + -0.020662540569901466, + -0.05999285355210304, + -0.053548794239759445, + -0.025959201157093048, + -0.015834785997867584, + 0.013910192996263504, + -0.015868371352553368, + -0.056620921939611435, + -0.06785159558057785, + -0.061030179262161255, + -0.03560228645801544, + -0.04177624359726906, + -0.024657463654875755, + -0.04889696091413498, + 0.004557035397738218, + 0.15414470434188843, + 0.21642963588237762, + 
0.035425592213869095, + -0.04339970648288727, + -0.05034525692462921, + -0.08522290736436844, + 0.10652441531419754, + 0.6791198253631592, + 0.7785530686378479, + 0.19941796362400055, + -0.05430706962943077, + -0.02583213709294796, + -0.055139996111392975, + 0.17940561473369598, + 0.6757862567901611, + 0.8240399360656738, + 0.25826773047447205, + -0.062254682183265686, + -0.026456547901034355, + -0.027271386235952377, + -0.0026193747762590647, + 0.11893659085035324, + 0.1915995329618454, + 0.013776157051324844, + 0.08452087640762329, + 0.009950258769094944, + 0.01774573139846325, + 0.06609759479761124, + 0.06512798368930817, + 0.07601971179246902, + 0.09192144125699997, + 0.007696932647377253, + -0.056120894849300385, + -0.03937293961644173, + 0.043086692690849304, + 0.055803027004003525, + 0.08208976686000824, + 0.03658852353692055, + 0.025779196992516518, + -0.0340605266392231, + 0.03186321631073952, + 0.09720855951309204, + 0.10651290416717529, + 0.09562067687511444, + 0.08120692521333694, + 0.06832587718963623, + 0.03940538689494133, + 0.09561086446046829, + -0.03726261481642723, + -0.3520663380622864, + -0.4187469184398651, + -0.11643502116203308, + 0.06203937157988548, + 0.056670401245355606, + 0.11540547758340836, + -0.2742924690246582, + -1.1301417350769043, + -1.2482489347457886, + -0.4411431849002838, + 0.08538330346345901, + 0.036888301372528076, + 0.08759869635105133, + -0.32129940390586853, + -1.1163593530654907, + -1.26430082321167, + -0.48638999462127686, + 0.1056363582611084, + 0.042436979711055756, + 0.07075526565313339, + -0.08341801166534424, + -0.30567145347595215, + -0.39268070459365845, + -0.10187282413244247, + -0.02507772110402584, + -0.0044433241710066795, + -0.009278317913413048, + -0.02964872494339943, + -0.018799586221575737, + -0.03760084509849548, + -0.030454028397798538, + -0.004638439975678921, + 0.026587119325995445, + 0.0095819728448987, + -0.007110759150236845, + -0.006491640582680702, + -0.028083719313144684, + -0.009543413296341896, + -0.005706887226551771, + 0.013012710027396679, + -0.010281933471560478, + -0.0544208325445652, + -0.023230208083987236, + -0.05344587564468384, + -0.04052828997373581, + -0.028035156428813934, + -0.011922319419682026, + -0.045427750796079636, + 0.020700184628367424, + 0.2117788940668106, + 0.21090814471244812, + 0.07214333862066269, + -0.019348343834280968, + -0.014455118216574192, + -0.03561105206608772, + 0.17339389026165009, + 0.49509289860725403, + 0.5219546556472778, + 0.26121678948402405, + -0.029803339391946793, + -0.013761913403868675, + -0.04028521850705147, + 0.17008572816848755, + 0.45583003759384155, + 0.4757367670536041, + 0.22357690334320068, + -0.050064269453287125, + -0.021086007356643677, + -0.039873600006103516, + 0.06433176249265671, + 0.20187893509864807, + 0.2078690379858017, + 0.07802058011293411, + 0.022050827741622925, + -0.05272649601101875, + -0.024311071261763573, + -0.12387345731258392, + -0.20065246522426605, + 0.0262442696839571, + 0.20101603865623474, + 0.056791841983795166, + -0.008266052231192589, + 0.025132112205028534, + -0.23289933800697327, + -0.5296569466590881, + -0.282010018825531, + 0.025113720446825027, + 0.13172000646591187, + 0.16999290883541107, + 0.31588253378868103, + 0.05583454668521881, + -0.5321000814437866, + -0.5585035085678101, + -0.23885560035705566, + 0.0461968369781971, + 0.13807418942451477, + 0.6536149382591248, + 0.6385176777839661, + -0.15636183321475983, + -0.5484278798103333, + -0.5470613241195679, + -0.06269911676645279, + -0.06726553291082382, + 
0.5561463236808777, + 1.0985187292099, + 0.6801460385322571, + 0.12841203808784485, + -0.21693651378154755, + -0.19168342649936676, + -0.43073776364326477, + -0.15226863324642181, + 0.41150590777397156, + 0.47421786189079285, + 0.25146934390068054, + -0.017203813418745995, + -0.09694849699735641, + -0.4082376956939697, + -0.3549531400203705, + -0.023591510951519012, + 0.12086013704538345, + 0.08050766587257385, + -0.044960521161556244, + -0.0031193571630865335, + 0.014398006722331047, + -0.005931032355874777, + 0.01548685971647501, + 0.05407734215259552, + -0.006386967841535807, + 0.021660227328538895, + 0.01656133122742176, + 0.002835798542946577, + 0.0008500503608956933, + 0.021802745759487152, + 0.13470955193042755, + 0.06802596151828766, + 0.0033256933093070984, + 0.03037848509848118, + 0.054654810577631, + -0.034221138805150986, + 0.015171438455581665, + 0.23395732045173645, + 0.24771827459335327, + 0.16352902352809906, + -0.07505007833242416, + -0.0814652070403099, + -0.21493901312351227, + -0.3109704852104187, + 0.013416547328233719, + 0.12807825207710266, + 0.12044191360473633, + -0.007915153168141842, + 0.0100772799924016, + -0.15165796875953674, + -0.4013277292251587, + -0.24811144173145294, + -0.06641282886266708, + 0.022568246349692345, + 0.061083581298589706, + 0.09920243173837662, + 0.0695505365729332, + -0.12213064730167389, + -0.12606006860733032, + -0.04593949392437935, + -0.040190644562244415, + 0.03899035230278969, + 0.12688779830932617, + 0.114081971347332, + -0.013348283246159554, + 0.03325115144252777, + 0.007111718878149986, + 0.048056699335575104, + -0.003726312192156911, + 0.05401211231946945, + 0.05355936661362648, + 0.21303032338619232, + 0.2944865822792053, + -0.13604623079299927, + -0.3770989775657654, + -0.0808275118470192, + -0.006103217601776123, + -0.02005188539624214, + 0.37605899572372437, + 0.7776278853416443, + 0.32064270973205566, + -0.23708422482013702, + -0.23380732536315918, + -0.22103570401668549, + -0.45596328377723694, + 0.07213663309812546, + 0.9384943246841431, + 0.8762810230255127, + 0.3557227551937103, + -0.09239326417446136, + -0.25462013483047485, + -0.9858288168907166, + -0.9860153198242188, + 0.2600172162055969, + 0.7731484770774841, + 0.7665594816207886, + 0.14806008338928223, + 0.13109923899173737, + -0.6917864680290222, + -1.580305814743042, + -0.9557210803031921, + -0.16357193887233734, + 0.3189502954483032, + 0.28703632950782776, + 0.5599567890167236, + 0.2459551841020584, + -0.5451022982597351, + -0.6926754713058472, + -0.4368602931499481, + 0.027606861665844917, + 0.025857241824269295, + 0.5376880764961243, + 0.535673975944519, + 0.09012678265571594, + -0.14688564836978912, + -0.1812361180782318, + 0.050619762390851974, + 0.021388273686170578, + -0.05923623591661453, + -0.006538081914186478, + 0.05171535536646843, + -0.051560595631599426, + -0.007643367163836956, + 0.027748188003897667, + 0.0024676925968378782, + -0.008760283701121807, + 0.13039670884609222, + 0.18568934500217438, + 0.06342563778162003, + 0.030788781121373177, + -0.004423442296683788, + -0.041261281818151474, + 0.013299684040248394, + 0.22491391003131866, + 0.27831292152404785, + 0.0883866474032402, + 0.048967570066452026, + 0.0012756097130477428, + -0.03215779736638069, + 0.02710782177746296, + 0.20178261399269104, + 0.22446107864379883, + 0.06052157282829285, + 0.019020315259695053, + 0.02715166099369526, + -0.03146626800298691, + -0.017960363999009132, + 0.11820292472839355, + 0.16114193201065063, + 0.05221821367740631, + -0.02201441302895546, + 
-0.026308327913284302, + 0.008580431342124939, + -0.02444308064877987, + 0.061380185186862946, + 0.11184953153133392, + 0.006053542252629995, + -0.03248603641986847, + -0.037558719515800476, + 0.01881473697721958, + -0.02349863201379776, + 0.02150980569422245, + 0.09881952404975891, + 0.03962325677275658, + -0.0031782283913344145, + -0.0030868228059262037, + -0.007606725674122572, + -0.06136326491832733, + 0.022755015641450882, + 0.09683670848608017, + 0.0016674631042405963, + 0.01306125894188881, + 0.011335537768900394, + -0.01769089885056019, + 0.005807302892208099, + 0.19103741645812988, + 0.2631426155567169, + 0.10424992442131042, + 0.025223100557923317, + -0.024689532816410065, + -0.03370697423815727, + 0.0512213259935379, + 0.2983294129371643, + 0.37597405910491943, + 0.18788966536521912, + 0.056492965668439865, + -0.006051253993064165, + -0.027141474187374115, + 0.06733105331659317, + 0.29171472787857056, + 0.32160115242004395, + 0.14176633954048157, + 0.008538221009075642, + -0.013039524666965008, + -0.04279422387480736, + 0.03345612436532974, + 0.19111940264701843, + 0.25728005170822144, + 0.09830093383789062, + -0.03371569141745567, + -0.05277566984295845, + -0.0011038694065064192, + -0.013657800853252411, + 0.10037966072559357, + 0.1724642813205719, + 0.04436478391289711, + -0.02240786701440811, + -0.02181128039956093, + 0.019526727497577667, + -0.050060197710990906, + 0.017275504767894745, + 0.07785085588693619, + -0.001727179973386228, + -0.0014453287003561854, + 0.019352080300450325, + -0.003202121239155531, + -0.04241566359996796, + 0.005586653482168913, + 0.06037082523107529, + 0.014115821570158005, + -0.00568200321868062, + 0.018071964383125305, + -0.0007147599244490266, + 0.011219227686524391, + 0.10582104325294495, + 0.15557849407196045, + 0.06189450994133949, + 0.014160261489450932, + 0.00814653467386961, + -0.028064200654625893, + 0.026086319237947464, + 0.1474728286266327, + 0.18273885548114777, + 0.06638553738594055, + 0.019263381138443947, + 0.028977060690522194, + -0.02551555074751377, + 0.01937149092555046, + 0.12000202387571335, + 0.1285850703716278, + 0.047506313771009445, + -0.011383740231394768, + 0.02826755866408348, + -0.009583448991179466, + -0.02093282900750637, + 0.07994058728218079, + 0.0926218256354332, + 0.0318426676094532, + -0.024409465491771698, + 0.020994359627366066, + 0.03295197710394859, + -0.034276511520147324, + 0.037398867309093475, + 0.0794353187084198, + 0.022805212065577507, + 0.0015407208120450377, + 0.013169347308576107, + 0.038584139198064804, + -0.002118688775226474, + 0.03358406573534012, + 0.09085306525230408, + 0.04255761206150055, + 0.010275964625179768, + 0.025351760908961296, + 0.04205995053052902, + 0.1319226324558258, + 0.049708493053913116, + -0.03743802383542061, + -0.04293569549918175, + -0.07646205276250839, + -0.04986324533820152, + 0.15992362797260284, + 0.011027384549379349, + -0.32150742411613464, + -0.3761928677558899, + -0.1654653549194336, + -0.08728181570768356, + 0.044714685529470444, + -0.007500737439841032, + -0.41376256942749023, + -0.6625701189041138, + -0.21809393167495728, + 0.10641554743051529, + 0.09274336695671082, + 0.10189083218574524, + -0.1175118163228035, + -0.2905261516571045, + -0.06248515099287033, + 0.4791955053806305, + 0.49865299463272095, + 0.23415400087833405, + 0.12729482352733612, + -0.05814196541905403, + -0.003843356389552355, + 0.16410382091999054, + 0.40895968675613403, + 0.22034852206707, + 0.021014101803302765, + -0.05658271536231041, + -0.012199933640658855, + 0.034277670085430145, + 
0.09565535932779312, + 0.18921032547950745, + 0.010441004298627377, + -0.07427560538053513, + -0.09049694985151291, + -0.00554919708520174, + 0.021386168897151947, + 0.0297325998544693, + 0.06431404501199722, + -0.07367311418056488, + -0.08734254539012909, + -0.059512097388505936, + 0.11382041126489639, + 0.19622667133808136, + 0.02534862980246544, + -0.09704668819904327, + -0.10857658833265305, + -0.10241919010877609, + -0.037928055971860886, + 0.17917697131633759, + -0.0396210141479969, + -0.472421795129776, + -0.5453466176986694, + -0.23921693861484528, + -0.06353127211332321, + 0.033679377287626266, + -0.011634309776127338, + -0.523267924785614, + -0.8400278091430664, + -0.3026646375656128, + 0.17986975610256195, + 0.20296970009803772, + 0.14190459251403809, + -0.12953802943229675, + -0.3968985378742218, + -0.13779792189598083, + 0.548722505569458, + 0.7039015293121338, + 0.4025704264640808, + 0.19535738229751587, + -0.08568660169839859, + -0.0589536651968956, + 0.1868993639945984, + 0.5782724618911743, + 0.43018248677253723, + 0.08876730501651764, + -0.10219226032495499, + -0.04660544916987419, + 0.018129168078303337, + 0.14359626173973083, + 0.3174169361591339, + 0.07668197154998779, + -0.13716676831245422, + -0.2058524489402771, + -0.023707473650574684, + 0.03213014453649521, + 0.06718969345092773, + 0.0917893499135971, + -0.10766899585723877, + -0.206499844789505, + -0.12713390588760376, + -0.03174767270684242, + 0.046395305544137955, + 0.018318502232432365, + -0.002416136907413602, + -0.027143845334649086, + -0.0036621293984353542, + -0.019220896065235138, + 0.05427055433392525, + 0.05058867856860161, + -0.05274957790970802, + -0.11321325600147247, + -0.07062514126300812, + -0.01720590703189373, + -0.00901520811021328, + 0.01746262051165104, + -0.08946436643600464, + -0.2304752618074417, + -0.1021895483136177, + 0.013501768000423908, + 0.029721295461058617, + -0.010094762779772282, + 0.009764805436134338, + -0.06424269080162048, + -0.03032868541777134, + 0.13044297695159912, + 0.12166891992092133, + 0.07157951593399048, + 0.029467372223734856, + -0.03827595338225365, + -0.031337328255176544, + -0.026486340910196304, + 0.05953369289636612, + 0.029497025534510612, + 0.022669093683362007, + -0.01055963709950447, + -0.025020133703947067, + 0.002589448355138302, + 0.017152298241853714, + 0.062067389488220215, + 0.008266719058156013, + 0.00563611788675189, + -0.0044869836419820786, + 0.003065212396904826, + 0.014371387660503387, + 0.013636622577905655, + 0.021183570846915245, + -0.012462744489312172, + -0.02493542619049549, + 0.009652925655245781, + -0.09309647232294083, + -0.09614148736000061, + 0.020278261974453926, + 0.262399286031723, + 0.0025974283926188946, + -0.09532646089792252, + -0.0391894206404686, + -0.003332971129566431, + -0.25919869542121887, + -0.2104814499616623, + 0.5975717306137085, + 0.20378711819648743, + -0.20521192252635956, + 0.005045099183917046, + 0.16707547008991241, + -0.08322134613990784, + -1.1734565496444702, + 0.4060916006565094, + 0.9109339714050293, + -0.22450445592403412, + -0.14085394144058228, + 0.19534644484519958, + 0.6220589280128479, + -1.0614460706710815, + -1.2444484233856201, + 1.1965712308883667, + 0.5032565593719482, + -0.26604175567626953, + -0.13583213090896606, + 0.6453277468681335, + 0.4994892477989197, + -1.7917202711105347, + -0.15182015299797058, + 0.7891079783439636, + 0.10711944103240967, + -0.11587982624769211, + 0.08287231624126434, + 0.7848142981529236, + -0.1764022707939148, + -1.0492321252822876, + 0.15281184017658234, + 
0.3100045919418335, + -0.0461110882461071, + -0.06824400275945663, + 0.25544390082359314, + 0.3444065451622009, + -0.3189513683319092, + -0.3503313362598419, + 0.05462741479277611, + -0.041028521955013275, + 0.00624969182536006, + -0.0014677124563604593, + 0.10383514314889908, + -0.03467189520597458, + -0.03946290910243988, + 0.012734192423522472, + -0.003676857566460967, + -0.1616411954164505, + -0.034441810101270676, + 0.34758275747299194, + -0.0017601394793018699, + -0.17407774925231934, + 0.05167992413043976, + 0.12394318729639053, + -0.018228475004434586, + -0.71342533826828, + 0.39672648906707764, + 0.4870489537715912, + -0.27272745966911316, + -0.02687050960958004, + 0.09090551733970642, + 0.46698617935180664, + -0.6089348196983337, + -0.7488552331924438, + 0.8327828645706177, + 0.19947239756584167, + -0.17806877195835114, + -0.09197663515806198, + 0.3198661506175995, + 0.42619431018829346, + -1.1321229934692383, + -0.05452701821923256, + 0.4155597984790802, + -0.001295815804041922, + -0.06596186012029648, + -0.05821318179368973, + 0.4515152871608734, + 0.06321248412132263, + -0.6065720319747925, + 0.10882120579481125, + 0.13767170906066895, + 0.01809641905128956, + -0.070295050740242, + 0.04035783186554909, + 0.22459834814071655, + -0.048405971378088, + -0.14622822403907776, + -0.01119917817413807, + 0.00666345190256834, + 0.04815478250384331, + -0.017866114154458046, + -0.04813665896654129, + -0.02366034686565399, + 0.03589487820863724, + -0.0066519430838525295, + 0.0004148671869188547, + -0.014153627678751945, + 0.04403751716017723, + 0.04098428785800934, + -0.10525348782539368, + -0.0078808031976223, + 0.0444580540060997, + -0.027595041319727898, + 0.010916849598288536, + -0.1390431821346283, + 0.20334453880786896, + -0.006475532427430153, + -0.16053295135498047, + 0.06964287906885147, + -0.025649840012192726, + 0.12622858583927155, + -0.09694403409957886, + -0.09791161119937897, + 0.2617567479610443, + -0.06268735229969025, + -0.03128494322299957, + -0.017743078991770744, + -0.02372320368885994, + 0.2195650041103363, + -0.2456466406583786, + 0.031090563163161278, + 0.010196326300501823, + -0.04323133826255798, + 0.02746250294148922, + -0.079569011926651, + 0.06894756853580475, + 0.11414647102355957, + -0.12175147980451584, + 0.025397513061761856, + 0.006027852185070515, + 0.013360690325498581, + -0.024561991915106773, + -0.10966529697179794, + 0.04913714900612831, + 0.09801583737134933, + 0.00013951699656900018, + -0.03194398432970047, + 0.002382949460297823, + -0.003335593966767192, + 0.023621119558811188, + 0.024585755541920662, + -0.016027197241783142, + -0.02846739999949932, + -0.012949706055223942, + -0.020852699875831604, + -0.016913240775465965, + 0.016088848933577538, + 0.141468346118927, + 0.07285624742507935, + -0.008997173048555851, + -0.033306676894426346, + -0.03418722003698349, + -0.15127411484718323, + -0.047440435737371445, + 0.2687169015407562, + 0.17237843573093414, + 0.03505166247487068, + -0.06994523108005524, + -0.031143782660365105, + -0.3024960458278656, + -0.1552041918039322, + 0.33517369627952576, + 0.28441429138183594, + 0.06471730768680573, + -0.0613982267677784, + -0.02271229960024357, + -0.29379361867904663, + -0.3259792923927307, + 0.16062304377555847, + 0.29220375418663025, + 0.10862076282501221, + -0.005909152328968048, + 0.049116987735033035, + -0.20140305161476135, + -0.3278747797012329, + -0.02566053718328476, + 0.14338354766368866, + 0.006411381531506777, + -0.007274044211953878, + 0.08232597261667252, + -0.04198717698454857, + 
-0.17330540716648102, + -0.01131037063896656, + 0.08018575608730316, + -0.02374250255525112, + -0.002276432001963258, + 0.00019528658594936132, + -0.024716932326555252, + 0.026509074494242668, + 0.08361849933862686, + 0.012956380844116211, + -0.06030649319291115, + -0.020338360220193863, + -0.03577016666531563, + -0.06858085840940475, + 0.008245388977229595, + 0.25225168466567993, + 0.16135559976100922, + -0.03690743073821068, + -0.09188401699066162, + -0.10410526394844055, + -0.25971388816833496, + -0.07926154136657715, + 0.3933144509792328, + 0.33186599612236023, + 0.059405017644166946, + -0.11824909597635269, + -0.10528354346752167, + -0.4808295667171478, + -0.25224801898002625, + 0.4267246127128601, + 0.4853539764881134, + 0.16933484375476837, + -0.073345847427845, + -0.02648857608437538, + -0.4723232388496399, + -0.4904792010784149, + 0.1938265562057495, + 0.44070878624916077, + 0.22439399361610413, + 0.03877745941281319, + 0.08536087721586227, + -0.31432414054870605, + -0.5158097743988037, + -0.09537900239229202, + 0.20227058231830597, + 0.07895126938819885, + 0.059195615351200104, + 0.14728911221027374, + -0.059377528727054596, + -0.2884902060031891, + -0.12288203090429306, + 0.05220698565244675, + -0.045279599726200104, + 0.019795719534158707, + -0.009819806553423405, + -0.013713877648115158, + 0.0012175077572464943, + 0.03281072899699211, + 0.0017424041870981455, + -0.028847966343164444, + -0.0032059827353805304, + -0.020358575507998466, + 0.0009416870889253914, + -0.007760196924209595, + 0.07921157032251358, + 0.03826644644141197, + -0.02976907789707184, + -0.03300238028168678, + -0.017963968217372894, + -0.055836472660303116, + -0.03299689665436745, + 0.15166012942790985, + 0.06786434352397919, + 0.008589516393840313, + -0.05790036544203758, + -0.0029997669626027346, + -0.14070068299770355, + -0.08799122273921967, + 0.19680362939834595, + 0.14703704416751862, + 0.03569985553622246, + -0.02847554162144661, + 0.03601403906941414, + -0.1339161992073059, + -0.20527805387973785, + 0.1060374304652214, + 0.16269326210021973, + 0.0575268417596817, + 0.0029672966338694096, + 0.018848277628421783, + -0.1029881089925766, + -0.19446833431720734, + -0.055140964686870575, + 0.09632515162229538, + 0.01196608692407608, + 0.01994382217526436, + 0.0030014747753739357, + 0.0029817752074450254, + -0.09395840018987656, + -0.038611751049757004, + 0.03793984279036522, + -0.006295992527157068, + 0.01736803539097309, + -0.0961727425456047, + 0.1318971812725067, + 0.00169672432821244, + 0.02773740515112877, + -0.03737606480717659, + -0.02413480542600155, + -0.07371329516172409, + 0.04465596005320549, + 0.34972262382507324, + 0.269726425409317, + 0.14907677471637726, + -0.15323053300380707, + -0.24987848103046417, + -0.32931339740753174, + 0.05209995433688164, + 0.5192161798477173, + 0.5108750462532043, + 0.2627664804458618, + -0.26889729499816895, + -0.49891141057014465, + -0.5081418752670288, + 0.13535383343696594, + 0.7318623661994934, + 0.7116816639900208, + 0.2973657250404358, + -0.38982102274894714, + -0.7131763100624084, + -0.5916072130203247, + 0.1200462281703949, + 0.7752112746238708, + 0.6947993636131287, + 0.21100594103336334, + -0.5576100945472717, + -0.7797606587409973, + -0.6058254837989807, + 0.08617032319307327, + 0.6432424187660217, + 0.522933304309845, + 0.16018754243850708, + -0.5134027004241943, + -0.6838728189468384, + -0.5088241100311279, + 0.10101393610239029, + 0.4321025311946869, + 0.3330003023147583, + 0.10116448998451233, + -0.2786642014980316, + -0.4134466052055359, + 
-0.3247438967227936, + 0.009768294170498848, + 0.008712833747267723, + -0.029476309195160866, + 0.007709377445280552, + 0.025279967114329338, + 0.01615188643336296, + 0.01585867628455162, + -0.0031516810413450003, + -0.06462288647890091, + -0.055517926812171936, + -0.013180199079215527, + -0.014849795028567314, + 0.05535515025258064, + 0.04162544384598732, + 0.0022392054088413715, + -0.09408581256866455, + -0.07889631390571594, + -0.032870080322027206, + 0.0382377915084362, + 0.07495865970849991, + 0.08439645916223526, + 0.008036677725613117, + -0.1167779192328453, + -0.10782196372747421, + -0.06854722648859024, + 0.06310252100229263, + 0.09643208235502243, + 0.08629462122917175, + -0.016969647258520126, + -0.10456187278032303, + -0.10410942137241364, + -0.017384463921189308, + 0.03931420296430588, + 0.11296819150447845, + 0.08688211441040039, + -0.018024103716015816, + -0.0985492691397667, + -0.10534191876649857, + 0.016594627872109413, + 0.024613894522190094, + 0.09626104682683945, + 0.056779902428388596, + -0.01314453687518835, + -0.1173979789018631, + -0.07576211541891098, + -0.00741730397567153, + 0.04463285952806473, + 0.06365535408258438, + 0.029472019523382187, + 0.06097950413823128, + -0.0884813666343689, + -0.020469073206186295, + -0.004499382339417934, + 0.006147715728729963, + 0.0061135985888540745, + 0.046618249267339706, + -0.024977274239063263, + -0.2809607684612274, + -0.20776452124118805, + -0.10792756825685501, + 0.10520339012145996, + 0.2195160835981369, + 0.27846819162368774, + -0.0425783209502697, + -0.4539273977279663, + -0.4210258722305298, + -0.24160517752170563, + 0.2377386838197708, + 0.4254952371120453, + 0.40258923172950745, + -0.08894401043653488, + -0.6261403560638428, + -0.6177268624305725, + -0.2941279113292694, + 0.36115866899490356, + 0.6176164746284485, + 0.5170959234237671, + -0.12760992348194122, + -0.6392932534217834, + -0.6288641095161438, + -0.20397846400737762, + 0.4859760105609894, + 0.7283636927604675, + 0.5233575105667114, + -0.08038943260908127, + -0.513219952583313, + -0.4611802101135254, + -0.08622774481773376, + 0.41959214210510254, + 0.6145293116569519, + 0.4252074360847473, + -0.08993257582187653, + -0.3586794435977936, + -0.23889268934726715, + -0.07402873039245605, + 0.2362663745880127, + 0.33187127113342285, + 0.24442552030086517, + -0.10037989169359207, + -0.1200498715043068, + -0.06188809871673584, + 0.009648810140788555, + 0.07703708112239838, + -0.07734857499599457, + -0.16337357461452484, + -0.13160429894924164, + -0.037760209292173386, + 0.10750655829906464, + 0.21975228190422058, + 0.21332265436649323, + 0.1482381671667099, + -0.012174196541309357, + -0.03128019720315933, + 0.06983920931816101, + 0.2055918425321579, + 0.16611628234386444, + 0.20955723524093628, + 0.21407610177993774, + 0.13214662671089172, + 0.01558306161314249, + 0.20919384062290192, + 0.21453723311424255, + 0.10980720072984695, + 0.10323476791381836, + 0.1754676252603531, + 0.16320686042308807, + 0.076839879155159, + 0.2669583261013031, + 0.29500535130500793, + 0.18005967140197754, + 0.14900699257850647, + 0.2337430715560913, + 0.2607984244823456, + -0.08909865468740463, + 0.12383633106946945, + 0.27329200506210327, + 0.2634970247745514, + 0.2298160344362259, + 0.22673286497592926, + 0.1753624528646469, + -0.14258335530757904, + -0.033422429114580154, + 0.09338828176259995, + 0.21975602209568024, + 0.2488732784986496, + 0.21165378391742706, + 0.08514796197414398, + 0.0776415765285492, + -0.028732767328619957, + -0.0827818363904953, + -0.14784079790115356, + 
-0.06101813539862633, + -0.10570015013217926, + -0.07298385351896286, + -0.03352680057287216, + -0.08094660192728043, + -0.08546923100948334, + -0.025722583755850792, + -0.04828448221087456, + -0.15816760063171387, + -0.22295169532299042, + -0.04976325109601021, + -0.12255501747131348, + -0.04869991913437843, + 0.09818085283041, + 0.2285904735326767, + 0.015187943354249, + -0.19952231645584106, + -0.1415022611618042, + -0.09511925280094147, + 0.10828559100627899, + 0.35640013217926025, + 0.5399265289306641, + 0.3026861250400543, + -0.10532847791910172, + -0.0455780103802681, + -0.09365752339363098, + 0.2482689470052719, + 0.5483031272888184, + 0.6572608947753906, + 0.4098849594593048, + -0.0039499495178461075, + -0.11641024053096771, + -0.22666053473949432, + -0.03133581206202507, + 0.2815704643726349, + 0.3229265809059143, + 0.009749597869813442, + -0.19616934657096863, + -0.05046992748975754, + -0.15597671270370483, + -0.22775587439537048, + -0.14872166514396667, + -0.12174414098262787, + -0.23433859646320343, + -0.238412007689476, + 0.09725375473499298, + 0.08522887527942657, + 0.006490080617368221, + -0.024619178846478462, + 0.07278231531381607, + 0.13406167924404144, + 0.22993306815624237, + 0.10250072181224823, + 0.09119024127721786, + -0.07687287777662277, + -0.1012108325958252, + -0.09500063210725784, + -0.10082961618900299, + 0.09466016292572021, + 0.11299365013837814, + -0.033278487622737885, + -0.20269805192947388, + -0.21449527144432068, + -0.08820098638534546, + -0.18970704078674316, + -0.050536416471004486, + -0.03471578657627106, + -0.13205547630786896, + -0.18150201439857483, + -0.03963223099708557, + 0.13029472529888153, + -0.11594776809215546, + -0.173879474401474, + 0.017406627535820007, + -0.11885572224855423, + -0.06966021656990051, + 0.1687183529138565, + 0.2677668035030365, + -0.020446041598916054, + -0.11710261553525925, + 0.044354867190122604, + -0.10054060816764832, + -0.1287878155708313, + -0.03600803390145302, + -0.03198331966996193, + -0.22372953593730927, + -0.11045534163713455, + 0.22963544726371765, + 0.16736479103565216, + -0.023956498131155968, + -0.0882943719625473, + -0.11904646456241608, + -0.10481738299131393, + 0.083598293364048, + 0.058089643716812134, + -0.04821285232901573, + 0.16764044761657715, + -0.13788309693336487, + -0.1412951946258545, + 0.059633608907461166, + 0.012824267148971558, + -0.03141501545906067, + -0.017422236502170563, + 0.3908282518386841, + -0.31520241498947144, + -0.27876099944114685, + 0.17109407484531403, + 0.011913848109543324, + -0.04440265893936157, + 0.05610174685716629, + 0.5290316343307495, + -0.4506116211414337, + -0.2946499288082123, + 0.2802693545818329, + 0.04180249199271202, + -0.05673402547836304, + 0.0445592887699604, + 0.4933576285839081, + -0.4903600513935089, + -0.3259376883506775, + 0.26069584488868713, + 0.047843094915151596, + -0.053804315626621246, + 0.029928382486104965, + 0.3588394224643707, + -0.39090782403945923, + -0.18598265945911407, + 0.1703576147556305, + 0.010407418012619019, + 0.019840527325868607, + -0.017079327255487442, + 0.21012797951698303, + -0.1586841642856598, + -0.12738685309886932, + 0.12431345880031586, + 0.028149213641881943, + 0.05083676427602768, + -0.07053223252296448, + 0.12090320140123367, + -0.13737183809280396, + -0.09807822853326797, + 0.07203921675682068, + -0.01965559460222721, + 0.036479320377111435, + -0.02657422423362732, + 0.2924504280090332, + -0.19397024810314178, + -0.20908842980861664, + 0.07435549795627594, + 0.011985386721789837, + -0.051603686064481735, + 
+    (several thousand added lines of raw floating-point constants, apparently test tensor data, collapsed here by extraction; one value per "+" diff line in the original patch)
+ 0.2992522716522217, + 0.4747498333454132, + -0.5266988277435303, + 0.04581758379936218, + -0.04037958011031151, + 0.0071074217557907104, + 0.047499995678663254, + 0.16617828607559204, + -0.03973710536956787, + -0.2953551113605499, + 0.10628587752580643, + -0.00904526561498642, + 0.010427894070744514, + 0.08035022020339966, + 0.03841109946370125, + -0.06335253268480301, + -0.06992083787918091, + 0.015409895218908787, + -0.026900725439190865, + -0.04523912072181702, + 0.08087682723999023, + 0.12542113661766052, + 0.018750213086605072, + -0.23430712521076202, + 0.11755944788455963, + -0.019747508689761162, + -0.03171322122216225, + -0.12132623791694641, + 0.2640603184700012, + 0.38445138931274414, + -0.5724408030509949, + 0.15661633014678955, + 0.01949799247086048, + -0.021771302446722984, + -0.18984957039356232, + -0.23499636352062225, + 1.2112919092178345, + -0.7037869095802307, + -0.14260035753250122, + 0.01848726160824299, + 0.06443414837121964, + -0.11740390956401825, + -0.8794785141944885, + 1.4160369634628296, + 0.016899125650525093, + -0.5444768071174622, + 0.017313210293650627, + 0.0508052259683609, + 0.11102095246315002, + -0.790285587310791, + 0.3501206636428833, + 0.7238660454750061, + -0.49468666315078735, + -0.019021952524781227, + -0.01212992612272501, + 0.15032203495502472, + -0.3573611080646515, + -0.1293754130601883, + 0.45295456051826477, + -0.08407819271087646, + -0.008717959746718407, + 0.022566653788089752, + -0.012640242464840412, + 0.03181227669119835, + 0.0638526976108551, + -0.058120664209127426, + -0.042917650192976, + 0.02129550836980343, + -0.018790805712342262, + -0.00655191857367754, + 0.05951414257287979, + 0.12890471518039703, + -0.1886381357908249, + 0.059096939861774445, + -0.016928592696785927, + 0.02327263168990612, + -0.17282842099666595, + 0.13812857866287231, + 0.38889989256858826, + -0.5282873511314392, + 0.07564643770456314, + -0.006128210574388504, + -0.00876594614237547, + -0.18427829444408417, + -0.26697441935539246, + 1.2529815435409546, + -0.6549165844917297, + -0.2111111879348755, + 0.011410325765609741, + 0.07089994102716446, + -0.12627695500850677, + -0.8245998024940491, + 1.4581915140151978, + -0.01822204887866974, + -0.5626582503318787, + -0.01661459542810917, + 0.03759436681866646, + 0.10841676592826843, + -0.7652962803840637, + 0.4360819458961487, + 0.7012669444084167, + -0.47011038661003113, + 0.01529701892286539, + -0.0033166150096803904, + 0.12170535326004028, + -0.3871544301509857, + -0.05247795954346657, + 0.4504147171974182, + -0.11442532390356064, + -0.00882577896118164, + 0.005190832540392876, + -0.05153197422623634, + 0.0055236960761249065, + 0.09320031106472015, + -0.03762076050043106, + -0.021778371185064316, + 0.00750907463952899, + 0.014965789392590523, + -0.015135630965232849, + -0.037086039781570435, + 0.08020154386758804, + -0.04429963231086731, + 0.0038218852132558823, + -0.01712334342300892, + 0.053772956132888794, + -0.05226677283644676, + -0.024439912289381027, + 0.12774989008903503, + -0.18722355365753174, + 0.0683830976486206, + -0.010828870348632336, + -0.012880662456154823, + 0.02679484151303768, + -0.13696907460689545, + 0.46868517994880676, + -0.322968989610672, + 0.052930932492017746, + 0.009463602676987648, + -0.046861011534929276, + 0.07714711129665375, + -0.35792097449302673, + 0.5517901182174683, + -0.13382655382156372, + -0.12921281158924103, + 0.018562642857432365, + -0.03842621296644211, + 0.10284601897001266, + -0.28243398666381836, + 0.13314206898212433, + 0.20769073069095612, + -0.1551610678434372, + 
0.018036767840385437, + -0.03553476929664612, + 0.036686040461063385, + -0.09568552672863007, + 0.008917863480746746, + 0.11340243369340897, + -0.04745811969041824, + 0.005833764094859362, + -0.04174824804067612, + 0.022730106487870216, + 0.0013601485406979918, + -0.07473982870578766, + -0.004801879171282053, + 0.05632775276899338, + -0.04081303998827934, + 0.11509573459625244, + 0.004507652949541807, + -0.24791881442070007, + 0.43171870708465576, + -0.1362573653459549, + -0.10758046060800552, + 0.02746163308620453, + -0.2954745888710022, + 0.30186471343040466, + 0.3135572075843811, + -1.2296111583709717, + 0.8754236102104187, + -0.11699853837490082, + 0.022482017055153847, + 0.24945153295993805, + -0.7858022451400757, + 0.5181443095207214, + 1.4243930578231812, + -1.876152515411377, + 0.4689188003540039, + 0.04258054122328758, + -0.030832920223474503, + 0.9340220093727112, + -1.512351632118225, + -0.3731614947319031, + 2.021338701248169, + -0.7801089286804199, + -0.09288544207811356, + -0.12423597276210785, + -0.36861127614974976, + 1.1679530143737793, + -0.4960964024066925, + -1.0398281812667847, + 0.686152458190918, + 0.02052121050655842, + 0.07246638089418411, + -0.01763315312564373, + -0.37442535161972046, + 0.33217450976371765, + 0.22260302305221558, + -0.2657756209373474, + 0.00016369696822948754, + 0.008136127144098282, + -0.03592197597026825, + 0.022231513634324074, + 0.041430093348026276, + -0.06439317017793655, + 0.03496818616986275, + -0.05143435671925545, + 0.09930871427059174, + 0.017110232263803482, + -0.3834381699562073, + 0.44344815611839294, + -0.00280396337620914, + -0.11487428843975067, + 0.050503507256507874, + -0.22837062180042267, + 0.47540077567100525, + 0.5802375674247742, + -1.7325034141540527, + 0.8587368130683899, + 0.10429240018129349, + -0.02456486038863659, + 0.1340152472257614, + -1.2299835681915283, + 0.7986555099487305, + 2.2204456329345703, + -2.4498374462127686, + 0.33742472529411316, + 0.1001473218202591, + 0.08700849115848541, + 0.9933257102966309, + -2.5278031826019287, + -0.5935835242271423, + 2.710871934890747, + -0.87749183177948, + -0.06125229224562645, + -0.19061818718910217, + -0.04017600044608116, + 1.7519460916519165, + -0.7798219919204712, + -1.28012216091156, + 0.7500321269035339, + 0.02245335467159748, + 0.08263842761516571, + -0.1563340127468109, + -0.3502165377140045, + 0.5060794949531555, + 0.11768018454313278, + -0.2394258826971054, + 0.0027446788735687733, + -0.0012661140644922853, + 0.010839025489985943, + 0.04500429332256317, + -0.04333498701453209, + -0.027386408299207687, + 0.04357098788022995, + -0.04407481476664543, + 0.08443310111761093, + -0.08108946681022644, + -0.20346391201019287, + 0.3825778365135193, + -0.16498182713985443, + -0.04287993535399437, + 0.05340999737381935, + -0.14011172950267792, + 0.29446643590927124, + 0.2738667130470276, + -1.1299961805343628, + 0.7827413082122803, + -0.07552053779363632, + -0.03602323681116104, + 0.16167275607585907, + -0.6924317479133606, + 0.4478289783000946, + 1.2428895235061646, + -1.4833877086639404, + 0.4690392315387726, + -0.00820756796747446, + -0.09873292595148087, + 0.692342221736908, + -1.0981175899505615, + -0.3906446695327759, + 1.438644528388977, + -0.719068169593811, + 0.026173872873187065, + -0.09383898228406906, + -0.3282022774219513, + 1.0363390445709229, + -0.23960772156715393, + -0.7638148069381714, + 0.5488630533218384, + -0.015319733880460262, + 0.11911362409591675, + 0.017409542575478554, + -0.4231888949871063, + 0.23724795877933502, + 0.1191876158118248, + 
-0.15694500505924225, + -0.03534351661801338, + 0.06342366337776184, + 0.17738288640975952, + 0.012300643138587475, + -0.06408121436834335, + -0.06030220910906792, + 0.0018237337935715914, + 0.07659764587879181, + 0.1820947527885437, + 0.24410061538219452, + -0.06998514384031296, + -0.1491813361644745, + -0.06184092164039612, + 0.04607890918850899, + 0.15362663567066193, + 0.18308304250240326, + 0.08175522834062576, + -0.305602103471756, + -0.2915116548538208, + -0.08144206553697586, + 0.07138665020465851, + -0.03521484509110451, + -0.0914112851023674, + -0.2766699492931366, + -0.6285344362258911, + -0.38168880343437195, + -0.0033710987772792578, + 0.14477019011974335, + -0.03885374590754509, + -0.11367184668779373, + -0.1979650855064392, + -0.3575190007686615, + 0.016150522977113724, + 0.28292712569236755, + 0.2836199402809143, + -0.016672370955348015, + -0.034946177154779434, + -0.014770845882594585, + -0.0004113636096008122, + 0.29938748478889465, + 0.3562523126602173, + 0.13313128054141998, + -0.029499055817723274, + 0.007187174167484045, + 0.0636785551905632, + 0.047712039202451706, + 0.20670579373836517, + 0.10999035090208054, + -0.1150810718536377, + 0.00879934523254633, + -0.009125287644565105, + -0.013732590712606907, + 0.04738131910562515, + 0.0549951009452343, + -0.014094026759266853, + -0.01195482350885868, + -0.017125386744737625, + -0.071754589676857, + -0.023961570113897324, + 0.013098018243908882, + 0.05972208455204964, + -0.032899752259254456, + -0.024354496970772743, + -0.013116234913468361, + -0.05865325778722763, + -0.006360829807817936, + 0.12809234857559204, + 0.14038555324077606, + -0.022946689277887344, + -0.039698828011751175, + 0.05144746974110603, + -0.025034509599208832, + 0.08764739334583282, + 0.24594412744045258, + 0.19307002425193787, + -0.04085381329059601, + -0.020323628559708595, + 0.022060081362724304, + 0.01799374632537365, + 0.09039195626974106, + 0.1681770235300064, + 0.0016234283102676272, + -0.23777234554290771, + -0.11634974926710129, + -0.014439117163419724, + -0.034799374639987946, + 0.0457066111266613, + 0.049919649958610535, + -0.1926913857460022, + -0.2680967450141907, + 0.0018220803467556834, + -0.012749310582876205, + -0.04389086738228798, + 0.0060565415769815445, + -0.012036234140396118, + -0.12737582623958588, + -0.05777670815587044, + 0.09932202100753784, + 0.09969642758369446, + -0.1296343356370926, + -0.2964152693748474, + -0.05487265810370445, + 0.12073978036642075, + 0.06634647399187088, + 0.004042446613311768, + -0.1586746722459793, + -0.6267098784446716, + -0.5184157490730286, + -0.032286129891872406, + 0.28023189306259155, + 0.12663227319717407, + -0.08828771114349365, + -0.2600027620792389, + -0.5287090539932251, + -0.0994620993733406, + 0.7820600271224976, + 0.9638882279396057, + 0.2193463146686554, + -0.13466303050518036, + 0.042050741612911224, + -0.02292742393910885, + 0.7523098587989807, + 1.7435946464538574, + 1.111282229423523, + -0.2104763388633728, + -0.35129284858703613, + 0.08224371820688248, + 0.11167984455823898, + 0.6513852477073669, + 0.9696454405784607, + -0.1501394510269165, + -1.1777327060699463, + -0.7738466262817383, + 0.01114045549184084, + 0.004884988535195589, + 0.2849186658859253, + 0.14232710003852844, + -1.0306764841079712, + -1.2078118324279785, + -0.14658716320991516, + 0.036605384200811386, + 0.0001495486794738099, + 0.12111346423625946, + -0.24653346836566925, + -0.7028710246086121, + -0.18977169692516327, + 0.5171932578086853, + -0.02514370158314705, + 0.0885375589132309, + -0.1023016944527626, + 
0.023200739175081253, + 0.11839435249567032, + -0.09749021381139755, + 0.008283962495625019, + 0.0106261121109128, + -0.031724803149700165, + -0.1594654619693756, + 0.433218389749527, + -0.33944255113601685, + 0.14406877756118774, + -0.0339396670460701, + 0.09370072185993195, + -0.35916459560394287, + 0.7577320337295532, + -0.5531823635101318, + -0.016844574362039566, + 0.2994873523712158, + -0.21487002074718475, + -0.16125759482383728, + 0.35567227005958557, + 0.09099612385034561, + -1.3889282941818237, + 1.9466298818588257, + -1.2556309700012207, + 0.4389301836490631, + -0.010665428824722767, + 0.4707520306110382, + -1.4310415983200073, + 2.0986156463623047, + -1.5515614748001099, + 0.3905705511569977, + 0.01881679706275463, + 0.057307951152324677, + -0.29734691977500916, + 0.369127094745636, + -0.05115725100040436, + -0.44008156657218933, + 0.48642784357070923, + -0.13904061913490295, + -0.004375698510557413, + -0.06351548433303833, + 0.256020188331604, + -0.34121274948120117, + 0.22490821778774261, + 0.004067304544150829, + -0.059063635766506195, + -0.010710661299526691, + 0.03514768183231354, + -0.08577805012464523, + 0.05103181675076485, + 0.04276616871356964, + -0.10832246392965317, + 0.03325289487838745, + 0.06318283081054688, + -0.11063538491725922, + -0.062119144946336746, + 0.40978243947029114, + -0.5597845315933228, + 0.34106317162513733, + -0.030269838869571686, + 0.057014383375644684, + -0.44329890608787537, + 1.0965592861175537, + -1.0767146348953247, + 0.13287265598773956, + 0.517289400100708, + -0.310720294713974, + -0.15501761436462402, + 0.5854693055152893, + -0.12469431757926941, + -1.7694847583770752, + 2.6433238983154297, + -1.596714735031128, + 0.3888415992259979, + -0.02415616251528263, + 0.42178481817245483, + -1.8008503913879395, + 2.8845136165618896, + -1.7628657817840576, + 0.1951047033071518, + 0.11415407806634903, + 0.07305648922920227, + -0.34212157130241394, + 0.46562451124191284, + 0.03175807744264603, + -0.7942091226577759, + 0.6133171319961548, + -0.14596694707870483, + 0.010496735572814941, + -0.03459644690155983, + 0.2948842942714691, + -0.47654271125793457, + 0.2612597346305847, + 0.016025209799408913, + -0.05287598818540573, + -0.01606004498898983, + 0.022197037935256958, + 0.028397703543305397, + -0.0390767939388752, + 0.0037972000427544117, + -0.07010228931903839, + 0.10934390872716904, + 0.017220165580511093, + 0.02215729095041752, + -0.14772991836071014, + 0.2353552132844925, + -0.3846408724784851, + 0.23990634083747864, + -0.02300707995891571, + 0.12085225433111191, + -0.3576957881450653, + 0.6410096883773804, + -0.532350480556488, + -0.002389132045209408, + 0.41821879148483276, + -0.24739143252372742, + -0.10216745734214783, + 0.16793736815452576, + 0.16367803514003754, + -1.1304419040679932, + 1.676539421081543, + -1.064436435699463, + 0.26995453238487244, + -0.07634275406599045, + 0.3324422240257263, + -1.11312997341156, + 1.8095507621765137, + -1.2477567195892334, + 0.3605581820011139, + -0.06627745926380157, + 0.008511146530508995, + -0.19528241455554962, + 0.4320055842399597, + -0.22881783545017242, + -0.18463851511478424, + 0.3064245581626892, + -0.14437103271484375, + 0.02049900032579899, + 0.018321938812732697, + 0.14011529088020325, + -0.26683253049850464, + 0.2172057181596756, + -0.12119362503290176, + 0.025965997949242592, + -0.03424325957894325, + 0.0433838777244091, + 0.1072857677936554, + 0.1997794657945633, + 0.0648089200258255, + -0.06444115936756134, + -0.13146057724952698, + 0.02106364443898201, + -0.22582228481769562, + 
-0.007233713287860155, + 0.18876874446868896, + -0.5612399578094482, + 0.2632557451725006, + 0.44088244438171387, + 0.11389002948999405, + -0.2791701555252075, + -0.18004432320594788, + 0.8571203947067261, + -1.9517340660095215, + -1.4906251430511475, + 0.3436146676540375, + 0.31222787499427795, + -0.20083315670490265, + -0.217665895819664, + 3.801243782043457, + 1.2014728784561157, + -0.9149202704429626, + 0.6968244910240173, + 0.12756747007369995, + -0.06783506274223328, + -2.086660385131836, + 0.5455523133277893, + 0.49095916748046875, + -0.5991013050079346, + 0.7938552498817444, + -0.1335069239139557, + 0.4730406701564789, + -1.00951087474823, + -0.537578821182251, + -0.49764835834503174, + -1.2683815956115723, + -0.045739322900772095, + -0.16049732267856598, + 0.30239275097846985, + 0.035600025206804276, + 0.6344828605651855, + 0.8256548643112183, + -0.12940075993537903, + 0.09257010370492935, + -0.11000311374664307, + 0.003206665627658367, + -0.008585316129028797, + -0.14573170244693756, + 0.172541081905365, + 0.2107972949743271, + -0.05270108953118324, + -0.08480435609817505, + 0.1914149820804596, + 0.21630872786045074, + -0.23309426009655, + -0.29484814405441284, + -0.1899339109659195, + 0.02601807750761509, + -0.05416746065020561, + 0.20924429595470428, + 0.15566189587116241, + -0.1556546688079834, + -0.23387494683265686, + -0.5112816691398621, + 0.24130745232105255, + -0.049835484474897385, + -0.2685615122318268, + -0.024764614179730415, + 0.5458847880363464, + 0.9501044750213623, + 0.1328524947166443, + 0.21218529343605042, + 0.2524968683719635, + -0.5205130577087402, + -0.3361912667751312, + 1.1678112745285034, + -0.004513490945100784, + -0.9149109125137329, + 0.2125048041343689, + 0.22423015534877777, + -0.08384363353252411, + -0.2866036593914032, + -0.20210212469100952, + -1.2377471923828125, + -0.7704879641532898, + 0.365038126707077, + -0.08308980613946915, + -0.08326874673366547, + 0.456358402967453, + 0.35142943263053894, + 0.19268833100795746, + 0.3706081509590149, + -0.04951317980885506, + 0.10151109844446182, + 0.005193099845200777, + -0.1124582439661026, + -0.08353164792060852, + -0.18709596991539001, + -0.18975794315338135, + 0.17628741264343262, + 0.05536900460720062, + 0.008301885798573494, + -0.1890449970960617, + 0.056875281035900116, + 0.7981322407722473, + -0.05872391164302826, + -0.4860122501850128, + -0.08073797076940536, + 0.13145819306373596, + -0.03608228266239166, + -0.6600452661514282, + 2.243560314178467, + 1.9288626909255981, + -0.5698518753051758, + -0.2486664056777954, + 0.42693793773651123, + 0.2667267322540283, + -4.395429611206055, + -2.15342378616333, + 0.819127082824707, + -0.9362612962722778, + -0.3760467767715454, + 0.5671858787536621, + 2.468177080154419, + -1.6694080829620361, + -0.49952322244644165, + 1.502772569656372, + -1.0188850164413452, + -0.10419629514217377, + -0.36795151233673096, + 1.2645196914672852, + 0.7223924994468689, + 1.751431941986084, + 2.018704891204834, + -0.3197852671146393, + 0.22054125368595123, + -0.19326329231262207, + -0.5307535529136658, + -0.9362435936927795, + -1.0772119760513306, + -0.19870880246162415, + -0.0650869607925415, + -0.0796947032213211, + 0.15733301639556885, + 0.08798394352197647, + 0.0010860684560611844, + 0.05327683687210083, + 0.1107875183224678, + 0.13224183022975922, + 0.08979664742946625, + 0.004348093178123236, + -0.07060158997774124, + -0.19925491511821747, + -0.15811985731124878, + -0.08220887929201126, + -0.022623460739850998, + 0.08509720861911774, + 0.00792989507317543, + 
-0.14345014095306396, + -0.2720486521720886, + -0.18885627388954163, + -0.11063539236783981, + -0.0355350486934185, + 0.048891279846429825, + -0.12828074395656586, + -0.2712610363960266, + -0.20134924352169037, + -0.1863398402929306, + -0.19976121187210083, + -0.09535074234008789, + 0.009852319024503231, + -0.2776590585708618, + -0.3087778687477112, + -0.21431012451648712, + -0.19772370159626007, + -0.23412325978279114, + -0.11640459299087524, + 0.09514907747507095, + -0.17561811208724976, + -0.29451555013656616, + -0.2381855845451355, + -0.18296842277050018, + -0.18682444095611572, + -0.023345205932855606, + 0.1438502073287964, + 0.02504260651767254, + -0.1554802507162094, + -0.1477985382080078, + -0.07874225080013275, + -0.002977968193590641, + 0.1048416793346405, + -0.1779504120349884, + 0.13204343616962433, + 0.14215172827243805, + 0.049610622227191925, + 0.0888131782412529, + 0.07250366359949112, + 0.0696505531668663, + 0.009899160824716091, + 0.032067786902189255, + 0.08401404321193695, + -0.03567894548177719, + -0.004740188363939524, + -0.0021664693485945463, + -0.011156522668898106, + 0.0821070745587349, + 0.10295391082763672, + -0.0017653254326432943, + -0.16915833950042725, + -0.062223054468631744, + 0.004783258773386478, + 0.038355808705091476, + 0.10124270617961884, + -0.003437258303165436, + -0.18881437182426453, + -0.15905225276947021, + -0.12576808035373688, + -0.11059725284576416, + 0.021587060764431953, + 0.07237453758716583, + -0.1706620156764984, + -0.27434206008911133, + -0.23003827035427094, + -0.20530915260314941, + -0.20856624841690063, + -0.021966496482491493, + 0.13395215570926666, + -0.03810539469122887, + -0.2409798800945282, + -0.2515420913696289, + -0.1872486174106598, + -0.15951117873191833, + 0.04223426431417465, + 0.09909931570291519, + 0.12328703701496124, + -0.057749148458242416, + -0.1300545036792755, + -0.046062104403972626, + 0.019744107499718666, + 0.09484386444091797, + -0.2709728479385376, + 0.03540695831179619, + 0.1206774190068245, + 0.057636432349681854, + 0.10385740548372269, + 0.032486993819475174, + -0.020434774458408356, + -0.10122086852788925, + -0.0023329253308475018, + 0.16941140592098236, + 0.098082534968853, + 0.1250472217798233, + 0.06134447827935219, + -0.025240115821361542, + 0.004181401338428259, + 0.14425808191299438, + 0.17515034973621368, + 0.04739757999777794, + 0.1618604063987732, + 0.1751406490802765, + 0.09162088483572006, + 0.09512057155370712, + 0.13736343383789062, + 0.028775952756404877, + 0.042535409331321716, + 0.08839954435825348, + 0.09229374676942825, + 0.1658262014389038, + 0.09852072596549988, + 0.002680110279470682, + -0.05479496717453003, + -0.03634755313396454, + -0.002902726177126169, + -0.023990361019968987, + 0.1277875006198883, + 0.12727677822113037, + 0.1002269834280014, + -0.040967896580696106, + -0.07101184874773026, + -0.007902896963059902, + 0.019561029970645905, + 0.145268052816391, + 0.017638152465224266, + 0.19240263104438782, + 0.12857146561145782, + 0.05043037235736847, + 0.11596394330263138, + 0.12513381242752075, + 0.12088746577501297, + 0.04333524778485298, + 0.05500142276287079, + 0.05169082432985306, + -0.09941842406988144, + -0.005959822330623865, + -0.032586321234703064, + -0.03065132349729538, + -0.04826900362968445, + 0.14192889630794525, + 0.2543988823890686, + 0.09563885629177094, + -0.28965362906455994, + -0.1341734230518341, + 0.033991701900959015, + -0.22402706742286682, + -0.3190857768058777, + 0.011840387247502804, + 0.9620282053947449, + 1.0609054565429688, + -0.13429726660251617, + 
-0.20191268622875214, + 0.05324135720729828, + -0.16234318912029266, + -0.9101927280426025, + -1.7916113138198853, + 0.3981992304325104, + 1.3173034191131592, + 0.53525310754776, + 0.18472574651241302, + 0.3719426691532135, + 0.7792536020278931, + -0.027768991887569427, + -2.245561122894287, + -1.2211185693740845, + 0.22817185521125793, + -0.0023349972907453775, + -0.12598364055156708, + 0.06836964190006256, + 0.9917387366294861, + 1.1885775327682495, + -0.2851368486881256, + -0.7428704500198364, + -0.04798422381281853, + -0.00811613816767931, + -0.19619861245155334, + -0.28184008598327637, + 0.0828644260764122, + 0.44643187522888184, + 0.1461745798587799, + -0.005575121380388737, + -0.06604957580566406, + 0.011459077708423138, + 0.03927984461188316, + 0.0634538009762764, + -0.005732079967856407, + -0.01014732290059328, + 0.07607843726873398, + 0.06948187947273254, + -0.010600326582789421, + -0.056259915232658386, + -0.24602480232715607, + -0.01649448834359646, + 0.11143466085195541, + -0.0027401424013078213, + -0.012853104621171951, + 0.08452893793582916, + 0.639316201210022, + 0.5167437195777893, + -0.2775256335735321, + -0.22241903841495514, + -0.07067711651325226, + -0.06368192285299301, + -0.4687917232513428, + -1.1776493787765503, + 0.36015447974205017, + 0.9171182513237, + 0.1905054748058319, + -0.010661551728844643, + 0.10800722986459732, + 0.5352235436439514, + 0.18558207154273987, + -1.5184046030044556, + -0.8130561709403992, + 0.15417319536209106, + 0.0713079422712326, + -0.07369451224803925, + -0.09037846326828003, + 0.6168488264083862, + 0.9663773775100708, + -0.007113471627235413, + -0.33585548400878906, + -0.02738586813211441, + 0.061310965567827225, + -0.0955657884478569, + -0.23896107077598572, + -0.1107473075389862, + 0.1830059289932251, + 0.10748914629220963, + -0.040772341191768646, + -0.05803938955068588, + -0.0004895658930763602, + 0.07664632797241211, + 0.039049405604600906, + -0.002806248841807246, + -0.02642429992556572, + 0.05169009417295456, + -0.036710865795612335, + -0.1002974808216095, + -0.12001149356365204, + -0.08043934404850006, + 0.11466419696807861, + 0.12322796136140823, + 0.07564827799797058, + 0.10148002207279205, + 0.04720174893736839, + 0.14046646654605865, + -0.0819464847445488, + -0.30803975462913513, + -0.0838734582066536, + -0.0801682323217392, + 0.05861072987318039, + 0.04970559477806091, + -0.20592759549617767, + 0.2673366665840149, + 0.2431953400373459, + -0.10027645528316498, + -0.07884806394577026, + -0.09939537942409515, + 0.1181628480553627, + 0.25269386172294617, + -0.3439132571220398, + -0.11160463094711304, + 0.08640077710151672, + 0.07200870662927628, + -0.03449570760130882, + -0.17610406875610352, + -0.021308166906237602, + 0.30556705594062805, + 0.05186203494668007, + -0.004691269714385271, + -0.005278654862195253, + 0.06289899349212646, + 0.052224051207304, + -0.05927770212292671, + -0.1586783081293106, + -0.022610770538449287, + 0.03463536128401756, + 0.004338411148637533, + 0.01452699676156044, + -0.008622901514172554, + 0.010536444373428822, + -0.038111478090286255, + 0.013373414985835552, + 0.007125865668058395, + -0.003420598339289427, + 0.03533756732940674, + 0.0320388600230217, + 0.045789655297994614, + -0.08139114826917648, + -0.03447948023676872, + -0.01453007198870182, + -0.004573625046759844, + 0.10279268026351929, + 0.10881853848695755, + 0.07537791877985, + -0.10887791216373444, + -0.0980544164776802, + -0.06889445334672928, + 0.006558350287377834, + 0.197514146566391, + 0.17890937626361847, + 0.07630149275064468, + 
-0.16081148386001587, + -0.16685302555561066, + -0.11421715468168259, + -0.013679573312401772, + 0.22477784752845764, + 0.20761631429195404, + 0.07321957498788834, + -0.17697854340076447, + -0.17810045182704926, + -0.1579347848892212, + -0.02679254300892353, + 0.1408146619796753, + 0.15144851803779602, + 0.08801613748073578, + -0.13237154483795166, + -0.13181765377521515, + -0.1279487907886505, + -0.01779216341674328, + 0.08145096898078918, + 0.05625852569937706, + 0.07724357396364212, + -0.04653938114643097, + -0.07479449361562729, + -0.06189379468560219, + -0.04310920089483261, + 0.02028634026646614, + -0.006228619255125523, + 0.03549303859472275, + -0.043929651379585266, + 0.007818001322448254, + 0.00874761026352644, + -0.017027731984853745, + 0.11014463752508163, + 0.0841977447271347, + 0.05960552394390106, + -0.12814101576805115, + -0.0544624924659729, + -0.045333195477724075, + 0.02336869016289711, + 0.22365787625312805, + 0.18523427844047546, + 0.09366372227668762, + -0.20144090056419373, + -0.16367222368717194, + -0.13003699481487274, + 0.0590205080807209, + 0.3301562964916229, + 0.26524844765663147, + 0.09425198286771774, + -0.26156124472618103, + -0.28513699769973755, + -0.21749621629714966, + 0.04356053099036217, + 0.35879984498023987, + 0.29898661375045776, + 0.0977487862110138, + -0.28175386786460876, + -0.2964495122432709, + -0.249031201004982, + 0.028877725824713707, + 0.26395633816719055, + 0.23059280216693878, + 0.09593978524208069, + -0.22489066421985626, + -0.2248908430337906, + -0.19214706122875214, + 0.007535146549344063, + 0.15299226343631744, + 0.09148521721363068, + 0.06946425884962082, + -0.1445557326078415, + -0.11587042361497879, + -0.0978587418794632, + -0.00984917301684618, + -0.012626220472157001, + -0.02837960794568062, + 0.02399199828505516, + -0.005340439733117819, + 0.023224178701639175, + 0.011642432771623135, + 0.003958537708967924, + 0.042965203523635864, + 0.01099414099007845, + 0.024063799530267715, + -0.0702008455991745, + 0.007805663626641035, + 0.0050195748917758465, + 0.017281856387853622, + 0.10123670846223831, + 0.06401767581701279, + 0.02626805007457733, + -0.1073761060833931, + -0.03802435100078583, + -0.014407800510525703, + -0.0006281707319431007, + 0.15516239404678345, + 0.12629136443138123, + 0.033691491931676865, + -0.17609107494354248, + -0.15251316130161285, + -0.07914211601018906, + -0.015578335151076317, + 0.18422608077526093, + 0.1740245372056961, + 0.06139932945370674, + -0.17213505506515503, + -0.1602732092142105, + -0.08922445774078369, + -0.012822975404560566, + 0.13543544709682465, + 0.12543149292469025, + 0.07651004195213318, + -0.13805902004241943, + -0.09661149233579636, + -0.052669934928417206, + -0.03268992528319359, + 0.0391642227768898, + 0.01116940937936306, + 0.04585625231266022, + -0.06474924832582474, + -0.023607701063156128, + -0.007017284631729126, + -0.026150476187467575, + 0.05729387328028679, + -0.10095079243183136, + 0.16617903113365173, + -0.13664309680461884, + 0.026482274755835533, + 0.008411461487412453, + -0.03410203382372856, + 0.022963764145970345, + 0.008903563022613525, + 0.11244194954633713, + -0.20863348245620728, + 0.11064451932907104, + -0.024916114285588264, + 0.009591493755578995, + -0.26092270016670227, + 0.5717483758926392, + -0.38539814949035645, + 0.035056713968515396, + 0.08623965084552765, + -0.016184961423277855, + 0.11129201203584671, + -0.6138678789138794, + 1.3646206855773926, + -1.4969615936279297, + 0.8465064764022827, + -0.2794847786426544, + 0.05826558917760849, + 
0.07709132134914398, + -0.5444677472114563, + 1.3013663291931152, + -1.5686073303222656, + 0.9930508732795715, + -0.39188963174819946, + 0.08085884898900986, + -0.05875617265701294, + 0.03498996049165726, + 0.23967482149600983, + -0.3468690514564514, + 0.19146253168582916, + 0.019604403525590897, + -0.027150027453899384, + -0.024670494720339775, + 0.09944183379411697, + -0.11718503385782242, + 0.09772855788469315, + -0.11857263743877411, + 0.09660946577787399, + -0.03638811036944389, + -0.0295167975127697, + 0.1032838523387909, + -0.12557579576969147, + 0.11812210828065872, + -0.08446288853883743, + 0.027706580236554146, + 0.010997293516993523, + -0.06348618865013123, + 0.09578556567430496, + -0.0165568757802248, + -0.014778072014451027, + -0.07772849500179291, + 0.11245536059141159, + -0.043248821049928665, + 0.013345679268240929, + -0.22149333357810974, + 0.6456363797187805, + -0.7280437350273132, + 0.3046833574771881, + 0.06304280459880829, + -0.07310052216053009, + 0.08824795484542847, + -0.65179842710495, + 1.6453673839569092, + -2.046448230743408, + 1.3267604112625122, + -0.42399832606315613, + 0.0010522910160943866, + 0.07953720539808273, + -0.5960973501205444, + 1.5601089000701904, + -2.084894895553589, + 1.4612183570861816, + -0.5491638779640198, + 0.13709494471549988, + -0.09170618653297424, + 0.07287970930337906, + 0.24422486126422882, + -0.4581631124019623, + 0.29479551315307617, + -0.07515113800764084, + -0.012292998842895031, + -0.04451148584485054, + 0.14961428940296173, + -0.15577177703380585, + 0.06323063373565674, + -0.07806269824504852, + 0.07061618566513062, + -0.026793144643306732, + -0.051938362419605255, + 0.13946141302585602, + -0.14129231870174408, + 0.11092118173837662, + -0.08889970183372498, + 0.034787945449352264, + -0.008983314968645573, + -0.04930088296532631, + 0.09856640547513962, + -0.09350966662168503, + 0.07015673816204071, + -0.06468848884105682, + 0.08028972148895264, + -0.02378295361995697, + 0.004251216538250446, + -0.11239825189113617, + 0.2660067081451416, + -0.367576539516449, + 0.2212517410516739, + -0.035011082887649536, + -0.037866897881031036, + 0.11835235357284546, + -0.4868132174015045, + 0.9402765035629272, + -1.0933791399002075, + 0.9518744349479675, + -0.5096855759620667, + 0.12277142703533173, + 0.12916085124015808, + -0.4648635983467102, + 0.8895858526229858, + -1.0776352882385254, + 1.023865818977356, + -0.5914785861968994, + 0.1682877242565155, + -0.05646277964115143, + 0.04132156819105148, + -0.01790236309170723, + -0.059831030666828156, + 0.10092897713184357, + -0.1268356889486313, + 0.013669619336724281, + -0.02746082842350006, + 0.11544085294008255, + -0.2124193012714386, + 0.2733248472213745, + -0.1360178142786026, + 0.025302443653345108, + 0.01249375008046627, + -0.015119954012334347, + 0.017966970801353455, + 0.00269943755120039, + 0.014392177574336529, + 0.007648292928934097, + 0.011665135622024536, + -0.006192799191921949, + 0.004215092398226261, + 0.017718149349093437, + 0.046436555683612823, + 0.044417623430490494, + 0.01518242433667183, + -0.0020157198887318373, + -0.01828707568347454, + -0.029163505882024765, + -0.03131464868783951, + -0.004393945913761854, + 0.048599082976579666, + 0.015757638961076736, + -0.015650734305381775, + -0.002684049541130662, + -0.0697445422410965, + -0.25050923228263855, + -0.4758685231208801, + -0.5382962822914124, + -0.38907238841056824, + -0.12599025666713715, + -0.00266047241166234, + 0.0758173018693924, + 0.26593172550201416, + 0.4203726053237915, + 0.4958920478820801, + 
0.3697706162929535, + 0.12434400618076324, + 0.026325728744268417, + 0.022295912727713585, + 0.08135133236646652, + 0.2627769708633423, + 0.26325660943984985, + 0.12326934933662415, + 0.058665141463279724, + 0.04346219077706337, + -0.0013142779935151339, + -0.10037153959274292, + -0.27075886726379395, + -0.28071707487106323, + -0.17300420999526978, + -0.06914675980806351, + 0.004067219793796539, + -0.020674005150794983, + 0.02103183977305889, + 0.0033879741095006466, + 0.013523808680474758, + -0.007318845018744469, + -0.009975744411349297, + -0.02981705591082573, + 0.023193644359707832, + 0.09624253213405609, + 0.1077117845416069, + 0.11186518520116806, + 0.07592211663722992, + 0.04614634811878204, + 0.015908582136034966, + -0.05212458223104477, + -0.1262977123260498, + -0.10974782705307007, + -0.07645918428897858, + -0.06987964361906052, + -0.08783216774463654, + -0.046172842383384705, + -0.22593465447425842, + -0.5281140804290771, + -0.8424770832061768, + -0.9608982801437378, + -0.7363743185997009, + -0.3312055170536041, + -0.10426472127437592, + 0.24067367613315582, + 0.5504152178764343, + 0.81276935338974, + 0.9592635035514832, + 0.7479950785636902, + 0.32608768343925476, + 0.14525265991687775, + 0.15008939802646637, + 0.32246851921081543, + 0.5287250876426697, + 0.5817036032676697, + 0.37340155243873596, + 0.20366452634334564, + 0.1546182781457901, + -0.11224830150604248, + -0.29856279492378235, + -0.5281672477722168, + -0.5890122056007385, + -0.4024880528450012, + -0.23706914484500885, + -0.0641399398446083, + -0.0025121152866631746, + 0.0051757702603936195, + -0.014290476217865944, + 0.0043721878901124, + -0.004783981014043093, + 0.021787043660879135, + -0.004969750996679068, + -0.022116241976618767, + 0.05208030343055725, + 0.07022145390510559, + 0.03730607405304909, + 0.03242917358875275, + 0.04344351217150688, + -0.01189794484525919, + -0.0418211966753006, + -0.059125497937202454, + -0.014576594345271587, + 0.01294493954628706, + -0.011262460611760616, + -0.059920165687799454, + -0.04733816161751747, + -0.12665517628192902, + -0.29677024483680725, + -0.5247481465339661, + -0.6474934816360474, + -0.4751538038253784, + -0.1937171369791031, + -0.05117221921682358, + 0.14646948873996735, + 0.32891425490379333, + 0.5415402054786682, + 0.6071264147758484, + 0.4653589427471161, + 0.18045872449874878, + 0.09937354922294617, + 0.1264665126800537, + 0.18507222831249237, + 0.31783968210220337, + 0.3545042872428894, + 0.22468777000904083, + 0.09973976761102676, + 0.1227618008852005, + -0.07824759930372238, + -0.20465101301670074, + -0.36476215720176697, + -0.38243186473846436, + -0.2540777623653412, + -0.13525226712226868, + -0.03621843457221985, + -0.012233156710863113, + -0.01481863297522068, + -0.04313792288303375, + 0.002874002791941166, + -0.028444716706871986, + -0.04687628522515297, + -0.026806645095348358, + -0.0228339321911335, + -0.015892738476395607, + -0.015550780110061169, + 0.07011140882968903, + 0.0017389585264027119, + -0.05721491947770119, + -0.017484690994024277, + -0.03954736143350601, + -0.006339249666780233, + 0.08166316151618958, + 0.37439921498298645, + 0.2830294966697693, + 0.00668215099722147, + -0.038873329758644104, + -0.012295035645365715, + 0.04932165890932083, + 0.31826695799827576, + 0.8449289202690125, + 0.7123299241065979, + 0.2574000954627991, + 0.04747961834073067, + -0.04416817054152489, + -0.005029442720115185, + 0.2027042657136917, + 0.6639980673789978, + 0.6243636012077332, + 0.21359916031360626, + 0.027929672971367836, + -0.05395142361521721, + 
-0.04981911554932594, + -0.006375179626047611, + 0.23660773038864136, + 0.2155737280845642, + 0.020577391609549522, + -0.032118700444698334, + -0.02332071214914322, + -0.009217707440257072, + -0.038096409291028976, + 0.05811609327793121, + 0.03776064142584801, + -0.03570764884352684, + -0.042420413345098495, + 0.017812976613640785, + 0.019242385402321815, + 0.030057156458497047, + 0.003040613606572151, + 0.02378096617758274, + 0.04043402150273323, + 0.0243258997797966, + 0.014026327058672905, + 0.005650558043271303, + -0.002831381279975176, + -0.0645776093006134, + -0.03761167451739311, + 0.043774381279945374, + 0.010685136541724205, + 0.031011218205094337, + -0.0025828774087131023, + -0.11959855258464813, + -0.3524792194366455, + -0.30037227272987366, + -0.053334690630435944, + 0.009859252721071243, + 0.0010005333460867405, + -0.04819931834936142, + -0.3154168128967285, + -0.7240553498268127, + -0.6380828022956848, + -0.25695785880088806, + -0.06639125943183899, + 0.03295261785387993, + -0.012727363035082817, + -0.24232468008995056, + -0.6055921912193298, + -0.5679556727409363, + -0.20067356526851654, + -0.03628019988536835, + 0.04774145409464836, + 0.029560575261712074, + -0.038632482290267944, + -0.24032950401306152, + -0.2095729559659958, + -0.006905315909534693, + 0.02563827484846115, + 0.03053808957338333, + 0.0012747920118272305, + 0.004095789045095444, + -0.07932732999324799, + -0.046672020107507706, + 0.02153847925364971, + 0.019504766911268234, + -0.006118285935372114, + 0.0026654782705008984, + 0.013819373212754726, + -0.01078135147690773, + 0.0070082321763038635, + 0.00906399916857481, + 0.010149766691029072, + 0.000516490894369781, + 0.00034157291520386934, + 0.02412085421383381, + 0.006926041562110186, + 0.023299943655729294, + 0.01129852794110775, + -0.0018704778049141169, + 0.016042279079556465, + 0.023886069655418396, + 0.04207555204629898, + -0.0021778997033834457, + 0.041684601455926895, + 0.05059140920639038, + 0.03518521040678024, + -0.0032736151479184628, + -0.0007146652205847204, + 0.015503454953432083, + -0.11896659433841705, + -0.07006713002920151, + 0.007565992418676615, + 0.012584990821778774, + 0.00843358226120472, + 0.017024952918291092, + 0.0359124094247818, + -0.05997823178768158, + -0.04116949439048767, + -0.016472430899739265, + 0.002696823561564088, + 0.00829327292740345, + 0.016238784417510033, + 0.0455794483423233, + 0.0019872160628437996, + -0.005927432328462601, + -0.003552153240889311, + 0.020063765347003937, + 0.00010026743984781206, + 0.01045019831508398, + 0.034689340740442276, + 0.014206668362021446, + 0.015128945000469685, + 0.00972809735685587, + 0.019944868981838226, + 0.020581791177392006, + 0.02938947267830372, + 0.03923909366130829, + 0.03601628914475441, + 0.030168617144227028, + 0.05403255671262741, + 0.03985666483640671, + 0.020015308633446693, + 0.0285494402050972, + 0.013555807992815971, + -0.04409409686923027, + -0.07503483444452286, + 0.01716756261885166, + 0.02053452841937542, + 0.057520389556884766, + 0.02973104454576969, + -0.04563397541642189, + -0.2676408588886261, + -0.30933722853660583, + -0.11671236902475357, + 0.0020135289523750544, + 0.022801443934440613, + -0.03161352127790451, + -0.2704106271266937, + -0.5803710222244263, + -0.5762420296669006, + -0.30449461936950684, + -0.0780220776796341, + 0.017343536019325256, + -0.05319945886731148, + -0.2906038463115692, + -0.598426342010498, + -0.5925986766815186, + -0.31852787733078003, + -0.09950074553489685, + 0.05888299271464348, + 0.01939479075372219, + -0.1060815081000328, + 
-0.3505017161369324, + -0.3200446665287018, + -0.10609738528728485, + 0.03659524768590927, + 0.056114207953214645, + 0.03447861596941948, + 0.014380007050931454, + -0.09436371922492981, + -0.07562272250652313, + 0.04223132133483887, + 0.06327345967292786, + -0.03735652193427086, + -0.052881840616464615, + -0.058017320930957794, + -0.02474917098879814, + -0.02431381866335869, + -0.0629878118634224, + -0.05212349444627762, + -0.03820814937353134, + -0.0034579068887978792, + -0.004930540919303894, + 0.07968354970216751, + 0.07278168946504593, + 0.015167324803769588, + -0.013638288713991642, + -0.05875609815120697, + -0.008851750753819942, + 0.10708516091108322, + 0.33075177669525146, + 0.3502756953239441, + 0.14791442453861237, + 0.03131852671504021, + -0.028764141723513603, + 0.07454497367143631, + 0.3000347316265106, + 0.6147283315658569, + 0.6289594173431396, + 0.3398674726486206, + 0.13494613766670227, + -0.03705109655857086, + 0.0633230209350586, + 0.3147434592247009, + 0.595033586025238, + 0.594217836856842, + 0.33864542841911316, + 0.11264053732156754, + -0.059276629239320755, + 0.005206871312111616, + 0.14524762332439423, + 0.37473905086517334, + 0.34477534890174866, + 0.12632343173027039, + 0.011062734760344028, + -0.06149457022547722, + -0.028670497238636017, + 0.011082210578024387, + 0.13112866878509521, + 0.1106843650341034, + -0.0025933771394193172, + -0.03781202808022499, + 0.030325254425406456, + 0.017758814617991447, + 0.01635698974132538, + -0.008786264806985855, + -0.0005018062074668705, + 0.005934061016887426, + 0.020206287503242493, + 0.019497420638799667, + -0.01290479488670826, + -0.010817185044288635, + -0.032760608941316605, + -0.026973316445946693, + -0.0021766452118754387, + -0.012848617509007454, + -0.0002560729335527867, + -0.02383977733552456, + -0.05322824791073799, + -0.05382781848311424, + -0.04459262639284134, + -0.04581240937113762, + -0.03465775027871132, + 0.0026904877740889788, + -0.026097090914845467, + -0.05170493200421333, + -0.04981262609362602, + -0.05221042037010193, + -0.05268307775259018, + -0.04735802114009857, + 0.019142162054777145, + -0.019374292343854904, + -0.03312355652451515, + -0.04133244976401329, + -0.033129844814538956, + -0.01844680868089199, + -0.024726904928684235, + 0.0012146441731601954, + -0.025521529838442802, + -0.03120318427681923, + -0.04863203689455986, + -0.021450525149703026, + -0.04190714284777641, + -0.02833862416446209, + 0.017827404662966728, + -0.010181388817727566, + -0.020994380116462708, + -0.04290826618671417, + -0.031555648893117905, + -0.030525390058755875, + -0.024981478229165077, + -0.017512500286102295, + 0.019927235320210457, + 0.00433371402323246, + -0.009276121854782104, + -0.03990143537521362, + -0.021251117810606956, + 0.017825132235884666, + -0.02313065528869629, + 0.012881814502179623, + 0.0009175563463941216, + -0.0656605213880539, + -0.007037178613245487, + 0.023603176698088646, + 0.04873553663492203, + 0.013912673108279705, + 9.78652315097861e-05, + -0.03166677802801132, + -0.11772678792476654, + -0.034320034086704254, + 0.04952533170580864, + 0.10113520920276642, + 0.030472615733742714, + -0.05131377652287483, + -0.1371452510356903, + -0.2326214611530304, + -0.0629519522190094, + 0.12444627285003662, + 0.15845368802547455, + 0.014535457827150822, + -0.06888624280691147, + -0.18798232078552246, + -0.24720685184001923, + -0.04858007654547691, + 0.26889580488204956, + 0.2433905005455017, + -0.01772989332675934, + -0.06027546152472496, + -0.12164203822612762, + -0.20018024742603302, + 
0.0035393801517784595, + 0.27190765738487244, + 0.1929154396057129, + -0.012923460453748703, + -0.013931642286479473, + -0.043986693024635315, + -0.0655391663312912, + 0.04751605913043022, + 0.13482201099395752, + 0.06690078228712082, + -0.01862635649740696, + 0.02938506379723549, + 0.01789080537855625, + -0.006509440019726753, + -0.029202938079833984, + -0.023693149909377098, + 0.01042762491852045, + -0.0035929735749959946, + 0.024952176958322525, + -0.013459124602377415, + -0.10798560827970505, + -0.020217353478074074, + 0.017876077443361282, + 0.07628928124904633, + 0.04444783553481102, + 0.012667268514633179, + -0.09012818336486816, + -0.22452381253242493, + -0.07556752860546112, + 0.07942477613687515, + 0.17035256326198578, + 0.0396822914481163, + -0.08236342668533325, + -0.23916372656822205, + -0.3645225763320923, + -0.10748416185379028, + 0.1996970921754837, + 0.3076043725013733, + -0.0033923503942787647, + -0.13259321451187134, + -0.28894615173339844, + -0.3605952262878418, + -0.07969008386135101, + 0.3583948314189911, + 0.4267900586128235, + -0.02228585258126259, + -0.11386624723672867, + -0.21445821225643158, + -0.26956692337989807, + 0.026791207492351532, + 0.37918713688850403, + 0.37130093574523926, + -0.05172214284539223, + -0.05132569745182991, + -0.07469630241394043, + -0.11400169134140015, + 0.07863093167543411, + 0.24061299860477448, + 0.19393151998519897, + -0.03217098489403725, + 0.013085477985441685, + 0.032348379492759705, + 0.03207695484161377, + 0.010604938492178917, + -0.026534704491496086, + -0.018284842371940613, + -0.01768680103123188, + -0.001516501884907484, + 0.013829287141561508, + -0.034318119287490845, + 0.015753330662846565, + -0.0018936718115583062, + 0.014737343415617943, + 0.03306088596582413, + 0.020835628733038902, + -0.03396771103143692, + -0.10758449137210846, + -0.03052518330514431, + 0.020080547779798508, + 0.06180800125002861, + 0.03735671192407608, + -0.037925880402326584, + -0.09720461815595627, + -0.21495617926120758, + -0.06842153519392014, + 0.08532039076089859, + 0.13350333273410797, + 0.03649023920297623, + -0.03904158994555473, + -0.1483580619096756, + -0.2068314403295517, + -0.05687328055500984, + 0.21108660101890564, + 0.21018920838832855, + 0.009318819269537926, + -0.037683792412281036, + -0.09845960140228271, + -0.1535443514585495, + 0.004504916723817587, + 0.20256847143173218, + 0.1799001693725586, + -0.03175490349531174, + -0.020391397178173065, + -0.007309200707823038, + -0.06765769422054291, + 0.013149870559573174, + 0.08469820767641068, + 0.04147877171635628, + -0.0027241194620728493, + 0.008016721345484257, + 0.001382349175401032, + 0.0001219741752720438, + -0.059255484491586685, + -0.03761141747236252, + 0.0381690077483654, + -0.01603613793849945, + 0.0017731477273628116, + -0.016544193029403687, + 0.09518970549106598, + 0.1735895872116089, + 0.005558829288929701, + -0.13464735448360443, + -0.0703420490026474, + 0.001990854274481535, + -0.03426021337509155, + -0.4390500485897064, + -0.11292288452386856, + 0.20430812239646912, + 0.14832687377929688, + 0.06074441969394684, + -0.03749264031648636, + 0.408058226108551, + 0.43119552731513977, + -0.3804298937320709, + -0.3694773018360138, + -0.03696960583329201, + 0.04022200033068657, + -0.0812998041510582, + -0.4322642385959625, + 0.19638888537883759, + 0.7809834480285645, + 0.11584538966417313, + -0.04975399747490883, + -0.015579828992486, + 0.1362757831811905, + 0.027220597490668297, + -0.4703449606895447, + -0.3726261258125305, + 0.11754196882247925, + -0.01204066164791584, + 
[... remaining float32 weight literals for onnx::Conv_501 ...]
+    ]
+    value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 3, 7, 7))
+    tensor = numpy_helper.from_array(value, name="onnx::Conv_501")
+
+    initializers.append(tensor)
+
+    list_value = [
[... float32 literals for onnx::Conv_502, one per output channel ...]
+    ]
+    value = numpy.array(list_value, dtype=numpy.float32)
+    tensor = numpy_helper.from_array(value, name="onnx::Conv_502")
+
+    initializers.append(tensor)
+
+    list_value = [
[... float32 literals for the next initializer; the list continues beyond this section ...]
+ 0.04683076590299606, + 0.009320023469626904, + -0.11629963666200638, + -0.0016945215174928308, + 0.01624997705221176, + -0.0063682254403829575, + 0.15033549070358276, + -0.5171176791191101, + -0.01525783073157072, + 0.016417231410741806, + -0.00303818890824914, + 0.2500321865081787, + 0.022074062377214432, + 0.01191191840916872, + 0.012274803593754768, + 0.016534989699721336, + -0.028437916189432144, + 0.04241323843598366, + -0.01824999786913395, + -0.34815871715545654, + 0.04734490439295769, + -0.06419701874256134, + -0.022288290783762932, + -0.0004865761147812009, + 0.05369419604539871, + -0.058212973177433014, + -0.2196469008922577, + 0.010950890369713306, + 0.029042819514870644, + -0.07349151372909546, + -0.0422789566218853, + 0.062069639563560486, + 0.05589267984032631, + 0.014877256006002426, + 0.04236084595322609, + 0.03975239768624306, + 0.16930873692035675, + 0.03981085494160652, + 0.11499395221471786, + 0.0271450225263834, + 0.013969083316624165, + -0.0002660648606251925, + 0.010936664417386055, + -0.18389767408370972, + -0.10237602889537811, + 0.03041323646903038, + -0.013864071108400822, + -0.015729930251836777, + 0.037400804460048676, + -0.009598327800631523, + -0.09533312171697617, + -0.014712700620293617, + 0.08537333458662033, + -0.007200485561043024, + -0.31139102578163147, + -0.06366845220327377, + 0.02039063163101673, + -0.023356139659881592, + -0.0029549277387559414, + -0.12494662404060364, + 0.011755092069506645, + -0.26468148827552795, + -0.11541861295700073, + 0.010529865510761738, + -0.05965733155608177, + -0.05945499241352081, + -0.08796169608831406, + -0.014683439396321774, + 0.008732054382562637, + 0.010073489509522915, + 0.09553763270378113, + 0.034884922206401825, + 0.018675342202186584, + -0.009549405425786972, + -0.0007051719003356993, + -0.16936513781547546, + -0.0030460187699645758, + -0.022060535848140717, + -0.06689190864562988, + 0.013926704414188862, + 0.012043816037476063, + -0.0587068572640419, + -0.03814113140106201, + 0.06235629320144653, + 0.013228330761194229, + 0.04154474660754204, + -0.08039120584726334, + 0.028436705470085144, + -0.042226389050483704, + -0.019135186448693275, + 0.03747033327817917, + -0.14261123538017273, + 0.02827540971338749, + 0.0455685593187809, + -0.031124960631132126, + -0.007588588632643223, + 0.0034326373133808374, + -0.07682976871728897, + 0.24654042720794678, + -0.014518304727971554, + -0.07052458822727203, + -0.08241941034793854, + -0.04116151109337807, + -0.048463717103004456, + -0.038745298981666565, + 0.036902472376823425, + 0.0442035011947155, + 0.05572585016489029, + -0.014312628656625748, + 0.010794793255627155, + -0.3440641760826111, + -0.5161325335502625, + 0.0005156552069820464, + -0.010257269255816936, + -0.02412656880915165, + -0.023385023698210716, + 0.05533458665013313, + -0.012186119332909584, + -0.029286568984389305, + 0.04116401448845863, + -0.044610101729631424, + -0.019175484776496887, + 0.06835268437862396, + 0.06366674602031708, + 0.0373748242855072, + 0.03804386034607887, + 0.05369521677494049, + -0.04451881721615791, + 0.0018838117830455303, + 0.34775662422180176, + 0.010958605445921421, + -0.047990139573812485, + 0.04386777803301811, + -0.10427688807249069, + 0.04417382925748825, + 4.402965714689344e-05, + 0.01935163326561451, + -0.06753949075937271, + 0.02735923044383526, + 0.01465953141450882, + 0.06198301538825035, + -0.015980403870344162, + -0.2108263075351715, + 0.008177559822797775, + 0.006046924740076065, + 0.002665479900315404, + 0.20868580043315887, + -0.013740362599492073, + 
0.008203004486858845, + -0.005066391546279192, + 0.026405498385429382, + 0.01383009273558855, + 0.012581533752381802, + 0.009014940820634365, + 0.022820021957159042, + -0.008534795604646206, + 0.2603924572467804, + 0.02297227643430233, + -0.000749691273085773, + 0.044753506779670715, + 0.018596511334180832, + 0.006852792575955391, + -0.008686172775924206, + -0.10452616959810257, + 0.017021872103214264, + 0.003722329391166568, + -0.025453045964241028, + -0.011473417282104492, + -0.017907623201608658, + 0.01400628499686718, + -0.1670989990234375, + 0.004298652987927198, + -0.0022204748820513487, + 0.16521315276622772, + -0.008831127546727657, + 0.026490870863199234, + 0.006190746556967497, + -0.0177209060639143, + 0.08967147767543793, + 0.0033069502096623182, + -0.005021366756409407, + 0.0004906906979158521, + 0.0169216375797987, + -0.06124846637248993, + -0.005200678016990423, + 0.08404737710952759, + -0.010559299029409885, + -0.006309974938631058, + 0.023113396018743515, + -0.010227260179817677, + 0.001256447983905673, + 0.019783375784754753, + -0.006308461539447308, + -0.04529590904712677, + -0.00908862054347992, + -0.043217338621616364, + -0.32200074195861816, + 0.02592635713517666, + 0.030795685946941376, + -0.001814531977288425, + 0.0092842485755682, + 0.07088880985975266, + -0.0867588147521019, + 0.024099843576550484, + -0.0034031609538942575, + 0.007234686985611916, + -0.02505563199520111, + 0.0030480287969112396, + -0.019158190116286278, + 0.26473408937454224, + -0.011918547563254833, + -0.023240016773343086, + -0.06084466353058815, + -0.021916134282946587, + -0.010251260362565517, + -0.0009625791572034359, + 0.082605741918087, + -0.013018425554037094, + 0.007627277635037899, + -0.0010813736589625478, + 0.007952406071126461, + 0.06551267951726913, + -0.026020025834441185, + 0.050048135221004486, + -0.010610008612275124, + -0.02429312653839588, + -0.025263017043471336, + -0.04611891135573387, + 0.04451768472790718, + -0.08045025914907455, + -0.048037610948085785, + 0.008019295521080494, + 0.0160224549472332, + 0.002078550634905696, + -0.0202508345246315, + -0.5446130633354187, + 0.012585492804646492, + -0.0331973135471344, + 0.08371605724096298, + -0.00590998912230134, + -0.013058983720839024, + 0.027742384001612663, + 0.1042199358344078, + -0.3072803318500519, + 0.06284149736166, + -0.28551968932151794, + 0.026768438518047333, + 0.022245990112423897, + 0.018242113292217255, + -0.035077981650829315, + 0.03546127676963806, + 0.10165776312351227, + -0.025475669652223587, + -0.014933750964701176, + 0.040547240525484085, + -0.033055808395147324, + 0.011755919083952904, + -0.014459444209933281, + -0.03455093130469322, + 0.020743343979120255, + 0.02720930427312851, + -0.287664532661438, + 0.008260028436779976, + -0.009877690114080906, + 0.16657423973083496, + -0.010943812318146229, + -0.012381386943161488, + 0.030678801238536835, + 0.1559792459011078, + 0.038967035710811615, + -0.023399239405989647, + 0.015019542537629604, + -0.014201333746314049, + -0.014202176593244076, + -0.006699408870190382, + -0.13175444304943085, + 0.004643211141228676, + 0.012747463770210743, + -0.04086190089583397, + 0.06581410765647888, + -0.12192045897245407, + -0.03126347437500954, + 0.011175516061484814, + -0.00914736744016409, + -0.02883930690586567, + -0.11305265873670578, + -0.04405384883284569, + -0.009120048955082893, + -0.008926079608500004, + -0.03169447183609009, + 0.05464877560734749, + 0.25674498081207275, + 0.08497058600187302, + -0.023222925141453743, + 0.35592252016067505, + -0.006929511670023203, 
+ 0.025255810469388962, + -0.05150032415986061, + 0.039239466190338135, + -0.07082924991846085, + -0.017321549355983734, + 0.17293211817741394, + -0.02155853807926178, + -0.014333213679492474, + 0.0031305316369980574, + -0.013490653596818447, + -0.1376512199640274, + -0.021713266149163246, + -0.029826253652572632, + -0.0011473714839667082, + -0.012434332631528378, + -0.04860873892903328, + 0.013857590034604073, + 0.0703854188323021, + 0.034528713673353195, + -0.014423011802136898, + 0.0882454589009285, + -0.091700978577137, + 0.038885727524757385, + 0.012043441645801067, + -0.03183690831065178, + -0.014495689421892166, + -0.019726552069187164, + -0.010094117373228073, + -0.004218627233058214, + -0.04413086175918579, + -0.1344134360551834, + -0.0004976870259270072, + -0.0008357573533430696, + 0.04518067091703415, + 0.046797975897789, + 0.24766182899475098, + 0.01065139751881361, + -0.0034267394803464413, + -0.016103556379675865, + -0.05139121413230896, + 0.012563390657305717, + -0.03310413286089897, + -0.030157553032040596, + 0.046670909970998764, + 0.012565785087645054, + -0.040275491774082184, + 0.023816417902708054, + -0.38536572456359863, + 0.04508889466524124, + 0.13637560606002808, + -0.010654824785888195, + 0.0459851399064064, + -0.0046302699483931065, + -0.020852191373705864, + 0.10662271827459335, + 0.06486576050519943, + 0.05727925896644592, + 0.09816201776266098, + 0.04878557100892067, + -0.16256237030029297, + 0.014547038823366165, + 0.018567964434623718, + -0.07284612208604813, + 0.017150163650512695, + 0.0246741883456707, + -0.38470372557640076, + -0.07465949654579163, + 0.03010236658155918, + -0.004397575277835131, + -0.06618984788656235, + -0.02908281609416008, + 0.060166433453559875, + -0.0020949048921465874, + 0.007689109072089195, + -0.0047390698455274105, + -0.014199030585587025, + -0.01794746331870556, + -0.02528063952922821, + 0.002218312583863735, + 0.10169881582260132, + 0.010602130554616451, + -0.06605861335992813, + -0.0008762837387621403, + -0.035027723759412766, + -0.011684391647577286, + 0.02247578091919422, + 0.17245104908943176, + 0.22525252401828766, + -0.010771296918392181, + 0.05595310404896736, + 0.06338834017515182, + -0.0038216698449105024, + -0.0032836494501680136, + 0.005779017228633165, + -0.18020786345005035, + -0.05066698044538498, + -0.0035458216443657875, + -0.10578767210245132, + -0.041712939739227295, + 0.2104150652885437, + -0.03753345459699631, + 0.013989892788231373, + 0.01988149993121624, + 0.05108603090047836, + 0.04496738687157631, + -0.3034508526325226, + 0.0226743221282959, + -0.0431472510099411, + -0.025635428726673126, + -0.18961989879608154, + -0.17218825221061707, + 0.03576141223311424, + 0.060613714158535004, + -0.011970550753176212, + -0.21435107290744781, + 0.01422552578151226, + 0.02974064089357853, + -0.061079952865839005, + 0.031064646318554878, + 0.009629320353269577, + -0.13762925565242767, + 0.01928475871682167, + 0.007310172542929649, + 0.06103459745645523, + -0.16216528415679932, + 0.03330384939908981, + 0.09578404575586319, + -0.0037327276077121496, + 0.029233848676085472, + -0.0015759399393573403, + 0.005511409603059292, + -0.4195749759674072, + 0.024169376119971275, + 0.13220365345478058, + 0.007961929775774479, + 0.008045470342040062, + 0.01919495314359665, + -0.023188553750514984, + 0.07084394991397858, + -0.24922333657741547, + 0.02011212892830372, + -0.18514998257160187, + 0.03114209696650505, + 0.09826567023992538, + 0.00592303741723299, + -0.010020115412771702, + 0.027117054909467697, + -0.214133620262146, + 
-0.01214816514402628, + 0.06564164906740189, + 0.02513044886291027, + 0.02132420241832733, + -0.02127540111541748, + -0.041606876999139786, + 0.04196378216147423, + -0.02060609683394432, + 0.01730814389884472, + -0.17418994009494781, + 0.03462710976600647, + -0.017470642924308777, + -0.3992193639278412, + 0.02652592957019806, + 0.025042008608579636, + 0.026447610929608345, + -0.19199316203594208, + 3.27593952533789e-05, + 0.002988220192492008, + -0.21171888709068298, + 0.03300239518284798, + 0.015727035701274872, + -0.008947308175265789, + 0.03924538940191269, + -0.08990193158388138, + 0.023726975545287132, + 0.03463870286941528, + -0.05018220469355583, + 0.13170146942138672, + 0.054000236093997955, + 0.01158218178898096, + 0.062349993735551834, + -0.014724616892635822, + 0.039657603949308395, + 0.04436490684747696, + 0.014076294377446175, + 0.07666806876659393, + 0.09630247205495834, + -0.04152659326791763, + -0.1860806941986084, + -0.07671733945608139, + 0.031573690474033356, + -0.44617798924446106, + -0.004897239152342081, + -0.03991628438234329, + 0.01880800537765026, + -0.04769768565893173, + 0.02198435738682747, + 0.01341161783784628, + -0.12239313870668411, + 0.019765935838222504, + 0.005221452098339796, + -0.025201082229614258, + 0.005132562946528196, + 0.08668412268161774, + 0.0035341952461749315, + 0.008583099581301212, + 0.032979920506477356, + 0.03324040770530701, + 0.04411708936095238, + -0.008390798233449459, + 0.040486790239810944, + -0.059673551470041275, + 0.02003314346075058, + -0.0990666076540947, + 0.03971675783395767, + 0.012021057307720184, + 0.0017271327087655663, + 0.01818535290658474, + 0.0025106174871325493, + 0.043714240193367004, + 0.019146842882037163, + -0.0041794623248279095, + 0.033447377383708954, + 0.06863203644752502, + -0.004350902978330851, + 0.0113364327698946, + -0.05825724080204964, + -0.04649435728788376, + -0.10618306696414948, + 0.02653644233942032, + 0.012514552101492882, + 0.019399365410208702, + -0.0022177041973918676, + 0.017741208896040916, + 0.04115311801433563, + 0.05122101679444313, + 0.055051617324352264, + 0.01687677949666977, + -0.03698579967021942, + 0.10053858160972595, + -0.007528421934694052, + 0.003968802746385336, + 0.02458524890244007, + -0.02144794538617134, + 0.026791265234351158, + -0.016701897606253624, + 0.014119372703135014, + -0.03460531681776047, + -0.02320348098874092, + 0.056146953254938126, + 0.028700685128569603, + -0.14820916950702667, + -0.016996873542666435, + 0.025667931884527206, + 0.08408629894256592, + 0.00034475952270440757, + 0.007573155220597982, + 0.06784884631633759, + 0.025982951745390892, + -0.08363039791584015, + -0.015748541802167892, + -0.0029514851048588753, + -0.01523523684591055, + 0.10500328987836838, + 0.3070858418941498, + -0.024624783545732498, + 0.0058471946977078915, + -0.039751242846250534, + 0.0012745993444696069, + -0.0796508714556694, + 0.024727927520871162, + 0.056764136999845505, + -0.013338261283934116, + -0.04794292524456978, + -0.02609768509864807, + -0.010784422047436237, + -0.048712026327848434, + 0.020345501601696014, + 0.0021618579048663378, + -0.0021724768448621035, + 0.03056410700082779, + -0.01633712649345398, + -0.47168225049972534, + -0.014639903791248798, + -0.012550815008580685, + 0.03358187526464462, + 0.07889427989721298, + -0.03615899011492729, + -0.002809660043567419, + -0.006953644100576639, + 0.02024337276816368, + -0.0738825723528862, + -0.006984011270105839, + -0.04472561925649643, + -0.027498915791511536, + 0.07207506150007248, + -0.09166522324085236, + 
-0.008861960843205452, + 0.05264359340071678, + 0.01889069564640522, + -0.1380404680967331, + -0.010141258127987385, + 0.015403619967401028, + -0.16416165232658386, + -0.03529815003275871, + 0.042106859385967255, + 0.11173021793365479, + -0.3143587112426758, + 0.011045016348361969, + 0.0012351945042610168, + 0.03840603306889534, + 0.0685538575053215, + -0.000746160454582423, + -0.028142500668764114, + 0.027154160663485527, + 0.005731801502406597, + 0.04433267563581467, + -0.8158469796180725, + 0.02226361259818077, + -0.07650655508041382, + 0.026958195492625237, + -0.005810025613754988, + -0.020102059468626976, + -0.0019310436910018325, + 0.07697021961212158, + -0.057701658457517624, + 0.05954534560441971, + 0.0027106746565550566, + -0.06311310827732086, + 0.011713752523064613, + -0.0034454476553946733, + -0.0006881420267745852, + 0.08937360346317291, + -0.0008253820124082267, + -0.031066063791513443, + -0.14708301424980164, + -0.04438449814915657, + 0.004772413522005081, + 0.05992274731397629, + 0.07473544776439667, + -0.1784757375717163, + -0.19057415425777435, + -0.014637955464422703, + -0.24898527562618256, + 0.13606221973896027, + -0.018039124086499214, + -0.047193415462970734, + -0.06526428461074829, + 0.04075757786631584, + 0.049901530146598816, + -0.008585861884057522, + 0.01616351678967476, + -3.091737016802654e-05, + 0.024283329024910927, + 0.008861682377755642, + -0.0005823548417538404, + 0.0997646301984787, + 0.051001910120248795, + 0.009473294951021671, + -0.0032046104315668344, + 0.018362928181886673, + 0.008627718314528465, + -0.4148157835006714, + -0.016077928245067596, + 0.0745391696691513, + 0.00724065862596035, + 0.08948155492544174, + 0.11626332253217697, + -0.052439428865909576, + 0.005599102005362511, + 0.002622961765155196, + 0.07586965709924698, + 0.03274847939610481, + -0.02099076844751835, + -0.04666733741760254, + -0.0013019372709095478, + 0.04945925995707512, + 0.11393380910158157, + 0.006346395239233971, + 0.04721064493060112, + 0.010331138968467712, + 0.08918803185224533, + 0.04288423806428909, + -0.09234773367643356, + 0.020141584798693657, + -3.256054696976207e-05, + -0.02799108810722828, + 0.018966441974043846, + -0.4136410355567932, + -0.07217283546924591, + 0.01840362884104252, + -0.055327851325273514, + 0.003275467548519373, + -0.017174070701003075, + -0.032178670167922974, + 0.09021560847759247, + -0.524413526058197, + 0.01994725503027439, + 0.10380692034959793, + -0.01043684035539627, + -0.00011200909648323432, + 0.01331041194498539, + 0.020127851516008377, + -0.025159789249300957, + 0.05252581834793091, + 0.04759140685200691, + 0.0032084162812680006, + -0.03579062595963478, + 0.054719552397727966, + -0.04674411937594414, + 0.028389262035489082, + 0.001127603929489851, + -0.0006243048119358718, + -0.00550495833158493, + -0.022523507475852966, + -0.024282312020659447, + 0.009519628249108791, + -0.39908328652381897, + -0.009265545755624771, + -0.00037090369733050466, + 0.06425131112337112, + -0.05998316407203674, + -0.015221518464386463, + -0.004825026262551546, + 0.11847284436225891, + -0.011302731931209564, + -0.006884834263473749, + -0.04678218811750412, + -0.012078279629349709, + 0.021638741716742516, + -0.016819776967167854, + -0.009127719327807426, + -0.002491263672709465, + 0.0016752213705331087, + -0.016600262373685837, + 0.011772023513913155, + -0.013447183184325695, + -0.020662957802414894, + -0.011593316681683064, + 0.008270744234323502, + -0.0026990456972271204, + -0.004406482446938753, + -0.023110052570700645, + -0.00208942755125463, + 
-0.1711198389530182, + 0.012432538904249668, + -0.0045453268103301525, + 0.024807902052998543, + -0.0035043740645051003, + -0.004001997876912355, + -0.013488625176250935, + -0.02020987868309021, + -0.01216109935194254, + -0.004432092886418104, + 0.09323672950267792, + -0.015641510486602783, + -0.019307948648929596, + 0.01117538008838892, + -0.01422040443867445, + 0.01705607771873474, + -0.0029596879612654448, + -0.0021530911326408386, + -0.006551788654178381, + 0.00429268553853035, + -0.1620807945728302, + -0.014128226786851883, + -0.005428737495094538, + -0.006771362852305174, + 0.005730633158236742, + 0.0007243106956593692, + 0.0024031582288444042, + -0.00199915561825037, + 0.006133859045803547, + -0.013380909338593483, + 0.00733462069183588, + -0.001863821060396731, + -0.0020169683266431093, + -0.014070986770093441, + -0.006501683499664068, + -0.029421553015708923, + 0.0009377509704791009, + -0.01718256250023842, + -0.05819401144981384, + -0.018859732896089554, + 0.0010356366401538253, + 0.006394123658537865, + -0.021985618397593498, + -0.01204769592732191, + -0.002014884725213051, + -0.019398409873247147, + -0.013122898526489735, + -0.017277296632528305, + -0.002270353492349386, + -0.05294327810406685, + -0.020317314192652702, + -0.018196573480963707, + -0.010375416837632656, + -0.019704729318618774, + -0.016109557822346687, + -0.0167380403727293, + -0.0285252146422863, + -0.02665277197957039, + -0.03554505482316017, + -0.00741522666066885, + -0.013580105267465115, + -0.026335405185818672, + -0.011694515123963356, + -0.004639182705432177, + -0.03996071219444275, + -0.022463932633399963, + -0.007204636000096798, + -0.021065134555101395, + -0.014410646632313728, + 0.0035447971895337105, + -0.0013098351191729307, + -0.024171002209186554, + 0.00047751085367053747, + -0.01870289072394371, + -0.06016797944903374, + -0.025703946128487587, + -0.009730588644742966, + -0.021792838349938393, + -0.024519823491573334, + -0.01843440905213356, + -0.0016325484029948711, + -0.008116388693451881, + -0.017774557694792747, + -0.04375867918133736, + -0.03893980756402016, + -0.018188582733273506, + -0.007122726645320654, + -0.028115490451455116, + -0.01821342669427395, + -0.01011319737881422, + -0.02616124413907528, + -0.013797983527183533, + -0.03202736750245094, + -0.030110370367765427, + -0.01883666031062603, + -0.01185502391308546, + -0.006012012716382742, + -0.017311619594693184, + -0.022577986121177673, + -0.02101938985288143, + 0.0025952248834073544, + -0.005058783106505871, + -0.004162575118243694, + -0.01559755764901638, + -0.017923563718795776, + -0.04231095686554909, + -0.017630560323596, + -0.011938830837607384, + -0.01587115228176117, + 0.004972478374838829, + -0.016601158306002617, + 0.15419845283031464, + 0.0009241115767508745, + 0.051028184592723846, + 0.008128340356051922, + -0.019917558878660202, + -0.0010339801665395498, + 0.022349294275045395, + -0.0072520882822573185, + 0.0017750378465279937, + -0.10526080429553986, + 0.03420695662498474, + 0.019183926284313202, + -0.0006544998032040894, + -0.0032203509472310543, + -0.01216941885650158, + -0.03561796247959137, + 0.024905826896429062, + -0.026948239654302597, + -0.01913355104625225, + -0.014459407888352871, + 0.006972283590584993, + -0.033184293657541275, + 0.04884861409664154, + -0.002296984428539872, + -0.19194477796554565, + 0.00392142403870821, + 0.009490449912846088, + -0.02687196619808674, + -0.06327224522829056, + -0.03684951737523079, + -0.0002613202668726444, + -0.012086644768714905, + 0.03630973398685455, + 
0.007296048104763031, + 0.011186012998223305, + 0.0074085514061152935, + -0.020394617691636086, + -0.010585476644337177, + -0.030289918184280396, + 0.0773506686091423, + 0.008841303177177906, + 0.019423579797148705, + 0.001184571417979896, + 0.005553434602916241, + 0.015373414382338524, + -0.0027953842654824257, + 0.013204757124185562, + 0.029097743332386017, + 0.012627501040697098, + 0.02102004364132881, + -0.09469914436340332, + -0.023324014618992805, + 0.029243655502796173, + 0.002979277865961194, + -0.004492263309657574, + 0.20549021661281586, + -0.3244459927082062, + 0.025892559438943863, + 0.009620796889066696, + -0.05520407855510712, + -0.02271144650876522, + 0.008378816768527031, + -0.0671214610338211, + -0.016056722030043602, + -0.02355658821761608, + 0.0005429868469946086, + -0.007960098795592785, + 0.02513299137353897, + -0.13005328178405762, + -0.0025323680602014065, + -0.02197088487446308, + -0.02404806576669216, + 0.08261960744857788, + 0.17078880965709686, + 0.02880753017961979, + -0.03642067685723305, + 0.021994341164827347, + -0.012368184514343739, + -0.10681373625993729, + 0.16371481120586395, + 0.17881983518600464, + -0.10202010720968246, + -0.08641688525676727, + -0.1259487271308899, + 0.06907707452774048, + 0.023792706429958344, + -0.02534419298171997, + 0.016984017565846443, + -0.06743635982275009, + 0.08445960283279419, + -0.08037827908992767, + -0.11935994029045105, + -0.31716489791870117, + -0.01860150322318077, + 0.060669515281915665, + -0.06137414649128914, + 0.09878886491060257, + 0.01794014871120453, + 0.12382296472787857, + -0.016424886882305145, + 0.09045679122209549, + -0.02998783066868782, + -0.00972777884453535, + -0.024124544113874435, + 0.09879253059625626, + 0.05500243604183197, + -0.06635259836912155, + 0.11268552392721176, + 0.011751363053917885, + -0.04690232127904892, + -0.025168607011437416, + 0.088335320353508, + -0.1140628531575203, + 0.04129032790660858, + -0.04258979484438896, + -0.0903872698545456, + 0.008473021909594536, + -0.026690304279327393, + -0.051559556275606155, + -0.05481572076678276, + -0.05251916125416756, + -0.0018165932269766927, + 0.09836867451667786, + 0.0054859439842402935, + 0.06432581692934036, + 0.10621821135282516, + -0.019325286149978638, + -0.028727786615490913, + 0.014013150706887245, + -0.008022608235478401, + -0.006281842477619648, + -0.0297000203281641, + 0.01525485422462225, + -0.4346403479576111, + 0.07787995040416718, + -0.25380268692970276, + 0.05261845141649246, + 0.010875157080590725, + 0.0014149334747344255, + 0.05021188035607338, + -0.24382442235946655, + 0.0807114690542221, + 0.022907381877303123, + 0.006440790370106697, + -0.017028095200657845, + 0.001552293193526566, + 0.05961666256189346, + -0.14113056659698486, + 0.03398876264691353, + -0.005411976482719183, + -0.014025667682290077, + -0.5433799624443054, + 0.019015472382307053, + 0.04091138765215874, + 0.05059061944484711, + 0.0274446289986372, + -0.010288042947649956, + -0.001335533568635583, + -0.013533512130379677, + 0.018798377364873886, + -0.04099345579743385, + 0.0031264263670891523, + -0.21071769297122955, + -0.014384736306965351, + -0.1045387014746666, + -0.014340974390506744, + 0.001986369490623474, + -0.04118456318974495, + -0.10952988266944885, + 0.049147430807352066, + -0.08382093161344528, + -0.1741400957107544, + -0.0885215476155281, + -0.10934099555015564, + 0.05553343519568443, + 0.02434251271188259, + 0.006634524557739496, + -0.0017163373995572329, + 0.0185443926602602, + 0.06250902265310287, + -0.17145656049251556, + 
-0.07543934881687164, + 0.026583310216665268, + 0.01634727604687214, + 0.003603539662435651, + -0.2817271649837494, + 0.03882112354040146, + 0.011341865174472332, + 0.00826666783541441, + 0.050427842885255814, + -0.22358834743499756, + 0.06419781595468521, + 0.03245265409350395, + -0.04503164440393448, + -0.023194484412670135, + -0.027968740090727806, + 0.08563586324453354, + 0.07954753190279007, + -0.08513130992650986, + 0.02850884199142456, + 0.008976672776043415, + 0.07886530458927155, + 0.0022273347713053226, + -0.09540755301713943, + 0.032016951590776443, + -0.05196075513958931, + 0.10555616766214371, + 0.07629868388175964, + 0.039732079952955246, + -0.0029798501636832952, + 0.014692343771457672, + 0.09200941026210785, + -0.04299614951014519, + -0.023488566279411316, + -0.01851060427725315, + 0.09257487207651138, + 0.055612049996852875, + 0.06423109769821167, + -0.28587806224823, + -0.09950444847345352, + 0.10397437959909439, + 0.025166453793644905, + -0.03235514089465141, + -0.033381711691617966, + 0.1513858139514923, + 0.06468874961137772, + 0.01928441785275936, + 0.0032701045274734497, + -0.0579083226621151, + -0.022929169237613678, + 0.012971373274922371, + -0.018524186685681343, + -0.06484643369913101, + 0.012233717367053032, + 0.06590451300144196, + -0.04558677598834038, + 0.05253027006983757, + 0.048656731843948364, + -0.2288871705532074, + 0.037114787846803665, + -0.20519588887691498, + 0.0058607361279428005, + -0.002009372925385833, + -0.006671734619885683, + -0.07107856124639511, + -0.07407436519861221, + 0.03941629081964493, + 0.0447598397731781, + 0.03509354963898659, + -0.061107732355594635, + -0.09305761009454727, + -0.012180411256849766, + 0.04902016744017601, + 0.07974442094564438, + -0.016854410991072655, + 0.005089411046355963, + -0.08127597719430923, + 0.03258403390645981, + 0.039813362061977386, + -0.01668727956712246, + 0.027226485311985016, + -0.029213925823569298, + -0.008598217740654945, + 0.00931101106107235, + 0.026936721056699753, + -0.03083401545882225, + -0.05799110606312752, + -0.008277476765215397, + -0.014854338951408863, + -0.20012643933296204, + 0.012290815822780132, + 0.007194168865680695, + 0.06858328729867935, + -0.3296163082122803, + -0.11424986273050308, + 0.009912200272083282, + -0.06211454048752785, + 0.0007546336855739355, + 0.03507614880800247, + 0.10649498552083969, + -0.03036407195031643, + 0.0646015852689743, + -0.01595110446214676, + -0.16919563710689545, + 0.0013865949586033821, + -0.08339446783065796, + 0.06962471455335617, + 0.016058098524808884, + -0.04729780554771423, + 0.010602935217320919, + 0.01470863912254572, + 0.06903938204050064, + 0.014901719056069851, + -0.15120048820972443, + 0.016727851703763008, + 0.05003673583269119, + 0.04370126873254776, + 0.029703885316848755, + 0.021875420585274696, + 0.026293285191059113, + -0.01048936415463686, + 0.00040810942300595343, + -0.015616541728377342, + -0.062451593577861786, + 0.010016348212957382, + -0.06790193170309067, + -0.02077331207692623, + 0.007985175587236881, + -0.04435744881629944, + 0.06920231133699417, + 0.018344474956393242, + 0.028591370210051537, + 0.021957838907837868, + 0.0017570338677614927, + 0.036665257066488266, + 0.015438515692949295, + -0.0006347382441163063, + 0.04621066153049469, + -0.001942177303135395, + 0.010664877481758595, + -0.016754357144236565, + 0.006541184149682522, + -0.027716301381587982, + -0.0058586387895047665, + -0.005346015095710754, + 0.020482052117586136, + 0.06882552057504654, + 0.0026622572913765907, + 0.016321638599038124, + 
0.017728103324770927, + -0.13356441259384155, + 0.030281176790595055, + 1.0354949154134374e-05, + 0.050639618188142776, + 0.0013030078262090683, + -0.11136802285909653, + -0.006832807790488005, + -0.09628921747207642, + 0.046699415892362595, + 0.002175685251131654, + 0.008100612089037895, + 0.012449901551008224, + -0.01713990420103073, + -0.000769267207942903, + 0.022544430568814278, + -0.0018787183798849583, + -0.014189678244292736, + 0.37042510509490967, + -0.030317893251776695, + 0.012663356028497219, + -0.04071582853794098, + 0.01653047651052475, + 0.06578584760427475, + 0.005606585182249546, + 0.0029362838249653578, + -0.02035594917833805, + 0.016131827607750893, + -0.06512665003538132, + 0.020292088389396667, + 0.12818951904773712, + -0.00017647731874603778, + 0.0004811069811694324, + 0.013025660999119282, + -0.006004344671964645, + 0.011330580338835716, + 0.0021733916364610195, + -0.0026290342211723328, + 0.008579215034842491, + -0.017107143998146057, + 0.0032798980828374624, + 0.21415431797504425, + -0.011049880646169186, + 0.04915957152843475, + -0.01152863260358572, + 0.01988764852285385, + -0.30189022421836853, + 0.1491061896085739, + 0.022540517151355743, + 0.02323656715452671, + -0.0028044115751981735, + -0.02501249685883522, + 0.0016759912250563502, + 0.023405946791172028, + 0.0865691602230072, + 0.0056661744602024555, + 0.2334042638540268, + -0.05771901085972786, + 0.03428330272436142, + -0.05191519856452942, + 0.025708407163619995, + -0.11474912613630295, + 0.05345827341079712, + 0.050046734511852264, + -0.03785427287220955, + 0.02726786397397518, + 0.008640051819384098, + -0.05810163915157318, + 0.19147679209709167, + 0.12065602838993073, + -0.08667072653770447, + -0.12831886112689972, + 0.027053257450461388, + -0.1771622896194458, + -0.2615586817264557, + 0.112942636013031, + 0.002398239215835929, + 0.00907410029321909, + 0.059947770088911057, + 0.040937639772892, + 0.003431124845519662, + 0.012721046805381775, + -0.10228776186704636, + 0.04169567674398422, + -0.04826785624027252, + -0.021415220573544502, + 0.027615519240498543, + 0.16087181866168976, + 0.03552674129605293, + -0.36409878730773926, + 0.0015418739058077335, + 0.03940089792013168, + -0.12929502129554749, + 0.017082052305340767, + -0.07193783670663834, + 0.10395099222660065, + -0.2240910828113556, + -0.003303584409877658, + -0.0074868109077215195, + -0.13708709180355072, + 0.2098008245229721, + 0.013808795250952244, + -0.03606148064136505, + 0.001965852687135339, + 0.04186573252081871, + 0.02105732634663582, + -0.11873909085988998, + -0.08529136329889297, + 0.0060731275007128716, + 0.04803553968667984, + 0.07665349543094635, + 0.026997262611985207, + 0.05191565304994583, + 0.09013131260871887, + 0.013081093318760395, + 0.04667182266712189, + -0.19899451732635498, + 0.004642056301236153, + 0.0025570227298885584, + -0.2640555500984192, + 0.008254006505012512, + 0.05971720814704895, + -0.002980671590194106, + 0.0011313167633488774, + -0.004445134196430445, + 0.01951296627521515, + -0.006634386721998453, + -0.008033698424696922, + 0.012400158680975437, + -0.15906694531440735, + 0.007047838997095823, + 0.0003521084145177156, + -0.00517050176858902, + -0.0003226286207791418, + -0.01226231548935175, + -0.06750697642564774, + -0.03061128593981266, + -0.0027100055012851954, + 0.004726986400783062, + 0.010185977444052696, + 0.021205933764576912, + -0.05105980113148689, + -0.006725164130330086, + 0.26042309403419495, + 0.003935054875910282, + 0.009450466372072697, + -0.009512278251349926, + 0.036205559968948364, + 
0.0066987741738557816, + 0.05687355250120163, + -0.0070350514724850655, + 0.021287698298692703, + 0.004246287513524294, + -0.004053668584674597, + 0.0030501342844218016, + -0.003596516093239188, + 0.00571554945781827, + 0.039099883288145065, + 0.06648323684930801, + 0.011140268296003342, + 0.002779693342745304, + 0.0004113377653993666, + 0.0019621821120381355, + 0.002047213725745678, + -9.034215327119455e-05, + 0.006674906238913536, + -0.024464793503284454, + 4.372629337012768e-05, + 0.04560312256217003, + 0.029951298609375954, + 0.0053787860088050365, + 0.010052027180790901, + 0.0018156497972086072, + 0.001613074098713696, + -0.3710610568523407, + 0.18385423719882965, + 0.0197732076048851, + -2.409513217571657e-05, + 0.043657880276441574, + 0.029824273660779, + -0.0015272254822775722, + -0.0009817760437726974, + 0.030571524053812027, + 0.05133187025785446, + 0.021092001348733902, + -0.022430723533034325, + -0.011050102300941944, + -0.01653454266488552, + 0.00856624636799097, + 0.007617316208779812, + 0.023697074502706528, + -0.00541776092723012, + -0.06940567493438721, + -0.024501511827111244, + 0.0029131292831152678, + 0.005110545549541712, + 0.02394089475274086, + 0.009317552670836449, + -0.05198051407933235, + -0.14872707426548004, + -0.03553030639886856, + 0.05354774370789528, + 0.053996339440345764, + 0.016679847612977028, + -0.4505158066749573, + 0.006403166800737381, + -0.014287465251982212, + 0.010499212890863419, + 0.00510875741019845, + 0.0230255089700222, + -0.04791099205613136, + -0.08405473828315735, + -0.00807158276438713, + -0.016310568898916245, + -0.018034789711236954, + -0.03381670266389847, + 0.038599055260419846, + 0.01189411524683237, + 0.0038598189130425453, + 0.0077203805558383465, + -0.0006835742969997227, + 0.3038807809352875, + 0.00930703990161419, + -0.017654214054346085, + -0.029550395905971527, + 0.0014829621650278568, + -0.010562432929873466, + -0.011867706663906574, + -0.008104459382593632, + 0.008003979921340942, + -0.028282882645726204, + 0.00898829661309719, + -0.04963170364499092, + 0.014971665106713772, + 0.028662119060754776, + 0.055792808532714844, + 0.018142173066735268, + 0.029526766389608383, + 0.04726170003414154, + 0.020290115848183632, + -0.01347910612821579, + -0.027794860303401947, + -0.033374592661857605, + 0.05699307844042778, + -0.005888971965759993, + 0.009723466821014881, + 0.011825029738247395, + 0.0005665962235070765, + -0.22433574497699738, + 0.04777664318680763, + 0.054696254432201385, + 0.06447272002696991, + 0.006656138692051172, + -0.2656468152999878, + -0.006602808367460966, + -0.04309352487325668, + 0.024392882362008095, + -0.046948980540037155, + 0.17317010462284088, + -0.014694501645863056, + 0.09150391072034836, + 0.05414793640375137, + -0.0034523033536970615, + -0.029682809486985207, + -0.11646991223096848, + 0.036394182592630386, + -0.008510537445545197, + -0.09555189311504364, + 0.012331446632742882, + 0.022554755210876465, + 0.037040166556835175, + 0.011939534917473793, + -0.035405583679676056, + -0.008284371346235275, + 0.008629710413515568, + -0.0017152110813185573, + -0.01656493730843067, + 0.02205522358417511, + -0.008015291765332222, + -0.02198217809200287, + -0.08165504783391953, + 0.018647879362106323, + 0.010489191859960556, + 0.0009643095545470715, + 0.08301698416471481, + 0.00795030314475298, + -0.08973152935504913, + 0.05324552580714226, + 0.0187348835170269, + 0.00770497927442193, + 0.016434336081147194, + 0.0031714467331767082, + 0.031489044427871704, + -0.01682765781879425, + -0.0006042059976607561, + 
0.006229344755411148, + 0.0031935630831867456, + -0.03694210946559906, + -0.027148112654685974, + 0.03319454565644264, + 0.013541879132390022, + 0.04362545907497406, + 0.010766182094812393, + 0.01287879142910242, + 0.02723391354084015, + 0.01831277459859848, + -0.0028144901152700186, + 0.0317537821829319, + -0.05053209140896797, + 0.03341667726635933, + 0.009338690899312496, + 0.030376508831977844, + 0.028512636199593544, + 0.002190604107454419, + 0.031132254749536514, + 0.04174429178237915, + 0.025147251784801483, + 0.02602408640086651, + 0.022863827645778656, + 0.024160150438547134, + 0.04043813422322273, + 0.011693909764289856, + 0.008020071312785149, + 0.010814648121595383, + 0.014862221665680408, + 0.043966785073280334, + 0.04133215174078941, + 0.03920775279402733, + 0.02128027193248272, + -0.0024078795686364174, + 0.03185494989156723, + 0.030951442196965218, + 0.008766901679337025, + -0.0013500713976100087, + 0.012680909596383572, + 0.01911563239991665, + 0.02226334996521473, + 0.03873631730675697, + 0.005242412444204092, + 0.02335301972925663, + 0.00577192846685648, + 0.0019918885082006454, + 0.019501060247421265, + 0.048295676708221436, + 0.027288099750876427, + 0.03500128164887428, + 0.032504353672266006, + 0.03619033470749855, + 0.022762063890695572, + 0.014124974608421326, + 0.04055529460310936, + 0.040181197226047516, + 0.04843837395310402, + 0.019578352570533752, + 0.04370861127972603, + 0.024640914052724838, + 0.027013463899493217, + 0.04700532928109169, + 0.018523193895816803, + 0.03569294884800911, + 0.031140455976128578, + 0.010298499837517738, + 0.03979840502142906, + 0.015059049241244793, + 0.020604899153113365, + 0.010335667058825493, + 0.02557498589158058, + 0.015946611762046814, + 0.018900645896792412, + 0.05494159087538719, + 0.015756357461214066, + 0.0452926866710186, + 0.04820817708969116, + -0.0183499027043581, + 0.04002442955970764, + -0.08226092159748077, + -0.034417178481817245, + 0.059122342616319656, + 0.028960591182112694, + -0.020427608862519264, + -0.043222296983003616, + 0.023134637624025345, + -0.014232538640499115, + -0.06970997899770737, + -0.0035826240200549364, + -0.015384080819785595, + -0.0695020854473114, + 0.03645527362823486, + 0.013986784033477306, + -0.027729706838726997, + -0.05711805075407028, + -0.0763891413807869, + -0.16338491439819336, + -0.02358265034854412, + -0.004730133805423975, + 0.022057903930544853, + -0.011578230187296867, + 0.040772147476673126, + -0.059327173978090286, + -0.03819728270173073, + -0.050089117139577866, + -0.005152902565896511, + -0.3071111738681793, + -0.010683669708669186, + 0.030922774225473404, + 0.08924981951713562, + 0.005679265595972538, + 0.06334424018859863, + 0.016136568039655685, + -0.02575727365911007, + -0.012562219053506851, + 0.007206748705357313, + -0.1373208612203598, + -0.010450832545757294, + -0.05991309881210327, + -0.006700845435261726, + -0.006468744482845068, + -0.02040017955005169, + -0.010068708099424839, + 0.008442427963018417, + 0.012259873561561108, + -0.002103718463331461, + -0.019605906680226326, + -0.010690353810787201, + 0.0005222380859777331, + -0.015031278133392334, + -0.012983204796910286, + -0.03552224859595299, + -0.007792052812874317, + -0.035602111369371414, + -0.03479204699397087, + -0.02480080910027027, + -0.05733964219689369, + 4.38804054283537e-05, + -0.021825626492500305, + -0.03287259489297867, + -0.05437042564153671, + -0.007981077767908573, + 0.023045696318149567, + 0.05785335600376129, + 0.03685669228434563, + 0.04314129799604416, + -0.005843586288392544, + 
-0.024806369096040726, + -0.02562016434967518, + 0.0015172295970842242, + -0.01568800024688244, + -0.005925294477492571, + 0.010173594579100609, + 0.06834683567285538, + 0.024159085005521774, + -0.009547322988510132, + 0.014080812223255634, + 0.013578452169895172, + 0.035671167075634, + 0.01240566186606884, + -0.021352441981434822, + 0.05245270952582359, + -0.008943279273808002, + -0.010131126269698143, + 0.02976749651134014, + 0.0600045844912529, + 0.0014893191400915384, + 0.03796907886862755, + 0.01258794590830803, + -0.025344882160425186, + 0.14140591025352478, + 0.028354406356811523, + 0.0035325682256370783, + 0.05017172172665596, + 0.01994139887392521, + 0.03679897263646126, + -0.009579945355653763, + -0.012607194483280182, + -0.00034231581958010793, + 0.00046832446241751313, + 0.057916246354579926, + 0.02351403795182705, + 0.06157909706234932, + 0.00789523497223854, + -0.018361341208219528, + 0.0018971840618178248, + -0.007180131506174803, + -0.0010631990153342485, + -0.03140748664736748, + -0.028505641967058182, + 0.010669395327568054, + -0.036474280059337616, + 0.01703447848558426, + 0.04667484760284424, + -0.007303370162844658, + 0.01768752932548523, + 0.012412219308316708, + 0.013702306896448135, + 0.07651616632938385, + 0.05469715967774391, + 0.013292597606778145, + -0.006288900971412659, + 0.0215559434145689, + 0.010094149969518185, + -0.024216346442699432, + -0.15225785970687866, + 0.05467289313673973, + 0.019871067255735397, + 0.04662928730249405, + 0.05072600021958351, + -0.011824453249573708, + -0.028083933517336845, + 0.013322187587618828, + -0.044827401638031006, + 0.05955006927251816, + -0.006152187939733267, + 0.013426700606942177, + -0.014220507815480232, + 0.022510837763547897, + 0.019426455721259117, + -0.05546477064490318, + -0.49202534556388855, + 0.026985207572579384, + -0.08852843940258026, + 0.07166163623332977, + 0.05509938299655914, + -0.42284780740737915, + -0.05131356418132782, + 0.0196990966796875, + -0.008681846782565117, + 0.02739996463060379, + 0.0010900507913902402, + 0.04289104416966438, + -0.06694932281970978, + 0.05930810049176216, + -0.02174360118806362, + 0.03464379161596298, + 0.018284866586327553, + 0.018807150423526764, + 0.019874336197972298, + -0.03665176033973694, + -0.2980017066001892, + 0.050937239080667496, + -0.013874954544007778, + -0.0229057464748621, + 0.016420641914010048, + 0.024160616099834442, + -0.10750921070575714, + -0.010134756565093994, + 0.026874780654907227, + 0.007151094265282154, + 0.06304068863391876, + -0.11811652034521103, + -0.12590888142585754, + 0.031846947968006134, + -0.06898463517427444, + 0.03395693376660347, + -0.00010166154243052006, + -0.19019480049610138, + 0.06616076827049255, + -0.035927142947912216, + 0.08526375889778137, + 0.0015017242403700948, + -0.009137739427387714, + 0.04529058188199997, + -0.23621641099452972, + 0.02148340456187725, + -0.02741178683936596, + -0.20779411494731903, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 64, 1, 1)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_504") + + initializers.append(tensor) + + list_value = [ + 5.195802688598633, + 0.940099835395813, + -7.016428470611572, + 5.185446739196777, + -4.134859085083008, + 2.0121846199035645, + 5.215719223022461, + 3.371406078338623, + 3.7616095542907715, + -3.6593239307403564, + 15.99945068359375, + 3.306276321411133, + 5.790191173553467, + 6.33050537109375, + 3.4512906074523926, + 2.5531861782073975, + 4.278702259063721, + 4.350361347198486, + 8.025779724121094, + -2.8830037117004395, + 
2.915111541748047, + 3.592482805252075, + 5.810481071472168, + 3.4743332862854004, + 3.5245680809020996, + 1.8243598937988281, + 8.069726943969727, + 1.401036024093628, + 5.110081672668457, + -12.873579978942871, + 10.977816581726074, + 5.909627437591553, + -0.4007779359817505, + -20.147268295288086, + 6.649413585662842, + 3.325921058654785, + 5.84471321105957, + 4.47447395324707, + 3.754193067550659, + -5.167671203613281, + 3.2778055667877197, + -9.067073822021484, + 2.6243438720703125, + 1.7002031803131104, + 5.476454734802246, + 2.510835647583008, + 3.856968402862549, + 2.3172807693481445, + 12.462139129638672, + 7.355924129486084, + 4.140628814697266, + 4.807559967041016, + 5.7524309158325195, + 4.128836154937744, + 11.4532470703125, + -12.482564926147461, + 5.590144157409668, + 0.9172697067260742, + 4.356811046600342, + 0.9934853315353394, + -4.3548994064331055, + 15.853201866149902, + -5.241130828857422, + 5.9644365310668945, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_505") + + initializers.append(tensor) + + # inputs + + inputs.append(make_tensor_value_info("input", 1, ["batch_size", 3, 32, 32])) + + # outputs + + outputs.append(make_tensor_value_info("/layer1/layer1.0/relu/Relu_output_0", 1, ["batch_size", 64, 8, 8])) + + # nodes + + node = make_node( + "Conv", + ["input", "onnx::Conv_501", "onnx::Conv_502"], + ["/conv1/Conv_output_0"], + name="/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[7, 7], + pads=[3, 3, 3, 3], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node("Relu", ["/conv1/Conv_output_0"], ["/relu/Relu_output_0"], name="/relu/Relu", domain="") + nodes.append(node) + + node = make_node( + "MaxPool", + ["/relu/Relu_output_0"], + ["/maxpool/MaxPool_output_0"], + name="/maxpool/MaxPool", + ceil_mode=0, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node( + "Conv", + ["/maxpool/MaxPool_output_0", "onnx::Conv_504", "onnx::Conv_505"], + ["/layer1/layer1.0/conv1/Conv_output_0"], + name="/layer1/layer1.0/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + domain="", + ) + nodes.append(node) + + node = make_node( + "Relu", + ["/layer1/layer1.0/conv1/Conv_output_0"], + ["/layer1/layer1.0/relu/Relu_output_0"], + name="/layer1/layer1.0/relu/Relu", + domain="", + ) + nodes.append(node) + + # opsets + opset_imports = [make_opsetid(domain, 1 if version is None else version) for domain, version in opsets.items()] + + # graph + graph = make_graph(nodes, "torch_jit", inputs, outputs, initializers) + # '7' + + onnx_model = make_model(graph, opset_imports=opset_imports, functions=functions) + onnx_model.ir_version = 7 + onnx_model.producer_name = "pytorch" + onnx_model.producer_version = "" + onnx_model.domain = "" + onnx_model.model_version = 0 + onnx_model.doc_string = "" + set_model_props(onnx_model, {}) + + return onnx_model diff --git a/onnxruntime/test/python/quantization/test_quantize_static_resnet.py b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py new file mode 100644 index 0000000000000..1efa283af6881 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import random +import tempfile +import unittest + +import numpy as np +import onnx +from numpy.testing import assert_allclose +from onnx.numpy_helper import to_array +from resnet_code import create_model + +from onnxruntime import InferenceSession +from onnxruntime import __version__ as ort_version +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static +from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod + + +class FakeResnetCalibrationDataReader(CalibrationDataReader): + def __init__(self, batch_size: int = 16): + super().__init__() + self.dataset = [ + (np.random.rand(1, 3, 32, 32).astype(np.float32), random.randint(0, 9)) for _ in range(batch_size) + ] + self.iterator = iter(self.dataset) + + def get_next(self) -> dict: + try: + return {"input": next(self.iterator)[0]} + except Exception: + return None + + +class TestStaticQuantizationResNet(unittest.TestCase): + def test_quantize_static_resnet(self): + kwargs = { + "activation_type": QuantType.QUInt8, + "weight_type": QuantType.QInt8, + "calibrate_method": CalibrationMethod.Percentile, + "extra_options": { + "ActivationSymmetric": False, + "EnableSubgraph": False, + "ForceQuantizeNoInputCheck": False, + "MatMulConstBOnly": False, + "WeightSymmetric": True, + "extra.Sigmoid.nnapi": False, + }, + "nodes_to_exclude": None, + "nodes_to_quantize": None, + "op_types_to_quantize": None, + "per_channel": True, + "quant_format": QuantFormat.QDQ, + "reduce_range": False, + } + + proto = create_model() + + with tempfile.TemporaryDirectory() as temp: + model = os.path.join(temp, "resnet_first_nodes.onnx") + with open(model, "wb") as f: + f.write(proto.SerializeToString()) + + for per_channel in [True, False]: + kwargs["per_channel"] = per_channel + dataloader = FakeResnetCalibrationDataReader(16) + with self.subTest(per_channel=per_channel): + qdq_file = os.path.join( + temp, f"preprocessed-small-qdq-{1 if per_channel else 0}-ort-{ort_version}.onnx" + ) + quantize_static( + model_input=model, + model_output=qdq_file, + calibration_data_reader=dataloader, + use_external_data_format=False, + **kwargs, + ) + + # With onnxruntime==1.15.1, the initializer 'onnx::Conv_504_zero_point' is: + # * uint8(128) if per_channel is False + # * int8([0, 0, ....]) if per_channel is True + # With onnxruntime>1.16.0 + # * uint8(128) if per_channel is False + # * uint8([128, 128, ..., 127, ...]) if per_channel is True + # QLinearConv : zero point of per-channel filter must be same. + # That's why the quantization forces a symmetric quantization into INT8. + # zero_point is guaranted to be zero whatever the channel is. + + with open(qdq_file, "rb") as f: + onx = onnx.load(f) + for init in onx.graph.initializer: + arr = to_array(init) + if ( + arr.dtype == np.int8 + and "zero_point" not in init.name + and not init.name.endswith("quantized") + ): + raise AssertionError( + f"Initializer {init.name!r} has type {arr.dtype} and " + f"shape {arr.shape} but should be {np.uint8}." 
+ ) + + sess = InferenceSession(qdq_file, providers=["CPUExecutionProvider"]) + shape = (1, 3, 32, 32) + size = np.prod(shape) + dummy = (np.arange(size) / float(size)).astype(np.float32).reshape(shape) + got = sess.run(None, {"input": dummy}) + self.assertEqual(got[0].shape, (1, 64, 8, 8)) + self.assertEqual(got[0].dtype, np.float32) + if per_channel: + expected = np.array( + [ + [[1.0862497091293335, 0.9609132409095764], [1.0862497091293335, 0.9191343784332275]], + [[0.7520190477371216, 1.0026921033859253], [1.0444709062576294, 1.0862497091293335]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.9609132409095764, 0.7937979102134705]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + else: + expected = np.array( + [ + [[1.428238868713379, 1.2602107524871826], [1.3442248106002808, 1.2182037830352783]], + [[0.8821475505828857, 1.0921826362609863], [1.1341897249221802, 1.1761966943740845]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [1.2182037830352783, 1.050175666809082]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From fcfc2391b818af5e28ded7f669a2f466f9276cb1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 25 Sep 2023 12:20:56 -0700 Subject: [PATCH 23/58] [JSEP] allow JsCustomAllocator to deal with zero sized input (#17660) ### Description allow JsCustomAllocator to deal with zero sized input. This is a standalone fix for zero-sized tensor handling for JsCustomAllocator. There are other components in JSEP not supporting zero-sized tensors need to be fixed. --- onnxruntime/core/providers/js/allocator.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/js/allocator.cc b/onnxruntime/core/providers/js/allocator.cc index c1d0aa9abbf6b..574c507222a5c 100644 --- a/onnxruntime/core/providers/js/allocator.cc +++ b/onnxruntime/core/providers/js/allocator.cc @@ -10,6 +10,10 @@ namespace onnxruntime { namespace js { void* JsCustomAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + void* p = EM_ASM_PTR({ return Module.jsepAlloc($0); }, size); stats_.num_allocs++; stats_.bytes_in_use += size; @@ -17,8 +21,10 @@ void* JsCustomAllocator::Alloc(size_t size) { } void JsCustomAllocator::Free(void* p) { - size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p); - stats_.bytes_in_use -= size; + if (p != nullptr) { + size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p); + stats_.bytes_in_use -= size; + } } void JsCustomAllocator::GetStats(AllocatorStats* stats) { From f50fa46fe09819734550a7e8725db999210e2b97 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 25 Sep 2023 12:21:20 -0700 Subject: [PATCH 24/58] [JSEP] allow DataTransfer to deal with zero sized input (#17661) ### Description allow DataTransfer to deal with zero sized input. This is a standalone fix for zero-sized tensor handling for JSEP DataTransfer. There are other components in JSEP not supporting zero-sized tensors need to be fixed. 
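Both this change and #17660 above apply the same zero-size guard pattern: never forward a zero-byte request to the JS side, and tolerate the resulting null pointers on free and copy. Below is a minimal, self-contained sketch of that pattern for illustration only; GuardedAlloc/GuardedFree/GuardedCopy and the Backend* helpers are hypothetical stand-ins, not ORT or JSEP APIs (the actual code is in the diffs above and below).

```cpp
// Illustrative sketch only: zero-size guards in the spirit of the JSEP
// allocator/data-transfer fixes. All names below are hypothetical.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Stand-ins for the JS-side helpers (in ORT these are EM_ASM calls such as
// Module.jsepAlloc / Module.jsepFree / Module.jsepCopy).
void* BackendAlloc(size_t size) { return std::malloc(size); }
void BackendFree(void* p) { std::free(p); }
void BackendCopy(const void* src, void* dst, size_t bytes) { std::memcpy(dst, src, bytes); }

// Allocation: a zero-sized tensor owns no storage, so a 0-byte request is
// never forwarded to the backend.
void* GuardedAlloc(size_t size) {
  if (size == 0) {
    return nullptr;
  }
  return BackendAlloc(size);
}

// Free: tolerate the nullptr that GuardedAlloc returns for zero-sized tensors.
void GuardedFree(void* p) {
  if (p != nullptr) {
    BackendFree(p);
  }
}

// Copy: when there are zero bytes to move, skip the backend call so the
// (possibly null) source/destination pointers are never dereferenced.
void GuardedCopy(const void* src, void* dst, size_t bytes) {
  if (bytes > 0) {
    BackendCopy(src, dst, bytes);
  }
}

int main() {
  void* empty = GuardedAlloc(0);   // nullptr, backend never called
  GuardedCopy(empty, nullptr, 0);  // no-op, safe even with null pointers
  GuardedFree(empty);              // no-op, safe
  std::printf("zero-sized tensor path handled\n");
  return 0;
}
```

Returning nullptr for a zero-byte allocation also keeps the allocator statistics consistent, since nothing is ever added to or subtracted from bytes_in_use for an empty tensor.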
---
 .../core/providers/js/data_transfer.cc       | 34 ++++++++++---------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc
index c62362d90867f..ebea041b80128 100644
--- a/onnxruntime/core/providers/js/data_transfer.cc
+++ b/onnxruntime/core/providers/js/data_transfer.cc
@@ -20,23 +20,25 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev
 common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
   size_t bytes = src.SizeInBytes();
-  const void* src_data = src.DataRaw();
-  void* dst_data = dst.MutableDataRaw();
-
-  auto& src_device = src.Location().device;
-  auto& dst_device = dst.Location().device;
-
-  if (dst_device.Type() == OrtDevice::GPU) {
-    if (src_device.Type() == OrtDevice::GPU) {
-      // copy from GPU to GPU
-      EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes);
-    } else {
-      // copy from CPU to GPU
-      EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes);
+  if (bytes > 0) {
+    const void* src_data = src.DataRaw();
+    void* dst_data = dst.MutableDataRaw();
+
+    auto& src_device = src.Location().device;
+    auto& dst_device = dst.Location().device;
+
+    if (dst_device.Type() == OrtDevice::GPU) {
+      if (src_device.Type() == OrtDevice::GPU) {
+        // copy from GPU to GPU
+        EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes);
+      } else {
+        // copy from CPU to GPU
+        EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes);
+      }
+    } else /* if (src_device.Type() == OrtDevice::GPU) */ {
+      // copy from GPU to CPU
+      jsepDownload(src_data, dst_data, bytes);
     }
-  } else /* if (src_device.Type() == OrtDevice::GPU) */ {
-    // copy from GPU to CPU
-    jsepDownload(src_data, dst_data, bytes);
   }
 
   return Status::OK();

From b2b140860817d974b5f9dd0b8943c49a725a5edd Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Mon, 25 Sep 2023 12:24:46 -0700
Subject: [PATCH 25/58] [js/web] update browser launch cmd flags (#17658)

### Description
Update the Chromium browser launch command-line flags. Canary already uses DXC, so there is no need to specify '--enable-dawn-features=use_dxc' for Canary.

---
 js/web/script/test-runner-cli.ts | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts
index a75321d45f1ef..f3764e63fcf45 100644
--- a/js/web/script/test-runner-cli.ts
+++ b/js/web/script/test-runner-cli.ts
@@ -493,19 +493,13 @@ async function main() {
       karmaArgs.push('--force-localhost');
     }
     if (webgpu) {
-      if (browser.includes('Canary')) {
-        chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc');
-      } else {
-        chromiumFlags.push('--enable-dawn-features=use_dxc');
-        chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis');
-      }
+      // flag 'allow_unsafe_apis' is required to enable experimental features like fp16 and profiling inside pass.
+      // flag 'use_dxc' is required to enable DXC compiler.
+ chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc'); } if (webnn) { chromiumFlags.push('--enable-experimental-web-platform-features'); } - if (config.options.globalEnvFlags?.webgpu?.profilingMode === 'default') { - chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis'); - } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { From a942bbf4897688fd1bd88bb1db73d2f8d648bcbc Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 25 Sep 2023 14:12:11 -0700 Subject: [PATCH 26/58] Update nodejs to 18.x (#17657) 1. Upgrade nodejs from 16.x to 18.x for Windows pipelines 2. Avoid using Azure DevOps "NodeTool" on Linux. The tool installs nodejs from internet or local disk cache. But we already moved all Linux tests to docker. So we do not need the installer anymore. 3. Remove some other unused code. --- .../azure-pipelines/linux-ci-pipeline.yml | 4 -- .../linux-cpu-aten-pipeline.yml | 4 -- .../linux-multi-gpu-tensorrt-ci-pipeline.yml | 2 - .../linux-openvino-ci-pipeline.yml | 2 - .../orttraining-linux-ci-pipeline.yml | 47 +------------------ .../orttraining-linux-gpu-ci-pipeline.yml | 11 +---- .../templates/jobs/win-ci-vs-2022-job.yml | 2 +- .../azure-pipelines/templates/linux-ci.yml | 40 +--------------- .../templates/linux-wasm-ci.yml | 2 +- .../templates/mac-cpu-packing-jobs.yml | 2 +- .../templates/react-native-ci.yml | 2 +- .../templates/web-browserstack-ci.yml | 2 +- .../azure-pipelines/templates/web-ci.yml | 2 +- .../azure-pipelines/templates/win-ci.yml | 2 +- .../azure-pipelines/templates/win-wasm-ci.yml | 2 +- .../azure-pipelines/templates/win-web-ci.yml | 2 +- .../templates/win-web-multi-browsers.yml | 2 +- .../azure-pipelines/win-ci-fuzz-testing.yml | 2 +- 18 files changed, 15 insertions(+), 117 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index 33fc9d94bac09..395c190ce9e11 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -56,10 +56,6 @@ stages: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml index 2c5a69e216d14..146186e9eeaf5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml @@ -53,10 +53,6 @@ jobs: clean: true submodules: recursive - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - template: templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu diff --git a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml index 0a7dc0e456a95..e4441853240e5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml @@ -36,5 +36,3 @@ jobs: JobName: 'Linux_CI_Multi_GPU_TensorRT_Dev' # The latest TensorRT container only supports ubuntu20.04 and python 3.8 RunDockerBuildArgs: '-o ubuntu20.04 -d tensorrt -x "--enable_multi_device_test"' - 
DoNugetPack: 'false' - ArtifactName: 'drop-linux' diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 93ee17b4cc7e6..c92fc93abba37 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -33,6 +33,4 @@ jobs: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2023.0.0 -x "--use_openvino CPU_FP32 --build_wheel"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml index 007630edb25be..018672e0b2dea 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml @@ -54,10 +54,6 @@ jobs: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -88,6 +84,7 @@ jobs: mkdir -p $(Pipeline.Workspace)/ccache docker run --rm \ --volume /data/onnx:/data/onnx:ro \ + --volume /data/models:/build/models:ro \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ @@ -109,51 +106,11 @@ jobs: --build_wheel \ --enable_onnx_tests \ --enable_training \ - --use_cache \ - --update --build; \ + --use_cache; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) - - task: CmdLine@2 - displayName: 'Install python deps' - inputs: - script: | - set -e -x - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. 
- sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_training - --ctest_path "" - - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml index adf5695bd76eb..2d2719fef8f3d 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml @@ -32,7 +32,6 @@ jobs: parameters: AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' JobName: 'Onnxruntime_Linux_GPU_Training' - SubmoduleCheckoutMode: 'recursive' RunDockerBuildArgs: > -o ubuntu20.04 -d gpu -t onnxruntime_orttraining_ortmodule_tests_image @@ -40,24 +39,16 @@ jobs: -e -x " --enable_training - --config $(buildConfig) + --config Release --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 --build_wheel --enable_nvtx_profile --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 " - DoNugetPack: 'false' RunInjectedPipeline: 'true' InjectedPipeline: 'orttraining-linux-gpu-test-ci-pipeline.yml' DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image' - BuildConfig: $(buildConfig) - ArtifactName: 'drop-linux' TimeoutInMinutes: 140 # Enable unreleased onnx opsets in CI builds # This facilitates testing the implementation for the new opsets AllowReleasedOpsetOnly: '0' - Strategy: - maxParallel: 2 - matrix: - Release: - buildConfig: Release diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 46f2ae7b97acc..3b1fde6cb6e4f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -97,7 +97,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' force32bit: ${{ parameters.isX86 }} # Our build machine doesn't have java x86 diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml index 05b2dee77e689..7b9788d90b17d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml 
@@ -1,23 +1,14 @@ parameters: AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' StageName : 'Linux_CI_Dev' - SubmoduleCheckoutMode: '' RunDockerBuildArgs: '-o ubuntu20.04 -d cpu -x "--build_wheel"' - DoNodejsPack: 'false' - DoNugetPack: 'false' NuPackScript: '' RunInjectedPipeline: 'false' InjectedPipeline: '' DockerImageTag: '' - BuildConfig: '' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 # Controls whether unreleased onnx opsets are allowed. Default is set to 1 AllowReleasedOpsetOnly: '1' - # to inject strategy, you need to pass in the whole yaml structure - - # https://docs.microsoft.com/en-us/azure/devops/pipelines/yaml-schema?view=azure-devops&tabs=schema#strategies - # see example in orttraining-linux-gpu-ci-pipeline.yml - Strategy: '' jobs: - job: ${{ parameters.StageName }} @@ -28,16 +19,8 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} skipComponentGovernanceDetection: true pool: ${{ parameters.AgentPool }} - ${{ if ne(parameters.Strategy, '') }}: - strategy: - ${{ parameters.Strategy }} steps: - checkout: self - ${{ if ne(parameters.SubmoduleCheckoutMode, '') }}: - submodules: ${{ parameters.SubmoduleCheckoutMode }} - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - template: run-docker-build-steps.yml parameters: RunDockerBuildArgs: '${{ parameters.RunDockerBuildArgs }}' @@ -48,31 +31,10 @@ jobs: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() - - ${{ if eq(parameters['DoNugetPack'], 'true') }}: - - script: | - ${{ parameters.NuPackScript }} - displayName: 'Create Artifacts' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - - script: | - npm pack - cp $(Build.SourcesDirectory)/js/node/onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) - cp -R $(Build.SourcesDirectory)/js/node/prebuilds $(Build.ArtifactStagingDirectory)/prebuilds - workingDirectory: '$(Build.SourcesDirectory)/js/node' - displayName: 'Create NPM Package' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact: ${{ parameters.ArtifactName }}' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - ${{ if eq(parameters['RunInjectedPipeline'], 'true') }}: - template: | ${{ parameters.InjectedPipeline }} parameters: DockerImageTag: ${{ parameters.DockerImageTag }} - BuildConfig: ${{ parameters.BuildConfig }} + BuildConfig: Release - template: clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 0e584b550f562..96a0ebd753d8e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -83,7 +83,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index adfcd98e37230..f5e5435cfaca0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -50,7 +50,7 @@ jobs: 
versionSpec: 3.11 - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 8c54e71448992..e63939ae0114c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -80,7 +80,7 @@ stages: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - script: brew install coreutils ninja npm yarn diff --git a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml index 4494fd36b336e..96e6ff89cd4f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml @@ -29,7 +29,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index 0b7bd3f645442..3f6c6af753a98 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -74,7 +74,7 @@ stages: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 80d285f3fd3fb..8d28b4ce580b4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -101,7 +101,7 @@ stages: - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: jobs/set-winenv.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 9d36e2dbe4944..406683af80222 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -74,7 +74,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index bad7448715936..d737376eb99b5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -72,7 +72,7 @@ jobs: displayName: 'Testing: force EOL to lf on windows for /js/**' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index 723567389579d..f7876f15029c1 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -33,7 +33,7 @@ jobs: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index f3a5728d6519b..98f1bf7ea1a16 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -32,7 +32,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: NuGetToolInstaller@0 displayName: Use Nuget 5.7.0 From 95e8dfaea51f6e6dda89b66d95c672927f578f21 Mon Sep 17 00:00:00 2001 From: aimilefth <60664743+aimilefth@users.noreply.github.com> Date: Tue, 26 Sep 2023 01:56:03 +0300 Subject: [PATCH 27/58] Update quant_utils.py/write_calibration_table (#17314) --- onnxruntime/python/tools/quantization/quant_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 74e54c3f1fa37..739e399042bf3 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -505,7 +505,7 @@ def apply_plot(hist, hist_edges): plt.show() -def write_calibration_table(calibration_cache): +def write_calibration_table(calibration_cache, dir="."): """ Helper function to write calibration table to files. """ @@ -519,7 +519,7 @@ def write_calibration_table(calibration_cache): logging.info(f"calibration cache: {calibration_cache}") - with open("calibration.json", "w") as file: + with open(os.path.join(dir, "calibration.json"), "w") as file: file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse # Serialize data using FlatBuffers @@ -551,7 +551,7 @@ def write_calibration_table(calibration_cache): builder.Finish(cal_table) buf = builder.Output() - with open("calibration.flatbuffers", "wb") as file: + with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file: file.write(buf) # Deserialize data (for validation) @@ -564,7 +564,7 @@ def write_calibration_table(calibration_cache): logging.info(key_value.Value()) # write plain text - with open("calibration.cache", "w") as file: + with open(os.path.join(dir, "calibration.cache"), "w") as file: for key in sorted(calibration_cache.keys()): value = calibration_cache[key] s = key + " " + str(max(abs(value[0]), abs(value[1]))) From ccb73fd827d29f69b1b5dfcbd4c27188b2364f0d Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 25 Sep 2023 20:03:24 -0700 Subject: [PATCH 28/58] [On-Device Training] Expose Parameters through the Training API (#17364) --- .../Training/CheckpointState.shared.cs | 133 +++++++---- .../Training/NativeTrainingMethods.shared.cs | 34 +++ .../Training/TrainingSession.shared.cs | 55 ++--- .../TrainingTest.cs | 128 ++++++++-- .../python/orttraining_pybind_state.cc | 80 ++++++- .../python/training/api/checkpoint_state.py | 220 ++++++++++++++++-- .../orttraining_test_python_bindings.py | 71 +++++- .../training_api/core/training_capi_tests.cc | 102 ++++++++ .../training_api/checkpoint_property.h | 10 +- .../include/onnxruntime_training_c_api.h | 61 ++++- .../include/onnxruntime_training_cxx_api.h | 36 ++- 
.../include/onnxruntime_training_cxx_inline.h | 12 + .../orttraining/training_api/module.cc | 59 +++++ orttraining/orttraining/training_api/module.h | 5 +- .../onnxruntime_training_c_api.cc | 79 ++++++- .../training_api/ort_training_apis.h | 10 + 16 files changed, 936 insertions(+), 159 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 659c6303702ac..6889112acb385 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -40,20 +40,16 @@ internal enum PropertyType : long String = 2 } - private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; - Memory memory = value; - using (var memHandle = memory.Pin()) + T[] value = { propertyValue }; + unsafe { - IntPtr memPtr; - unsafe + fixed (T* memPtr = value) { - memPtr = (IntPtr)memHandle.Pointer; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr)); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -103,13 +99,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, } /// - /// Adds the given int property to the checkpoint state. + /// Adds or updates the given int property to/in the checkpoint state. /// - /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, long propertyValue) { @@ -117,13 +113,13 @@ public void AddProperty(string propertyName, long propertyValue) } /// - /// Adds the given float property to the checkpoint state. + /// Adds or updates the given float property to/in the checkpoint state. /// - /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. 
/// Property value associated with the given name. public void AddProperty(string propertyName, float propertyValue) { @@ -131,28 +127,25 @@ public void AddProperty(string propertyName, float propertyValue) } /// - /// Adds the given string property to the checkpoint state. + /// Adds or updates the given string property to/in the checkpoint state. /// - /// Runtime properties that are strings such as parameter names, custom strings, and others can be added - /// to the checkpoint state by the user if they desire by calling this function with the appropriate property - /// name and value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, string propertyValue) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); - IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); - try - { - Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); - } - finally + unsafe { - Marshal.FreeHGlobal(unmanagedPointer); + fixed (byte* p = propertyValueUtf8) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p)); + } } } @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue) /// Gets the property value from an existing entry in the checkpoint state. The property must /// exist in the checkpoint state to be able to retrieve it successfully. /// - /// Unique name of the property being retrieved. + /// Name of the property being retrieved. /// Property value associated with the given property name. public object GetProperty(string propertyName) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) + try { - var longPropertyValue = Marshal.ReadInt64(propertyValue); - allocator.FreeMemory(propertyValue); - return longPropertyValue; + if (propertyType == PropertyType.Int) + { + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.Float) + { + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue); + } + + throw new ArgumentException("Expected the property type to be one of long, float or string. 
Unknown type retrieved " + propertyValue.ToString()); } - else if (propertyType == PropertyType.Float) + finally { - float[] value = new float[1]; - Marshal.Copy(propertyValue, value, 0, 1); allocator.FreeMemory(propertyValue); - return value[0]; } - else if (propertyType == PropertyType.String) + } + + /// + /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + /// + /// This function updates a model parameter in the checkpoint state with the given parameter data. + /// The training session must be already created with the checkpoint state that contains the parameter + /// being updated. The given parameter is copied over to the registered device for the training session. + /// The parameter must exist in the checkpoint state to be able to update it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that should replace the existing parameter data. + public void UpdateParameter(string parameterName, OrtValue parameter) + { + if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { - return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); + throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter."); } - throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle)); + } + + /// + /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + /// + /// This function retrieves the model parameter data from the checkpoint state for the given parameter name. + /// The parameter is copied over to the provided OrtValue. The training session must be already created + /// with the checkpoint state that contains the parameter being retrieved. + /// The parameter must exist in the checkpoint state to be able to retrieve it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that is retrieved from the checkpoint state. 
+ public OrtValue GetParameter(string parameterName) + { + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle)); + + return new OrtValue(parameterHandle); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 1868ff509bfc3..68a399f8b9671 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -42,6 +42,9 @@ public struct OrtTrainingApi public IntPtr AddProperty; public IntPtr GetProperty; public IntPtr LoadCheckpointFromBuffer; + public IntPtr GetParameterTypeAndShape; + public IntPtr UpdateParameter; + public IntPtr GetParameter; } internal static class NativeTrainingMethods @@ -97,6 +100,9 @@ static NativeTrainingMethods() OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName)); OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty)); OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty)); + OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape)); + OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter)); + OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter)); } } @@ -359,6 +365,34 @@ out UIntPtr inputCount public static DOrtGetProperty OrtGetProperty; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape + ); + + public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtUpdateParameter OrtUpdateParameter; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter + ); + + public static DOrtGetParameter OrtGetParameter; + #endregion TrainingSession API public static bool TrainingEnabled() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 33993c2be135b..877677dcad57b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -358,13 +358,14 @@ public void 
EvalStep( IReadOnlyCollection inputValues, IReadOnlyCollection outputValues) { - if (!_evalOutputCount.Equals(outputValues.Count)) + if (_evalOutputCount != (ulong)outputValues.Count()) { - throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); + throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount})."); } - IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); + const bool isInput = true; + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput); - IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ + IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -509,18 +510,17 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec /// Returns a contiguous buffer that holds a copy of all training state parameters /// /// Whether to only copy trainable parameters or to copy all parameters. - public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) + public OrtValue ToBuffer(bool onlyTrainable) { UIntPtr bufferSize = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable)); float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; - var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + var shape = new long[] { (long)bufferSize }; + var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); return buffer; } @@ -528,45 +528,30 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(FixedBufferOnnxValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. 
Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); - return; + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". 
Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index ea2b6d7dbc118..68b1d5bcc6147 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -484,20 +484,23 @@ public void TestEvalModelOutputNames() public void TestToBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); + } } } @@ -505,22 +508,25 @@ public void TestToBuffer() public void TestFromBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + 
Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); + } } } @@ -530,6 +536,82 @@ public void TestSetSeed() TrainingUtils.SetSeed(8888); } + [Fact(DisplayName = "TestGetParameter")] + public void TestGetParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(state); + Assert.NotNull(parameter); + + var typeShape = parameter.GetTensorTypeAndShape(); + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + } + } + + [Fact(DisplayName = "TestUpdateParameter")] + public void TestUpdateParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + { + Assert.NotNull(state); + + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape)) + { + state.UpdateParameter("fc1.weight", updated_parameter); + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + } + + state.UpdateParameter("fc1.weight", parameter); + + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 3f3aa396e6ca0..35d9755ba0ba7 100644 --- 
a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc"); checkpoint_state .def(py::init()) - .def("add_property", [](onnxruntime::training::api::CheckpointState* state, - const std::string& property_name, - const std::variant& property_value) { - state->property_bag.AddProperty(property_name, property_value); - }) - .def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.GetProperty(property_name); - }) - .def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.HasProperty(property_name); - }); + .def("add_property", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& property_name, + const std::variant& property_value) { + state->property_bag.AddProperty(property_name, property_value); + }) + .def("get_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.GetProperty(property_name); + }) + .def("has_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.HasProperty(property_name); + }) + .def("copy_parameter_from", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& parameter_name, OrtValue& value) -> void { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + ORT_THROW_IF_ERROR(it->second->CopyFrom( + state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }) + .def("get_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + return it->second; + }) + .def("has_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + return state->module_checkpoint_state.named_parameters.count(parameter_name); + }) + .def("parameter_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }) + .def("property_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->property_bag) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }); py::class_ training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc"); @@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(scheduler->Step()); }); + py::class_> + parameter(m, "Parameter"); + parameter + .def_property_readonly("name", &onnxruntime::training::api::Parameter::Name) + .def_property_readonly("data", 
&onnxruntime::training::api::Parameter::Data) + .def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient) + .def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad) + .def("copy_from", + [](onnxruntime::training::api::Parameter* parameter, + onnxruntime::training::api::CheckpointState* state, + OrtValue& value) -> void { + ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }); + m.def( "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index 285264bbed744..ba95cd04fce7e 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -5,70 +5,171 @@ import os +import numpy as np + from onnxruntime.capi import _pybind_state as C +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue -class CheckpointState: - """Class that holds the state of the training session +class Parameter: + """Class that represents a model parameter - This class holds all the state information of the training session such as the model parameters, - its gradients, the optimizer state and user defined properties. + This class represents a model parameter and provides access to its data, + gradient and other properties. This class is not expected to be instantiated directly. + Instead, it is returned by the `CheckpointState` object. + + Args: + parameter: The C.Parameter object that holds the underlying parameter data. + state: The C.CheckpointState object that holds the underlying session state. + """ + + def __init__(self, parameter: C.Parameter, state: C.CheckpointState): + self._parameter = parameter + self._state = state - User defined properties can be indexed by name from the `CheckpointState` object. + @property + def name(self) -> str: + """The name of the parameter""" + return self._parameter.name - To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + @property + def data(self) -> np.ndarray: + """The data of the parameter""" + return self._parameter.data.numpy() + + @data.setter + def data(self, value: np.ndarray) -> None: + """Sets the data of the parameter""" + self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + @property + def grad(self) -> np.ndarray: + """The gradient of the parameter""" + return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None + + @property + def requires_grad(self) -> bool: + """Whether or not the parameter requires its gradient to be computed""" + return self._parameter.requires_grad + + def __repr__(self) -> str: + """Returns a string representation of the parameter""" + return f"Parameter(name={self.name}, requires_grad={self.requires_grad})" + + +class Parameters: + """Class that holds all the model parameters + + This class holds all the model parameters and provides access to them. + This class is not expected to be instantiated directly. Instead, it is returned by the + `CheckpointState`'s parameters attribute. + This class behaves like a dictionary and provides access to the parameters by name. Args: - state: The C.Checkpoint state object that holds the underlying session state. + state: The C.CheckpointState object that holds the underlying session state. 
""" def __init__(self, state: C.CheckpointState): - if not isinstance(state, C.CheckpointState): - raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") self._state = state - @classmethod - def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: - """Loads the checkpoint state from the checkpoint file + def __getitem__(self, name: str) -> Parameter: + """Gets the parameter associated with the given name + + Searches for the name in the parameters of the checkpoint state. Args: - checkpoint_uri: The path to the checkpoint file. + name: The name of the parameter Returns: - CheckpointState: The checkpoint state object. + The value of the parameter + + Raises: + KeyError: If the parameter is not found """ - return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) - @classmethod - def save_checkpoint( - cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False - ) -> None: - """Saves the checkpoint state to the checkpoint file + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + return Parameter(self._state.get_parameter(name), self._state) + + def __setitem__(self, name: str, value: np.ndarray) -> None: + """Sets the parameter value for the given name + + Searches for the name in the parameters of the checkpoint state. + If the name is found in parameters, the value is updated. Args: - state: The checkpoint state object. - checkpoint_uri: The path to the checkpoint file. - include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + name: The name of the parameter + value: The value of the parameter as a numpy array + + Raises: + KeyError: If the parameter is not found """ - C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + def __contains__(self, name: str) -> bool: + """Checks if the parameter exists in the state + + Args: + name: The name of the parameter + + Returns: + True if the name is a parameter False otherwise + """ + + return self._state.has_parameter(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for parameter_name in self._state.parameter_names(): + yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state) + + def __repr__(self) -> str: + """Returns a string representation of the parameters""" + return self._state.parameter_names() + + def __len__(self) -> int: + """Returns the number of parameters""" + return len(self._state.parameter_names()) + + +class Properties: + def __init__(self, state: C.CheckpointState): + self._state = state def __getitem__(self, name: str) -> int | float | str: """Gets the property associated with the given name + Searches for the name in the properties of the checkpoint state. + Args: name: The name of the property Returns: The value of the property + + Raises: + KeyError: If the property is not found """ + + if name not in self: + raise KeyError(f"Property {name} not found.") + return self._state.get_property(name) def __setitem__(self, name: str, value: int | float | str) -> None: """Sets the property value for the given name + Searches for the name in the properties of the checkpoint state. + The value is added or updated in the properties. 
+ Args: name: The name of the property value: The value of the property + Properties only support int, float and str values. """ self._state.add_property(name, value) @@ -79,6 +180,75 @@ def __contains__(self, name: str) -> bool: name: The name of the property Returns: - True if the property exists, False otherwise + True if the name is a property, False otherwise """ + return self._state.has_property(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for property_name in self._state.property_names(): + yield property_name, self._state.get_property(property_name) + + def __repr__(self) -> str: + """Returns a string representation of the properties""" + return self._state.property_names() + + def __len__(self) -> int: + """Returns the number of properties""" + return len(self._state.property_names()) + + +class CheckpointState: + """Class that holds the state of the training session + + This class holds all the state information of the training session such as the model parameters, + its gradients, the optimizer state and user defined properties. + + To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + + Args: + state: The C.Checkpoint state object that holds the underlying session state. + """ + + def __init__(self, state: C.CheckpointState): + if not isinstance(state, C.CheckpointState): + raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") + self._state = state + self._parameters = Parameters(self._state) + self._properties = Properties(self._state) + + @classmethod + def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: + """Loads the checkpoint state from the checkpoint file + + Args: + checkpoint_uri: The path to the checkpoint file. + + Returns: + CheckpointState: The checkpoint state object. + """ + return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) + + @classmethod + def save_checkpoint( + cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False + ) -> None: + """Saves the checkpoint state to the checkpoint file + + Args: + state: The checkpoint state object. + checkpoint_uri: The path to the checkpoint file. + include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. 
+ """ + C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + + @property + def parameters(self) -> Parameters: + """Returns the model parameters from the checkpoint state""" + return self._parameters + + @property + def properties(self) -> Properties: + """Returns the properties from the checkpoint state""" + return self._properties diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 56338ddbaffef..d5c37b3e36ee7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -360,14 +360,18 @@ def test_add_get_property(property_value): if isinstance(property_value, float): property_value = float(np.float32(property_value)) - state["property"] = property_value - assert "property" in state - assert state["property"] == property_value + assert len(state.properties) == 0 + + state.properties["property"] = property_value + assert "property" in state.properties + assert state.properties["property"] == property_value + assert len(state.properties) == 1 CheckpointState.save_checkpoint(state, checkpoint_file_path) new_state = CheckpointState.load_checkpoint(checkpoint_file_path) - assert "property" in new_state - assert new_state["property"] == property_value + assert "property" in new_state.properties + assert new_state.properties["property"] == property_value + assert len(new_state.properties) == 1 def test_get_input_output_names(): @@ -563,3 +567,60 @@ def test_eval_step_with_ort_values(): fetches = model(inputs, labels) assert isinstance(fetches, OrtValue) assert fetches + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_and_set_parameter_values(device): + with tempfile.TemporaryDirectory() as temp_dir: + ( + checkpoint_file_path, + training_model_file_path, + eval_model_file_path, + _, + pt_model, + ) = _create_training_artifacts( + temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] + ) + + state = CheckpointState.load_checkpoint(checkpoint_file_path) + + model = Module(training_model_file_path, state, eval_model_file_path, device=device) + + state_dict = pt_model.state_dict() + assert len(state_dict) == len(state.parameters) + for parameter_name, _ in state.parameters: + assert parameter_name in state_dict + + for name, pt_param in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32)) + + original_param = state.parameters["fc1.weight"].data + state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32) + updated_param = state.parameters["fc1.weight"].data + assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32)) + + model.train() + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + loss = model(inputs, labels) + assert loss is not None + for name, _ in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + if name in ["fc1.weight", "fc1.bias"]: + assert 
ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert ort_param.grad.any() + + state.parameters["fc1.weight"] = original_param + assert np.allclose(state.parameters["fc1.weight"].data, original_param) diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index d734be8e3474b..e46952d87c2bf 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); } } + +TEST(TrainingCApiTest, GetParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); +} + +TEST(TrainingCApiTest, UpdateParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} + +#ifdef USE_CUDA +TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { 
+ auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} +#endif + } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h index d7b1e295df53e..3c38c99b3152f 100644 --- a/orttraining/orttraining/training_api/checkpoint_property.h +++ b/orttraining/orttraining/training_api/checkpoint_property.h @@ -22,10 +22,12 @@ struct PropertyBag { PropertyBag() = default; void AddProperty(const std::string& name, const PropertyDataType& val) { - ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(), - "Duplicated property named ", name); - - named_properties_.insert({name, val}); + auto it = named_properties_.find(name); + if (it == named_properties_.end()) { + named_properties_.insert({name, val}); + } else { + it->second = val; + } } template diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 0af737074964d..0e8544a7639ba 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -608,14 +608,14 @@ struct OrtTrainingApi { /// \name Accessing The Training Session State /// @{ - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. 
* * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * * \param[in] checkpoint_state The checkpoint state which should hold the property. - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_type Type of the property associated with the given name. * \param[in] property_value Property value associated with the given name. * @@ -632,7 +632,7 @@ struct OrtTrainingApi { * exist in the checkpoint state to be able to retrieve it successfully. * * \param[in] checkpoint_state The checkpoint state that is currently holding the property. - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \param[in] allocator Allocator used to allocate the memory for the property_value. * \param[out] property_type Type of the property associated with the given name. * \param[out] property_value Property value associated with the given name. @@ -669,6 +669,57 @@ struct OrtTrainingApi { ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name. + * + * This function retrieves the type and shape of the parameter associated with the given parameter name. + * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over and returned as an OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the parameter. + * \param[out] parameter The parameter data that is retrieved from the checkpoint state. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 0edef20ba6da8..218bef524200c 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -112,13 +112,13 @@ class CheckpointState : public detail::Base { const std::basic_string& path_to_checkpoint, const bool include_optimizer_state = false); - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_value Property value associated with the given name. * */ @@ -129,12 +129,38 @@ class CheckpointState : public detail::Base { * Gets the property value from an existing entry in the checkpoint state. The property must * exist in the checkpoint state to be able to retrieve it successfully. * - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \return Property value associated with the given property name. * */ Property GetProperty(const std::string& property_name); + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. 
The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + */ + void UpdateParameter(const std::string& parameter_name, const Value& parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over to the provided OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] parameter_name Name of the parameter being retrieved. + * \return The parameter data that is retrieved from the checkpoint state. + * + */ + Value GetParameter(const std::string& parameter_name); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index c0048458ddf4d..7d1326a10f8f8 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) { return property; } +inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) { + ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter)); +} + +inline Value CheckpointState::GetParameter(const std::string& parameter_name) { + AllocatorWithDefaultOptions allocator; + OrtValue* parameter; + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter)); + + return Value{parameter}; +} + } // namespace Ort diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index d1775e358163c..cf49a01517d6b 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph, #endif } // namespace +Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { + ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. 
Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get(), *data.GetMutable())); + + return Status::OK(); +} + +Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) { + ORT_ENFORCE(data_.IsAllocated(), + "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get(), *data_.GetMutable())); + + return Status::OK(); +} + Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) { // assert param is allocated ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient."); @@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers, } } +Module::~Module() { + state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr; +} + size_t Module::GetTrainingModelOutputCount() const noexcept { return train_output_names_.size(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index adb633343263e..f323e6be72d49 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -21,6 +21,8 @@ struct Parameter { // Return the mutable data. OrtValue& Data() { return data_; } + Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const; + Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data); const std::string& Name() const { return name_; } // Returns whether this parameter is trainable or not. @@ -34,7 +36,6 @@ struct Parameter { // Reset and release the gradient buffer of this Parameter greedily. 
Status ResetGrad(); - protected: Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad); private: @@ -83,6 +84,8 @@ struct Module { const std::vector>& providers, gsl::span op_domains = gsl::span()); + ~Module(); + // Return the trainable/nontrainable parameters std::vector> Parameters() const; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 6693bba348648..38a9aad9640ea 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN + if (checkpoint_buffer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr."); + } + *checkpoint_state = nullptr; auto chkpt_state = std::make_unique(); const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer); @@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) { + API_IMPL_BEGIN + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter) { + API_IMPL_BEGIN + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter) { + API_IMPL_BEGIN + + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. 
Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + if (!it->second->Data().IsTensor()) { + return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type."); + } + const auto& parameter_tensor = it->second->Data().Get(); + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + + auto status = it->second->CopyTo( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter); + if (!status.IsOK()) { + OrtApis::ReleaseValue(*parameter); + return onnxruntime::ToOrtStatus(status); + } + + return nullptr; + API_IMPL_END +} + static constexpr OrtTrainingApi ort_training_api = { // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially // released, it is OK to change the order here, however a corresponding matching change should also be done in the @@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::TrainingSessionGetEvalModelInputName, &OrtTrainingApis::AddProperty, &OrtTrainingApis::GetProperty, - &OrtTrainingApis::LoadCheckpointFromBuffer}; + &OrtTrainingApis::LoadCheckpointFromBuffer, + &OrtTrainingApis::GetParameterTypeAndShape, + &OrtTrainingApis::UpdateParameter, + &OrtTrainingApis::GetParameter}; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { // No constraints on the API version yet. 
diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index c87108957c975..2a8c1e30361c6 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); +ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + +ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + +ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + } // namespace OrtTrainingApis From aed43f429a961e2fee543b801da50899a11b43cb Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 26 Sep 2023 04:49:13 -0400 Subject: [PATCH 29/58] [java] Enable output pinning in OrtSession and OrtTrainingSession (#16835) --- .../main/java/ai/onnxruntime/OrtSession.java | 202 +++++++++++++++--- .../ai/onnxruntime/OrtTrainingSession.java | 172 ++++++++++++--- .../main/native/ai_onnxruntime_OrtSession.c | 38 +++- .../ai_onnxruntime_OrtTrainingSession.c | 109 ++++++---- .../java/ai/onnxruntime/InferenceTest.java | 152 ++++++++++++- .../java/ai/onnxruntime/ModelGenerators.java | 96 +++++++++ .../test/java/ai/onnxruntime/TestHelpers.java | 6 + .../java/ai/onnxruntime/TrainingTest.java | 12 +- .../resources/java-three-output-matmul.onnx | Bin 0 -> 530 bytes 9 files changed, 668 insertions(+), 119 deletions(-) create mode 100644 java/src/test/resources/java-three-output-matmul.onnx diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 435f86daa5fe2..fbea13d155507 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -239,7 +239,7 @@ public Result run(Map inputs, RunOptions runOp */ public Result run(Map inputs, Set requestedOutputs) throws OrtException { - return run(inputs, requestedOutputs, null); + return run(inputs, requestedOutputs, Collections.emptyMap(), null); } /** @@ -259,17 +259,90 @@ public Result run( Set requestedOutputs, RunOptions runOptions) throws OrtException { + return run(inputs, requestedOutputs, Collections.emptyMap(), runOptions); + } + + /** + * Scores an input feed dict, returning the map of pinned outputs. + * + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, Map pinnedOutputs) + throws OrtException { + return run(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Scores an input feed dict, returning the map of requested and pinned outputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link OrtException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs) + throws OrtException { + return run(inputs, requestedOutputs, pinnedOutputs, null); + } + + /** + * Scores an input feed dict, returning the map of requested and pinned outputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link OrtException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @param runOptions The RunOptions to control this run. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs, + RunOptions runOptions) + throws OrtException { if (!closed) { if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) { throw new OrtException( "Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > numOutputs)) { + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + numOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -284,20 +357,41 @@ public Result run( "Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString()); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (outputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString()); + } + } for (String s : requestedOutputs) { if (outputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + + s + + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + outputNames.toString()); } } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - - OnnxValue[] outputValues = + boolean[] ownedByResult = run( OnnxRuntime.ortApiHandle, nativeHandle, @@ -307,13 +401,40 @@ public Result run( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new Result(outputNamesArray, outputValues); + return new Result(outputNamesArray, outputValues, ownedByResult); } else { throw new IllegalStateException("Trying to score a closed OrtSession."); } } + /** + * Pulls out the native handle by casting it to the appropriate type. + * + * @param v The OnnxValue. + * @return The native handle. 
+ */ + static long getHandle(OnnxValue v) { + /* + * Note this method exists as interface methods are all public, but we do not want users to be + * able to access the native pointer via a public API so can't add a method to OnnxValue which + * exposes it. + */ + if (v instanceof OnnxTensorLike) { + return ((OnnxTensorLike) v).nativeHandle; + } else if (v instanceof OnnxSequence) { + return ((OnnxSequence) v).nativeHandle; + } else if (v instanceof OnnxMap) { + return ((OnnxMap) v).nativeHandle; + } else { + throw new IllegalArgumentException( + "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " + + v.getClass()); + } + } + /** * Gets the metadata for the currently loaded model. * @@ -409,8 +530,9 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long throws OrtException; /** - * The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other - * handles must be valid pointers. + * The native run call. runOptionsHandle can be zero (i.e. the null pointer), outputValues can + * contain null entries, and outputHandles can contain zero values (i.e. the null pointer), but + * all other handles must be valid pointers. * * @param apiHandle The pointer to the api. * @param nativeHandle The pointer to the session. @@ -419,12 +541,14 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long * @param inputs The input tensors. * @param numInputs The number of inputs. * @param outputNamesArray The requested output names. + * @param outputValues The OnnxValue output array. + * @param outputHandles The OrtValue output pointer array. * @param numOutputs The number of requested outputs. * @param runOptionsHandle The (possibly null) pointer to the run options. - * @return The OnnxValues produced by this run. + * @return A boolean array representing if the OnnxValues were allocated by this run call. * @throws OrtException If the native call failed in some way. */ - private native OnnxValue[] run( + private native boolean[] run( long apiHandle, long nativeHandle, long allocatorHandle, @@ -433,6 +557,8 @@ private native OnnxValue[] run( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; @@ -1417,9 +1543,13 @@ private native void addRunConfigEntry( /** * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s. * - *

When this is closed it closes all the {@link OnnxValue}s inside it. If you maintain a - * reference to a value after this object has been closed it will throw an {@link + *

When this is closed it closes all the {@link OnnxValue}s owned by the result object. If you + * maintain a reference to a value after this object has been closed it will throw an {@link * IllegalStateException} upon access. + * + *

{@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call are not closed + * by the {@link Result#close()} method. Ownership of each output can be checked with {@link + * Result#isResultOwner(int)}. */ public static class Result implements AutoCloseable, Iterable> { @@ -1429,6 +1559,8 @@ public static class Result implements AutoCloseable, Iterable list; + private final boolean[] ownedByResult; + private boolean closed; /** @@ -1437,21 +1569,23 @@ public static class Result implements AutoCloseable, IterableThrows {@link IllegalStateException} if the container has been closed, and {@link + * ArrayIndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return Is that value owned by this result object? + */ + public boolean isResultOwner(int index) { + if (!closed) { + return ownedByResult[index]; + } else { + throw new IllegalStateException("Result is closed"); + } + } + /** * Returns the number of outputs in this Result. * diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 8c03c5b80433c..49ddf29c22335 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -418,7 +418,7 @@ private static native void setSeed(long apiHandle, long trainingHandle, long see */ public OrtSession.Result trainStep(Map inputs) throws OrtException { - return trainStep(inputs, trainOutputNames, null); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), null); } /** @@ -432,7 +432,7 @@ public OrtSession.Result trainStep(Map inputs) public OrtSession.Result trainStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return trainStep(inputs, trainOutputNames, runOptions); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), runOptions); } /** @@ -446,14 +446,41 @@ public OrtSession.Result trainStep( public OrtSession.Result trainStep( Map inputs, Set requestedOutputs) throws OrtException { - return trainStep(inputs, requestedOutputs, null); + return trainStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single step of training, accumulating the gradients. * + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * * @param inputs The inputs (must include both the features and the target). - * @param requestedOutputs The requested outputs. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return Requested outputs produced by the training step. + * @throws OrtException If the native call failed. + */ + public OrtSession.Result trainStep( + Map inputs, Map pinnedOutputs) + throws OrtException { + return trainStep(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Performs a single step of training, accumulating the gradients. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link OrtException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs (must include both the features and the target). + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. * @param runOptions Run options for controlling this specific call. * @return Requested outputs produced by the training step. * @throws OrtException If the native call failed. @@ -461,6 +488,7 @@ public OrtSession.Result trainStep( public OrtSession.Result trainStep( Map inputs, Set requestedOutputs, + Map pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException { checkClosed(); @@ -472,12 +500,14 @@ public OrtSession.Result trainStep( + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > trainOutputNames.size())) { + int numTrainOutputs = trainOutputNames.size(); + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numTrainOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," - + trainOutputNames.size() + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + + numTrainOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -492,12 +522,35 @@ public OrtSession.Result trainStep( "Unknown input name " + t.getKey() + ", expected one of " + trainInputNames); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (trainOutputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = OrtSession.getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + + e.getKey() + + ", expected one of " + + trainOutputNames.toString()); + } + } for (String s : requestedOutputs) { if (trainOutputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + s + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + trainOutputNames.toString()); @@ -505,7 +558,7 @@ public OrtSession.Result trainStep( } long runOptionsHandle = runOptions == null ? 
0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = trainStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -516,8 +569,10 @@ public OrtSession.Result trainStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -540,7 +595,7 @@ public OrtSession.Result trainStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] trainStep( + private native boolean[] trainStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -550,6 +605,8 @@ private native OnnxValue[] trainStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle); /** @@ -561,7 +618,7 @@ private native OnnxValue[] trainStep( */ public OrtSession.Result evalStep(Map inputs) throws OrtException { - return evalStep(inputs, evalOutputNames, null); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), null); } /** @@ -575,7 +632,7 @@ public OrtSession.Result evalStep(Map inputs) public OrtSession.Result evalStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return evalStep(inputs, evalOutputNames, runOptions); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), runOptions); } /** @@ -589,14 +646,41 @@ public OrtSession.Result evalStep( public OrtSession.Result evalStep( Map inputs, Set requestedOutputs) throws OrtException { - return evalStep(inputs, requestedOutputs, null); + return evalStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single evaluation step using the supplied inputs. * - * @param inputs The model inputs. - * @param requestedOutputs The requested output names. + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs to score. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The requested outputs. + * @throws OrtException If the native call failed. + */ + public OrtSession.Result evalStep( + Map inputs, Map pinnedOutputs) + throws OrtException { + return evalStep(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Performs a single evaluation step using the supplied inputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link OrtException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. * @param runOptions Run options for controlling this specific call. * @return The requested outputs. * @throws OrtException If the native call failed. @@ -604,6 +688,7 @@ public OrtSession.Result evalStep( public OrtSession.Result evalStep( Map inputs, Set requestedOutputs, + Map pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException { checkClosed(); @@ -615,12 +700,14 @@ public OrtSession.Result evalStep( + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > evalOutputNames.size())) { + int numEvalOutputs = evalOutputNames.size(); + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numEvalOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," - + evalOutputNames.size() + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + + numEvalOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -635,12 +722,35 @@ public OrtSession.Result evalStep( "Unknown input name " + t.getKey() + ", expected one of " + evalInputNames.toString()); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (evalOutputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = OrtSession.getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + + e.getKey() + + ", expected one of " + + evalOutputNames.toString()); + } + } for (String s : requestedOutputs) { if (evalOutputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + s + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + evalOutputNames.toString()); @@ -648,7 +758,7 @@ public OrtSession.Result evalStep( } long runOptionsHandle = runOptions == null ? 
0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = evalStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -659,8 +769,10 @@ public OrtSession.Result evalStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -682,7 +794,7 @@ public OrtSession.Result evalStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] evalStep( + private native boolean[] evalStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -692,6 +804,8 @@ private native OnnxValue[] evalStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 6f4e34648cf81..f4d5ab080cd31 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -316,14 +316,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo(JNIE /* * Class: ai_onnxruntime_OrtSession * Method: run - * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; - * private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs) + * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z + * private native boolean[] run(long apiHandle, long nativeHandle, long allocatorHandle, + * String[] inputNamesArray, long[] inputs, long numInputs, + * String[] outputNamesArray, long numOutputs, + * OnnxValue[] outputValues, long[] outputHandles, + * long runOptionsHandle) throws OrtException; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, - jlong numOutputs, jlong runOptionsHandle) { + jlong numOutputs, jobjectArray outputValuesArr, + jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; @@ -331,7 +336,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv OrtSession* session = (OrtSession*)sessionHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers const char** inputNames = allocarray(numInputs, sizeof(char*)); @@ -376,13 +381,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv // Release the java array copy of pointers to the tensors. 
(*jniEnv)->ReleaseLongArrayElements(jniEnv, tensorArr, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } + // Release the java array copy of pointers to the outputs. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT); + // Actually score the inputs. // ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, // _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, @@ -394,21 +405,26 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, ORTJNI_OnnxValueClassName); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. cleanup_output_values: diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index b3b530a8b15aa..9f7b8d3a3dcfc 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ #include @@ -330,12 +330,12 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad /* * Class: ai_onnxruntime_OrtTrainingSession * Method: trainStep - * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs, - jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { + jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle; @@ -343,31 +343,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers - const char** inputNames = malloc(sizeof(char*) * numInputs); + const char** inputNames = allocarray(numInputs, sizeof(char*)); if (inputNames == NULL) { // Nothing to cleanup, return and throw exception return outputArray; } - const char** outputNames = malloc(sizeof(char*) * numOutputs); + const char** outputNames = allocarray(numOutputs, sizeof(char*)); if (outputNames == NULL) { goto cleanup_input_names; } - jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject)); if (javaInputStrings == NULL) { goto cleanup_output_names; } - jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject)); if (javaOutputStrings == NULL) { goto cleanup_java_input_strings; } - const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*)); if (inputValuePtrs == NULL) { goto cleanup_java_output_strings; } - OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*)); if (outputValues == NULL) { goto cleanup_input_values; } @@ -388,13 +388,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. 
for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } + // Release the java array copy of pointers to the outputs. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT); + // Actually score the inputs. //ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, // size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, @@ -406,24 +412,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue"); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. 
- cleanup_output_values: +cleanup_output_values: free(outputValues); // Release the Java output strings @@ -437,15 +448,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep } // Release the buffers - cleanup_input_values: +cleanup_input_values: free((void*)inputValuePtrs); - cleanup_java_output_strings: +cleanup_java_output_strings: free(javaOutputStrings); - cleanup_java_input_strings: +cleanup_java_input_strings: free(javaInputStrings); - cleanup_output_names: +cleanup_output_names: free((void*)outputNames); - cleanup_input_names: +cleanup_input_names: free((void*)inputNames); return outputArray; @@ -454,12 +465,12 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep /* * Class: ai_onnxruntime_OrtTrainingSession * Method: evalStep - * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs, - jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { + jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle; @@ -467,31 +478,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers - const char** inputNames = malloc(sizeof(char*) * numInputs); + const char** inputNames = allocarray(numInputs, sizeof(char*)); if (inputNames == NULL) { // Nothing to cleanup, return and throw exception return outputArray; } - const char** outputNames = malloc(sizeof(char*) * numOutputs); + const char** outputNames = allocarray(numOutputs, sizeof(char*)); if (outputNames == NULL) { goto cleanup_input_names; } - jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject)); if (javaInputStrings == NULL) { goto cleanup_output_names; } - jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject)); if (javaOutputStrings == NULL) { goto cleanup_java_input_strings; } - const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*)); if (inputValuePtrs == NULL) { goto cleanup_java_output_strings; } - OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*)); if (outputValues == NULL) { goto cleanup_input_values; } @@ -512,11 +523,14 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep // Release the java array copy of pointers to the 
tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } // Actually score the inputs. @@ -530,24 +544,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue"); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. 
- cleanup_output_values: +cleanup_output_values: free(outputValues); // Release the Java output strings @@ -561,15 +580,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep } // Release the buffers - cleanup_input_values: +cleanup_input_values: free((void*)inputValuePtrs); - cleanup_java_output_strings: +cleanup_java_output_strings: free(javaOutputStrings); - cleanup_java_input_strings: +cleanup_java_input_strings: free(javaInputStrings); - cleanup_output_names: +cleanup_output_names: free((void*)outputNames); - cleanup_input_names: +cleanup_input_names: free((void*)inputNames); return outputArray; diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 08d2a5698d579..e975117fb75bd 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -6,11 +6,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import ai.onnxruntime.OrtException.OrtErrorCode; import ai.onnxruntime.OrtSession.Result; import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode; @@ -31,6 +34,8 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -71,7 +76,7 @@ public void environmentTest() { @Test public void testVersion() { String version = env.getVersion(); - Assertions.assertFalse(version.isEmpty()); + assertFalse(version.isEmpty()); } @Test @@ -749,6 +754,151 @@ public void testOverridingInitializer() throws OrtException { } } + @Test + public void testPinnedOutputs() throws OrtException { + String modelPath = TestHelpers.getResourcePath("/java-three-output-matmul.onnx").toString(); + FloatBuffer outputABuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer outputBBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer outputCBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer tooSmallBuf = + ByteBuffer.allocateDirect(4 * 2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer tooBigBuf = + ByteBuffer.allocateDirect(4 * 6).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer wrongShapeBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + LongBuffer wrongTypeBuf = + ByteBuffer.allocateDirect(8 * 4).order(ByteOrder.nativeOrder()).asLongBuffer(); + + try (SessionOptions options = new SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options); + OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}}); + OnnxTensor outputA = OnnxTensor.createTensor(env, outputABuf, new long[] {1, 4}); + OnnxTensor outputB = OnnxTensor.createTensor(env, outputBBuf, new long[] {1, 4}); + OnnxTensor outputC = OnnxTensor.createTensor(env, outputCBuf, new long[] {1, 4}); + OnnxTensor 
tooSmall = OnnxTensor.createTensor(env, tooSmallBuf, new long[] {1, 2}); + OnnxTensor tooBig = OnnxTensor.createTensor(env, tooBigBuf, new long[] {1, 6}); + OnnxTensor wrongShape = OnnxTensor.createTensor(env, wrongShapeBuf, new long[] {2, 2}); + OnnxTensor wrongType = OnnxTensor.createTensor(env, wrongTypeBuf, new long[] {1, 4})) { + Map inputMap = Collections.singletonMap("input", t); + Set requestedOutputs = new LinkedHashSet<>(); + Map pinnedOutputs = new LinkedHashMap<>(); + + // Test that all outputs can be pinned + pinnedOutputs.put("output-0", outputA); + pinnedOutputs.put("output-1", outputB); + pinnedOutputs.put("output-2", outputC); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(3, r.size()); + assertSame(outputA, r.get(0)); + assertSame(outputB, r.get(1)); + assertSame(outputC, r.get(2)); + assertFalse(r.isResultOwner(0)); + assertFalse(r.isResultOwner(1)); + assertFalse(r.isResultOwner(2)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test a single pinned output + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(1, r.size()); + assertSame(outputB, r.get(0)); + assertSame(outputB, r.get("output-1").get()); + assertFalse(r.isResultOwner(0)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test a mixture of pinned and generated outputs + requestedOutputs.add("output-0"); + requestedOutputs.add("output-2"); + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(3, r.size()); + // pinned outputs are first + assertSame(outputB, r.get(0)); + assertSame(outputB, r.get("output-1").get()); + // requested outputs are different + assertNotSame(outputA, r.get("output-0").get()); + assertNotSame(outputC, r.get("output-2").get()); + // check ownership. + assertFalse(r.isResultOwner(0)); + assertTrue(r.isResultOwner(1)); + assertTrue(r.isResultOwner(2)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that overlapping names causes an error + requestedOutputs.add("output-1"); + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_JAVA_UNKNOWN, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor of the wrong type causes an error + pinnedOutputs.put("output-0", wrongType); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor of the wrong shape (but right capacity) causes an error. 
+ pinnedOutputs.put("output-1", wrongShape); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor which is too small causes an error + pinnedOutputs.put("output-1", tooSmall); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor which is too large causes an error + pinnedOutputs.put("output-1", tooBig); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + } + } + } + private static File getTestModelsDir() throws IOException { // get build directory, append downloaded models location String cwd = System.getProperty("user.dir"); diff --git a/java/src/test/java/ai/onnxruntime/ModelGenerators.java b/java/src/test/java/ai/onnxruntime/ModelGenerators.java index 90fda4c5cf610..7bf7cef43208a 100644 --- a/java/src/test/java/ai/onnxruntime/ModelGenerators.java +++ b/java/src/test/java/ai/onnxruntime/ModelGenerators.java @@ -182,6 +182,102 @@ public void generateMatMul() throws IOException { } } + public void generateThreeOutputMatmul() throws IOException { + OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder(); + graph.setName("ort-test-three-matmul"); + + // Add placeholders + OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder(); + input.setName("input"); + OnnxMl.TypeProto inputType = + buildTensorTypeNode( + new long[] {-1, 4}, + new String[] {"batch_size", null}, + OnnxMl.TensorProto.DataType.FLOAT); + input.setType(inputType); + graph.addInput(input); + OnnxMl.ValueInfoProto.Builder outputA = OnnxMl.ValueInfoProto.newBuilder(); + outputA.setName("output-0"); + OnnxMl.TypeProto outputType = + buildTensorTypeNode( + new long[] {-1, 4}, + new String[] {"batch_size", null}, + OnnxMl.TensorProto.DataType.FLOAT); + outputA.setType(outputType); + graph.addOutput(outputA); + OnnxMl.ValueInfoProto.Builder outputB = OnnxMl.ValueInfoProto.newBuilder(); + outputB.setName("output-1"); + outputB.setType(outputType); + graph.addOutput(outputB); + OnnxMl.ValueInfoProto.Builder outputC = OnnxMl.ValueInfoProto.newBuilder(); + outputC.setName("output-2"); + outputC.setType(outputType); + graph.addOutput(outputC); + + // Add initializers + OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder(); + tensor.addDims(4); + tensor.addDims(4); + Float[] floats = + new Float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f}; + tensor.addAllFloatData(Arrays.asList(floats)); + tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); + tensor.setName("tensor"); + graph.addInitializer(tensor); + OnnxMl.TensorProto.Builder addInit = OnnxMl.TensorProto.newBuilder(); + addInit.addDims(4); + Float[] addFloats = new Float[] {1f, 2f, 3f, 4f}; + addInit.addAllFloatData(Arrays.asList(addFloats)); + addInit.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); + addInit.setName("add-init"); + graph.addInitializer(addInit); + + // 
Add operations + OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder(); + matmul.setName("matmul-0"); + matmul.setOpType("MatMul"); + matmul.addInput("input"); + matmul.addInput("tensor"); + matmul.addOutput("matmul-output"); + graph.addNode(matmul); + + OnnxMl.NodeProto.Builder id = OnnxMl.NodeProto.newBuilder(); + id.setName("id-1"); + id.setOpType("Identity"); + id.addInput("matmul-output"); + id.addOutput("output-0"); + graph.addNode(id); + + OnnxMl.NodeProto.Builder add = OnnxMl.NodeProto.newBuilder(); + add.setName("add-2"); + add.setOpType("Add"); + add.addInput("matmul-output"); + add.addInput("add-init"); + add.addOutput("output-1"); + graph.addNode(add); + + OnnxMl.NodeProto.Builder log = OnnxMl.NodeProto.newBuilder(); + log.setName("log-3"); + log.setOpType("Log"); + log.addInput("matmul-output"); + log.addOutput("output-2"); + graph.addNode(log); + + // Build model + OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder(); + model.setGraph(graph); + model.setDocString("ORT three output matmul test"); + model.setModelVersion(0); + model.setIrVersion(8); + model.setDomain("ai.onnxruntime.test"); + model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build()); + try (OutputStream os = + Files.newOutputStream( + Paths.get("src", "test", "resources", "java-three-output-matmul.onnx"))) { + model.build().writeTo(os); + } + } + private static void genCast( String name, OnnxMl.TensorProto.DataType inputDataType, diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index 7d41918b1c6c7..55d8169434d48 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -262,6 +262,12 @@ public static Path getResourcePath(String path) { return new File(TestHelpers.class.getResource(path).getFile()).toPath(); } + public static void zeroBuffer(FloatBuffer buf) { + for (int i = 0; i < buf.capacity(); i++) { + buf.put(i, 0.0f); + } + } + public static float[] loadTensorFromFile(Path filename) { return loadTensorFromFile(filename, true); } diff --git a/java/src/test/java/ai/onnxruntime/TrainingTest.java b/java/src/test/java/ai/onnxruntime/TrainingTest.java index a02f5a88b2ac5..eaa7da1fc6a16 100644 --- a/java/src/test/java/ai/onnxruntime/TrainingTest.java +++ b/java/src/test/java/ai/onnxruntime/TrainingTest.java @@ -16,7 +16,6 @@ import java.util.Map; import java.util.Set; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfSystemProperty; @@ -69,8 +68,6 @@ public void testCreateTrainingSessionWithEval() throws OrtException { } } - // this test is not enabled as ORT Java doesn't support supplying an output buffer - @Disabled @Test public void testTrainingSessionTrainStep() throws OrtException { String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString(); @@ -99,14 +96,11 @@ public void testTrainingSessionTrainStep() throws OrtException { ByteBuffer.allocateDirect(4 * expectedOutput.length) .order(ByteOrder.nativeOrder()) .asFloatBuffer(); - OnnxTensor outputTensor = - OnnxTensor.createTensor(env, output, new long[expectedOutput.length]); + OnnxTensor outputTensor = OnnxTensor.createTensor(env, output, new long[0]); outputMap.put("onnx::loss::21273", outputTensor); - /* Disabled as we haven't implemented this yet - try (trainingSession.trainStep(pinnedInputs, outputMap)) { - 
Assertions.assertArrayEquals(expectedOutput, (float[]) outputTensor.getValue(), 1e-3f); + try (OrtSession.Result r = trainingSession.trainStep(pinnedInputs, outputMap)) { + Assertions.assertEquals(expectedOutput[0], (float) outputTensor.getValue(), 1e-3f); } - */ } finally { OnnxValue.close(outputMap); OnnxValue.close(pinnedInputs);
diff --git a/java/src/test/resources/java-three-output-matmul.onnx b/java/src/test/resources/java-three-output-matmul.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..fed0bbca460cfed900548ed1bfcc5018d8a927b3
GIT binary patch
literal 530
zcma)&%}T>S5P-K$T9!c}tSAyZq(KiBwx;ps#iTdEVi8YXG&M`Rke`xHKt1(Qd=#I;
zvybA$gn$L1hmYNz{braM&fSAZkMb;gEy@gasz#{R=%3u(KRCE7lydSCS0y@WglU;L
z)$i4p0Uq>pMset)%GP-y_G>}by3L!X=k})&PRj(&;jbcitxC@}bu7m&zljyKfNyZI
zr2>!QSn5n;n>4n2Rm^vdFplADE1}hVyO-n(dFdLr`9d7#1
Date: Tue, 26 Sep 2023 09:28:17 -0700
Subject: [PATCH 30/58] [TensorRT EP] Back out the PerThreadContext (#17690)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Current TRT EP's PerThreadContext allows more than one IExecutionContext instance to be created from one engine instance. However, it is possible to hit an error caused by the TRT API context.setBindingDimensions() in our TRT EP code [here](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L2775) when the input shape changes (meaning the engine is rebuilt) under multithreading.

From the [doc](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#ada050e88320bcc40987b0acadc2ef962) and the [discussion](https://github.com/NVIDIA/TensorRT/issues/846), it seems each IExecutionContext should have its own OptimizationProfile, which our current TRT EP does not support regardless of whether the PerThreadContext implementation is used.

Back out the PerThreadContext until we completely solve this issue.
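For reference, the arrangement the TensorRT documentation and the linked discussion point to is roughly one IExecutionContext per optimization profile, so that each thread only ever calls setBindingDimensions() on its own context/profile pair. The sketch below is illustrative only: it assumes TensorRT 8.x, an engine already built with one profile per worker thread, and a hypothetical helper name. It is not part of this change and not how TRT EP currently builds its engines, which is why the PerThreadContext approach is backed out here rather than reworked.

```cpp
// Illustrative sketch only: pin one IExecutionContext to each optimization
// profile of an already-built engine. Assumes TensorRT 8.x (where `delete` on
// TRT objects is supported) and an engine created with one profile per thread.
#include <NvInfer.h>

#include <memory>
#include <vector>

std::vector<std::unique_ptr<nvinfer1::IExecutionContext>> CreatePerProfileContexts(
    nvinfer1::ICudaEngine& engine) {
  std::vector<std::unique_ptr<nvinfer1::IExecutionContext>> contexts;
  for (int profile = 0; profile < engine.getNbOptimizationProfiles(); ++profile) {
    std::unique_ptr<nvinfer1::IExecutionContext> ctx(engine.createExecutionContext());
    // Binding each context to its own profile lets the owning thread call
    // setBindingDimensions() later without clashing with the other contexts
    // created from the same engine.
    ctx->setOptimizationProfile(profile);  // setOptimizationProfileAsync() on newer TRT
    contexts.emplace_back(std::move(ctx));
  }
  return contexts;
}
```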
--- .../tensorrt/tensorrt_execution_provider.cc | 120 ++++++------------ .../tensorrt/tensorrt_execution_provider.h | 17 +++ 2 files changed, 55 insertions(+), 82 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 96893f63b4540..55204abc80187 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1143,46 +1143,35 @@ bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } -bool TensorrtExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); -} - -Status TensorrtExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); -} - -void TensorrtExecutionProvider::PerThreadContext::SetGraphStream(cudaStream_t stream) { - cuda_graph_.SetStream(stream); -} - -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { +bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } -void TensorrtExecutionProvider::PerThreadContext::CaptureBegin() { +void TensorrtExecutionProvider::CaptureBegin() { cuda_graph_.Reset(); cuda_graph_.CaptureBegin(); } -void TensorrtExecutionProvider::PerThreadContext::CaptureEnd() { +void TensorrtExecutionProvider::CaptureEnd() { cuda_graph_.CaptureEnd(); is_graph_captured_ = true; } -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptured() const { +bool TensorrtExecutionProvider::IsGraphCaptured() const { return is_graph_captured_; } -Status TensorrtExecutionProvider::PerThreadContext::ReplayGraph() { +Status TensorrtExecutionProvider::ReplayGraph() { ORT_ENFORCE(IsGraphCaptured()); // Please note that CUDAGraph::Replay() is not thread safe. - // The cuda graph object is maintained by a per thread basis, + // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. return cuda_graph_.Replay(); } -void TensorrtExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { - // The cuda graph object is maintained by a per thread basis, +void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + // Please note that this function is not thread safe. + // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), // therefore following increment is guaranteed to be thread safe. ++regular_run_count_before_graph_capture_; } @@ -1213,18 +1202,6 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } - - // The reason of !IsGraphCaptureEnabled(): - // If cuda graph is enabled, the per thread context will not be released - // because the per thread cuda graph needs to be maintained and replayed for - // the next run. - // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): - // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), - // PerThreadContext won't be created and there is nothing to release. 
- if (!IsGraphCaptureEnabled() && - PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { - ReleasePerThreadContext(); - } return Status::OK(); } @@ -2384,6 +2361,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &builders_[context->node_name], + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, @@ -2445,6 +2415,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorinput_shape_ranges; auto trt_builder = trt_state->builder->get(); auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; int num_inputs = static_cast(input_indexes.size()); @@ -2502,7 +2473,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine->reset(); *(trt_state->engine) = std::unique_ptr( trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { + if (!(*(trt_state->engine))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; @@ -2527,7 +2498,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine->reset(); *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { + if (!(*(trt_state->engine))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } @@ -2556,10 +2527,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext->reset(); trt_state->engine->reset(); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); @@ -2660,7 +2628,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectortrt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } } - if (*(trt_state->engine) == nullptr) { + if (!(*(trt_state->engine))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } trt_engine = trt_state->engine->get(); @@ -2706,32 +2674,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector new_context; + if (context_update) { if (trt_state->context_memory_sharing_enable) { - new_context.reset(trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); } else { - new_context.reset(trt_state->engine->get()->createExecutionContext()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext()); } - auto context_status = 
GetPerThreadContext().UpdateTensorRTContext(fused_node_name, std::move(new_context)); - if (!context_status) { + if (!(*(trt_state->context))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } - GetPerThreadContext().UpdateProfileShapes(fused_node_name, shape_ranges); + trt_context = trt_state->context->get(); } - // Get the reference to the IExecutionContext object that is maintained on a per thread basis. - nvinfer1::IExecutionContext& trt_context = GetPerThreadContext().GetTensorRTContext(fused_node_name); - // Get input and output binding names int total_bindings = trt_engine->getNbBindings(); std::vector buffers(total_bindings); @@ -2767,12 +2723,12 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorisShapeBinding(binding_index)) { - trt_context.setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); + trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); } else { for (int j = 0, end = nb_dims; j < end; ++j) { dimensions.d[j] = static_cast(tensor_shapes[j]); } - const bool status = trt_context.setBindingDimensions(binding_index, dimensions); + const bool status = trt_context->setBindingDimensions(binding_index, dimensions); if (!status) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP cannot set the dynamic dimensions of a binding")); @@ -2911,7 +2867,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsecond; } - nvinfer1::Dims dimensions = trt_context.getBindingDimensions(static_cast(binding_index)); + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(binding_index)); int nb_dims = dimensions.nbDims; std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { @@ -3045,20 +3001,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; } - trt_context.setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - GetPerThreadContext().SetGraphStream(stream); - GetPerThreadContext().CaptureBegin(); + cuda_graph_.SetStream(stream); + CaptureBegin(); } // Run TRT inference - if (!trt_context.enqueueV2(&buffers[0], stream, nullptr)) { + if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } @@ -3089,14 +3045,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector* parser = nullptr; std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; std::unique_ptr* builder = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; @@ -246,6 +247,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. 
std::unordered_map> parsers_; std::unordered_map> engines_; + std::unordered_map> contexts_; std::unordered_map> builders_; std::unordered_map> networks_; std::unordered_map>> input_info_; @@ -256,6 +258,21 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture + cudnnHandle_t external_cudnn_handle_ = nullptr; + cublasHandle_t external_cublas_handle_ = nullptr; + + CUDAGraph cuda_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + + // [Note] We don't use PerThreadContext for now since it has issue with multithreading + // // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure. // For example, TensorRT execution context and CUDA graph are the ones to be put here. class PerThreadContext final { From 1c245e6775201fec7a1d269a71deb0f3af6f7bea Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Tue, 26 Sep 2023 09:46:30 -0700 Subject: [PATCH 31/58] Stop throwing exception on python binding when multiple EP available (#17659) Stop throwing the exception when the provider list is empty but there are multiple available EPs. Other language bindings throw no exception at all, this change will align them up. --------- Co-authored-by: Randy Shuai --- .../onnxruntime_inference_collection.py | 12 ++------ .../test/python/onnxruntime_test_python.py | 30 +++++-------------- 2 files changed, 9 insertions(+), 33 deletions(-) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 4124822adef1f..bcc6f15129231 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -438,7 +438,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. if "TensorrtExecutionProvider" in available_providers: - if any( + if providers and any( provider == "CUDAExecutionProvider" or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider") for provider in providers @@ -448,7 +448,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi self._fallback_providers = ["CPUExecutionProvider"] # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU. 
elif "MIGraphXExecutionProvider" in available_providers: - if any( + if providers and any( provider == "ROCMExecutionProvider" or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider") for provider in providers @@ -463,14 +463,6 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi providers, provider_options = check_and_normalize_provider_args( providers, provider_options, available_providers ) - if not providers and len(available_providers) > 1: - self.disable_fallback() - raise ValueError( - f"This ORT build has {available_providers} enabled. " - "Since ORT 1.9, you are required to explicitly set " - "the providers parameter when instantiating InferenceSession. For example, " - f"onnxruntime.InferenceSession(..., providers={available_providers}, ...)" - ) session_options = self._sess_options if self._sess_options else C.get_default_session_options() if self._model_path: diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 59f7781bb4f8a..1d954fe4370ad 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -80,11 +80,7 @@ def test_model_serialization(self): so.log_severity_level = 1 so.logid = "TestModelSerialization" so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx" - onnxrt.InferenceSession( - get_name("mul_1.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) os.remove(so.optimized_model_filepath) except Fail as onnxruntime_error: @@ -107,11 +103,7 @@ def test_model_serialization_with_external_initializers(self): "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession( - get_name("mnist.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) self.assertTrue(os.path.isfile(external_initializers_file)) os.remove(so.optimized_model_filepath) @@ -137,7 +129,7 @@ def test_model_serialization_with_external_initializers_to_directory(self): "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so, providers=["CPUExecutionProvider"]) + onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file))) os.remove(so.optimized_model_filepath) @@ -163,9 +155,7 @@ def test_model_serialization_with_original_external_initializers_to_directory(se "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession( - get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"] - ) + onnxrt.InferenceSession(get_name("model_with_orig_ext_data.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) 
self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file))) os.remove(so.optimized_model_filepath) @@ -198,9 +188,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire # still refers to the original external data file. We shall fix this issue so that the # optimized model only refers to one external data file. so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "10") - session1 = onnxrt.InferenceSession( - get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"] - ) + session1 = onnxrt.InferenceSession(get_name("model_with_orig_ext_data.onnx"), sess_options=so) del session1 self.assertTrue(os.path.isfile(optimized_model_filepath)) self.assertTrue(os.path.isfile(external_initializers_file)) @@ -216,9 +204,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire # verify that we can load the optimized model with external data in current directory and save # optimized model with external data to current directory. - session2 = onnxrt.InferenceSession( - optimized_model_filepath, sess_options=so2, providers=["CPUExecutionProvider"] - ) + session2 = onnxrt.InferenceSession(optimized_model_filepath, sess_options=so2) del session2 self.assertTrue(os.path.isfile(optimized_model_filepath_2)) self.assertTrue(os.path.isfile(external_initializers_file_2)) @@ -227,9 +213,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire os.remove(optimized_model_filepath) os.remove(external_initializers_file) - session3 = onnxrt.InferenceSession( - optimized_model_filepath_2, sess_options=onnxrt.SessionOptions(), providers=["CPUExecutionProvider"] - ) + session3 = onnxrt.InferenceSession(optimized_model_filepath_2, sess_options=onnxrt.SessionOptions()) del session3 os.remove(optimized_model_filepath_2) From f43acf2d33ca4c2f87b0927929123ebfaed82b1a Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Wed, 27 Sep 2023 01:51:13 +0900 Subject: [PATCH 32/58] Close the JSON object in settings.json (#17583) ### Description This patch adds a closing curly bracket at the end of `settings.json`. ### Motivation and Context `settings.json` is just not closed. It was accidentally removed at 4e6ea730d633756e9e04df8968304d11a575dde4 --- .vscode/settings.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.vscode/settings.json b/.vscode/settings.json index b7a1292efb2c6..fd28e2d7b335c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,3 +40,4 @@ "-build/include_subdir", "-runtime/references" ] +} From b8e348145cbcfaa72172aa78b8c21f1544d3854c Mon Sep 17 00:00:00 2001 From: Vadym Stupakov Date: Tue, 26 Sep 2023 19:57:01 +0300 Subject: [PATCH 33/58] fixed #16873 (#16932) --- tools/perf_view/ort_perf_view.html | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/perf_view/ort_perf_view.html b/tools/perf_view/ort_perf_view.html index e00e38702d342..509fe5593f6a1 100644 --- a/tools/perf_view/ort_perf_view.html +++ b/tools/perf_view/ort_perf_view.html @@ -5,7 +5,7 @@ - +