#8689: Add multi-device API to fetch device shard from tensor
- From the Python side, this adds an API:
ttnn.get_device_tensor(tensor, device: Union[int, Device])
- This also adds an API on the device_mesh to get all of its devices (see the usage sketch below).
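For example, a minimal usage sketch (assuming a device_mesh opened beforehand, e.g. through the open_device_mesh binding, and ShardTensorToMesh importable from ttnn as in the unit test below; the variable names are illustrative, not part of this change):

import torch

import ttnn
from ttnn import ShardTensorToMesh

# Shard a host tensor across the mesh along the last dimension.
torch_tensor = torch.rand((1, 1, 32, 32 * device_mesh.get_num_devices()), dtype=torch.bfloat16)
sharded_tensor = ttnn.from_torch(
    torch_tensor,
    layout=ttnn.TILE_LAYOUT,
    mesh_mapper=ShardTensorToMesh(device_mesh, dim=3),
    device=device_mesh,
)

# New API: fetch the shard resident on each device of the mesh.
for device in device_mesh.get_devices():
    shard = ttnn.get_device_tensor(sharded_tensor, device)
    print(ttnn.to_torch(shard).shape)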
cfjchu committed May 21, 2024
1 parent 19bef99 commit 7ba07ec
Showing 7 changed files with 78 additions and 11 deletions.
23 changes: 23 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
@@ -499,3 +499,26 @@ def test_clone(device_mesh):
    results_11BH = ttnn.to_device(results_11BH, device_mesh)
    results_11BH = ttnn.clone(results_11BH, dtype=ttnn.bfloat8_b, memory_config=ttnn.L1_MEMORY_CONFIG)
    print(results_11BH)


def test_device_shard_to_torch(device_mesh):
    """Test the `ttnn.get_device_tensor(..)` API."""
    torch_input_tensor = torch.rand((1, 1, 32, 32 * device_mesh.get_num_devices()), dtype=torch.bfloat16)
    torch_output_golden = torch.nn.functional.gelu(torch_input_tensor)
    torch_output_golden = torch.exp(torch_output_golden)

    ttnn_input_tensor = ttnn.from_torch(
        torch_input_tensor,
        layout=ttnn.TILE_LAYOUT,
        mesh_mapper=ShardTensorToMesh(device_mesh, dim=3),
        device=device_mesh,
    )

    ttnn_gelu_output = ttnn.gelu(ttnn_input_tensor)
    ttnn_output_tensor = ttnn.exp(ttnn_gelu_output)

    # Skip the compose/torch.cat step entirely; check each device shard against its slice of the golden tensor
    for i, device in enumerate(device_mesh.get_devices()):
        device_tensor = ttnn.get_device_tensor(ttnn_output_tensor, device)
        torch_device_tensor = ttnn.to_torch(device_tensor)
        assert_with_pcc(torch_device_tensor, torch_output_golden[..., i * 32 : (i + 1) * 32], pcc=0.999)
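The shards could also be gathered and re-aggregated with the pre-existing ttnn.get_device_tensors and ttnn.aggregate_as_tensor helpers (a sketch only, assuming they accept the multi-device output tensor as shown; not part of the test):

# Gather one shard per device, then rebuild a single multi-device tensor from the list.
shards = ttnn.get_device_tensors(ttnn_output_tensor)
assert len(shards) == device_mesh.get_num_devices()
reaggregated = ttnn.aggregate_as_tensor(shards)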
10 changes: 7 additions & 3 deletions tt_eager/tensor/tensor_utils.cpp
@@ -205,11 +205,11 @@ const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volu

bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; }

Tensor get_device_tensor(Device* device, const Tensor& multi_device_tensor) {
Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) {
const auto& tensor_storage = std::get<MultiDeviceStorage>(multi_device_tensor.get_storage());
if (tensor_storage.buffers.find(device->id()) != tensor_storage.buffers.end()) {
if (tensor_storage.buffers.find(device_id) != tensor_storage.buffers.end()) {
return Tensor{
DeviceStorage{tensor_storage.buffers.at(device->id())},
DeviceStorage{tensor_storage.buffers.at(device_id)},
multi_device_tensor.get_legacy_shape(),
multi_device_tensor.get_dtype(),
multi_device_tensor.get_layout()
@@ -218,6 +218,10 @@ Tensor get_device_tensor(Device* device, const Tensor& multi_device_tensor) {
TT_THROW("Device not found in multi-device tensor");
}

Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device) {
return get_device_tensor(multi_device_tensor, device->id());
}

bool is_multi_device_tensor(const Tensor& tensor) {
return tensor.storage_type() == StorageType::MULTI_DEVICE or tensor.storage_type() == StorageType::MULTI_DEVICE_HOST;
}
7 changes: 4 additions & 3 deletions tt_eager/tensor/tensor_utils.hpp
@@ -73,7 +73,8 @@ bool is_cpu_tensor(const Tensor& tensor);
bool is_device_tensor(const Tensor& tensor);

// Given a multi-device tensor and a device, returns the tensor on the given device.
Tensor get_device_tensor(Device* device, const Tensor& multi_device_tensor);
Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device);
Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id);

// Returns true if the tensor has MultiDeviceHost/MultiDevice storage
bool is_multi_device_tensor(const Tensor& tensor);
@@ -130,12 +131,12 @@ auto get_device_tensors(Device* device, const TensorContainer& input_tensors) {
for (const auto& tensor : input_tensors) {
if constexpr (IsOptional::value) {
if (tensor.has_value()) {
transformed_tensors.emplace_back(get_device_tensor(device, tensor.value()));
transformed_tensors.emplace_back(get_device_tensor(tensor.value(), device));
} else {
transformed_tensors.emplace_back(std::nullopt);
}
} else {
transformed_tensors.emplace_back(get_device_tensor(device, tensor));
transformed_tensors.emplace_back(get_device_tensor(tensor, device));
}
}
return transformed_tensors;
4 changes: 2 additions & 2 deletions tt_metal/impl/device/multi_device.cpp
@@ -40,11 +40,11 @@ DeviceMesh::~DeviceMesh() {
}


Device &DeviceMesh::get_device(int queried_device_id)
Device* DeviceMesh::get_device(int queried_device_id)
{
for (const auto& [device_id, device] : mesh_devices) {
if (device_id == queried_device_id) {
return *device;
return device.get();
}
}
TT_THROW("User has provided an invalid device index");
2 changes: 1 addition & 1 deletion tt_metal/impl/device/multi_device.hpp
@@ -34,7 +34,7 @@ class DeviceMesh
DeviceMesh &operator=(DeviceMesh &&) = delete;

std::vector<Device*> get_devices() const;
Device &get_device(int queried_device_id);
Device* get_device(int queried_device_id);

const DeviceIds get_device_ids() const;

41 changes: 40 additions & 1 deletion ttnn/cpp/pybind11/multi_device.hpp
@@ -7,6 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "tt_eager/tensor/tensor_utils.hpp"
#include "ttnn/multi_device.hpp"

namespace py = pybind11;
@@ -25,7 +26,13 @@ void py_module(py::module& module) {
py::arg("l1_small_size"))
.def("get_device", &ttnn::multi_device::DeviceMesh::get_device, py::return_value_policy::reference)
.def("get_num_devices", &ttnn::multi_device::DeviceMesh::num_devices)
.def("get_device_ids", &ttnn::multi_device::DeviceMesh::get_device_ids);
.def("get_device_ids", &ttnn::multi_device::DeviceMesh::get_device_ids)
.def("get_devices", &ttnn::multi_device::DeviceMesh::get_devices, py::return_value_policy::reference, R"doc(
Get the devices in the device mesh.
Returns:
List[Device]: The devices in the device mesh.
)doc");

module.def(
"open_device_mesh",
@@ -36,6 +43,38 @@ void py_module(py::module& module) {
py::arg("l1_small_size"));

module.def("close_device_mesh", &close_device_mesh, py::arg("device_mesh"), py::kw_only());
module.def(
"get_device_tensor",
py::overload_cast<const Tensor&, int>(&tt::tt_metal::get_device_tensor),
py::arg("tensor"),
py::arg("device_id"),
py::kw_only(),
R"doc(
Get the tensor shard corresponding to the device_id.
Args:
tensor (Tensor): The tensor to get the shard from.
device_id (int): The device id to get the shard for.
Returns:
Tensor: The shard of the tensor corresponding to the device_id.
)doc");
module.def(
"get_device_tensor",
py::overload_cast<const Tensor&, const Device*>(&tt::tt_metal::get_device_tensor),
py::arg("tensor"),
py::arg("device"),
py::kw_only(),
R"doc(
Get the tensor shard corresponding to the device.
Args:
tensor (Tensor): The tensor to get the shard from.
device (Device): The device to get the shard for.
Returns:
Tensor: The shard of the tensor corresponding to the device.
)doc");
module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only());
module.def("aggregate_as_tensor", &aggregate_as_tensor, py::arg("tensors"), py::kw_only());
}
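Both bound overloads are reachable from Python; pybind11 dispatches on the type of the second argument (a Device handle vs. an int device id). A small sketch with placeholder values:

# multi_device_tensor stands in for any tensor backed by multi-device storage.
shard_by_handle = ttnn.get_device_tensor(multi_device_tensor, device_mesh.get_devices()[0])
shard_by_id = ttnn.get_device_tensor(multi_device_tensor, device_mesh.get_device_ids()[0])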
2 changes: 1 addition & 1 deletion ttnn/ttnn/__init__.py
@@ -146,7 +146,7 @@ def manage_config(name, value):
logger.debug(f"Restored ttnn.CONFIG.{name} to {original_value}")


from ttnn._ttnn.multi_device import get_device_tensors, aggregate_as_tensor
from ttnn._ttnn.multi_device import get_device_tensor, get_device_tensors, aggregate_as_tensor

from ttnn.types import (
TILE_SIZE,
