Skip to content

Commit

Permalink
#0: temp
Browse files Browse the repository at this point in the history
  • Loading branch information
aliuTT committed May 21, 2024
1 parent 678bd04 commit 0c168ec
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,91 @@
using namespace tt;
using namespace tt::test_utils;

// Opens device 0 through the DevicePool, verifies that every active device
// reports the requested configuration, then closes the devices and queries
// the pool a second time.
TEST_F(FDBasicFixture, DevicePoolOpenClose) {
    std::vector<chip_id_t> device_ids{0};
    int num_hw_cqs = 1;
    int l1_small_size = 1024;
    tt::DevicePool::initialize(device_ids, num_hw_cqs, l1_small_size);
    std::vector<Device *> devices = tt::DevicePool::instance().get_all_active_devices();
    for (const auto& dev: devices) {
        // ASSERT_EQ reports both operands on failure, unlike ASSERT_TRUE(a == b).
        ASSERT_EQ(static_cast<int>(dev->get_l1_small_size()), l1_small_size);
        ASSERT_EQ(static_cast<int>(dev->num_hw_cqs()), num_hw_cqs);
        ASSERT_TRUE(dev->is_initialized());
    }

    // Close then get devices again
    for (const auto& dev: devices) {
        dev->close();
    }
    // NOTE(review): the pool is not re-initialized before this query. If
    // close() makes a device inactive, the list below is presumably empty and
    // the checks that follow are vacuous -- confirm the intended behavior.
    devices = tt::DevicePool::instance().get_all_active_devices();
    for (const auto& dev: devices) {
        ASSERT_EQ(static_cast<int>(dev->get_l1_small_size()), l1_small_size);
        ASSERT_EQ(static_cast<int>(dev->num_hw_cqs()), num_hw_cqs);
        ASSERT_TRUE(dev->is_initialized());
    }
    for (const auto& dev: devices) {
        dev->close();
    }
}

// Opens device 0 with one l1_small_size, closes it, then re-initializes the
// pool with a different l1_small_size and verifies that the active devices
// report the new configuration.
TEST_F(FDBasicFixture, DevicePoolReconfigDevices) {
    std::vector<chip_id_t> device_ids{0};
    int num_hw_cqs = 1;
    int l1_small_size = 1024;
    tt::DevicePool::initialize(device_ids, num_hw_cqs, l1_small_size);
    std::vector<Device *> devices = tt::DevicePool::instance().get_all_active_devices();
    for (const auto& dev: devices) {
        // ASSERT_EQ reports both operands on failure, unlike ASSERT_TRUE(a == b).
        ASSERT_EQ(static_cast<int>(dev->get_l1_small_size()), l1_small_size);
        ASSERT_EQ(static_cast<int>(dev->num_hw_cqs()), num_hw_cqs);
        ASSERT_TRUE(dev->is_initialized());
    }

    // Close then get devices with different configs
    for (const auto& dev: devices) {
        dev->close();
    }
    l1_small_size = 2048;
    tt::DevicePool::initialize(device_ids, num_hw_cqs, l1_small_size);
    devices = tt::DevicePool::instance().get_all_active_devices();
    for (const auto& dev: devices) {
        // l1_small_size now holds the second configuration (2048).
        ASSERT_EQ(static_cast<int>(dev->get_l1_small_size()), l1_small_size);
        ASSERT_EQ(static_cast<int>(dev->num_hw_cqs()), num_hw_cqs);
        ASSERT_TRUE(dev->is_initialized());
    }
    for (const auto& dev: devices) {
        dev->close();
    }
}

// Opens device 0, closes it, then re-initializes the pool with devices 0-3 and
// verifies that at least four active devices report the requested
// configuration. Skipped unless exactly 8 devices are available.
TEST_F(FDBasicFixture, DevicePoolAddDevices) {
    if (tt::tt_metal::GetNumAvailableDevices() != 8) {
        GTEST_SKIP();
    }
    std::vector<chip_id_t> device_ids{0};
    int num_hw_cqs = 1;
    int l1_small_size = 1024;
    tt::DevicePool::initialize(device_ids, num_hw_cqs, l1_small_size);
    std::vector<Device *> devices = tt::DevicePool::instance().get_all_active_devices();
    for (const auto& dev: devices) {
        // ASSERT_EQ reports both operands on failure, unlike ASSERT_TRUE(a == b).
        ASSERT_EQ(static_cast<int>(dev->get_l1_small_size()), l1_small_size);
        ASSERT_EQ(static_cast<int>(dev->num_hw_cqs()), num_hw_cqs);
        ASSERT_TRUE(dev->is_initialized());
    }

    // Close then get more devices
    for (const auto& dev: devices) {
        dev->close();
    }
    devices.clear();
    device_ids = {0, 1, 2, 3};
    tt::DevicePool::initialize(device_ids, num_hw_cqs, l1_small_size);
    devices = tt::DevicePool::instance().get_all_active_devices();
    // ">= 4" rather than "== 4": the pool may report more active devices than
    // were requested.
    // NOTE(review): presumably because extra devices get activated alongside
    // the requested ones -- confirm.
    ASSERT_GE(devices.size(), 4u);
    for (const auto& dev: devices) {
        ASSERT_EQ(static_cast<int>(dev->get_l1_small_size()), l1_small_size);
        ASSERT_EQ(static_cast<int>(dev->num_hw_cqs()), num_hw_cqs);
        ASSERT_TRUE(dev->is_initialized());
    }
    for (const auto& dev: devices) {
        dev->close();
    }
}
11 changes: 6 additions & 5 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ void ::detail::ProgramDeleter::operator()(Program *p) {

// Constructs a Device for `device_id` and immediately initializes it with the
// supplied configuration.
//
// Only id_ and work_executor are set in the member-initializer list.
// num_hw_cqs is forwarded to initialize(), which validates it and assigns
// num_hw_cqs_ and build_key_, so the same setup also runs when an existing
// Device object is initialized again.
//
// @param device_id      chip id this object represents
// @param num_hw_cqs     number of hardware command queues (checked in initialize())
// @param l1_small_size  forwarded to initialize()
// @param l1_bank_remap  forwarded to initialize()
// @param minimal        forwarded to initialize()
Device::Device(
    chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector<uint32_t> &l1_bank_remap, bool minimal) :
    id_(device_id), work_executor(device_id) {
    ZoneScoped;
    this->initialize(num_hw_cqs, l1_small_size, l1_bank_remap, minimal);
}

void Device::initialize_cluster() {
Expand Down Expand Up @@ -1385,9 +1383,12 @@ void Device::initialize_synchronous_sw_cmd_queue() {
}
}

bool Device::initialize(size_t l1_small_size, const std::vector<uint32_t> &l1_bank_remap, bool minimal) {
bool Device::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector<uint32_t> &l1_bank_remap, bool minimal) {
ZoneScoped;
log_info(tt::LogMetal, "Initializing device {}. Program cache is {}enabled", this->id_, this->program_cache.is_enabled() ? "": "NOT ");
TT_ASSERT(num_hw_cqs > 0 and num_hw_cqs < 3, "num_hw_cqs can be between 1 and 2");
this->build_key_ = tt::Cluster::instance().get_harvesting_mask(this->id());
this->num_hw_cqs_ = num_hw_cqs;
this->initialize_cluster();
this->initialize_allocator(l1_small_size, l1_bank_remap);
this->initialize_build();
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class Device {

// Checks that the given arch is on the given pci_slot and that it's responding
// Puts device into reset
bool initialize(size_t l1_small_size, const std::vector<uint32_t> &l1_bank_remap = {}, bool minimal = false);
bool initialize(const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector<uint32_t> &l1_bank_remap = {}, bool minimal = false);
void initialize_cluster();
void initialize_allocator(size_t l1_small_size, const std::vector<uint32_t> &l1_bank_remap = {});
void initialize_build();
Expand Down
26 changes: 14 additions & 12 deletions tt_metal/impl/device/device_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,14 @@ void DevicePool::activate_device(chip_id_t id) {
this->initialize_device(dev);
this->devices[id] = std::unique_ptr<Device>(dev);

} else if (this->devices[id]->state() == ActiveState::ACTIVE) {
TT_THROW("Cannot re-initialize device {}, must first call close()", id);
} else {
const auto& dev = this->devices[id];
if (not dev->is_initialized()) {
dev->initialize(num_hw_cqs, this->l1_small_size, this->l1_bank_remap);
this->initialize_device(dev.get());
} else {
TT_THROW("Cannot re-initialize device {}, must first call close()", id);
}
}

}
Expand All @@ -65,6 +71,7 @@ void DevicePool::add_devices_to_pool(std::vector<chip_id_t> device_ids, const ui
}
}
}
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true);
}

DevicePool::DevicePool(std::vector<chip_id_t> device_ids, const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector<uint32_t> &l1_bank_remap) {
Expand All @@ -79,18 +86,13 @@ Device* DevicePool::get_active_device(chip_id_t device_id) const {
return this->devices[device_id].get();
}

std::vector<Device*> DevicePool::get_all_devices() const {
std::vector<Device*> DevicePool::get_all_active_devices() const {
std::vector<Device*> user_devices;
for (const auto& dev : this->devices) {
if (dev != nullptr) {
if (not dev->is_initialized()) {
dev->initialize(this->l1_small_size, this->l1_bank_remap);
this->initialize_device(dev.get());
}
user_devices.emplace_back(dev.get());
}
for (int id=0; id < this->devices.size(); id++) {
if(this->is_device_active(id)) {
user_devices.emplace_back(this->devices[id].get());
}
}
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true);
return user_devices;
}

Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/device/device_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class DevicePool {
}

Device* get_active_device(chip_id_t device_id) const;
std::vector<Device*> get_all_devices() const;
std::vector<Device*> get_all_active_devices() const;
bool close_device(chip_id_t device_id) const;

private:
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ std::map<chip_id_t, Device *> CreateDevices(
ZoneScoped;
std::cout << " CreateDevices " << std::endl;
tt::DevicePool::initialize(device_ids, num_hw_cqs, l1_small_size);
std::vector<Device *> devices = tt::DevicePool::instance().get_all_devices();
std::vector<Device *> devices = tt::DevicePool::instance().get_all_active_devices();
std::map<chip_id_t, Device *> ret_devices;
for (Device * dev: devices) {
ret_devices.insert({dev->id(), dev});
Expand Down

0 comments on commit 0c168ec

Please sign in to comment.