Skip to content

Commit

Permalink
#13432: fix t3k ethernet tests (#13453)
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu authored Oct 3, 2024
1 parent 8ffb584 commit 73e0dd8
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <random>
#include <tuple>

#include "impl/device/mesh_device_view.hpp"
#include "tt_metal/common/logger.hpp"
#include "device/tt_arch_types.h"
#include "impl/device/device.hpp"
Expand All @@ -26,6 +27,7 @@
#include "tt_metal/test_utils/stimulus.hpp"

#include "tt_metal/detail/persistent_kernel_cache.hpp"
#include "tt_metal/impl/device/mesh_device.hpp"

using tt::tt_metal::Device;

Expand All @@ -41,8 +43,7 @@ class T3000TestDevice {
num_devices_ = tt::tt_metal::GetNumAvailableDevices();
if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() == 8 and
tt::tt_metal::GetNumPCIeDevices() == 4) {
devices_ = tt::tt_metal::detail::CreateDevices({0,1,2,3,4,5,6,7});
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true);
mesh_device_ = tt::tt_metal::MeshDevice::create(tt::tt_metal::MeshDeviceConfig(tt::tt_metal::MeshShape{2, 4}));

} else {
TT_THROW("This suite can only be run on T3000 Wormhole devices");
Expand All @@ -57,15 +58,12 @@ class T3000TestDevice {

void TearDown() {
device_open = false;
tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false);
for (auto [device_id, device_ptr] : devices_) {
tt::tt_metal::CloseDevice(device_ptr);
}
mesh_device_->close_devices();
}

std::map<chip_id_t, Device *> devices_;
tt::ARCH arch_;
size_t num_devices_;
std::shared_ptr<tt::tt_metal::MeshDevice> mesh_device_;

private:
bool device_open;
Expand Down Expand Up @@ -420,23 +418,51 @@ int main (int argc, char** argv) {
TT_ASSERT(std::all_of(max_concurrent_samples.begin(), max_concurrent_samples.end(), [](std::size_t n) { return n > 0; }));

T3000TestDevice test_fixture;
auto view = test_fixture.mesh_device_->get_view();

// Device setup
std::vector<chip_id_t> device_ids = std::vector<chip_id_t>{0, 1, 2, 3, 4, 5, 6, 7};

auto get_device_list = [](std::map<chip_id_t, Device*> &all_devices, std::size_t n_hops) {
auto get_device_list = [](const std::shared_ptr<MeshDeviceView>& view, std::size_t n_hops) {
switch (n_hops) {
case 2:
return std::vector<Device*>{all_devices[0], all_devices[1]};
return std::vector<Device*>{
view->get_device(0, 0),
view->get_device(0, 1),
};

case 4:
return std::vector<Device*>{all_devices[0], all_devices[1], all_devices[2], all_devices[3]};
return std::vector<Device*>{
view->get_device(1, 1),
view->get_device(0, 1),
view->get_device(0, 2),
view->get_device(1, 2),
};

case 8:
return std::vector<Device*>{all_devices[0], all_devices[4], all_devices[5], all_devices[1], all_devices[2], all_devices[6], all_devices[7], all_devices[3]};
return std::vector<Device*>{
view->get_device(1, 1),
view->get_device(1, 0),
view->get_device(0, 0),
view->get_device(0, 1),
view->get_device(0, 2),
view->get_device(0, 3),
view->get_device(1, 3),
view->get_device(1, 2),
};

case 12: // Does an extra loop through the inner ring
return std::vector<Device*>{all_devices[0], all_devices[4], all_devices[5], all_devices[1], all_devices[2], all_devices[3], all_devices[0], all_devices[1], all_devices[2], all_devices[6], all_devices[7], all_devices[3]};
return std::vector<Device*>{
view->get_device(1, 1),
view->get_device(1, 0),
view->get_device(0, 0),
view->get_device(0, 1),
view->get_device(0, 2),
view->get_device(1, 2),
view->get_device(1, 1),
view->get_device(0, 1),
view->get_device(0, 2),
view->get_device(0, 3),
view->get_device(1, 3),
view->get_device(1, 2),
};

default:
TT_THROW("Unsupported hop_count");
Expand All @@ -448,7 +474,7 @@ int main (int argc, char** argv) {
constexpr std::size_t placeholder_arg_value = 1;
for (auto n_hops : hop_counts) {

auto devices = get_device_list(test_fixture.devices_, n_hops);
auto devices = get_device_list(view, n_hops);
std::vector<hop_eth_sockets> hop_eth_sockets = build_eth_sockets_list(devices);

for (auto max_concurrent_samples : max_concurrent_samples) {
Expand Down
5 changes: 0 additions & 5 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,7 @@ class T3kMultiDeviceFixture : public ::testing::Test {
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1;
mesh_device_ = MeshDevice::create(
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
DEFAULT_NUM_COMMAND_QUEUES,
DispatchCoreType::WORKER,
MeshDeviceConfig(MeshShape{2, 4}, MeshType::Ring));
}

Expand Down
4 changes: 2 additions & 2 deletions tt_metal/impl/device/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::w
: mesh_device_shape(mesh_device_shape), type(type), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {}

std::shared_ptr<MeshDevice> MeshDevice::create(
const MeshDeviceConfig& config,
size_t l1_small_size,
size_t trace_region_size,
size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const MeshDeviceConfig& config)
DispatchCoreType dispatch_core_type)
{
auto mesh_device = std::make_shared<MeshDevice>(config.mesh_shape, config.mesh_type);
mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config);
Expand Down
12 changes: 6 additions & 6 deletions tt_metal/impl/device/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct MeshDeviceConfig {

MeshDeviceConfig(
const MeshShape &mesh_shape,
MeshType mesh_type = MeshType::RowMajor) :
MeshType mesh_type) :
mesh_shape(mesh_shape),
offset(MeshOffset{0, 0}),
physical_device_ids(std::vector<chip_id_t>()),
Expand Down Expand Up @@ -174,11 +174,11 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {

static std::shared_ptr<MeshDevice> fetch_mesh_device(const std::vector<Device*>& devices);
static std::shared_ptr<MeshDevice> create(
size_t l1_small_size,
size_t trace_region_size,
size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const MeshDeviceConfig &config);
const MeshDeviceConfig &config,
size_t l1_small_size = DEFAULT_L1_SMALL_SIZE,
size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE,
size_t num_command_queues = 1,
DispatchCoreType dispatch_core_type = DispatchCoreType::WORKER);
};

std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device);
Expand Down
5 changes: 2 additions & 3 deletions ttnn/cpp/pybind11/multi_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,12 @@ void py_module(py::module& module) {
const std::pair<size_t, size_t>& offset,
const std::vector<chip_id_t>& physical_device_ids,
MeshType mesh_type) {
auto config = MeshDeviceConfig(mesh_device_shape, offset, physical_device_ids, mesh_type);
return MeshDevice::create(
MeshDeviceConfig(mesh_device_shape, offset, physical_device_ids, mesh_type),
l1_small_size,
trace_region_size,
num_command_queues,
dispatch_core_type,
config);
dispatch_core_type);
}),
py::kw_only(),
py::arg("mesh_shape"),
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/multi_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace ttnn::multi_device {

std::shared_ptr<MeshDevice> open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, MeshType mesh_type, const std::pair<size_t, size_t>& offset, const std::vector<int>& physical_device_ids) {
auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type);
return MeshDevice::create(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config);
return MeshDevice::create(config, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type);
}

void close_mesh_device(const std::shared_ptr<MeshDevice>& mesh_device) {
Expand Down

0 comments on commit 73e0dd8

Please sign in to comment.