Issue 4073: Fix for host-side hanging when an invalid DPRINT WAIT command is running on the device. #4103

Merged 3 commits on Dec 11, 2023
4 changes: 2 additions & 2 deletions docs/source/tools/kernel_print.rst
@@ -27,7 +27,7 @@ Note that the core coordinates are currently physical NOC coordinates (not logical coordinates).
To generate kernel debug prints on the device, include the ``debug/dprint.h`` header and use the APIs defined there.
An example with the different features available is shown below:

.. code-block::
.. code-block:: c++

#include "debug/dprint.h" // required in all kernels using DPRINT

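A minimal sketch of what such a kernel can look like, assuming the stream-style DPRINT API from dprint.h (the values and the ENDL/SETW/HEX modifiers are illustrative, not the exact example from the docs):

    #include "debug/dprint.h"  // required in all kernels using DPRINT

    void kernel_main() {
        // Print a literal string followed by an explicit end-of-line.
        DPRINT << "Test string" << ENDL();
        // Formatting modifiers (assumed names from dprint.h): fixed field
        // width and hexadecimal output.
        DPRINT << SETW(8) << 123 << ENDL();
        DPRINT << HEX() << 0x1234abcdu << ENDL();
    }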
@@ -64,7 +64,7 @@ to the current CB read or write pointer. This means that for printing a tile read from the front of the CB, the
``DPRINT`` call has to occur between the ``cb_wait_front`` and ``cb_pop_front`` calls. For printing a tile from the
back of the CB, the ``DPRINT`` call has to occur between the ``cb_reserve_back`` and ``cb_push_back`` calls.
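
A hedged sketch of this ordering for the front of a CB, assuming the TSLICE/SliceRange tile-printing helpers from dprint.h; the CB index and slice choice are hypothetical:

    #include "dataflow_api.h"  // for cb_wait_front/cb_pop_front
    #include "debug/dprint.h"

    void kernel_main() {
        constexpr uint32_t cb_id = 0;  // hypothetical CB index
        cb_wait_front(cb_id, 1);
        // Print a slice of the tile at the front of the CB; this must happen
        // before cb_pop_front, while the read pointer still covers the tile.
        DPRINT << TSLICE(cb_id, 0, SliceRange::hw0_32_8()) << ENDL();
        cb_pop_front(cb_id, 1);
    }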

.. code-block::
.. code-block:: sh

#include "debug/dprint.h" // required in all kernels using DPRINT

15 changes: 15 additions & 0 deletions tests/tt_metal/tt_metal/test_kernels/misc/print_hang.cpp
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "debug/dprint.h"
#include "debug/dprint_test_common.h"

/*
 * Test kernel that waits for a signal that is never raised.
 */

void kernel_main() {
DPRINT << WAIT{1};
print_test_data();
}
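
For contrast, a sketch of the well-behaved pairing that this test deliberately omits: a kernel on another core raising the signal, which would satisfy the WAIT above (hypothetical, not part of this PR):

    #include "debug/dprint.h"

    void kernel_main() {
        // Raise signal 1; any core blocked on DPRINT << WAIT{1} may proceed.
        DPRINT << RAISE{1};
    }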
@@ -0,0 +1,59 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "command_queue_fixture.hpp"
#include "common/bfloat16.hpp"
#include "impl/debug/dprint_server.hpp"
#include "gtest/gtest.h"
#include "test_utils.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/host_api.hpp"

//////////////////////////////////////////////////////////////////////////////////////////
// A test for checking that we can handle an invalid WAIT command.
//////////////////////////////////////////////////////////////////////////////////////////
using namespace tt;
using namespace tt::tt_metal;

const std::string golden_output =
R"(DPRINT server timed out on core (1,1) riscv 4, waiting on a RAISE signal: 1
)";

TEST_F(CommandQueueWithDPrintFixture, TestPrintHanging) {
// Device already set up by gtest fixture.
Device *device = this->device_;

// Set up program and command queue
CommandQueue& cq = *tt::tt_metal::detail::GLOBAL_CQ;
Program program = Program();

// Run a kernel that just waits on a signal that never comes (BRISC only).
constexpr CoreCoord core = {0, 0}; // Print on first core only
KernelHandle brisc_print_kernel_id = CreateKernel(
program,
"tests/tt_metal/tt_metal/test_kernels/misc/print_hang.cpp",
core,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}
);

// Run the program, we expect it to throw on waiting for CQ to finish
EnqueueProgram(cq, program, false);
try {
Finish(cq);
tt_await_debug_print_server();
} catch (std::runtime_error& e) {
const string expected = "Command Queue could not finish: device hang due to unanswered DPRINT WAIT.";
const string error = string(e.what());
log_info(tt::LogTest, "Caught exception (one is expected in this test): {}", error);
EXPECT_TRUE(error.find(expected) != string::npos);
}

// Check the print log against golden output.
EXPECT_TRUE(
FilesMatchesString(
CommandQueueWithDPrintFixture::dprint_file_name,
golden_output
)
);
}
108 changes: 82 additions & 26 deletions tt_metal/impl/debug/dprint_server.cpp
@@ -83,7 +83,8 @@ struct DebugPrintServerContext {

stop_print_server_ = false;
mute_print_server_ = false;
new_data_processed_ = false;
new_data_last_iter_ = false;
server_killed_due_to_hang_ = false;
print_server_thread_ = new std::thread(
[this, chip_ids, cores, hart_mask] { thread_poll(chip_ids, cores, hart_mask); }
);
@@ -116,10 +117,15 @@ struct DebugPrintServerContext {
// TODO(dma): once we have access to the device is there a way we can poll the device to
// check whether more print data is coming?
do {
// No need to await if the server was killed already due to a hang.
if (server_killed_due_to_hang_)
break;
std::this_thread::sleep_for(std::chrono::milliseconds(5));
} while (hart_waiting_on_signal_.size() > 0 || new_data_processed_);
} while (hart_waiting_on_signal_.size() > 0 || new_data_last_iter_);
}

bool print_hang_detected() { return server_killed_due_to_hang_; }

private:

// Flag for main thread to signal the print server thread to stop.
@@ -129,9 +135,12 @@
std::atomic<bool> mute_print_server_;
// Flag for signalling whether the print server thread has recently processed data (and is
// therefore likely to continue processing data in the next round of polling).
std::atomic<bool> new_data_processed_;
std::atomic<bool> new_data_last_iter_;
std::thread* print_server_thread_;

// A flag to signal to the main thread if the print server detected a print-based hang.
bool server_killed_due_to_hang_;

std::ofstream* outfile_ = nullptr; // non-cout
std::ostream* stream_ = nullptr; // either == outfile_ or is &cout

Expand All @@ -142,7 +151,12 @@ struct DebugPrintServerContext {
std::set<uint32_t> raised_signals_;

void thread_poll(const vector<int>& chip_ids, const vector<CoreCoord>& cores, uint32_t hart_mask);
bool peek_flush_one_hart_nonblocking(int chip_id, const CoreCoord& core, int hart_index);
bool peek_flush_one_hart_nonblocking(
int chip_id,
const CoreCoord& core,
int hart_index,
bool new_data_this_iter
);
};

static void print_tile_slice(ostream& stream, uint8_t* ptr, int hart_id) {
@@ -210,7 +224,12 @@ bool check_init_magic_cleared(int chip_id, const CoreCoord& core, int hart_id) {
// Peeks a specified hart for any debug prints present in the buffer and flushes it, printing the
// contents out to host-side stream. Returns true if some data was read out, and false if no new
// print data was present on the device.
bool DebugPrintServerContext::peek_flush_one_hart_nonblocking(int chip_id, const CoreCoord& core, int hart_id) {
bool DebugPrintServerContext::peek_flush_one_hart_nonblocking(
int chip_id,
const CoreCoord& core,
int hart_id,
bool new_data_this_iter
) {
// compute the buffer address for the requested hart
uint32_t base_addr = PRINT_BUFFER_NC + hart_id*PRINT_BUFFER_SIZE;

@@ -226,7 +245,45 @@ bool DebugPrintServerContext::peek_flush_one_hart_nonblocking(int chip_id, const CoreCoord& core, int hart_id) {
uint32_t counter = 0;
uint32_t sigval = 0;
char val = 0;

// If the print server is muted, dump the output to a null stream instead.
ostream& stream = (mute_print_server_)? null_stream : *stream_;

// Check whether this hart is currently waiting on a WAIT to be fulfilled.
tuple<uint32_t, uint32_t, uint32_t> hart_key {core.x, core.y, hart_id};
if (hart_waiting_on_signal_.count(hart_key) > 0) {
// Check if the signal the hart is waiting for has been raised.
uint32_t wait_signal = hart_waiting_on_signal_[hart_key];
if (raised_signals_.count(wait_signal) > 0) {
// The signal has been raised, we can continue.
hart_waiting_on_signal_.erase(hart_key);
} else {
// This hart is still waiting. This is fine as long as the print server (and therefore
// the device) is still making progress. Unfortunately there's no way to check if the
// print buffer is full, because the next print that would overflow the buffer spins the
// device until the buffer has more space, but checking for any new prints seems to work
// for the cases seen so far.
if (!new_data_this_iter && !new_data_last_iter_) {
// If no progress was made on both sides, then it could be an invalid wait
// condition, which could cause a deadlock. Print a warning and set a flag to close
// the print server in this case.
string core_str = "core (" + to_string(core.x) + "," + to_string(core.y) +
") riscv " + to_string(hart_id);
string error_str = "DPRINT server timed out on " +
core_str +
", waiting on a RAISE signal: " +
to_string(wait_signal) + "\n";
stream << error_str << flush;
log_warning(tt::LogMetal, "Debug Print Server encountered an error: {}", error_str);
server_killed_due_to_hang_ = true;
return false;
}

// Since it's still waiting, return false here, as no data was read.
return false;
}
}

if (rpos < wpos) {
// Now read the entire buffer
from_dev = tt::llrt::read_hex_vec_from_core(chip_id, core, base_addr, PRINT_BUFFER_SIZE);
@@ -404,44 +461,38 @@ void DebugPrintServerContext::thread_poll(
if (stop_print_server_) {
// If the stop signal was received, exit the print server thread, but wait for any
// existing prints to be wrapped up first.
if (hart_waiting_on_signal_.size() == 0 && !new_data_processed_)
if (hart_waiting_on_signal_.size() == 0 && !new_data_last_iter_)
break;
}

// Flag for whether any new print data was found in this round of polling.
bool new_print_data = false;
bool new_data_this_iter = false;
for (auto chip: chip_ids) {
for (auto core: cores) {
for (int hart_index = 0; hart_index < DPRINT_NRISCVS; hart_index++) {
if (hart_mask & (1<<hart_index)) {
if (!check_init_magic_cleared(chip, core, hart_index))
continue;

// Make sure that this core is not waiting on a raise signal to continue
// printing.
tuple<uint32_t, uint32_t, uint32_t> hart_key {core.x, core.y, hart_index};
if (hart_waiting_on_signal_.count(hart_key) > 0) {
uint32_t wait_signal = hart_waiting_on_signal_[hart_key];
if (raised_signals_.count(wait_signal) > 0) {
// The signal this hart is waiting for has been raised, it's not
// waiting anymore.
hart_waiting_on_signal_.erase(hart_key);
} else {
// Not raised yet, keep waiting.
continue;
}
}

new_print_data |= peek_flush_one_hart_nonblocking(chip, core, hart_index);
new_data_this_iter |= peek_flush_one_hart_nonblocking(
chip,
core,
hart_index,
new_data_this_iter
);

// If this read detected a print hang, stop processing prints.
if (server_killed_due_to_hang_)
return;
}
}
}
}

// Signal whether the print server is currently processing data.
new_data_processed_ = new_print_data;
new_data_last_iter_ = new_data_this_iter;
// Sleep for a few ms if no data was processed.
if (!new_print_data)
if (!new_data_last_iter_)
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
@@ -489,6 +540,11 @@ void tt_await_debug_print_server() {
}
}

bool tt_print_hang_detected() {
return (DebugPrintServerContext::inst != nullptr)
&& (DebugPrintServerContext::inst->print_hang_detected());
}

// The print server is not valid without alive Cluster and tt_device
void tt_start_debug_print_server(
std::function<CoreCoord ()>get_grid_size,
@@ -533,7 +589,7 @@ void tt_start_debug_print_server(
if (all_worker_cores.count(core) > 0) {
print_cores_sanitized.push_back(core);
} else {
log_info(
log_warning(
tt::LogDevice,
"TT_METAL_DPRINT_CORES included worker core ({}, {}), which is not a valid coordinate. This coordinate will be ignored by the dprint server.",
core.x,
10 changes: 10 additions & 0 deletions tt_metal/impl/debug/dprint_server.hpp
@@ -81,3 +81,13 @@ Note that this function does not actually check whether the device will continue producing print
data, it only checks whether the print server is currently processing print data.
*/
void tt_await_debug_print_server();

/**
@brief Check whether a print hang has been detected by the print server.

The print server tries to determine if a core is stalled due to the combination of (1) a WAIT
print command and (2) no new print data coming through. An invalid WAIT command and the print
buffer filling up afterwards can cause the core to spin forever. In this case this function will
return true and the print server will be terminated.
*/
bool tt_print_hang_detected();
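
A minimal host-side sketch of consulting this flag; only tt_print_hang_detected() comes from this header, and the helper and its error message are illustrative:

    #include <stdexcept>

    #include "tt_metal/impl/debug/dprint_server.hpp"

    // Hypothetical helper: abort a host-side wait loop once the print server
    // has reported (and been shut down by) a DPRINT-induced device hang.
    void check_dprint_hang() {
        if (tt_print_hang_detected()) {
            throw std::runtime_error(
                "Device hang due to unanswered DPRINT WAIT.");
        }
    }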
7 changes: 7 additions & 0 deletions tt_metal/impl/dispatch/command_queue.cpp
@@ -11,6 +11,7 @@
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/impl/buffers/semaphore.hpp"
#include "tt_metal/impl/debug/dprint_server.hpp"
#include "tt_metal/third_party/umd/device/tt_xy_pair.h"
#include "dev_msgs.h"
#include <algorithm> // for copy() and assign()
Expand Down Expand Up @@ -876,6 +877,12 @@ void CommandQueue::finish() {
uint32_t finish;
do {
tt::Cluster::instance().read_sysmem(&finish, 4, HOST_CQ_FINISH_PTR, 0);

// There's also a case where the device can be hung due to an unanswered DPRINT WAIT and
// a full print buffer. Poll the print server for this case and throw if it happens.
if (tt_print_hang_detected()) {
TT_THROW("Command Queue could not finish: device hang due to unanswered DPRINT WAIT.");
}
Review comment (Contributor):

We cannot do this since it's not like we have one Finish at the end of every run. We can call Finish an arbitrary number of times, and if we break and continue enqueueing more commands, this could lead to undefined behaviour.

Reply (Contributor Author):

Pushed up a fix to change this to a throw (throw here instead of in the print server so that the exception comes from Finish()), which should still work for CI (running that now).

} while (finish != 1);

// Reset this value to 0 before moving on