diff --git a/docs/source/tools/kernel_print.rst b/docs/source/tools/kernel_print.rst index 4011fe0cdda..0b2c5b70316 100644 --- a/docs/source/tools/kernel_print.rst +++ b/docs/source/tools/kernel_print.rst @@ -27,7 +27,7 @@ Note that the core coordinates are currently physical NOC coordinates (not logic To generate kernel debug prints on the device, include the ``debug/dprint.h`` header and use the APIs defined there. And example with the different features available is shown below: -.. code-block:: +.. code-block:: c++ #include "debug/dprint.h" // required in all kernels using DPRINT @@ -64,7 +64,7 @@ to the current CB read or write pointer. This means that for printing a tile rea ``DPRINT`` call has to occur between the ``cb_wait_front`` and ``cb_pop_front`` calls. For printing a tile from the back of the CB, the ``DPRINT`` call has to occur between the ``cb_reserve_back`` and ``cb_push_back`` calls. -.. code-block:: +.. code-block:: c++ #include "debug/dprint.h" // required in all kernels using DPRINT diff --git a/tests/tt_metal/tt_metal/test_kernels/misc/print_hang.cpp b/tests/tt_metal/tt_metal/test_kernels/misc/print_hang.cpp new file mode 100644 index 00000000000..e7ae95886ba --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/misc/print_hang.cpp @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "debug/dprint.h" +#include "debug/dprint_test_common.h" + +/* + * Test kernel that waits for a signal that is never raised. +*/ + +void kernel_main() { + DPRINT << WAIT{1}; + print_test_data(); +} diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/dprint/test_print_hanging.cpp b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/dprint/test_print_hanging.cpp new file mode 100644 index 00000000000..9881ae63655 --- /dev/null +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/dprint/test_print_hanging.cpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "command_queue_fixture.hpp" +#include "common/bfloat16.hpp" +#include "impl/debug/dprint_server.hpp" +#include "gtest/gtest.h" +#include "test_utils.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" + +////////////////////////////////////////////////////////////////////////////////////////// +// A test for checking that we can handle an invalid WAIT command. +////////////////////////////////////////////////////////////////////////////////////////// +using namespace tt; +using namespace tt::tt_metal; + +const std::string golden_output = +R"(DPRINT server timed out on core (1,1) riscv 4, waiting on a RAISE signal: 1 +)"; + +TEST_F(CommandQueueWithDPrintFixture, TestPrintHanging) { + // Device already set up by gtest fixture. + Device *device = this->device_; + + // Set up program and command queue + CommandQueue& cq = *tt::tt_metal::detail::GLOBAL_CQ; + Program program = Program(); + + // Run a kernel that just waits on a signal that never comes (BRISC only). + constexpr CoreCoord core = {0, 0}; // Print on first core only + KernelHandle brisc_print_kernel_id = CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/misc/print_hang.cpp", + core, + DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default} + ); + + // Run the program, we expect it to throw on waiting for CQ to finish + EnqueueProgram(cq, program, false); +try { + Finish(cq); + tt_await_debug_print_server(); +} catch (std::runtime_error& e) { + const string expected = "Command Queue could not finish: device hang due to unanswered DPRINT WAIT."; + const string error = string(e.what()); + log_info(tt::LogTest, "Caught exception (one is expected in this test): {}", error); + EXPECT_TRUE(error.find(expected) != string::npos); +} + + // Check the print log against golden output. 
+ EXPECT_TRUE( + FilesMatchesString( + CommandQueueWithDPrintFixture::dprint_file_name, + golden_output + ) + ); +} diff --git a/tt_metal/impl/debug/dprint_server.cpp b/tt_metal/impl/debug/dprint_server.cpp index f9ce034f656..838b32d1335 100644 --- a/tt_metal/impl/debug/dprint_server.cpp +++ b/tt_metal/impl/debug/dprint_server.cpp @@ -83,7 +83,8 @@ struct DebugPrintServerContext { stop_print_server_ = false; mute_print_server_ = false; - new_data_processed_ = false; + new_data_last_iter_ = false; + server_killed_due_to_hang_ = false; print_server_thread_ = new std::thread( [this, chip_ids, cores, hart_mask] { thread_poll(chip_ids, cores, hart_mask); } ); @@ -116,10 +117,15 @@ struct DebugPrintServerContext { // TODO(dma): once we have access to the device is there a way we can poll the device to // check whether more print data is coming? do { + // No need to await if the server was killed already due to a hang. + if (server_killed_due_to_hang_) + break; std::this_thread::sleep_for(std::chrono::milliseconds(5)); - } while (hart_waiting_on_signal_.size() > 0 || new_data_processed_); + } while (hart_waiting_on_signal_.size() > 0 || new_data_last_iter_); } + bool print_hang_detected() { return server_killed_due_to_hang_; } + private: // Flag for main thread to signal the print server thread to stop. @@ -129,9 +135,12 @@ struct DebugPrintServerContext { std::atomic mute_print_server_; // Flag for signalling whether the print server thread has recently processed data (and is // therefore likely to continue processing data in the next round of polling). - std::atomic new_data_processed_; + std::atomic new_data_last_iter_; std::thread* print_server_thread_; + // A flag to signal to the main thread if the print server detected a print-based hang. 
+ bool server_killed_due_to_hang_; + std::ofstream* outfile_ = nullptr; // non-cout std::ostream* stream_ = nullptr; // either == outfile_ or is &cout @@ -142,7 +151,12 @@ struct DebugPrintServerContext { std::set raised_signals_; void thread_poll(const vector& chip_ids, const vector& cores, uint32_t hart_mask); - bool peek_flush_one_hart_nonblocking(int chip_id, const CoreCoord& core, int hart_index); + bool peek_flush_one_hart_nonblocking( + int chip_id, + const CoreCoord& core, + int hart_index, + bool new_data_this_iter + ); }; static void print_tile_slice(ostream& stream, uint8_t* ptr, int hart_id) { @@ -210,7 +224,12 @@ bool check_init_magic_cleared(int chip_id, const CoreCoord& core, int hart_id) { // Peeks a specified hart for any debug prints present in the buffer and flushes it, printing the // contents out to host-side stream. Returns true if some data was read out, and false if no new // print data was present on the device. -bool DebugPrintServerContext::peek_flush_one_hart_nonblocking(int chip_id, const CoreCoord& core, int hart_id) { +bool DebugPrintServerContext::peek_flush_one_hart_nonblocking( + int chip_id, + const CoreCoord& core, + int hart_id, + bool new_data_this_iter +) { // compute the buffer address for the requested hart uint32_t base_addr = PRINT_BUFFER_NC + hart_id*PRINT_BUFFER_SIZE; @@ -226,7 +245,45 @@ bool DebugPrintServerContext::peek_flush_one_hart_nonblocking(int chip_id, const uint32_t counter = 0; uint32_t sigval = 0; char val = 0; + + // If the print server is muted, dump the output to a null stream instead. ostream& stream = (mute_print_server_)? null_stream : *stream_; + + // Check whether this hart is currently waiting on a WAIT to be fulfilled. + tuple hart_key {core.x, core.y, hart_id}; + if (hart_waiting_on_signal_.count(hart_key) > 0) { + // Check if the signal the hart is waiting for has been raised. 
+ uint32_t wait_signal = hart_waiting_on_signal_[hart_key]; + if (raised_signals_.count(wait_signal) > 0) { + // The signal has been raised, we can continue. + hart_waiting_on_signal_.erase(hart_key); + } else { + // This hart is still waiting. This is fine as long as the print server (and therefore + // the device) is still making progress. Unfortunately there's no way to check if the + // print server is full because the next print that would overflow the buffer spins the + // device until the buffer has more space, but checking for any new prints seems to work + // for cases so far. + if (!new_data_this_iter && !new_data_last_iter_) { + // If no progress was made on both sides, then it could be an invalid wait + // condition, which could cause a deadlock. Print a warning and set a flag to close + // the print server in this case. + string core_str = "core (" + to_string(core.x) + "," + to_string(core.y) + + ") riscv " + to_string(hart_id); + string error_str = "DPRINT server timed out on " + + core_str + + ", waiting on a RAISE signal: " + + to_string(wait_signal) + "\n"; + stream << error_str << flush; + log_warning(tt::LogMetal, "Debug Print Server encountered an error: {}", error_str); + server_killed_due_to_hang_ = true; + return false; + } + + // Since it's still waiting, return false here since no data was read. + return false; + } + } + if (rpos < wpos) { // Now read the entire buffer from_dev = tt::llrt::read_hex_vec_from_core(chip_id, core, base_addr, PRINT_BUFFER_SIZE); @@ -404,12 +461,12 @@ void DebugPrintServerContext::thread_poll( if (stop_print_server_) { // If the stop signal was received, exit the print server thread, but wait for any // existing prints to be wrapped up first. - if (hart_waiting_on_signal_.size() == 0 && !new_data_last_iter_) break; } // Flag for whether any new print data was found in this round of polling. 
- bool new_print_data = false; + bool new_data_this_iter = false; for (auto chip: chip_ids) { for (auto core: cores) { for (int hart_index = 0; hart_index < DPRINT_NRISCVS; hart_index++) { @@ -417,31 +474,25 @@ void DebugPrintServerContext::thread_poll( if (!check_init_magic_cleared(chip, core, hart_index)) continue; - // Make sure that this core is not waiting on a raise signal to continue - // printing. - tuple hart_key {core.x, core.y, hart_index}; - if (hart_waiting_on_signal_.count(hart_key) > 0) { - uint32_t wait_signal = hart_waiting_on_signal_[hart_key]; - if (raised_signals_.count(wait_signal) > 0) { - // The signal this hart is waiting for has been raised, it's not - // waiting anymore. - hart_waiting_on_signal_.erase(hart_key); - } else { - // Not raised yet, keep waiting. - continue; - } - } - - new_print_data |= peek_flush_one_hart_nonblocking(chip, core, hart_index); + new_data_this_iter |= peek_flush_one_hart_nonblocking( + chip, + core, + hart_index, + new_data_this_iter + ); + + // If this read detected a print hang, stop processing prints. + if (server_killed_due_to_hang_) + return; } } } } // Signal whether the print server is currently processing data. - new_data_processed_ = new_print_data; + new_data_last_iter_ = new_data_this_iter; // Sleep for a few ms if no data was processed. 
- if (!new_print_data) + if (!new_data_last_iter_) std::this_thread::sleep_for(std::chrono::milliseconds(1)); } } @@ -489,6 +540,11 @@ void tt_await_debug_print_server() { } } +bool tt_print_hang_detected() { + return (DebugPrintServerContext::inst != nullptr) + && (DebugPrintServerContext::inst->print_hang_detected()); +} + // The print server is not valid without alive Cluster and tt_device void tt_start_debug_print_server( std::functionget_grid_size, @@ -533,7 +589,7 @@ void tt_start_debug_print_server( if (all_worker_cores.count(core) > 0) { print_cores_sanitized.push_back(core); } else { - log_info( + log_warning( tt::LogDevice, "TT_METAL_DPRINT_CORES included worker core ({}, {}), which is not a valid coordinate. This coordinate will be ignored by the dprint server.", core.x, diff --git a/tt_metal/impl/debug/dprint_server.hpp b/tt_metal/impl/debug/dprint_server.hpp index 9d7fd5c6d7d..6bd370c484f 100644 --- a/tt_metal/impl/debug/dprint_server.hpp +++ b/tt_metal/impl/debug/dprint_server.hpp @@ -81,3 +81,13 @@ Note that this function does not actually check whether the device will continue data, it only checks whether the print server is currently processing print data. */ void tt_await_debug_print_server(); + +/** +@brief Check whether a print hang has been detected by the print server. + +The print server tries to determine if a core is stalled due to the combination of (1) a WAIT +print command and (2) no new print data coming through. An invalid WAIT command and the print +buffer filling up afterwards can cause the core to spin forever. In this case this function will +return true and the print server will be terminated. 
+*/ +bool tt_print_hang_detected(); diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index da7d199d963..c7e95a097b9 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -11,6 +11,7 @@ #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/impl/buffers/semaphore.hpp" +#include "tt_metal/impl/debug/dprint_server.hpp" #include "tt_metal/third_party/umd/device/tt_xy_pair.h" #include "dev_msgs.h" #include // for copy() and assign() @@ -876,6 +877,12 @@ void CommandQueue::finish() { uint32_t finish; do { tt::Cluster::instance().read_sysmem(&finish, 4, HOST_CQ_FINISH_PTR, 0); + + // There's also a case where the device can be hung due to an unanswered DPRINT WAIT and + // a full print buffer. Poll the print server for this case and throw if it happens. + if (tt_print_hang_detected()) { + TT_THROW("Command Queue could not finish: device hang due to unanswered DPRINT WAIT."); + } } while (finish != 1); // Reset this value to 0 before moving on