diff --git a/tests/tt_metal/tt_metal/unit_tests_common/common/dprint_fixture.hpp b/tests/tt_metal/tt_metal/unit_tests_common/common/dprint_fixture.hpp index c161df12857..83bceb91043 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/common/dprint_fixture.hpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/common/dprint_fixture.hpp @@ -14,14 +14,27 @@ class DPrintFixture: public ::testing::Test { public: inline static const string dprint_file_name = "gtest_dprint_log.txt"; + + // A function to run a program, according to which dispatch mode is set. + void RunProgram(Device* device, Program& program) { + if (this->slow_dispatch_) { + // Slow dispatch uses LaunchProgram + tt::tt_metal::detail::LaunchProgram(device, program); + } else { + // Fast Dispatch uses the command queue + CommandQueue& cq = tt::tt_metal::detail::GetCommandQueue(device); + EnqueueProgram(cq, program, false); + Finish(cq); + } + + // Wait for the print server to catch up if needed. + tt::DprintServerAwait(); + } protected: tt::ARCH arch_; - Device* device_; + vector devices_; bool slow_dispatch_; - - // A flag to mark if the test is skipped or not. Since we skip before - // device setup, we need to skip device teardown if the test is skipped. - bool test_skipped = false; + bool has_remote_devices_; void SetUp() override { // Skip for slow dispatch for now @@ -34,26 +47,36 @@ class DPrintFixture: public ::testing::Test { slow_dispatch_ = false; } // The core range (physical) needs to be set >= the set of all cores - // used by all tests using this fixture. TODO: update with a way to - // just set all physical cores to have printing enabled. + // used by all tests using this fixture, so set dprint enabled for + // all cores and all devices tt::llrt::OptionsG.set_dprint_all_cores(true); + tt::llrt::OptionsG.set_dprint_all_chips(true); // Send output to a file so the test can check after program is run. tt::llrt::OptionsG.set_dprint_file_name(dprint_file_name); - // Parent call, sets up the device + // Set up all available devices this->arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + auto num_pci_devices = tt::tt_metal::GetNumPCIeDevices(); + for (unsigned int id = 0; id < num_devices; id++) { + if (SkipTest(id)) + continue; + auto* device = tt::tt_metal::CreateDevice(id); + devices_.push_back(device); + } - const int device_id = 0; - this->device_ = tt::tt_metal::CreateDevice(device_id); - + // An extra flag for if we have remote devices, as some tests are disabled for fast + // dispatch + remote devices. + this->has_remote_devices_ = num_devices > num_pci_devices; } void TearDown() override { - if (!test_skipped) { - tt::tt_metal::CloseDevice(this->device_); - // Remove the DPrint output file after the test is finished. - std::remove(dprint_file_name.c_str()); + // Close all opened devices + for (unsigned int id = 0; id < devices_.size(); id++) { + tt::tt_metal::CloseDevice(devices_.at(id)); } + // Remove the DPrint output file after the test is finished. + std::remove(dprint_file_name.c_str()); // Reset DPrint settings tt::llrt::OptionsG.set_dprint_cores({}); @@ -61,19 +84,29 @@ class DPrintFixture: public ::testing::Test { tt::llrt::OptionsG.set_dprint_file_name(""); } - // A function to run a program, according to which dispatch mode is set. - void RunProgram(Program& program) { - if (this->slow_dispatch_) { - // Slow dispatch uses LaunchProgram - tt::tt_metal::detail::LaunchProgram(this->device_, program); - } else { - // Fast Dispatch uses the command queue - CommandQueue& cq = tt::tt_metal::detail::GetCommandQueue(this->device_); - EnqueueProgram(cq, program, false); - Finish(cq); + bool SkipTest(unsigned int device_id) { + // Skip condition is fast dispatch for remote devices + if (this->has_remote_devices_ && !this->slow_dispatch_ && device_id != 0) { + log_info( + tt::LogTest, + "Skipping test on device {} due to fast dispatch unsupported on remote devices.", + device_id + ); + return true; } + return false; + } - // Wait for the print server to catch up if needed. - tt::DprintServerAwait(); + void RunTestOnDevice( + const std::function& run_function, + Device* device + ) { + if (SkipTest(device->id())) + return; + log_info(tt::LogTest, "Running test on device {}.", device->id()); + run_function(this, device); + log_info(tt::LogTest, "Finished running test on device {}.", device->id()); + tt::DPrintServerClearLogFile(); + tt::DPrintServerClearSignals(); } }; diff --git a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_mute_print_server.cpp b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_mute_print_server.cpp index 1470b2a207b..4b76841b498 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_mute_print_server.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_mute_print_server.cpp @@ -18,10 +18,7 @@ const std::string golden_output = R"(Printing int from arg: 0 Printing int from arg: 2)"; -TEST_F(DPrintFixture, TestPrintMuting) { - // Device already set up by gtest fixture. - Device *device = this->device_; - +static void RunTest(DPrintFixture* fixture, Device* device) { // Set up program Program program = Program(); @@ -42,7 +39,7 @@ TEST_F(DPrintFixture, TestPrintMuting) { core, {test_number} ); - RunProgram(program); + fixture->RunProgram(device, program); }; // Run the program, prints should be enabled. @@ -64,3 +61,9 @@ TEST_F(DPrintFixture, TestPrintMuting) { ) ); } + +TEST_F(DPrintFixture, TestPrintMuting) { + for (Device* device : this->devices_) { + this->RunTestOnDevice(RunTest, device); + } +} diff --git a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_all_harts.cpp b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_all_harts.cpp index 0557921071e..9e241ce7ce4 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_all_harts.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_all_harts.cpp @@ -144,10 +144,7 @@ TILE: ( ptr=122880))"; -TEST_F(DPrintFixture, TestPrintFromAllHarts) { - // Device already set up by gtest fixture. - Device *device = this->device_; - +static void RunTest(DPrintFixture* fixture, Device* device) { // Set up program and command queue constexpr CoreCoord core = {0, 0}; // Print on first core only Program program = Program(); @@ -183,7 +180,7 @@ TEST_F(DPrintFixture, TestPrintFromAllHarts) { ); // Run the program - RunProgram(program); + fixture->RunProgram(device, program); // Check that the expected print messages are in the log file EXPECT_TRUE( @@ -193,3 +190,9 @@ TEST_F(DPrintFixture, TestPrintFromAllHarts) { ) ); } + +TEST_F(DPrintFixture, TestPrintFromAllHarts) { + for (Device* device : this->devices_) { + this->RunTestOnDevice(RunTest, device); + } +} diff --git a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_hanging.cpp b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_hanging.cpp index 3e9c36bc096..0ed2278908e 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_hanging.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_print_hanging.cpp @@ -19,16 +19,7 @@ const std::string golden_output = R"(DPRINT server timed out on core (1,1) riscv 4, waiting on a RAISE signal: 1 )"; -TEST_F(DPrintFixture, TestPrintHanging) { - // Skip this test for slow dipatch for now. Due to how llrt currently sits below device, it's - // tricky to check print server status from the finish loop for slow dispatch. Once issue #4363 - // is resolved, we should add a check for print server handing in slow dispatch as well. - if (this->slow_dispatch_) - GTEST_SKIP(); - - // Device already set up by gtest fixture. - Device *device = this->device_; - +static void RunTest(DPrintFixture* fixture, Device* device) { // Set up program Program program = Program(); @@ -43,7 +34,7 @@ TEST_F(DPrintFixture, TestPrintHanging) { // Run the program, we expect it to throw on waiting for CQ to finish try { - RunProgram(program); + fixture->RunProgram(device, program); } catch (std::runtime_error& e) { const string expected = "Command Queue could not finish: device hang due to unanswered DPRINT WAIT."; const string error = string(e.what()); @@ -59,3 +50,15 @@ try { ) ); } + +TEST_F(DPrintFixture, TestPrintHanging) { + // Skip this test for slow dipatch for now. Due to how llrt currently sits below device, it's + // tricky to check print server status from the finish loop for slow dispatch. Once issue #4363 + // is resolved, we should add a check for print server handing in slow dispatch as well. + if (this->slow_dispatch_) + GTEST_SKIP(); + + for (Device* device : this->devices_) { + this->RunTestOnDevice(RunTest, device); + } +} diff --git a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_raise_wait.cpp b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_raise_wait.cpp index 65652e2fbb9..9be90cb2de4 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_raise_wait.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/dprint/test_raise_wait.cpp @@ -215,10 +215,8 @@ TestConstCharStrNC{4,4} ---------- TestStrBR{4,4} +++++++++++++++)"; -TEST_F(DPrintFixture, TestPrintRaiseWait) { - // Device already set up by gtest fixture. - Device *device = this->device_; +static void RunTest(DPrintFixture* fixture, Device* device) { // Set up program and command queue Program program = Program(); @@ -280,7 +278,7 @@ TEST_F(DPrintFixture, TestPrintRaiseWait) { // Run the program - RunProgram(program); + fixture->RunProgram(device, program); // Check the print log against golden output. EXPECT_TRUE( @@ -290,3 +288,9 @@ TEST_F(DPrintFixture, TestPrintRaiseWait) { ) ); } + +TEST_F(DPrintFixture, TestPrintRaiseWait) { + for (Device* device : this->devices_) { + this->RunTestOnDevice(RunTest, device); + } +} diff --git a/tt_metal/impl/debug/dprint_server.cpp b/tt_metal/impl/debug/dprint_server.cpp index a45c0337017..10a320ec174 100644 --- a/tt_metal/impl/debug/dprint_server.cpp +++ b/tt_metal/impl/debug/dprint_server.cpp @@ -75,9 +75,15 @@ struct DebugPrintServerContext { // This device must have been attached previously. void DetachDevice(Device* device); + // Clears the log file of a currently-running print server. + void ClearLogFile(); + + // Clears any raised signals (so they can be used again in a later run). + void ClearSignals(); + int GetNumAttachedDevices() { return device_to_core_range_.size(); } - bool print_hang_detected() { return server_killed_due_to_hang_; } + bool PrintHangDetected() { return server_killed_due_to_hang_; } private: @@ -252,7 +258,7 @@ void DebugPrintServerContext::WaitForPrintsFinished() { // No need to await if the server was killed already due to a hang. if (server_killed_due_to_hang_) break; - std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); } while (hart_waiting_on_signal_.size() > 0 || new_data_last_iter_); } // WaitForPrintsFinished @@ -331,6 +337,22 @@ void DebugPrintServerContext::DetachDevice(Device* device) { log_info(tt::LogMetal, "DPRINT Server dettached device {}", device->id()); } // DetachDevice +void DebugPrintServerContext::ClearLogFile() { + if (outfile_) { + // Just close the file and re-open it (without append) to clear it. + outfile_->close(); + delete outfile_; + + string file_name = tt::llrt::OptionsG.get_dprint_file_name(); + outfile_ = new std::ofstream(file_name); + stream_ = outfile_ ? outfile_ : &cout; + } +} // ClearLogFile + +void DebugPrintServerContext::ClearSignals() { + raised_signals_.clear(); +} // ClearSignals + bool DebugPrintServerContext::PeekOneHartNonBlocking( int chip_id, const CoreCoord& core, @@ -615,8 +637,9 @@ void DprintServerAttach(Device* device) { // Skip if RTOptions doesn't enable DPRINT for this device vector chip_ids = tt::llrt::OptionsG.get_dprint_chip_ids(); - if (std::find(chip_ids.begin(), chip_ids.end(), device->id()) == chip_ids.end()) - return; + if (!tt::llrt::OptionsG.get_dprint_all_chips()) + if (std::find(chip_ids.begin(), chip_ids.end(), device->id()) == chip_ids.end()) + return; // If no server ir running, create one if (!DprintServerIsRunning()) @@ -666,7 +689,16 @@ void DprintServerAwait() { } bool DPrintServerHangDetected() { - return (DebugPrintServerContext::inst != nullptr) - && (DebugPrintServerContext::inst->print_hang_detected()); + return DprintServerIsRunning() && DebugPrintServerContext::inst->PrintHangDetected(); +} + +void DPrintServerClearLogFile() { + if (DprintServerIsRunning()) + DebugPrintServerContext::inst->ClearLogFile(); +} + +void DPrintServerClearSignals() { + if (DprintServerIsRunning()) + DebugPrintServerContext::inst->ClearSignals(); } } // namespace tt diff --git a/tt_metal/impl/debug/dprint_server.hpp b/tt_metal/impl/debug/dprint_server.hpp index d4619bfb761..d853d78c830 100644 --- a/tt_metal/impl/debug/dprint_server.hpp +++ b/tt_metal/impl/debug/dprint_server.hpp @@ -86,4 +86,14 @@ return true and the print server will be terminated. */ bool DPrintServerHangDetected(); +/** +@brief Clears the print server log file. +*/ +void DPrintServerClearLogFile(); + +/** +@brief Clears any RAISE signals in the print server, so they can be used again in a later run. +*/ +void DPrintServerClearSignals(); + } // namespace tt diff --git a/tt_metal/llrt/rtoptions.hpp b/tt_metal/llrt/rtoptions.hpp index 80c6c9639a9..cdce708fa16 100644 --- a/tt_metal/llrt/rtoptions.hpp +++ b/tt_metal/llrt/rtoptions.hpp @@ -30,6 +30,7 @@ class RunTimeOptions { std::vector dprint_cores; bool dprint_all_cores; std::vector dprint_chip_ids; + bool dprint_all_chips; uint32_t dprint_riscv_mask; std::string dprint_file_name; @@ -76,6 +77,11 @@ class RunTimeOptions { inline void set_dprint_chip_ids(std::vector chip_ids) { dprint_chip_ids = chip_ids; } + // An alternative to setting cores by range, a flag to enable all. + inline void set_dprint_all_chips(bool all_chips) { + dprint_all_chips = all_chips; + } + inline bool get_dprint_all_chips() { return dprint_all_chips; } inline uint32_t get_dprint_riscv_mask() { return dprint_riscv_mask; } inline void set_dprint_riscv_mask(uint32_t riscv_mask) { dprint_riscv_mask = riscv_mask;