Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable streams for the DML EP #19481

Merged
merged 2 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()

# enable stream for all the non-minimal build
if (NOT onnxruntime_MINIMAL_BUILD)
# Enable stream for all the non-minimal build, except for DML. There's currently a bug
# in the allocation planner when reusing buffers and more than one streams are used that
# make it possible (although rarely) to reach a reference count of 0 for a buffer that is
# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
# safest option for now.
# https://github.com/microsoft/onnxruntime/issues/19480
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
add_compile_definitions(ORT_ENABLE_STREAM)
endif()

Expand Down
21 changes: 17 additions & 4 deletions onnxruntime/test/framework/allocation_planner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,23 @@ class PlannerTest : public ::testing::Test {

if (invoke_createPlan_explicityly) {
onnxruntime::GraphViewer graph_viewer{graph_};
status = SequentialPlanner::CreatePlan(nullptr, graph_viewer, outer_scope_node_args, execution_providers_,
kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context,
MockStreamHandleRegsitry(), /* {{kCpuExecutionProvider, 1}}, {},*/
ORT_TSTR(""), DefaultLoggingManager().DefaultLogger(), plan_);
status = SequentialPlanner::CreatePlan(
nullptr,
graph_viewer,
outer_scope_node_args,
execution_providers_,
kernel_create_info_map,
{},
{},
state_->GetOrtValueNameIdxMap(),
test_context,
#ifdef ORT_ENABLE_STREAM
MockStreamHandleRegsitry(),
#endif
/* {{kCpuExecutionProvider, 1}}, {},*/
ORT_TSTR(""),
DefaultLoggingManager().DefaultLogger(),
plan_);

EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
// AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size());
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/framework/bfc_arena_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ struct StreamMock : public Stream {
Status CleanUpOnRunEnd() override { return Status::OK(); }
};

#ifdef ORT_ENABLE_STREAM
TEST(StreamAwareArenaTest, TwoStreamAllocation) {
StreamAwareArena a(std::unique_ptr<IAllocator>(new CPUAllocator()), 1 << 30, false);
CheckStats(&a, 0, 0, 0, 0);
Expand Down Expand Up @@ -413,6 +414,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) {
EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked";
a.Free(p2);
}
#endif

TEST(BFCArenaTest, TestExtendStrategy) {
int64_t extend_delta_bytes = 0;
Expand Down
55 changes: 50 additions & 5 deletions onnxruntime/test/framework/execution_frame_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,16 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));

vector<OrtValue> outputs;
ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state);
ExecutionFrame frame(
{},
{},
{},
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

int start_index = frame.GetNodeOffset(node->Index());
ASSERT_EQ(start_index, 0);
Expand Down Expand Up @@ -150,7 +159,16 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) {
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));

vector<OrtValue> outputs;
ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state);
ExecutionFrame frame(
{},
{},
{},
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

int start_index = frame.GetNodeOffset(node->Index());
ASSERT_EQ(start_index, 0);
Expand Down Expand Up @@ -216,7 +234,16 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) {
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK());

vector<OrtValue> outputs;
ExecutionFrame frame(AsSpan({x_idx}), AsSpan({value}), AsSpan({y_idx}), outputs, {}, {}, state);
ExecutionFrame frame(
AsSpan({x_idx}),
AsSpan({value}),
AsSpan({y_idx}),
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0);
Tensor* p_tensor_arg_0 = p_ml_value ? p_ml_value->GetMutable<Tensor>() : nullptr;
Expand Down Expand Up @@ -299,7 +326,16 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
std::vector<float>(6, 1.0f), &v3);

std::vector<OrtValue> outputs;
ExecutionFrame frame(AsSpan({x1_idx, x2_idx, x3_idx}), AsSpan({v1, v2, v3}), AsSpan({t3_idx}), outputs, {}, {}, state);
ExecutionFrame frame(
AsSpan({x1_idx, x2_idx, x3_idx}),
AsSpan({v1, v2, v3}),
AsSpan({t3_idx}),
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

OrtValue& mlvalue3 = *frame.GetMutableNodeInputOrOutputMLValue(3);
OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4);
Expand Down Expand Up @@ -388,7 +424,16 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {
CreateMLValue<float>(cpu_allocator, std::vector<int64_t>{2, 2}, std::vector<float>(4, 1.0f), &t_value);

vector<OrtValue> outputs;
ExecutionFrame frame(AsSpan({x_idx}), AsSpan({x_value}), AsSpan({y_idx}), outputs, {}, {}, state);
ExecutionFrame frame(
AsSpan({x_idx}),
AsSpan({x_value}),
AsSpan({y_idx}),
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

ASSERT_FALSE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor());
ASSERT_STATUS_OK(frame.SetOutputMLValue(t_idx, t_value));
Expand Down
Loading