Skip to content

Commit

Permalink
#11403: Support 2x4-submeshes across 8x4 mesh
Browse files Browse the repository at this point in the history
1. Submeshing to support creating submesh on galaxy mesh
2. Key change to start enabling more T3000 Tests onto galaxy:
- Currently ttnn.all_gather(..) in a ring relies on MeshDevice being
initialized in ring-order. Now we decouple this so we don't require that
MeshDevice is initialized with devices in a ring-order. Instead, we now
explicitly request for a ring-order in the operation that requires it.
  • Loading branch information
cfjchu committed Sep 30, 2024
1 parent d8706ff commit c5c86a3
Show file tree
Hide file tree
Showing 14 changed files with 242 additions and 67 deletions.
3 changes: 1 addition & 2 deletions models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,11 @@ def load_weights(self):
padded_w3[:, :, :, :H4] = self.state_dict[w3_str].transpose(-2, -1)

# w1: 8k x 4k. width-sharded on 12 banks, 4224 over 12 banks.
device = self.mesh_device.get_device(0)
weight_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1),
ttnn.CoreCoord(self.mesh_device.dram_grid_size().x - 1, self.mesh_device.dram_grid_size().y - 1),
)
}
)
Expand Down
6 changes: 5 additions & 1 deletion tests/scripts/tg/run_tg_model_perf_tests.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#!/bin/bash

run_tg_llm_tests() {
run_t3k_tests_on_tg_tests() {

echo "LOG_METAL: Running T3000 tests on TG"
env pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$?

# Merge all the generated reports
env python models/perf/merge_perf_results.py; fail+=$?

Expand Down
16 changes: 16 additions & 0 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,3 +1573,19 @@ def test_sharded_distributed_layernorm(mesh_device, input_width, input_height, c
is_pass, output_pcc = comp_pcc(torch_output_tensor, tt_output_tensor, pcc=0.999)

assert is_pass, f"PCC value: {output_pcc}"


def test_ttnn_multi_device_all_gather_all_devices(t3k_mesh_device):
"""Example test for running a 2x4-Ring All-Gather on galaxy"""
full_tensor = torch.ones((1, 1, 32, 32 * t3k_mesh_device.get_num_devices()), dtype=torch.bfloat16)
for i in range(t3k_mesh_device.get_num_devices()):
full_tensor[..., i * 32 : (i + 1) * 32] = i

ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=3))
ttnn_tensor = ttnn.to_device(ttnn_tensor, t3k_mesh_device)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1)

device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(ttnn_tensor)
for device_tensor in device_tensors:
device_tensor_torch = ttnn.to_torch(device_tensor)
assert torch.all(device_tensor_torch == full_tensor)
12 changes: 7 additions & 5 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
// Iterate over each row and run line all-gather multiple times.
// For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
auto view = MeshDeviceView(*mesh);
for (uint32_t row = 0; row < 8; row++) {
auto devs = mesh->get_devices_on_row(row);
auto devs = view.get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
for (auto dev : devs) {
device_ids.push_back(dev->id());
Expand Down Expand Up @@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
std::shared_ptr<MeshDevice> mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
std::vector<Device*> ring_devices = mesh->get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
auto view = MeshDeviceView(*mesh);
std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
std::vector<Device*> ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::vector<Device*> ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::reverse(ring_devices_2.begin(), ring_devices_2.end());
ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
std::vector<Device*> ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::reverse(ring_devices_3.begin(), ring_devices_3.end());
ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);

Expand Down
4 changes: 4 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width):
for device in mesh_device.get_devices():
device_tensor = ttnn.get_device_tensor(tensor, device)
assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor)


def test_ttnn_visualize_mesh_device(t3k_mesh_device):
ttnn.visualize_mesh_device(t3k_mesh_device)
2 changes: 1 addition & 1 deletion tt_metal/impl/device/mesh_configurations/T3000.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]],
[[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]]
[[1, 0], [1, 0, 0, 0]], [[1, 1], [1, 1, 0, 0]], [[1, 2], [1, 2, 0, 0]], [[1, 3], [1, 3, 0, 0]]
]
}
140 changes: 115 additions & 25 deletions tt_metal/impl/device/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ std::vector<chip_id_t> SystemMesh::get_mapped_physical_device_ids(const MeshDevi
}
return physical_device_ids;
}
void SystemMesh::register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices) {
std::vector<chip_id_t> physical_device_ids;
for (auto device : devices) {
physical_device_ids.push_back(device->id());
}
this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device});
this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids});
}

std::vector<Device*> SystemMesh::map_mesh_device(
std::shared_ptr<MeshDevice> mesh_device,
Expand All @@ -145,7 +153,6 @@ std::vector<Device*> SystemMesh::map_mesh_device(
TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows);
TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols);

this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device});

auto physical_device_ids = user_provided_physical_device_ids.empty() ?
this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) :
Expand All @@ -158,35 +165,43 @@ std::vector<Device*> SystemMesh::map_mesh_device(
for (auto physical_device_id : physical_device_ids) {
auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id);
mapped_devices.push_back(mapped_device);
this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id);
this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device});
}

this->register_mesh_device(mesh_device, mapped_devices); // TODO: change this
return mapped_devices;
}

void SystemMesh::unmap_mesh_device(const std::shared_ptr<MeshDevice>& mesh_device) {
auto mesh_id = mesh_device->get_mesh_id();

// Clean up all state related to this virtual mesh
this->assigned_mesh_device_devices.erase(mesh_id);

// Remove the devices from assigned_physical_id_to_device
for (auto physical_id : this->assigned_devices.at(mesh_id)) {
this->assigned_physical_id_to_device.erase(physical_id);
// Close the devices
if (mesh_device->is_parent_mesh()) {
for (auto physical_id : this->assigned_devices.at(mesh_id)) {
this->assigned_physical_id_to_device.erase(physical_id);
}
tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id));
this->opened_devices.erase(mesh_id);
}
this->assigned_devices.erase(mesh_id);
}

// Close the devices
tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id));
this->opened_devices.erase(mesh_id);
Device* SystemMesh::get_device(const chip_id_t physical_device_id) {
auto it = this->assigned_physical_id_to_device.find(physical_device_id);
if (it == this->assigned_physical_id_to_device.end()) {
TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id);
}
return it->second;
}

static MeshDeviceID generate_unique_mesh_id() {
static std::atomic<MeshDeviceID> next_id{0};
return next_id++;
}

MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {}
MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, std::shared_ptr<MeshDevice> parent_mesh)
: mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {}

std::shared_ptr<MeshDevice> MeshDevice::create(
const MeshShape& mesh_device_shape,
Expand All @@ -203,6 +218,36 @@ std::shared_ptr<MeshDevice> MeshDevice::create(
return mesh_device;
}

std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset) {
if (submesh_shape.first <= 0 || submesh_shape.second <= 0) {
TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second);
}

if (offset.first < 0 || offset.second < 0) {
TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second);
}

if (offset.first + submesh_shape.first > this->mesh_device_shape.first ||
offset.second + submesh_shape.second > this->mesh_device_shape.second) {
TT_THROW("Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).",
submesh_shape.first, submesh_shape.second,
offset.first, offset.second,
this->mesh_device_shape.first, this->mesh_device_shape.second);
}

auto submesh = std::make_shared<MeshDevice>(submesh_shape, shared_from_this());
auto start_coordinate = Coordinate{offset.first, offset.second};
auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1};
submesh->primary_view = std::make_unique<MeshDeviceView>(*this, start_coordinate, end_coordinate);
submesh->devices = submesh->primary_view->get_devices();
SystemMesh::instance().register_mesh_device(submesh, submesh->devices);
this->submeshes.push_back(submesh);
log_trace(LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), submesh_shape.first, submesh_shape.second, offset.first, offset.second);
log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices);

return submesh;
}

void MeshDevice::initialize(
size_t l1_small_size,
size_t trace_region_size,
Expand All @@ -223,16 +268,18 @@ void MeshDevice::initialize(
this->devices = instance.map_mesh_device(
shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids);
this->primary_view = std::make_unique<tt::tt_metal::MeshDeviceView>(*this);

for (int device_index = 0; device_index < this->devices.size(); device_index++) {
this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index});
}
}

MeshDevice::~MeshDevice() {
if (not this->devices.empty()) {
this->close_devices();
}
for (auto submesh : this->submeshes) {
submesh->close_devices();
}
this->primary_view.reset();
this->devices.clear();
this->parent_mesh.reset();
}

Device* MeshDevice::get_device_index(int logical_device_id) const {
Expand All @@ -241,7 +288,7 @@ Device* MeshDevice::get_device_index(int logical_device_id) const {
}

Device* MeshDevice::get_device(int physical_device_id) const {
return this->devices.at(this->physical_id_to_device_index.at(physical_device_id));
return SystemMesh::instance().get_device(physical_device_id);
}

std::vector<Device*> MeshDevice::get_devices() const { return this->devices; }
Expand All @@ -250,14 +297,6 @@ Device* MeshDevice::get_device(int row_idx, int col_idx) const {
return this->get_device_index(row_idx * num_cols() + col_idx);
}

std::vector<Device*> MeshDevice::get_devices_on_row(int row_idx) const {
return this->primary_view->get_devices_on_row(row_idx);
}

std::vector<Device*> MeshDevice::get_devices_on_column(int col_idx) const {
return this->primary_view->get_devices_on_column(col_idx);
}

const DeviceIds MeshDevice::get_device_ids() const {
DeviceIds device_ids;
for (auto device : this->get_devices()) {
Expand All @@ -283,7 +322,6 @@ MeshShape MeshDevice::shape() const { return this->mesh_device_shape; }
void MeshDevice::close_devices() {
SystemMesh::instance().unmap_mesh_device(shared_from_this());
this->devices.clear();
this->physical_id_to_device_index.clear();
this->primary_view.reset();
}

Expand All @@ -295,8 +333,60 @@ std::shared_ptr<const MeshDeviceView> MeshDevice::get_view() const { return this

std::shared_ptr<MeshDeviceView> MeshDevice::get_view() { return this->primary_view; }

std::vector<std::shared_ptr<MeshDeviceView>> MeshDevice::get_submesh_views() {
std::vector<std::shared_ptr<MeshDeviceView>> submesh_views;
if (this->submeshes.empty()) {
submesh_views.push_back(this->get_view());
}
else {
for (auto submesh : this->submeshes) {
submesh_views.push_back(submesh->get_view());
}
}
return submesh_views;
}

MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; }

bool MeshDevice::is_parent_mesh() const { return this->parent_mesh == nullptr; }

std::shared_ptr<MeshDevice> SystemMesh::get_mesh_device(const std::vector<chip_id_t>& physical_device_ids) {
log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids);
std::unordered_set<chip_id_t> input_set(physical_device_ids.begin(), physical_device_ids.end());

for (const auto& [mesh_id, mesh_device] : this->assigned_mesh_device_devices) {
const auto& assigned_devices = this->assigned_devices.at(mesh_id);
std::unordered_set<chip_id_t> assigned_set(assigned_devices.begin(), assigned_devices.end());
log_trace(LogMetal, "Assigned devices: {}", assigned_devices);

if (input_set == assigned_set) {
return mesh_device;
}
}
TT_THROW("No mesh device found for the provided devices");
}

std::shared_ptr<MeshDevice> MeshDevice::fetch_mesh_device(const std::vector<Device*>& devices) {
TT_FATAL(devices.size() > 0, "No devices provided");
auto& instance = SystemMesh::instance();
std::vector<chip_id_t> physical_device_ids;
for (auto device : devices) {
physical_device_ids.push_back(device->id());
}
return instance.get_mesh_device(physical_device_ids);
}

std::vector<std::shared_ptr<MeshDevice>> MeshDevice::get_submeshes() const { return this->submeshes; }

std::shared_ptr<MeshDeviceView> MeshDevice::get_view(const Device* device) {
for (auto submesh_view : this->get_submesh_views()) {
if (submesh_view->contains_device(device->id())) {
return submesh_view;
}
}
TT_THROW("Device {} not found in any submesh view", device->id());
}

std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); }

bool validate_worker_modes(const std::vector<Device*>& workers) {
Expand Down
19 changes: 15 additions & 4 deletions tt_metal/impl/device/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class SystemMesh {

// Get the physical device IDs mapped to a MeshDevice
std::vector<chip_id_t> get_mapped_physical_device_ids(const MeshDeviceConfig &config) const;
void register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices);

// Map MeshDevice to physical devices
std::vector<Device *> map_mesh_device(
Expand All @@ -85,14 +86,18 @@ class SystemMesh {

// Unmap MeshDevice, releasing the associated physical devices.
void unmap_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device);
std::shared_ptr<MeshDevice> get_mesh_device(const std::vector<chip_id_t>& physical_device_ids);
Device* get_device(const chip_id_t physical_device_id);
};

class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
private:
MeshDeviceID mesh_id;
MeshShape mesh_device_shape;
std::shared_ptr<MeshDeviceView> primary_view;
std::vector<Device *> devices;
std::unordered_map<chip_id_t, int> physical_id_to_device_index;
std::shared_ptr<MeshDevice> parent_mesh;
std::vector<std::shared_ptr<MeshDevice>> submeshes;

void initialize(
size_t l1_small_size,
Expand All @@ -103,7 +108,7 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
const std::vector<chip_id_t> &physical_device_ids);

public:
MeshDevice(const MeshShape &mesh_device_shape);
MeshDevice(const MeshShape &mesh_device_shape, std::shared_ptr<MeshDevice> parent_mesh = nullptr);
~MeshDevice();

MeshDevice(const MeshDevice &) = delete;
Expand All @@ -116,8 +121,6 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
Device *get_device_index(int logical_device_id) const;
Device *get_device(int physical_device_id) const;
Device *get_device(int row_idx, int col_idx) const;
std::vector<Device *> get_devices_on_row(int row_idx) const;
std::vector<Device *> get_devices_on_column(int col_idx) const;

const DeviceIds get_device_ids() const;

Expand All @@ -138,6 +141,7 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {

std::string to_string() const;
MeshDeviceID get_mesh_id() const;
bool is_parent_mesh() const;

static std::shared_ptr<MeshDevice> create(
const MeshShape &mesh_device_shape,
Expand All @@ -147,6 +151,13 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
DispatchCoreType dispatch_core_type,
const std::pair<size_t, size_t> &offset = {0, 0},
const std::vector<chip_id_t> &physical_device_ids = {});

std::vector<std::shared_ptr<MeshDevice>> get_submeshes() const;
std::vector<std::shared_ptr<MeshDeviceView>> get_submesh_views();
std::shared_ptr<MeshDeviceView> get_view(const Device* device);

std::shared_ptr<MeshDevice> create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset = {0, 0});
static std::shared_ptr<MeshDevice> fetch_mesh_device(const std::vector<Device*>& devices);
};

std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device);
Expand Down
Loading

0 comments on commit c5c86a3

Please sign in to comment.