Skip to content

Commit

Permalink
#5048: Add CreateDevices and CloseDevices api to detail
Browse files Browse the repository at this point in the history
  • Loading branch information
aliuTT committed Feb 5, 2024
1 parent 001b7d7 commit 214049e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
// SPDX-License-Identifier: Apache-2.0

#include "gtest/gtest.h"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/test_utils/env_vars.hpp"
#include "tt_metal/impl/dispatch/command_queue.hpp"
#include "tt_metal/llrt/rtoptions.hpp"
#include "tt_metal/test_utils/env_vars.hpp"

using namespace tt::tt_metal;
class CommandQueueFixture : public ::testing::Test {
Expand Down Expand Up @@ -81,21 +82,20 @@ class CommandQueuePCIDevicesFixture : public ::testing::Test {
GTEST_SKIP();
}

std::vector<chip_id_t> chip_ids;
for (unsigned int id = 0; id < num_devices_; id++) {
auto* device = tt::tt_metal::CreateDevice(id);
devices_.push_back(device);
chip_ids.push_back(id);
}
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true);
}

void TearDown() override {
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false);
for (unsigned int id = 0; id < devices_.size(); id++) {
tt::tt_metal::CloseDevice(devices_.at(id));
reserved_devices_ = tt::tt_metal::detail::CreateDevices(chip_ids);
for (const auto& id : chip_ids) {
devices_.push_back(reserved_devices_.at(id));
}
}

void TearDown() override { tt::tt_metal::detail::CloseDevices(reserved_devices_); }

std::vector<tt::tt_metal::Device*> devices_;
std::map<chip_id_t, tt::tt_metal::Device*> reserved_devices_;
tt::ARCH arch_;
size_t num_devices_;
};
7 changes: 7 additions & 0 deletions tt_metal/detail/tt_metal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ namespace tt::tt_metal{
return fd;
}

std::map<chip_id_t, Device *> CreateDevices(
std::vector<chip_id_t> device_ids,
const uint8_t num_hw_cqs = 1,
const std::vector<uint32_t> &l1_bank_remap = {});

void CloseDevices(std::map<chip_id_t, Device *> devices);

/**
* Copies data from a host buffer into the specified buffer
*
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ bool ActiveDevices::activate_device(chip_id_t id) {
} else if (this->active_devices_[id] == ActiveState::ACTIVE) {
TT_THROW("Cannot re-initialize device {}, must first call close()", id);
} else {
already_initialized = true;
already_initialized = (this->active_devices_[id] == ActiveState::INACTIVE) ? true : false;
}
this->active_devices_[id] = ActiveState::ACTIVE;

Expand Down
25 changes: 25 additions & 0 deletions tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,31 @@ inline void SetRuntimeArgs(const Program &program, KernelHandle kernel_id, const

namespace detail {

std::map<chip_id_t, Device *> CreateDevices(
std::vector<chip_id_t> device_ids, const uint8_t num_hw_cqs, const std::vector<uint32_t> &l1_bank_remap) {
std::map<chip_id_t, Device *> active_devices; // TODO: pass this to CloseDevices
for (const auto &device_id : device_ids) {
const auto &mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id);
if (active_devices.find(mmio_device_id) == active_devices.end()) {
for (const auto &mmio_controlled_device_id :
tt::Cluster::instance().get_devices_controlled_by_mmio_device(mmio_device_id)) {
active_devices.insert(
{mmio_controlled_device_id, CreateDevice(mmio_controlled_device_id, num_hw_cqs, l1_bank_remap)});
}
}
}
// TODO: need to only enable routing for used mmio chips
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true);
return active_devices;
}

void CloseDevices(std::map<chip_id_t, Device *> devices) {
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false);
for (const auto &[device_id, dev] : devices) {
CloseDevice(dev);
}
}

void print_page(uint32_t dev_page_id, CoreCoord core, uint32_t host_page_id, CoreCoord noc_coordinates, uint32_t l1_address, uint32_t bank_id, std::vector<uint32_t> page){
std::cout << "dev_page_index " << dev_page_id << " on core " << core.str() << std::endl;
std::cout << "host_page_index " << host_page_id << std::endl;
Expand Down

0 comments on commit 214049e

Please sign in to comment.