Skip to content

Commit

Permalink
#4269: Supports specifying up to 2 hardware command queues on a singl…
Browse files Browse the repository at this point in the history
…e device
  • Loading branch information
DrJessop committed Jan 4, 2024
1 parent 1417237 commit 0891c9b
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,6 @@ class CommandQueueFixture : public ::testing::Test {
}
};

class MultiCommandQueueFixture : public ::testing::Test {
protected:
void SetUp() override {
auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
if (slow_dispatch) {
TT_THROW("This suite can only be run with fast dispatch or TT_METAL_SLOW_DISPATCH_MODE unset");
GTEST_SKIP();
}
arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name());

num_devices_ = tt::tt_metal::GetNumAvailableDevices();

for (unsigned int id = 0; id < num_devices_; id++) {
auto* device = tt::tt_metal::CreateDevice(id);
devices_.push_back(device);
}
}

void TearDown() override {
for (unsigned int id = 0; id < devices_.size(); id++) {
tt::tt_metal::CloseDevice(devices_.at(id));
}
}

std::vector<tt::tt_metal::Device*> devices_;
tt::ARCH arch_;
size_t num_devices_;
};

class CommandQueueMultiDeviceFixture : public ::testing::Test {
protected:
Expand All @@ -73,6 +45,8 @@ class CommandQueueMultiDeviceFixture : public ::testing::Test {
}
arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name());

num_devices_ = tt::tt_metal::GetNumAvailableDevices();

for (unsigned int id = 0; id < num_devices_; id++) {
auto* device = tt::tt_metal::CreateDevice(id);
devices_.push_back(device);
Expand Down
2 changes: 2 additions & 0 deletions tt_metal/common/core_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ inline const core_descriptor_t &get_core_descriptor_config(chip_id_t device_id,
}
producer_cores.push_back(coord);
}
TT_ASSERT(producer_cores.size(), "Producer cores size must be positive");

std::vector<RelativeCoreCoord> consumer_cores;
for (const auto& core_node : desc_yaml["consumer_cores"]) {
Expand All @@ -150,6 +151,7 @@ inline const core_descriptor_t &get_core_descriptor_config(chip_id_t device_id,
}
consumer_cores.push_back(coord);
}
TT_ASSERT(consumer_cores.size(), "Consumer cores size must be positive");

config_by_num_cqs[num_hw_cqs] = core_descriptor_t{
.compute_grid_size = compute_grid_size,
Expand Down
23 changes: 21 additions & 2 deletions tt_metal/core_descriptors/grayskull_120_arch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,24 @@ E75:
storage_cores: # Relative to grid of tensix cores
[[11, 1], [11, 2], [11, 3], [11, 5], [11, 6], [11, 7]]

dispatch_cores: # Relative to grid of tensix cores
[[11, 0], [11, 4]]
producer_cores:
[[11, 0]]

consumer_cores:
[[11, 4]]
2:
l1_bank_size:
1048576

compute_with_storage_grid_range: # Logical only start and end [x, y]
start: [0, 0]
end: [10, 7]

storage_cores: # Relative to grid of tensix cores
[[11, 2], [11, 3], [11, 6], [11, 7]]

producer_cores:
[[11, 0], [11, 1]]

consumer_cores:
[[11, 4], [11, 5]]
2 changes: 1 addition & 1 deletion tt_metal/detail/tt_metal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ namespace tt::tt_metal{
static std::mutex cq_creation_mutex;
{
std::lock_guard<std::mutex> lock(cq_creation_mutex);
command_queues[device->id()] = std::make_unique<CommandQueue>(device, id);
command_queues[device->id()] = std::make_unique<CommandQueue>(device, 0);
}
return *(command_queues[id]);
}
Expand Down
2 changes: 2 additions & 0 deletions tt_metal/impl/dispatch/command_queue_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class SystemMemoryManager {
tt::Cluster::instance().get_fast_pcie_static_tlb_write_callable(device_id)),
worker_from_logical_callable(worker_from_logical) {

TT_ASSERT(cq_cores.size(), "cq_cores size must be positive");

uint8_t num_hw_cqs = cq_cores.size();
this->issue_byte_addrs.resize(num_hw_cqs);
this->completion_byte_addrs.resize(num_hw_cqs);
Expand Down

0 comments on commit 0891c9b

Please sign in to comment.