diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index e11537492d3a7..15d89b536b39a 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -737,6 +737,7 @@ public void testCoreML() throws OrtException {
     runProvider(OrtProvider.CORE_ML);
   }
 
+  @Disabled("The DirectML Java API is not supported yet")
   @Test
   @EnabledIfSystemProperty(named = "USE_DML", matches = "1")
   public void testDirectML() throws OrtException {
diff --git a/onnxruntime/test/common/cuda_op_test_utils.h b/onnxruntime/test/common/cuda_op_test_utils.h
index 6f3e460628566..d3e069237217e 100644
--- a/onnxruntime/test/common/cuda_op_test_utils.h
+++ b/onnxruntime/test/common/cuda_op_test_utils.h
@@ -5,6 +5,11 @@
 
 #include "test/util/include/default_providers.h"
 
+#define SKIP_CUDA_TEST_WITH_DML                                           \
+  if (DefaultCudaExecutionProvider() == nullptr) {                        \
+    GTEST_SKIP() << "CUDA Tests are not supported while DML is enabled";  \
+  }
+
 namespace onnxruntime {
 namespace test {
 
@@ -13,6 +18,10 @@ namespace test {
 int GetCudaArchitecture();
 
 inline bool HasCudaEnvironment(int min_cuda_architecture) {
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return false;
+  }
+
   if (DefaultCudaExecutionProvider().get() == nullptr) {
     return false;
   }
diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc
index 5f94d30112f0e..f6fc9ea7662cb 100644
--- a/onnxruntime/test/contrib_ops/beam_search_test.cc
+++ b/onnxruntime/test/contrib_ops/beam_search_test.cc
@@ -73,6 +73,9 @@ TEST(BeamSearchTest, GptBeamSearchFp32) {
   const char* const output_names[] = {"sequences"};
 
   Ort::SessionOptions session_options;
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
 #ifdef USE_CUDA
   OrtCUDAProviderOptionsV2 cuda_options;
   cuda_options.use_tf32 = false;
@@ -166,6 +169,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16) {
   bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
   if (enable_cuda || enable_rocm) {
     Ort::SessionOptions session_options;
+#if defined(USE_CUDA) && defined(USE_DML)
+    SKIP_CUDA_TEST_WITH_DML;
+#endif
 #ifdef USE_CUDA
     OrtCUDAProviderOptionsV2 cuda_options;
     cuda_options.use_tf32 = false;
diff --git a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc
index 027d4b3fff1b0..297629b015796 100644
--- a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc
+++ b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc
@@ -181,6 +181,9 @@ void RunBiasDropoutTest(const bool use_mask, const std::vector& input_s
   t.SetCustomOutputVerifier(output_verifier);
   std::vector<std::unique_ptr<IExecutionProvider>> t_eps;
 #ifdef USE_CUDA
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   t_eps.emplace_back(DefaultCudaExecutionProvider());
 #elif USE_ROCM
   t_eps.emplace_back(DefaultRocmExecutionProvider());
diff --git a/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc
index 7ca4e1004066c..26b0e3a4dd7a9 100644
--- a/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc
+++ b/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc
@@ -61,7 +61,9 @@ void RunTestForInference(const std::vector& input_dims, bool has_ratio
 
   std::vector<std::unique_ptr<IExecutionProvider>> test_eps;
 #ifdef USE_CUDA
-  test_eps.emplace_back(DefaultCudaExecutionProvider());
+  if (DefaultCudaExecutionProvider() != nullptr) {
+    test_eps.emplace_back(DefaultCudaExecutionProvider());
+  }
 #elif USE_ROCM
   test_eps.emplace_back(DefaultRocmExecutionProvider());
 #endif
@@ -122,6 +124,9 @@ void RunTestForTraining(const std::vector& input_dims) {
 
   std::vector<std::unique_ptr<IExecutionProvider>> dropout_eps;
 #ifdef USE_CUDA
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   dropout_eps.emplace_back(DefaultCudaExecutionProvider());
 #elif USE_ROCM
   dropout_eps.emplace_back(DefaultRocmExecutionProvider());
diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc
index 46082e1b0cd31..b414a98c4e756 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "test/providers/compare_provider_test_utils.h"
+#include "test/util/include/default_providers.h"
 
 namespace onnxruntime {
 namespace test {
@@ -79,14 +80,20 @@ static void TestLayerNorm(const std::vector& x_dims,
 #endif
 
 #ifdef USE_CUDA
-  test.CompareWithCPU(kCudaExecutionProvider);
+  if (DefaultCudaExecutionProvider() != nullptr) {
+    test.CompareWithCPU(kCudaExecutionProvider);
+  }
 #elif USE_ROCM
   test.CompareWithCPU(kRocmExecutionProvider);
-#elif USE_DML
-  test.CompareWithCPU(kDmlExecutionProvider);
 #elif USE_WEBGPU
   test.CompareWithCPU(kWebGpuExecutionProvider);
 #endif
+
+#ifdef USE_DML
+  if (DefaultDmlExecutionProvider() != nullptr) {
+    test.CompareWithCPU(kDmlExecutionProvider);
+  }
+#endif
 }
 
 TEST(CudaKernelTest, LayerNorm_NullInput) {
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index eb6e316202da9..7d99ffab9a88f 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -489,13 +489,17 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
   if (use_float16) {
 #ifdef USE_CUDA
-    execution_providers.push_back(DefaultCudaExecutionProvider());
+    if (DefaultCudaExecutionProvider() != nullptr) {
+      execution_providers.push_back(DefaultCudaExecutionProvider());
+    }
 #endif
 #ifdef USE_ROCM
     execution_providers.push_back(DefaultRocmExecutionProvider());
 #endif
 #ifdef USE_DML
-    execution_providers.push_back(DefaultDmlExecutionProvider());
+    if (DefaultDmlExecutionProvider() != nullptr) {
+      execution_providers.push_back(DefaultDmlExecutionProvider());
+    }
 #endif
 #ifdef USE_WEBGPU
     execution_providers.push_back(DefaultWebGpuExecutionProvider());
@@ -513,8 +517,11 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
 }  // namespace
 
 TEST(MatMulNBits, Float16Cuda) {
-#if defined(USE_CUDA) || defined(USE_ROCM)
-  auto has_gidx_options = {true, false};
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
+  std::vector<bool> has_gidx_options = {true, false};
+  if (DefaultDmlExecutionProvider() != nullptr) {
+    has_gidx_options.assign(1, false);
+  }
 #else
   auto has_gidx_options = {false};
 #endif
@@ -525,7 +532,9 @@ TEST(MatMulNBits, Float16Cuda) {
     for (auto block_size : {16, 32, 64, 128}) {
       for (auto has_gidx : has_gidx_options) {
 #ifdef USE_DML
-        RunTest(M, N, K, block_size, 0, false, true, has_gidx, true, 0.04f);
+        if (DefaultDmlExecutionProvider() != nullptr) {
+          RunTest(M, N, K, block_size, 0, false, true, has_gidx, true, 0.04f);
+        }
 #else
         RunTest(M, N, K, block_size, 0, false, true, has_gidx);
         RunTest(M, N, K, block_size, 0, true, true, has_gidx, false);
@@ -538,12 +547,16 @@ TEST(MatMulNBits, Float16Cuda) {
 }
 
 TEST(MatMulNBits, Float16Large) {
-#ifdef USE_DML
+#if defined(USE_CUDA) || defined(USE_DML)
   // For some reason, the A10 machine that runs these tests during CI has a much bigger error than all retail
   // machines we tested on. All consumer-grade machines from Nvidia/AMD/Intel seem to pass these tests with an
   // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number
   // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances.
-  float abs_error = 0.3f;
+  float abs_error = 0.05f;
+  if (DefaultDmlExecutionProvider() != nullptr) {
+    // The DML EP is the one actually running here, so relax the absolute error tolerance to 0.3f.
+    abs_error = 0.3f;
+  }
 #elif USE_WEBGPU
   // See Intel A770 to pass these tests with an absolute error of 0.08.
   float abs_error = 0.08f;
@@ -559,7 +572,6 @@ TEST(MatMulNBits, Float16Large) {
     }
   }
 }
-
 #endif  // defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
index 8d7629b5fda1c..d88c3131a4ca5 100644
--- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
@@ -227,7 +227,7 @@ TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8S8) {
 }
 
 // DML EP supports Float16 output type and Signed A Matrix and Unsigned B Matric for Float32 output
-#if defined(USE_DML)
+#if defined(USE_DML) && !defined(USE_CUDA)
 
 TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8U8) {
   RunMatMulIntegerToFloatTest();
diff --git a/onnxruntime/test/contrib_ops/tensor_op_test.cc b/onnxruntime/test/contrib_ops/tensor_op_test.cc
index bc2ff5f4f724d..d5e2ddebfe67f 100644
--- a/onnxruntime/test/contrib_ops/tensor_op_test.cc
+++ b/onnxruntime/test/contrib_ops/tensor_op_test.cc
@@ -121,7 +121,15 @@ void MeanVarianceNormalizationAcrossChannels(bool across_channels, bool normaliz
   test.AddAttribute("normalize_variance", normalize_variance ? one : zero);
   test.AddInput<float>("input", {N, C, H, W}, X);
   test.AddOutput<float>("output", {N, C, H, W}, result);
+#if defined(USE_CUDA) && defined(USE_DML)
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaExecutionProvider, kTensorrtExecutionProvider});
+  } else if (DefaultDmlExecutionProvider() == nullptr) {
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kDmlExecutionProvider, kTensorrtExecutionProvider});
+  }
+#else
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kTensorrtExecutionProvider});  // OpenVINO doesn't support MVN operator below opset 9. TensorRT doesn't support opset 8 of MVN operator.
+#endif
 }
 
 void MeanVarianceNormalizationPerChannel(bool across_channels, bool normalize_variance) {
@@ -188,7 +196,15 @@ void MeanVarianceNormalizationPerChannel(bool across_channels, bool normalize_va
   test.AddAttribute("normalize_variance", normalize_variance ? one : zero);
   test.AddInput<float>("input", {N, C, H, W}, X);
   test.AddOutput<float>("output", {N, C, H, W}, result);
+#if defined(USE_CUDA) && defined(USE_DML)
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaExecutionProvider, kTensorrtExecutionProvider});
+  } else if (DefaultDmlExecutionProvider() == nullptr) {
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kDmlExecutionProvider, kTensorrtExecutionProvider});
+  }
+#else
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kTensorrtExecutionProvider});  // OpenVINO doesn't support MVN operator below opset 9. TensorRT doesn't support opset 8 of MVN operator.
+#endif
 }
 
 TEST(MVNContribOpTest, MeanVarianceNormalizationCPUTest_Version1_TO_8) {
@@ -230,7 +246,9 @@ TEST(UnfoldTensorOpTest, LastDim) {
 
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
 #ifdef USE_CUDA
-  execution_providers.push_back(DefaultCudaExecutionProvider());
+  if (DefaultCudaExecutionProvider() != nullptr) {
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+  }
 #endif
   execution_providers.push_back(DefaultCpuExecutionProvider());
   tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc
index 0105e90b5a24a..a7f8a6424aa50 100644
--- a/onnxruntime/test/framework/allocation_planner_test.cc
+++ b/onnxruntime/test/framework/allocation_planner_test.cc
@@ -28,6 +28,7 @@ using json = nlohmann::json;
 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_execution_provider.h"
 #include "core/providers/cuda/cuda_provider_factory.h"
+#include "test/common/cuda_op_test_utils.h"
 #endif  // USE_CUDA
 #include "core/session/onnxruntime_session_options_config_keys.h"
 using namespace ONNX_NAMESPACE;
@@ -894,6 +895,9 @@ TEST_F(PlannerTest, LocationPlanningForPassThroughExplicitAndImplicitSubgraphInp
   SessionOptions so;
   InferenceSession sess{so, GetEnvironment()};
 
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
   ASSERT_TRUE(status.IsOK());
@@ -1036,6 +1040,9 @@ TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) {
   SessionOptions so;
   InferenceSession sess{so, GetEnvironment()};
 
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
   ASSERT_TRUE(status.IsOK());
@@ -1143,6 +1150,9 @@ TEST_F(PlannerTest, LocationPlanningForInitializersUsedOnDifferentDevicesInMainG
   SessionOptions so;
   InferenceSession sess{so, GetEnvironment()};
 
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
   ASSERT_TRUE(status.IsOK());
@@ -1235,6 +1245,9 @@ TEST_F(PlannerTest, LocationPlanningForImplicitInputsWithoutExplicitConsumersInM
   SessionOptions so;
   InferenceSession sess{so, GetEnvironment()};
 
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
   ASSERT_TRUE(status.IsOK());
@@ -1267,6 +1280,10 @@ TEST_F(PlannerTest, LocationPlanningForImplicitInputsWithoutExplicitConsumersInM
 // Test MultiStream scenario for the graph:
 // node1(CPU ep)->node2(CPU ep)->node3(CUDA ep)->node4(CPU ep)
 TEST_F(PlannerTest, MultiStream) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+
   ONNX_NAMESPACE::TensorProto tensor;
   tensor.add_dims(1);
   tensor.add_float_data(1.0f);
@@ -1285,6 +1302,7 @@ TEST_F(PlannerTest, MultiStream) {
   onnxruntime::ProviderInfo_CUDA& ep = onnxruntime::GetProviderInfo_CUDA();
   auto epFactory = ep.CreateExecutionProviderFactory(epi);
   std::unique_ptr<IExecutionProvider> execution_provider = epFactory->CreateProvider();
+
   ORT_THROW_IF_ERROR(GetExecutionProviders().Add("CUDAExecutionProvider", std::move(execution_provider)));
 
   CreatePlan({}, false);
@@ -1312,6 +1330,9 @@ TEST_F(PlannerTest, MultiStream) {
 // node3
 // All 3 nodes are CUDA EP, node1 is in stream0, node2 is in stream1, node3 is in stream2
 TEST_F(PlannerTest, MultiStream1StreamWaitFor2Streams) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
   std::unique_ptr<::onnxruntime::KernelDef> cudaKernel = KernelDefBuilder().SetName("Transpose").Provider(kCudaExecutionProvider).SinceVersion(1, 10).Build();
   std::unique_ptr<::onnxruntime::KernelDef> cudaKernelAdd = KernelDefBuilder().SetName("Add").Provider(kCudaExecutionProvider).SinceVersion(1, 10).Build();
   std::string Graph_input("Graph_input"), Arg1("Arg1"), Arg2("Arg2"), Arg3("Arg3"), node1("node1"), node2("node2"), node3("node3");
@@ -1353,6 +1374,9 @@ TEST_F(PlannerTest, MultiStream1StreamWaitFor2Streams) {
 // stream 1: node2 (CPU EP)
 // node1's output, which is consumed by both node2 and node3, is in CPU.
 TEST_F(PlannerTest, MultiStreamCudaEPNodeCPUOutput) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
   MemcpyToHostInCuda_TransposeInCudaAndCpu("./testdata/multi_stream_models/memcpyToHost_same_stream_with_transpose.json");
   EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams";
   EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 5) << "stream 0 has 5 steps";
@@ -1374,6 +1398,11 @@ TEST_F(PlannerTest, MultiStreamCudaEPNodeCPUOutput) {
 // TODO(leca): there is a bug in the corresponding graph that node2 will be visited twice when traversing node1's output nodes
 // (see: for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) in BuildExecutionPlan()). We can just break the loop and don't need the extra variables once it is fixed
 TEST_F(PlannerTest, MultiStreamMultiOutput) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
+#endif
   std::unique_ptr<::onnxruntime::KernelDef> cudaKernel = KernelDefBuilder().SetName("RNN").Provider(kCudaExecutionProvider).SinceVersion(7).Build();
   std::string Graph_input1("Graph_input1"), Graph_input2("Graph_input2"), Graph_input3("Graph_input3"), Arg1("Arg1"), Arg2("Arg2"), Arg3("Arg3"), node1("node1"), node2("node2");
   std::vector input1{Arg(Graph_input1), Arg(Graph_input2), Arg(Graph_input3)}, output1{Arg(Arg1), Arg(Arg2)}, input2{Arg(Arg1), Arg(Arg2)}, output2{Arg(Arg3)};
@@ -1411,6 +1440,9 @@ TEST_F(PlannerTest, MultiStreamMultiOutput) {
 // TODO(leca): the ideal case is there is only 1 wait step before launching node3,
 // as there is a specific order between node1 and node2 if they are in the same stream, thus node3 will only need to wait the latter one
 TEST_F(PlannerTest, MultiStream2NodesSameStreamConsumedBy1NodeInDifferentStream) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
   std::unique_ptr<::onnxruntime::KernelDef> cudaKernel = KernelDefBuilder().SetName("Transpose").Provider(kCudaExecutionProvider).SinceVersion(1, 10).Build();
   std::string Graph_input1("Graph_input1"), Graph_input2("Graph_input2"), Graph_input3("Graph_input3"), Arg1("Arg1"), Arg2("Arg2"), Arg3("Arg3"), node1("node1"), node2("node2"), node3("node3");
   std::vector input1{Arg(Graph_input1)}, input2{Arg(Graph_input2)}, output1{Arg(Arg1)}, output2{Arg(Arg2)}, input3{Arg(Arg1), Arg(Arg2)}, output3{Arg(Arg3)};
@@ -1448,6 +1480,9 @@ TEST_F(PlannerTest, MultiStream2NodesSameStreamConsumedBy1NodeInDifferentStream)
 
 #if !defined(__wasm__) && defined(ORT_ENABLE_STREAM)
 TEST_F(PlannerTest, ParaPlanCreation) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
   TypeProto graph_in_type;
   graph_in_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
   auto* graph_in_shape = graph_in_type.mutable_tensor_type()->mutable_shape();
@@ -1889,6 +1924,10 @@ TEST_F(PlannerTest, ParaPlanCreation) {
 }
 
 TEST_F(PlannerTest, TestMultiStreamConfig) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+
   const char* type = "DeviceBasedPartitioner";
   constexpr size_t type_len = 22;
 
@@ -1962,6 +2001,10 @@ TEST_F(PlannerTest, TestMultiStreamSaveConfig) {
 
 // Load with partition config where a node is missing, session load expected to fail.
 TEST_F(PlannerTest, TestMultiStreamMissingNodeConfig) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+
   const char* config_file_path = "./testdata/multi_stream_models/conv_add_relu_single_stream_missing_node.json";
   SessionOptions sess_opt;
   sess_opt.graph_optimization_level = TransformerLevel::Default;
@@ -1982,6 +2025,9 @@ TEST_F(PlannerTest, TestMultiStreamMissingNodeConfig) {
 
 // Load with partition config where streams and devices has mismatch
 TEST_F(PlannerTest, TestMultiStreamMismatchDevice) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
   const char* config_file_path = "./testdata/multi_stream_models/conv_add_relu_single_stream_mismatch_device.json";
   SessionOptions sess_opt;
   sess_opt.graph_optimization_level = TransformerLevel::Default;
@@ -2007,6 +2053,9 @@ TEST_F(PlannerTest, TestCpuIf) {
   sess_opt.graph_optimization_level = TransformerLevel::Default;
   InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/cpu_if.onnx"));
 
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
   ASSERT_STATUS_OK(sess.Load());
   ASSERT_STATUS_OK(sess.Initialize());
@@ -2067,10 +2116,17 @@ TEST_F(PlannerTest, TestCpuIf) {
 //  onnx.save(model, 'issue_19480.onnx')
 //
 TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+
   SessionOptions sess_opt;
   sess_opt.graph_optimization_level = TransformerLevel::Default;
   InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx"));
 
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
   status = sess.Load();
   status = sess.Initialize();
diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc
index e28327941dda4..3e5ef30e7ebef 100644
--- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc
+++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc
@@ -115,6 +115,9 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) {
   SessionOptions so;
   FenceCudaTestInferenceSession session(so, GetEnvironment());
   ASSERT_STATUS_OK(LoadInferenceSessionFromModel(session, *model));
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(session.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
   ASSERT_TRUE(session.Initialize().IsOK());
   ASSERT_TRUE(1 == CountCopyNodes(graph));
@@ -164,6 +167,9 @@ TEST(CUDAFenceTests, TileWithInitializer) {
   SessionOptions so;
   FenceCudaTestInferenceSession session(so, GetEnvironment());
   ASSERT_STATUS_OK(LoadInferenceSessionFromModel(session, *model));
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(session.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
   ASSERT_STATUS_OK(session.Initialize());
@@ -224,6 +230,9 @@ TEST(CUDAFenceTests, TileWithComputedInput) {
   SessionOptions so;
   FenceCudaTestInferenceSession session(so, GetEnvironment());
   ASSERT_STATUS_OK(LoadInferenceSessionFromModel(session, *model));
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(session.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
   ASSERT_TRUE(session.Initialize().IsOK());
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index e57864cecb850..9c7e6e9761728 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -34,6 +34,7 @@
 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_provider_factory.h"
 #include "core/providers/cuda/gpu_data_transfer.h"
+#include "test/common/cuda_op_test_utils.h"
 #endif
 #ifdef USE_TENSORRT
 #include "core/providers/tensorrt/tensorrt_provider_options.h"
@@ -689,6 +690,9 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) {
 
   InferenceSession session_object(so, GetEnvironment());
 #ifdef USE_CUDA
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
 #endif
 #ifdef USE_ROCM
@@ -743,6 +747,9 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) {
 
   InferenceSession session_object(so, GetEnvironment());
 #ifdef USE_CUDA
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
 #endif
 #ifdef USE_ROCM
@@ -1055,6 +1062,9 @@ static void TestBindHelper(const std::string& log_str,
     if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider) {
 #ifdef USE_CUDA
       auto provider = DefaultCudaExecutionProvider();
+      if (provider == nullptr) {
+        return;
+      }
       gpu_provider = provider.get();
       ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider)));
 #endif
@@ -1650,6 +1660,9 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
 #if USE_TENSORRT
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider()));
 #elif USE_CUDA
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
 #elif USE_ROCM
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
@@ -1802,6 +1815,9 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) {
 #if USE_TENSORRT
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider()));
 #elif USE_CUDA
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    return;
+  }
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
 #elif USE_ROCM
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
@@ -2157,6 +2173,9 @@ TEST(InferenceSessionTests, TestStrictShapeInference) {
 #ifdef USE_CUDA
 // disable it, since we are going to enable parallel execution with cuda ep
 TEST(InferenceSessionTests, DISABLED_TestParallelExecutionWithCudaProvider) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
   string model_uri = "testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";
 
   SessionOptions so;
@@ -2180,6 +2199,10 @@ TEST(InferenceSessionTests, DISABLED_TestParallelExecutionWithCudaProvider) {
 }
 
 TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+
   OrtArenaCfg arena_cfg;
   arena_cfg.arena_extend_strategy = 1;  // kSameAsRequested
diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc
index 6e86e5b58aead..2313f00e4d123 100644
--- a/onnxruntime/test/framework/memcpy_transformer_test.cc
+++ b/onnxruntime/test/framework/memcpy_transformer_test.cc
@@ -9,6 +9,9 @@
 #include "default_providers.h"
 #include "gtest/gtest.h"
#include "test_utils.h" +#ifdef USE_CUDA +#include "test/common/cuda_op_test_utils.h" +#endif #include "test/test_environment.h" #include "asserts.h" @@ -74,6 +77,9 @@ void ExpectCopy(const onnxruntime::Node& source, const std::string copy_op, #ifdef USE_CUDA TEST(TransformerTest, MemcpyTransformerTest) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 7; auto model = std::make_shared("test", false, ModelMetaData(), PathString(), @@ -106,7 +112,9 @@ TEST(TransformerTest, MemcpyTransformerTest) { KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -129,6 +137,9 @@ TEST(TransformerTest, MemcpyTransformerTest) { } TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 7; auto model = std::make_shared("test", false, ModelMetaData(), PathString(), @@ -161,7 +172,9 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -281,7 +294,11 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) { KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -323,7 +340,11 @@ TEST(TransformerTest, MemcpyTransformerTestGraphInputConsumedOnMultipleDevices) KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -425,7 +446,11 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc 
index 7bd6b47f52b7d..db9592c293fd0 100644
--- a/onnxruntime/test/framework/sparse_kernels_test.cc
+++ b/onnxruntime/test/framework/sparse_kernels_test.cc
@@ -1457,6 +1457,9 @@ TEST(SparseTensorConversionTests, CsrConversion) {
 
 #ifdef USE_CUDA
   auto cuda_provider = DefaultCudaExecutionProvider();
+  if (cuda_provider == nullptr) {
+    return;
+  }
   auto cuda_allocator = cuda_provider->CreatePreferredAllocators()[0];
   {
     auto cuda_transfer = cuda_provider->GetDataTransfer();
@@ -1684,6 +1687,9 @@ TEST(SparseTensorConversionTests, CooConversion) {
 
 #ifdef USE_CUDA
   auto cuda_provider = DefaultCudaExecutionProvider();
+  if (cuda_provider == nullptr) {
+    return;
+  }
   auto cuda_allocator = cuda_provider->CreatePreferredAllocators()[0];
   {
     auto cuda_transfer = cuda_provider->GetDataTransfer();
diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc
index fde603858f9a9..8338c7d547a09 100644
--- a/onnxruntime/test/lora/lora_test.cc
+++ b/onnxruntime/test/lora/lora_test.cc
@@ -201,6 +201,14 @@ TEST(LoraAdapterTest, Load) {
 
 #ifdef USE_CUDA
 TEST(LoraAdapterTest, VerifyCudaDeviceCopy) {
+  if (DefaultCudaExecutionProvider() == nullptr) {
+    GTEST_SKIP() << "Skipping test because the CUDA EP is not available";
+  }
+#ifdef USE_DML
+  if (DefaultDmlExecutionProvider() != nullptr) {
+    GTEST_FAIL() << "This test should not run when the DML EP is enabled";
+  }
+#endif
   auto cpu_ep = DefaultCpuExecutionProvider();
   auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
   auto cuda_allocator = DefaultCudaExecutionProvider()->CreatePreferredAllocators()[0];
@@ -234,6 +242,17 @@ TEST(LoraAdapterTest, VerifyCudaDeviceCopy) {
 
 #ifdef USE_DML
 TEST(LoraAdapterTest, VerifyDmlDeviceCopy) {
+  // When NO_DML_TEST is set, the DML test is skipped.
+  if (DefaultDmlExecutionProvider() == nullptr) {
+    GTEST_SKIP() << "Skipping test because the DML EP is not available";
+  }
+
+#ifdef USE_CUDA
+  if (DefaultCudaExecutionProvider() != nullptr) {
+    GTEST_FAIL() << "This test should not run when the CUDA EP is enabled";
+  }
+#endif
+
   auto cpu_ep = DefaultCpuExecutionProvider();
   auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
 
diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc
index dea39bc99d3e9..9d83c789c5124 100644
--- a/onnxruntime/test/providers/base_tester.cc
+++ b/onnxruntime/test/providers/base_tester.cc
@@ -529,6 +529,17 @@ void BaseTester::Run(ExpectResult expect_result, const std::string& expected_fai
   so.use_deterministic_compute = use_determinism_;
   so.graph_optimization_level = TransformerLevel::Default;  // 'Default' == off
 
+  // Remove any nullptr entries from execution_providers.
+  // This is a little ugly, but it is necessary because DefaultXXXExecutionProvider() can return nullptr at runtime,
+  // and there are many places that add DefaultXXXExecutionProvider() to execution_providers directly.
+  if (execution_providers != nullptr) {
+    execution_providers->erase(std::remove(execution_providers->begin(), execution_providers->end(), nullptr), execution_providers->end());
+    if (execution_providers->size() == 0) {
+      // No execution provider is left, so there is nothing to run.
+      return;
+    }
+  }
+
   Run(so, expect_result, expected_failure_string, excluded_provider_types, run_options, execution_providers, options);
 }
diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc
index 386a5656d8a01..9acb37c24ddd0 100644
--- a/onnxruntime/test/providers/compare_provider_test_utils.cc
+++ b/onnxruntime/test/providers/compare_provider_test_utils.cc
@@ -53,6 +53,11 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type,
   SetTestFunctionCalled();
 
   std::unique_ptr<IExecutionProvider> target_execution_provider = GetExecutionProvider(target_provider_type);
+#if defined(USE_CUDA) && defined(USE_DML)
+  if (target_execution_provider == nullptr) {
+    return;
+  }
+#endif
   ASSERT_TRUE(target_execution_provider != nullptr) << "provider_type " << target_provider_type
                                                     << " is not supported.";
 
diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc
index e3c86a137484f..b46c253fb8ed9 100644
--- a/onnxruntime/test/providers/cpu/model_tests.cc
+++ b/onnxruntime/test/providers/cpu/model_tests.cc
@@ -491,6 +491,18 @@ ::std::vector<::std::basic_string> GetParameterStrings() {
   // the number of times these are run to reduce the CI time.
   provider_names.erase(provider_name_cpu);
 #endif
+
+#if defined(USE_CUDA) && defined(USE_DML)
+  const std::string no_cuda_ep_test = Env::Default().GetEnvironmentVar("NO_CUDA_TEST");
+  if (no_cuda_ep_test == "1") {
+    provider_names.erase(provider_name_cuda);
+  }
+  const std::string no_dml_ep_test = Env::Default().GetEnvironmentVar("NO_DML_TEST");
+  if (no_dml_ep_test == "1") {
+    provider_names.erase(provider_name_dml);
+  }
+#endif
+
   std::vector<std::basic_string<ORTCHAR_T>> v;
   // Permanently exclude following tests because ORT support only opset starting from 7,
   // Please make no more changes to the list
diff --git a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
index be79a6d29d539..0f23e4c39d7e2 100644
--- a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
@@ -3,6 +3,9 @@
 
 #include "core/session/onnxruntime_session_options_config_keys.h"
 #include "gtest/gtest.h"
+#if USE_CUDA
+#include "test/common/cuda_op_test_utils.h"
+#endif
 #include "test/providers/provider_test_utils.h"
 #include "test/util/include/default_providers.h"
 
@@ -122,6 +125,9 @@ TEST(GatherOpTest, Gather_invalid_index_gpu) {
                         4.0f, 5.0f, 6.0f, 7.0f,
                         0.0f, 0.0f, 0.0f, 0.0f});
 
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
   // On GPU, just set the value to 0 instead of report error. exclude all other providers
   test
 #if defined(USE_CUDA)
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
index 05cfb5c13d689..7e1a2384d7fc6 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
@@ -15,11 +15,13 @@ std::vector> GetExecutionProviders(int opset
   execution_providers.emplace_back(DefaultCpuExecutionProvider());
 
 #ifdef USE_CUDA
-  if (opset_version < 20) {
-    execution_providers.emplace_back(DefaultCudaExecutionProvider());
+  if (DefaultCudaExecutionProvider() != nullptr) {
+    if (opset_version < 20) {
+      execution_providers.emplace_back(DefaultCudaExecutionProvider());
 #ifdef ENABLE_CUDA_NHWC_OPS
-    execution_providers.push_back(DefaultCudaNHWCExecutionProvider());
+      execution_providers.push_back(DefaultCudaNHWCExecutionProvider());
 #endif
+    }
   }
 #endif
 
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 57f748ab8b6bd..62bdedd833025 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -122,6 +122,12 @@ std::unique_ptr DefaultOpenVINOExecutionProvider() {
 
 std::unique_ptr<IExecutionProvider> DefaultCudaExecutionProvider() {
 #ifdef USE_CUDA
+#ifdef USE_DML
+  const std::string no_cuda_ep_test = Env::Default().GetEnvironmentVar("NO_CUDA_TEST");
+  if (no_cuda_ep_test == "1") {
+    return nullptr;
+  }
+#endif
   OrtCUDAProviderOptionsV2 provider_options{};
   provider_options.do_copy_in_default_stream = true;
   provider_options.use_tf32 = false;
@@ -328,6 +334,12 @@ std::unique_ptr DefaultCannExecutionProvider() {
 
 std::unique_ptr<IExecutionProvider> DefaultDmlExecutionProvider() {
 #ifdef USE_DML
+#ifdef USE_CUDA
+  const std::string no_dml_ep_test = Env::Default().GetEnvironmentVar("NO_DML_TEST");
+  if (no_dml_ep_test == "1") {
+    return nullptr;
+  }
+#endif
   ConfigOptions config_options{};
   if (auto factory = DMLProviderFactoryCreator::CreateFromDeviceOptions(config_options, nullptr, false, false)) {
     return factory->CreateProvider();
diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml
index 9c7fbc24ab1b6..0b3eac0110abc 100644
--- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml
@@ -50,6 +50,8 @@ stages:
       win_trt_home: ${{ parameters.win_trt_home }}
       win_cuda_home: ${{ parameters.win_cuda_home }}
       buildJava: ${{ parameters.buildJava }}
+      SpecificArtifact: ${{ parameters.SpecificArtifact }}
+      BuildId: ${{ parameters.BuildId }}
 
 - template: nuget-cuda-packaging-stage.yml
   parameters:
diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml
index 445066f08995a..d6b25c98936f0 100644
--- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml
@@ -34,7 +34,7 @@ parameters:
   displayName: Specific Artifact's BuildId
   type: string
   default: '0'
-  
+
 - name: buildJava
   type: boolean
 
@@ -50,13 +50,14 @@ stages:
       msbuildPlatform: x64
       packageName: x64-cuda
      CudaVersion: ${{ parameters.CudaVersion }}
-      buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80"
+      buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" --use_dml --build_csharp --parallel
      runTests: ${{ parameters.RunOnnxRuntimeTests }}
      buildJava: ${{ parameters.buildJava }}
      java_artifact_id: onnxruntime_gpu
      UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }}
      SpecificArtifact: ${{ parameters.SpecificArtifact }}
      BuildId: ${{ parameters.BuildId }}
+      ComboTests: true
 # Windows CUDA with TensorRT Packaging
 - template: ../templates/win-ci.yml
   parameters:
@@ -68,7 +69,7 @@ stages:
      msbuildPlatform: x64
      CudaVersion: ${{ parameters.CudaVersion }}
      packageName: x64-tensorrt
-      buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80"
+      buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" --parallel
      runTests: ${{ parameters.RunOnnxRuntimeTests }}
      buildJava: ${{ parameters.buildJava }}
      java_artifact_id: onnxruntime_gpu
diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml
index e7be1d0a4f7ef..115ea010ad00e 100644
--- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml
@@ -25,7 +25,7 @@ parameters:
 
 - name: runTests
   type: boolean
-  default: true
+  default: false
 
 - name: buildJava
   type: boolean
@@ -71,6 +71,10 @@ parameters:
   - 11.8
   - 12.2
 
+- name: ComboTests
+  type: boolean
+  default: false
+
 - name: SpecificArtifact
   displayName: Use Specific Artifact
   type: boolean
@@ -220,7 +224,7 @@ stages:
         condition: and(succeeded(), eq('${{ parameters.runTests}}', true))
         inputs:
           scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
-          arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }}'
+          arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --test --skip_submodule_sync --build_shared_lib --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }}'
         workingDirectory: '$(Build.BinariesDirectory)'
     - ${{ else }}:
       - powershell: |
@@ -332,6 +336,10 @@ stages:
         displayName: 'Clean Agent Directories'
         condition: always()
 
+      - script:
+          echo ${{ parameters.SpecificArtifact }}
+        displayName: 'Print Specific Artifact'
+
       - checkout: self
         clean: true
        submodules: none
@@ -395,13 +403,34 @@ stages:
        displayName: 'Append dotnet x86 Directory to PATH'
        condition: and(succeeded(), eq('${{ parameters.buildArch}}', 'x86'))
 
-      - task: PythonScript@0
-        displayName: 'test'
-        condition: and(succeeded(), eq('${{ parameters.runTests}}', true))
-        inputs:
-          scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
-          arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) '
-        workingDirectory: '$(Build.BinariesDirectory)'
+      - ${{ if eq(parameters.ComboTests, 'true') }}:
+        - task: PythonScript@0
+          displayName: 'test excludes CUDA'
+          condition: and(succeeded(), eq('${{ parameters.runTests}}', true))
+          inputs:
+            scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
+            arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) '
+          workingDirectory: '$(Build.BinariesDirectory)'
+          env:
+            NO_CUDA_TEST: '1'
+        - task: PythonScript@0
+          displayName: 'test excludes DML'
+          condition: and(succeeded(), eq('${{ parameters.runTests}}', true))
+          inputs:
+            scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
+            arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) '
+          workingDirectory: '$(Build.BinariesDirectory)'
+          env:
+            NO_DML_TEST: '1'
+      - ${{ else }}:
+        - task: PythonScript@0
+          displayName: 'test'
+          condition: and(succeeded(), eq('${{ parameters.runTests}}', true))
+          inputs:
+            scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
+            arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) '
+          workingDirectory: '$(Build.BinariesDirectory)'
+
 # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine
 - ${{ if eq(parameters.buildJava, 'true') }}:
   - template: make_java_win_binaries.yml