Skip to content

Commit

Permalink
#11403: SubMesh Support + Porting/Stamping T3K Tests to Galaxy
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Oct 2, 2024
1 parent e5e6e29 commit 5161b0a
Show file tree
Hide file tree
Showing 16 changed files with 501 additions and 193 deletions.
7 changes: 6 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
request.node.pci_ids = device_ids[:num_pcie_devices_requested]

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 2), dispatch_core_type=get_dispatch_core_type(), **device_params, offset=(0, 1)
ttnn.MeshShape(2, 2),
dispatch_core_type=get_dispatch_core_type(),
**device_params,
offset=(0, 1),
mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand Down Expand Up @@ -283,6 +287,7 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device
ttnn.MeshShape(2, 4),
dispatch_core_type=get_dispatch_core_type(),
**device_params,
mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,11 @@ def load_weights(self):
padded_w3 = self.state_dict[w3_str].transpose(-2, -1).view(1, 1, H, H4)

# w1: 8k x 4k. width-sharded on 12 banks, 4224 over 12 banks.
device = self.mesh_device.get_device(0)
weight_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1),
ttnn.CoreCoord(self.mesh_device.dram_grid_size().x - 1, self.mesh_device.dram_grid_size().y - 1),
)
}
)
Expand Down
4 changes: 4 additions & 0 deletions tests/scripts/tg/run_tg_model_perf_tests.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#!/bin/bash

run_tg_llm_tests() {

echo "LOG_METAL: Running run_t3000_llama2_70b_tests"
pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$?

# Merge all the generated reports
env python models/perf/merge_perf_results.py; fail+=$?

Expand Down
12 changes: 7 additions & 5 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
// Iterate over each row and run line all-gather multiple times.
// For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
auto view = MeshDeviceView(*mesh);
for (uint32_t row = 0; row < 8; row++) {
auto devs = mesh->get_devices_on_row(row);
auto devs = view.get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
for (auto dev : devs) {
device_ids.push_back(dev->id());
Expand Down Expand Up @@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
std::shared_ptr<MeshDevice> mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
std::vector<Device*> ring_devices = mesh->get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
auto view = MeshDeviceView(*mesh);
std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
std::vector<Device*> ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::vector<Device*> ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::reverse(ring_devices_2.begin(), ring_devices_2.end());
ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
std::vector<Device*> ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::reverse(ring_devices_3.begin(), ring_devices_3.end());
ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);

Expand Down
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ class T3kMultiDeviceFixture : public ::testing::Test {
}
constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1;
mesh_device_ = MeshDevice::create(
MeshShape{2, 4},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
DEFAULT_NUM_COMMAND_QUEUES,
DispatchCoreType::WORKER);
DispatchCoreType::WORKER,
MeshDeviceConfig(MeshShape{2, 4}, MeshType::Ring));
}

void TearDown() override {
Expand Down
25 changes: 25 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,28 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width):
for device in mesh_device.get_devices():
device_tensor = ttnn.get_device_tensor(tensor, device)
assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor)


def test_visualize_mesh_device(t3k_mesh_device):
ttnn.visualize_mesh_device(t3k_mesh_device)


def test_all_gather_multiple_submeshes(t3k_mesh_device):
"""Test all_gather with multiple submeshes"""

def model(submesh):
full_tensor = torch.ones((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16)
for i in range(submesh.get_num_devices()):
full_tensor[..., i * 32 : (i + 1) * 32] = i

ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3))
ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1)

for device_tensor in ttnn.get_device_tensors(ttnn_tensor):
device_tensor_torch = ttnn.to_torch(device_tensor)
assert torch.all(device_tensor_torch == full_tensor)

submesh_devices = t3k_mesh_device.create_submeshes((2, 2), ttnn.MeshType.Ring)
for submesh in submesh_devices:
model(submesh)
2 changes: 1 addition & 1 deletion tt_metal/impl/device/mesh_configurations/T3000.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]],
[[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]]
[[1, 0], [1, 0, 0, 0]], [[1, 1], [1, 1, 0, 0]], [[1, 2], [1, 2, 0, 0]], [[1, 3], [1, 3, 0, 0]]
]
}
Loading

0 comments on commit 5161b0a

Please sign in to comment.