Skip to content

Commit

Permalink
#4214: Add dprint testing for multi-device
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-dma committed Dec 22, 2023
1 parent 0675d8e commit d477a42
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 58 deletions.
87 changes: 60 additions & 27 deletions tests/tt_metal/tt_metal/unit_tests_common/common/dprint_fixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,27 @@
class DPrintFixture: public ::testing::Test {
public:
inline static const string dprint_file_name = "gtest_dprint_log.txt";

// Executes `program` on `device` through whichever dispatch path this fixture
// is configured for (slow_dispatch_), then waits on the DPRINT server.
void RunProgram(Device* device, Program& program) {
    const bool use_fast_dispatch = !this->slow_dispatch_;
    if (use_fast_dispatch) {
        // Fast dispatch goes through the device's command queue.
        CommandQueue& queue = tt::tt_metal::detail::GetCommandQueue(device);
        EnqueueProgram(queue, program, false);
        Finish(queue);
    } else {
        // Slow dispatch launches the program directly.
        tt::tt_metal::detail::LaunchProgram(device, program);
    }

    // Give the print server a chance to catch up, if it needs to.
    tt::DprintServerAwait();
}
protected:
tt::ARCH arch_;
Device* device_;
vector<Device*> devices_;
bool slow_dispatch_;

// A flag to mark if the test is skipped or not. Since we skip before
// device setup, we need to skip device teardown if the test is skipped.
bool test_skipped = false;
bool has_remote_devices_;

void SetUp() override {
// Skip for slow dispatch for now
Expand All @@ -34,46 +47,66 @@ class DPrintFixture: public ::testing::Test {
slow_dispatch_ = false;
}
// The core range (physical) needs to be set >= the set of all cores
// used by all tests using this fixture. TODO: update with a way to
// just set all physical cores to have printing enabled.
// used by all tests using this fixture, so set dprint enabled for
// all cores and all devices
tt::llrt::OptionsG.set_dprint_all_cores(true);
tt::llrt::OptionsG.set_dprint_all_chips(true);
// Send output to a file so the test can check after program is run.
tt::llrt::OptionsG.set_dprint_file_name(dprint_file_name);

// Parent call, sets up the device
// Set up all available devices
this->arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name());
auto num_devices = tt::tt_metal::GetNumAvailableDevices();
auto num_pci_devices = tt::tt_metal::GetNumPCIeDevices();
for (unsigned int id = 0; id < num_devices; id++) {
if (SkipTest(id))
continue;
auto* device = tt::tt_metal::CreateDevice(id);
devices_.push_back(device);
}

const int device_id = 0;
this->device_ = tt::tt_metal::CreateDevice(device_id);

// An extra flag for if we have remote devices, as some tests are disabled for fast
// dispatch + remote devices.
this->has_remote_devices_ = num_devices > num_pci_devices;
}

void TearDown() override {
if (!test_skipped) {
tt::tt_metal::CloseDevice(this->device_);
// Remove the DPrint output file after the test is finished.
std::remove(dprint_file_name.c_str());
// Close all opened devices
for (unsigned int id = 0; id < devices_.size(); id++) {
tt::tt_metal::CloseDevice(devices_.at(id));
}
// Remove the DPrint output file after the test is finished.
std::remove(dprint_file_name.c_str());

// Reset DPrint settings
tt::llrt::OptionsG.set_dprint_cores({});
tt::llrt::OptionsG.set_dprint_all_cores(false);
tt::llrt::OptionsG.set_dprint_file_name("");
}

// A function to run a program, according to which dispatch mode is set.
void RunProgram(Program& program) {
if (this->slow_dispatch_) {
// Slow dispatch uses LaunchProgram
tt::tt_metal::detail::LaunchProgram(this->device_, program);
} else {
// Fast Dispatch uses the command queue
CommandQueue& cq = tt::tt_metal::detail::GetCommandQueue(this->device_);
EnqueueProgram(cq, program, false);
Finish(cq);
// Returns true when `device_id` must be skipped, logging the reason.
//
// The skip condition is fast dispatch combined with remote devices: in that
// case every device other than device 0 is skipped.
//
// Whether remote devices are present is derived from the device counts here
// instead of being read from has_remote_devices_. SetUp() calls this function
// while it is still creating devices, i.e. before it assigns that member, so
// reading the member at that point would use an indeterminate value.
bool SkipTest(unsigned int device_id) {
    // Device 0 and slow dispatch are never skipped.
    if (device_id == 0 || this->slow_dispatch_) {
        return false;
    }

    // Same comparison SetUp() uses to compute has_remote_devices_.
    const bool has_remote_devices =
        tt::tt_metal::GetNumAvailableDevices() > tt::tt_metal::GetNumPCIeDevices();
    if (!has_remote_devices) {
        return false;
    }

    log_info(
        tt::LogTest,
        "Skipping test on device {} due to fast dispatch unsupported on remote devices.",
        device_id
    );
    return true;
}

// Wait for the print server to catch up if needed.
tt::DprintServerAwait();
// Invokes `run_function(this, device)` unless SkipTest() rejects the device.
// Logs before and after the run, then clears the DPRINT log file and signals.
void RunTestOnDevice(
    const std::function<void(DPrintFixture*, Device*)>& run_function,
    Device* device
) {
    const bool should_run = !SkipTest(device->id());
    if (should_run) {
        log_info(tt::LogTest, "Running test on device {}.", device->id());
        run_function(this, device);
        log_info(tt::LogTest, "Finished running test on device {}.", device->id());

        // Clear the print server's log file and signals after the run.
        tt::DPrintServerClearLogFile();
        tt::DPrintServerClearSignals();
    }
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ const std::string golden_output =
R"(Printing int from arg: 0
Printing int from arg: 2)";

TEST_F(DPrintFixture, TestPrintMuting) {
// Device already set up by gtest fixture.
Device *device = this->device_;

static void RunTest(DPrintFixture* fixture, Device* device) {
// Set up program
Program program = Program();

Expand All @@ -42,7 +39,7 @@ TEST_F(DPrintFixture, TestPrintMuting) {
core,
{test_number}
);
RunProgram(program);
fixture->RunProgram(device, program);
};

// Run the program, prints should be enabled.
Expand All @@ -64,3 +61,9 @@ TEST_F(DPrintFixture, TestPrintMuting) {
)
);
}

// Runs RunTest once per device held in the fixture's devices_ list.
TEST_F(DPrintFixture, TestPrintMuting) {
    for (Device* target : devices_) {
        RunTestOnDevice(RunTest, target);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,7 @@ TILE: (
ptr=122880))";

TEST_F(DPrintFixture, TestPrintFromAllHarts) {
// Device already set up by gtest fixture.
Device *device = this->device_;

static void RunTest(DPrintFixture* fixture, Device* device) {
// Set up program and command queue
constexpr CoreCoord core = {0, 0}; // Print on first core only
Program program = Program();
Expand Down Expand Up @@ -183,7 +180,7 @@ TEST_F(DPrintFixture, TestPrintFromAllHarts) {
);

// Run the program
RunProgram(program);
fixture->RunProgram(device, program);

// Check that the expected print messages are in the log file
EXPECT_TRUE(
Expand All @@ -193,3 +190,9 @@ TEST_F(DPrintFixture, TestPrintFromAllHarts) {
)
);
}

// Runs RunTest once per device held in the fixture's devices_ list.
TEST_F(DPrintFixture, TestPrintFromAllHarts) {
    for (Device* target : devices_) {
        RunTestOnDevice(RunTest, target);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,7 @@ const std::string golden_output =
R"(DPRINT server timed out on core (1,1) riscv 4, waiting on a RAISE signal: 1
)";

TEST_F(DPrintFixture, TestPrintHanging) {
// Skip this test for slow dispatch for now. Due to how llrt currently sits below device, it's
// tricky to check print server status from the finish loop for slow dispatch. Once issue #4363
// is resolved, we should add a check for print server hanging in slow dispatch as well.
if (this->slow_dispatch_)
GTEST_SKIP();

// Device already set up by gtest fixture.
Device *device = this->device_;

static void RunTest(DPrintFixture* fixture, Device* device) {
// Set up program
Program program = Program();

Expand All @@ -43,7 +34,7 @@ TEST_F(DPrintFixture, TestPrintHanging) {

// Run the program, we expect it to throw on waiting for CQ to finish
try {
RunProgram(program);
fixture->RunProgram(device, program);
} catch (std::runtime_error& e) {
const string expected = "Command Queue could not finish: device hang due to unanswered DPRINT WAIT.";
const string error = string(e.what());
Expand All @@ -59,3 +50,15 @@ try {
)
);
}

// Runs RunTest once per device held in the fixture's devices_ list, fast dispatch only.
TEST_F(DPrintFixture, TestPrintHanging) {
    // Slow dispatch is skipped for now: because of how llrt currently sits below device,
    // it is tricky to check print server status from the finish loop in that mode. Once
    // issue #4363 is resolved, a check for print server hanging should be added for slow
    // dispatch as well.
    if (slow_dispatch_) {
        GTEST_SKIP();
    }

    for (Device* target : devices_) {
        RunTestOnDevice(RunTest, target);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,8 @@ TestConstCharStrNC{4,4}
----------
TestStrBR{4,4}
+++++++++++++++)";
TEST_F(DPrintFixture, TestPrintRaiseWait) {
// Device already set up by gtest fixture.
Device *device = this->device_;

static void RunTest(DPrintFixture* fixture, Device* device) {
// Set up program and command queue
Program program = Program();

Expand Down Expand Up @@ -280,7 +278,7 @@ TEST_F(DPrintFixture, TestPrintRaiseWait) {


// Run the program
RunProgram(program);
fixture->RunProgram(device, program);

// Check the print log against golden output.
EXPECT_TRUE(
Expand All @@ -290,3 +288,9 @@ TEST_F(DPrintFixture, TestPrintRaiseWait) {
)
);
}

// Runs RunTest once per device held in the fixture's devices_ list.
TEST_F(DPrintFixture, TestPrintRaiseWait) {
    for (Device* target : devices_) {
        RunTestOnDevice(RunTest, target);
    }
}
44 changes: 38 additions & 6 deletions tt_metal/impl/debug/dprint_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,15 @@ struct DebugPrintServerContext {
// This device must have been attached previously.
void DetachDevice(Device* device);

// Clears the log file of a currently-running print server.
void ClearLogFile();

// Clears any raised signals (so they can be used again in a later run).
void ClearSignals();

int GetNumAttachedDevices() { return device_to_core_range_.size(); }

bool print_hang_detected() { return server_killed_due_to_hang_; }
bool PrintHangDetected() { return server_killed_due_to_hang_; }

private:

Expand Down Expand Up @@ -252,7 +258,7 @@ void DebugPrintServerContext::WaitForPrintsFinished() {
// No need to await if the server was killed already due to a hang.
if (server_killed_due_to_hang_)
break;
std::this_thread::sleep_for(std::chrono::milliseconds(5));
std::this_thread::sleep_for(std::chrono::milliseconds(20));
} while (hart_waiting_on_signal_.size() > 0 || new_data_last_iter_);
} // WaitForPrintsFinished

Expand Down Expand Up @@ -331,6 +337,22 @@ void DebugPrintServerContext::DetachDevice(Device* device) {
log_info(tt::LogMetal, "DPRINT Server dettached device {}", device->id());
} // DetachDevice

void DebugPrintServerContext::ClearLogFile() {
if (outfile_) {
// Just close the file and re-open it (without append) to clear it.
outfile_->close();
delete outfile_;

string file_name = tt::llrt::OptionsG.get_dprint_file_name();
outfile_ = new std::ofstream(file_name);
stream_ = outfile_ ? outfile_ : &cout;
}
} // ClearLogFile

// Empties raised_signals_, so that signals raised during an earlier run can be
// used again in a later run.
void DebugPrintServerContext::ClearSignals() {
raised_signals_.clear();
} // ClearSignals

bool DebugPrintServerContext::PeekOneHartNonBlocking(
int chip_id,
const CoreCoord& core,
Expand Down Expand Up @@ -615,8 +637,9 @@ void DprintServerAttach(Device* device) {

// Skip if RTOptions doesn't enable DPRINT for this device
vector<chip_id_t> chip_ids = tt::llrt::OptionsG.get_dprint_chip_ids();
if (std::find(chip_ids.begin(), chip_ids.end(), device->id()) == chip_ids.end())
return;
if (!tt::llrt::OptionsG.get_dprint_all_chips())
if (std::find(chip_ids.begin(), chip_ids.end(), device->id()) == chip_ids.end())
return;

// If no server is running, create one
if (!DprintServerIsRunning())
Expand Down Expand Up @@ -666,7 +689,16 @@ void DprintServerAwait() {
}

// Reports whether the print server was killed due to a detected hang.
// Returns false when no print server is running.
bool DPrintServerHangDetected() {
    return DprintServerIsRunning() && DebugPrintServerContext::inst->PrintHangDetected();
}

// Clears the print server's log file; a no-op when no server is running.
void DPrintServerClearLogFile() {
    if (!DprintServerIsRunning()) {
        return;
    }
    DebugPrintServerContext::inst->ClearLogFile();
}

// Clears the print server's raised signals; a no-op when no server is running.
void DPrintServerClearSignals() {
    if (!DprintServerIsRunning()) {
        return;
    }
    DebugPrintServerContext::inst->ClearSignals();
}
} // namespace tt
10 changes: 10 additions & 0 deletions tt_metal/impl/debug/dprint_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,14 @@ return true and the print server will be terminated.
*/
bool DPrintServerHangDetected();

/**
@brief Clears the print server log file.
*/
void DPrintServerClearLogFile();

/**
@brief Clears any RAISE signals in the print server, so they can be used again in a later run.
*/
void DPrintServerClearSignals();

} // namespace tt
6 changes: 6 additions & 0 deletions tt_metal/llrt/rtoptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class RunTimeOptions {
std::vector<CoreCoord> dprint_cores;
bool dprint_all_cores;
std::vector<int> dprint_chip_ids;
bool dprint_all_chips;
uint32_t dprint_riscv_mask;
std::string dprint_file_name;

Expand Down Expand Up @@ -76,6 +77,11 @@ class RunTimeOptions {
// Replaces the stored list of DPRINT chip IDs with `chip_ids`.
inline void set_dprint_chip_ids(std::vector<int> chip_ids) {
dprint_chip_ids = chip_ids;
}
// An alternative to listing individual chip IDs: a flag to enable all chips.
inline void set_dprint_all_chips(bool all_chips) {
dprint_all_chips = all_chips;
}
// Returns the stored dprint_all_chips flag.
inline bool get_dprint_all_chips() { return dprint_all_chips; }
// Returns the stored dprint_riscv_mask value.
inline uint32_t get_dprint_riscv_mask() { return dprint_riscv_mask; }
inline void set_dprint_riscv_mask(uint32_t riscv_mask) {
dprint_riscv_mask = riscv_mask;
Expand Down

0 comments on commit d477a42

Please sign in to comment.