Skip to content

Commit

Permalink
#4269: Supports specifying up to 2 hardware command queues on a single device
Browse files Browse the repository at this point in the history
  • Loading branch information
DrJessop committed Dec 28, 2023
1 parent a8b2b0b commit 00fa0bd
Show file tree
Hide file tree
Showing 35 changed files with 682 additions and 420 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
#include "tt_metal/tools/profiler/op_profiler.hpp"

using namespace tt;

//
void measure_latency(string kernel_name) {
const int device_id = 0;
tt_metal::Device *device = tt_metal::CreateDevice(device_id);

auto dispatch_cores = device->dispatch_cores().begin();
CoreCoord producer_logical_core = *dispatch_cores++;
CoreCoord consumer_logical_core = *dispatch_cores;
// auto dispatch_cores = device->dispatch_cores().begin();
CoreCoord producer_logical_core = {0, 0};//*dispatch_cores++;
CoreCoord consumer_logical_core = {0, 0}; //*dispatch_cores;

auto first_worker_physical_core = device->worker_core_from_logical_core({0, 0});

Expand Down
6 changes: 3 additions & 3 deletions tests/tt_metal/tt_metal/unit_tests/basic/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,14 @@ TEST_F(BasicFixture, SingleDeviceInitializeAndTeardown) {
auto arch = tt::get_arch_from_string(get_env_arch_name());
tt::tt_metal::Device* device;
const unsigned int device_id = 0;
device = tt::tt_metal::CreateDevice(device_id);
device = tt::tt_metal::CreateDevice(device_id, 1);
ASSERT_TRUE(tt::tt_metal::CloseDevice(device));
}
TEST_F(BasicFixture, SingleDeviceHarvestingPrints) {
auto arch = tt::get_arch_from_string(get_env_arch_name());
tt::tt_metal::Device* device;
const unsigned int device_id = 0;
device = tt::tt_metal::CreateDevice(device_id);
device = tt::tt_metal::CreateDevice(device_id, 1);
CoreCoord unharvested_logical_grid_size = {.x = 12, .y = 10};
if (arch == tt::ARCH::WORMHOLE_B0) {
unharvested_logical_grid_size = {.x = 8, .y = 10};
Expand Down Expand Up @@ -196,7 +196,7 @@ TEST_F(BasicFixture, SingleDeviceLoadBlankKernels) {
auto arch = tt::get_arch_from_string(get_env_arch_name());
tt::tt_metal::Device* device;
const unsigned int device_id = 0;
device = tt::tt_metal::CreateDevice(device_id);
device = tt::tt_metal::CreateDevice(device_id, 1);
unit_tests::basic::device::load_all_blank_kernels(device);
ASSERT_TRUE(tt::tt_metal::CloseDevice(device));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class DPrintFixture: public ::testing::Test {
for (unsigned int id = 0; id < num_devices; id++) {
if (SkipTest(id))
continue;
auto* device = tt::tt_metal::CreateDevice(id);
auto* device = tt::tt_metal::CreateDevice(id, 1);
devices_.push_back(device);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,20 @@ using namespace tt::tt_metal;

namespace host_tests {

TEST_F(MultiCommandQueueFixture, TestAccessCommandQueue) {
namespace single_device_tests {

// Smoke test: builds two CommandQueue objects on the fixture's single device
// (second constructor argument 0 and 1 — presumably the hardware queue id)
// and calls Finish on each in turn. There are no explicit assertions; the
// test fails only if one of these calls throws or never returns.
TEST_F(MultiCommandQueueSingleDeviceFixture, TestFinishOnTwoCqs) {
    CommandQueue first_cq(this->device_, 0);
    CommandQueue second_cq(this->device_, 1);

    Finish(first_cq);
    Finish(second_cq);
}

}  // namespace single_device_tests

namespace multi_device_tests {
TEST_F(CommandQueueMultiDeviceFixture, TestAccessCommandQueue) {
for (unsigned int device_id = 0; device_id < num_devices_; device_id++) {
EXPECT_NO_THROW(detail::GetCommandQueue(devices_[device_id]));
}
Expand All @@ -36,7 +49,7 @@ TEST(FastDispatchHostSuite, TestCannotAccessCommandQueueForClosedDevice) {
EXPECT_ANY_THROW(detail::GetCommandQueue(device));
}

TEST_F(MultiCommandQueueFixture, TestDirectedLoopbackToUniqueHugepage) {
TEST_F(CommandQueueMultiDeviceFixture, TestDirectedLoopbackToUniqueHugepage) {
std::unordered_map<chip_id_t, std::vector<uint32_t>> golden_data;

const uint32_t byte_size = 2048 * 16;
Expand All @@ -62,5 +75,9 @@ TEST_F(MultiCommandQueueFixture, TestDirectedLoopbackToUniqueHugepage) {
EXPECT_EQ(readback_data, golden_data.at(device_id));
}
}
}




} // namespace host_tests
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,39 @@ bool test_EnqueueWriteBuffer_and_EnqueueReadBuffer(Device* device, CommandQueue&
return pass;
}

// Round-trips one buffer per command queue and checks that the data read back
// matches what was written.
//
// For each of the two EnqueueWriteBuffer/EnqueueReadBuffer call styles (raw
// pointer and std::vector), a buffer described by `config` is created for every
// queue in `cqs`, filled from generate_arange_vector(), and written with the
// final argument `false`. Each buffer is then read back on the same queue with
// the final argument `true` and compared against its source.
// (The trailing bool is presumably the "blocking" flag — confirm against the
// EnqueueWriteBuffer/EnqueueReadBuffer declarations.)
//
// @param device  Device on which every buffer is allocated.
// @param cqs     Command queues to exercise; one buffer is created per entry.
// @param config  Page count, page size and buffer type of each buffer.
// @return true only if every read-back vector equals its source vector for
//         both call styles.
bool test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(Device* device, vector<std::reference_wrapper<CommandQueue>>& cqs, const TestBufferConfig& config) {
    // Neither value changes between the two API passes, so compute them once.
    const size_t buf_size = config.num_pages * config.page_size;
    const size_t num_cqs = cqs.size();

    bool pass = true;
    for (const bool use_void_star_api : {true, false}) {
        vector<unique_ptr<Buffer>> buffers;
        vector<vector<uint32_t>> srcs;
        // Reserve up front so that elements of `srcs` and `buffers` are never
        // relocated by a later push_back while a write that was enqueued with
        // `false` (and was handed srcs[i] / *buffers[i]) may still be pending.
        buffers.reserve(num_cqs);
        srcs.reserve(num_cqs);

        // Phase 1: issue one write per queue.
        for (size_t i = 0; i < num_cqs; i++) {
            buffers.push_back(make_unique<Buffer>(device, buf_size, config.page_size, config.buftype));
            srcs.push_back(generate_arange_vector(buffers[i]->size()));
            if (use_void_star_api) {
                EnqueueWriteBuffer(cqs[i], *buffers[i], srcs[i].data(), false);
            } else {
                EnqueueWriteBuffer(cqs[i], *buffers[i], srcs[i], false);
            }
        }

        // Phase 2: read every buffer back on its own queue and compare.
        for (size_t i = 0; i < num_cqs; i++) {
            vector<uint32_t> result;
            if (use_void_star_api) {
                // The raw-pointer overload is given pre-sized destination storage.
                result.resize(buf_size / sizeof(uint32_t));
                EnqueueReadBuffer(cqs[i], *buffers[i], result.data(), true);
            } else {
                EnqueueReadBuffer(cqs[i], *buffers[i], result, true);
            }
            pass &= (srcs[i] == result);
        }
    }

    return pass;
}

bool stress_test_EnqueueWriteBuffer_and_EnqueueReadBuffer(
Device* device, CommandQueue& cq, const BufferStressTestConfig& config) {
srand(config.seed);
Expand Down Expand Up @@ -197,7 +230,6 @@ bool stress_test_EnqueueWriteBuffer_and_EnqueueReadBuffer_sharded(

bool test_EnqueueWrap_on_EnqueueReadBuffer(Device* device, CommandQueue& cq, const TestBufferConfig& config) {
auto [buffer, src] = EnqueueWriteBuffer_prior_to_wrap(device, cq, config);

vector<uint32_t> dst;
EnqueueReadBuffer(cq, buffer, dst, true);

Expand Down Expand Up @@ -260,8 +292,13 @@ namespace dram_tests {

TEST_F(CommandQueueFixture, WriteOneTileToDramBank0) {
TestBufferConfig config = {.num_pages = 1, .page_size = 2048, .buftype = BufferType::DRAM};

EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(this->device_, tt::tt_metal::detail::GetCommandQueue(device_), config));
CommandQueue a(this->device_, 0);
CommandQueue b(this->device_, 1);
// Finish(a);
// Finish(b);
vector<reference_wrapper<CommandQueue>> cqs = {a, b};
EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(this->device_, cqs, config));
// EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(this->device_, tt::tt_metal::detail::GetCommandQueue(device_), config));
}

TEST_F(CommandQueueFixture, WriteOneTileToAllDramBanks) {
Expand Down Expand Up @@ -329,8 +366,8 @@ TEST_F(CommandQueueFixture, TestWrapHostHugepageOnEnqueueReadBuffer) {
uint32_t num_pages = buffer_size / page_size;

TestBufferConfig buf_config = {.num_pages = num_pages, .page_size = page_size, .buftype = BufferType::DRAM};

EXPECT_TRUE(local_test_functions::test_EnqueueWrap_on_EnqueueReadBuffer(this->device_, tt::tt_metal::detail::GetCommandQueue(device_), buf_config));
CommandQueue a(this->device_, 0);
EXPECT_TRUE(local_test_functions::test_EnqueueWrap_on_EnqueueReadBuffer(this->device_, a, buf_config));
}

TEST_F(CommandQueueFixture, TestIssueMultipleReadWriteCommandsForOneBuffer) {
Expand Down Expand Up @@ -420,18 +457,26 @@ namespace l1_tests {

TEST_F(CommandQueueFixture, WriteOneTileToL1Bank0) {
TestBufferConfig config = {.num_pages = 1, .page_size = 2048, .buftype = BufferType::L1};

EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(this->device_, tt::tt_metal::detail::GetCommandQueue(device_), config));
CommandQueue a(this->device_, 0);
CommandQueue b(this->device_, 1);
// Finish(a);
// Finish(b);
vector<reference_wrapper<CommandQueue>> cqs = {a, b};
EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(this->device_, cqs, config));
}

TEST_F(CommandQueueFixture, WriteOneTileToAllL1Banks) {
auto compute_with_storage_grid = this->device_->compute_with_storage_grid_size();
TestBufferConfig config = {
.num_pages = uint32_t(compute_with_storage_grid.x * compute_with_storage_grid.y),
.num_pages = 124,//uint32_t(compute_with_storage_grid.x * compute_with_storage_grid.y),
.page_size = 2048,
.buftype = BufferType::L1};

EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(this->device_, tt::tt_metal::detail::GetCommandQueue(device_), config));
CommandQueue a(this->device_, 0);
CommandQueue b(this->device_, 1);
vector<reference_wrapper<CommandQueue>> cqs = {a, b};
EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(this->device_, cqs, config));
// EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(this->device_, tt::tt_metal::detail::GetCommandQueue(device_), config));
}

TEST_F(CommandQueueFixture, WriteOneTileToAllL1BanksTwiceRoundRobin) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CommandQueueFixture : public ::testing::Test {
this->arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name());

const int device_id = 0;
this->device_ = tt::tt_metal::CreateDevice(device_id);
this->device_ = tt::tt_metal::CreateDevice(device_id, 1);

this->pcie_id = 0;
}
Expand Down Expand Up @@ -63,7 +63,7 @@ class MultiCommandQueueFixture : public ::testing::Test {
size_t num_devices_;
};

class MultiCommandQueueFixture : public ::testing::Test {
class CommandQueueMultiDeviceFixture : public ::testing::Test {
protected:
void SetUp() override {
auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
Expand All @@ -73,8 +73,6 @@ class MultiCommandQueueFixture : public ::testing::Test {
}
arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name());

num_devices_ = tt::tt_metal::GetNumAvailableDevices();

for (unsigned int id = 0; id < num_devices_; id++) {
auto* device = tt::tt_metal::CreateDevice(id);
devices_.push_back(device);
Expand All @@ -91,3 +89,24 @@ class MultiCommandQueueFixture : public ::testing::Test {
tt::ARCH arch_;
size_t num_devices_;
};

// GoogleTest fixture for tests that drive more than one hardware command queue
// on a single device. SetUp creates device 0 with two command-queue hardware
// resources; TearDown closes it again.
class MultiCommandQueueSingleDeviceFixture : public ::testing::Test {
   protected:
    void SetUp() override {
        auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
        if (slow_dispatch) {
            // NOTE(review): if TT_THROW throws (as its name suggests), the
            // GTEST_SKIP() below is never reached and the suite fails instead
            // of being skipped — confirm which outcome is intended.
            TT_THROW("This suite can only be run with fast dispatch or TT_METAL_SLOW_DISPATCH_MODE unset");
            GTEST_SKIP();
        }
        arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name());
        const uint32_t num_cq_hw_resources = 2;
        device_ = tt::tt_metal::CreateDevice(0, num_cq_hw_resources, {});
    }

    void TearDown() override {
        // GoogleTest still calls TearDown() when SetUp() failed or threw, in
        // which case no device was ever created; only close a real device.
        if (device_ != nullptr) {
            tt::tt_metal::CloseDevice(device_);
            device_ = nullptr;
        }
    }

    // Null until SetUp() has successfully created the device.
    tt::tt_metal::Device* device_ = nullptr;
    // Architecture parsed from the environment in SetUp().
    tt::ARCH arch_ = tt::ARCH::Invalid;
};
2 changes: 1 addition & 1 deletion tt_eager/tt_lib/csrc/tt_lib_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void DeviceModule(py::module &m_device) {
.def(
py::init<>(
[](int device_id) {
return Device(device_id);
return Device(device_id, 1);
}
), "Create device."
)
Expand Down
3 changes: 3 additions & 0 deletions tt_metal/common/base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ using std::string;
using std::size_t;
using std::map;

// Rounds `addr` up to the next multiple of `alignment`; an already-aligned
// `addr` is returned unchanged. The bit trick is only meaningful when
// `alignment` is a power of two. `addr == 0` yields 0 (the intermediate
// `addr - 1` wraps around, which is well defined for unsigned arithmetic).
inline uint32_t align(uint32_t addr, uint32_t alignment) {
    const uint32_t low_bits_mask = alignment - 1;
    // Step back by one so an already-aligned address stays in its own slot,
    // saturate the low bits, then step forward onto the boundary.
    const uint32_t last_in_slot = (addr - 1) | low_bits_mask;
    return last_in_slot + 1;
}


namespace tt
{

Expand Down
17 changes: 14 additions & 3 deletions tt_metal/common/metal_soc_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,26 @@ void metal_SocDescriptor::load_dispatch_and_banking_config(uint32_t harvesting_m

// dispatch_cores are a subset of worker cores
// they have already been parsed as CoreType::WORKER and saved into `cores` map when parsing `functional_workers`
for (const auto& core_node : config["dispatch_cores"]) {
for (const auto& core_node : config["producer_cores"]) {
RelativeCoreCoord coord = {};
if (core_node.IsSequence()) {
// Logical coord
coord = RelativeCoreCoord({.x = core_node[0].as<int>(), .y = core_node[1].as<int>()});
} else {
TT_THROW("Only logical relative coords supported for dispatch_cores cores");
TT_THROW("Only logical relative coords supported for producer_cores cores");
}
this->dispatch_cores.push_back(coord);
this->producer_cores.push_back(coord);
}

for (const auto& core_node : config["consumer_cores"]) {
RelativeCoreCoord coord = {};
if (core_node.IsSequence()) {
// Logical coord
coord = RelativeCoreCoord({.x = core_node[0].as<int>(), .y = core_node[1].as<int>()});
} else {
TT_THROW("Only logical relative coords supported for consumer_cores cores");
}
this->consumer_cores.push_back(coord);
}
}

Expand Down
3 changes: 2 additions & 1 deletion tt_metal/common/metal_soc_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ struct metal_SocDescriptor : public tt_SocDescriptor {
CoreCoord compute_with_storage_grid_size;
std::vector<RelativeCoreCoord> compute_with_storage_cores; // saved as CoreType::WORKER
std::vector<RelativeCoreCoord> storage_cores; // saved as CoreType::WORKER
std::vector<RelativeCoreCoord> dispatch_cores; // saved as CoreType::WORKER
std::vector<RelativeCoreCoord> producer_cores;
std::vector<RelativeCoreCoord> consumer_cores;
std::vector<CoreCoord> logical_ethernet_cores;
int l1_bank_size;
uint32_t dram_core_size;
Expand Down
11 changes: 8 additions & 3 deletions tt_metal/common/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
// Needed for TargetDevice enum
#include "common/base.hpp"

inline std::string get_soc_description_file(const tt::ARCH &arch, tt::TargetDevice target_device, string output_dir = "") {

inline std::string get_soc_description_file(const tt::ARCH &arch, tt::TargetDevice target_device, uint32_t num_cqs, string output_dir = "") {
// Ability to skip this runtime opt, since trimmed SOC desc limits which DRAM channels are available.
bool use_full_soc_desc = getenv("TT_METAL_VERSIM_FORCE_FULL_SOC_DESC");
string tt_metal_home;
Expand All @@ -47,7 +46,13 @@ inline std::string get_soc_description_file(const tt::ARCH &arch, tt::TargetDevi
switch (arch) {
case tt::ARCH::Invalid: throw std::runtime_error("Invalid arch not supported"); // will be overwritten in tt_global_state constructor
case tt::ARCH::JAWBRIDGE: throw std::runtime_error("JAWBRIDGE arch not supported");
case tt::ARCH::GRAYSKULL: return tt_metal_home + "tt_metal/soc_descriptors/grayskull_120_arch.yaml";
case tt::ARCH::GRAYSKULL: {
    // Grayskull has a separate SoC descriptor for each supported number of
    // hardware command queues.
    if (num_cqs == 1) {
        return tt_metal_home + "tt_metal/soc_descriptors/grayskull_120_arch_one_cq.yaml";
    }
    if (num_cqs == 2) {
        return tt_metal_home + "tt_metal/soc_descriptors/grayskull_120_arch_two_cqs.yaml";
    }
    // Fail explicitly here: without this, any other count would fall through
    // into the following case label and report the wrong architecture.
    throw std::runtime_error(
        "GRAYSKULL arch supports only 1 or 2 command queues, got " + std::to_string(num_cqs));
}
case tt::ARCH::WORMHOLE: throw std::runtime_error("WORMHOLE arch not supported");
case tt::ARCH::WORMHOLE_B0: return tt_metal_home + "tt_metal/soc_descriptors/wormhole_b0_80_arch.yaml";
default: throw std::runtime_error("Unsupported device arch");
Expand Down
Loading

0 comments on commit 00fa0bd

Please sign in to comment.