#8895: Fix ttnn.as_tensor(..) method for placing tensors on-device
This fixes the two reported issues:
1. The first time the cache is generated, the tensor was left on the
host; only loading the tensor from cache placed it on device. Now, when
the cache is generated, the multi-device tensor is moved onto the device
when the user has requested it to be.
2. Memory management for the Device* should be left to C++, not
Python.
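
In practice this means the very first ttnn.as_tensor call that writes the cache should already return a device-resident multi-device tensor, rather than only doing so on a later cache hit. A minimal sketch of the expected behaviour (the device_mesh handle and cache path are illustrative assumptions; the new unit test below exercises the same flow):

import torch
import ttnn

# device_mesh is an already-opened ttnn.DeviceMesh (assumed for this sketch)
tensor = ttnn.as_tensor(
    torch.rand((7, 3), dtype=torch.float32),
    dtype=ttnn.float32,
    layout=ttnn.TILE_LAYOUT,
    device=device_mesh,
    memory_config=ttnn.L1_MEMORY_CONFIG,
    mesh_mapper=ttnn.ReplicateTensorToMesh(device_mesh),
    cache_file_name="model_cache/weight",  # cold cache: the file is created on this call
)

# Previously this only held after a second call that loaded the cache.
assert tensor.devices() == device_mesh.get_devices()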
cfjchu committed May 31, 2024
1 parent fdfe175 commit 2c8d039
Showing 7 changed files with 89 additions and 29 deletions.
39 changes: 39 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
@@ -522,3 +522,42 @@ def test_device_shard_to_torch(device_mesh):
        device_tensor = ttnn.get_device_tensor(ttnn_output_tensor, device)
        torch_device_tensor = ttnn.to_torch(device_tensor)
        assert_with_pcc(torch_device_tensor, torch_output_golden[..., i * 32 : (i + 1) * 32], pcc=0.999)


@pytest.mark.parametrize("height", [7])
@pytest.mark.parametrize("width", [3])
def test_validate_as_tensor(tmp_path, device_mesh, height, width):
    torch_input_tensor = torch.rand((height, width), dtype=torch.float32)

    memory_config = ttnn.L1_MEMORY_CONFIG
    tensor = ttnn.as_tensor(
        torch_input_tensor,
        dtype=ttnn.float32,
        layout=ttnn.TILE_LAYOUT,
        device=device_mesh,
        memory_config=memory_config,
        mesh_mapper=ttnn.ReplicateTensorToMesh(device_mesh),
        cache_file_name=tmp_path / "cache_file",
    )
    assert tensor.dtype == ttnn.float32
    assert tensor.devices() == device_mesh.get_devices()
    assert tensor.layout == ttnn.TILE_LAYOUT
    assert ttnn.get_memory_config(tensor) == memory_config

    tensor = ttnn.as_tensor(
        torch_input_tensor,
        dtype=ttnn.float32,
        layout=ttnn.TILE_LAYOUT,
        device=device_mesh,
        memory_config=memory_config,
        mesh_mapper=ttnn.ReplicateTensorToMesh(device_mesh),
        cache_file_name=tmp_path / "cache_file",
    )
    assert tensor.dtype == ttnn.float32
    assert tensor.devices() == device_mesh.get_devices()
    assert tensor.layout == ttnn.TILE_LAYOUT
    assert ttnn.get_memory_config(tensor) == memory_config

    for device in device_mesh.get_devices():
        device_tensor = ttnn.get_device_tensor(tensor, device)
        assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor)
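
Assuming the repository's usual pytest fixtures (device_mesh, tmp_path) and multi-device hardware are available, the new test can be run on its own, for example:

# Hypothetical invocation; requires multi-device hardware and the repo's pytest fixtures.
import pytest

pytest.main(["tests/ttnn/unit_tests/test_multi_device.py::test_validate_as_tensor", "-vv"])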
20 changes: 12 additions & 8 deletions tt_eager/tensor/serialization.cpp
@@ -12,6 +12,7 @@

#include "tensor/host_buffer/functions.hpp"
#include "tensor/tensor_utils.hpp"
#include "tt_eager/tensor/types.hpp"

namespace tt {

@@ -45,12 +46,14 @@ void dump_borrowed_storage(ofstream& output_stream, const BorrowedStorage& stora
    );
}

void dump_multi_device_host_storage(ofstream& output_stream, const MultiDeviceHostStorage& storage) {
void dump_multi_device_host_storage(ofstream& output_stream, const MultiDeviceHostStorage& storage, const DistributedTensorConfig& strategy) {
    std::size_t num_buffers = storage.num_buffers();
    output_stream.write(reinterpret_cast<const char*>(&num_buffers), sizeof(std::size_t));
    output_stream.write(reinterpret_cast<const char*>(&storage.strategy), sizeof(DistributedTensorConfig));

    if (std::holds_alternative<ReplicateTensor>(storage.strategy)) {
    // Use the user-specified strategy which defines how it gets distributed when mapped onto multi-device
    output_stream.write(reinterpret_cast<const char*>(&strategy), sizeof(DistributedTensorConfig));

    if (std::holds_alternative<ReplicateTensor>(strategy)) {
        std::visit(
            [&output_stream]<typename T>(const owned_buffer::Buffer<T>& generic_buffer) {
                const auto buffer = owned_buffer::get_as<T>(generic_buffer);
@@ -175,7 +178,7 @@ MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream, Da

template <typename T>
Storage load_storage(ifstream& input_stream, DataType data_type, StorageType storage_type, T device) {
    if (storage_type == StorageType::MULTI_DEVICE_HOST) {
    if (storage_type == StorageType::MULTI_DEVICE_HOST or storage_type == StorageType::MULTI_DEVICE) {
        if constexpr (std::is_same_v<T, DeviceMesh*>) {
            return load_multi_device_host_storage(input_stream, data_type, device);
        } else {
@@ -186,9 +189,9 @@ }
}
}

}
} // namespace detail

void dump_tensor(const std::string& file_name, const Tensor& tensor) {
void dump_tensor(const std::string& file_name, const Tensor& tensor, const std::unordered_map<std::string, std::string>& strategy) {
    ofstream output_stream(file_name, ios::out | ios::binary);
    if (not output_stream) {
        throw std::runtime_error(fmt::format("Cannot open \"{}\"", file_name));
Expand Down Expand Up @@ -221,7 +224,7 @@ void dump_tensor(const std::string& file_name, const Tensor& tensor) {
}

    std::visit(
        [&output_stream](const auto& storage) {
        [&output_stream, &strategy](const auto& storage) {

            using StorageType = std::decay_t<decltype(storage)>;
            if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
@@ -237,7 +240,8 @@
                TT_THROW("Device storage isn't supported");
            }
            else if constexpr (std::is_same_v<StorageType, MultiDeviceHostStorage>) {
                detail::dump_multi_device_host_storage(output_stream, storage);
                auto distribute_config = get_distributed_tensor_config(strategy);
                detail::dump_multi_device_host_storage(output_stream, storage, distribute_config);
            }
            else {
                raise_unsupported_storage<StorageType>();
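
With load_storage now accepting MULTI_DEVICE as well as MULTI_DEVICE_HOST storage, a cached multi-device tensor can be materialized directly onto the mesh at load time. A minimal Python-side sketch (the file name and device_mesh handle are assumptions for illustration):

import ttnn

# device_mesh is an already-opened ttnn.DeviceMesh; the .bin file was written by an
# earlier ttnn.as_tensor / ttnn.dump_tensor call (both assumed for this sketch).
tensor = ttnn.load_tensor("weights_cache.bin", device=device_mesh)

# Per-device shards are available without a separate ttnn.to_device call.
for device in device_mesh.get_devices():
    shard = ttnn.get_device_tensor(tensor, device)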
3 changes: 2 additions & 1 deletion tt_eager/tensor/serialization.hpp
@@ -7,12 +7,13 @@
#include "tensor/tensor.hpp"

#include <string>
#include <unordered_map>

namespace tt {

namespace tt_metal {

void dump_tensor(const std::string& file_name, const Tensor& tensor);
void dump_tensor(const std::string& file_name, const Tensor& tensor, const std::unordered_map<std::string, std::string>& strategy);

template <typename T>
Tensor load_tensor(const std::string& file_name, T device = nullptr);
2 changes: 1 addition & 1 deletion tt_eager/tensor/types.cpp
@@ -197,7 +197,7 @@ bool operator==(const MemoryConfig& config_a, const MemoryConfig& config_b) {

bool operator!=(const MemoryConfig& config_a, const MemoryConfig& config_b) { return not(config_a == config_b); }

void dump_memory_config(std::ofstream& output_stream, const MemoryConfig& memory_config) {
void dump_memory_config(std::ostream& output_stream, const MemoryConfig& memory_config) {
    output_stream.write(reinterpret_cast<const char*>(&VERSION_ID), sizeof(std::uint8_t));
    output_stream.write(reinterpret_cast<const char*>(&memory_config.memory_layout), sizeof(TensorMemoryLayout));
    output_stream.write(reinterpret_cast<const char*>(&memory_config.buffer_type), sizeof(BufferType));
2 changes: 1 addition & 1 deletion tt_eager/tensor/types.hpp
@@ -250,7 +250,7 @@ struct MemoryConfig {
bool operator==(const MemoryConfig &config_a, const MemoryConfig &config_b);
bool operator!=(const MemoryConfig &config_a, const MemoryConfig &config_b);

void dump_memory_config(std::ofstream &output_stream, const MemoryConfig &memory_config);
void dump_memory_config(std::ostream &output_stream, const MemoryConfig &memory_config);
void dump_memory_config(const std::string &file_name, const MemoryConfig &memory_config);

MemoryConfig load_memory_config(std::ifstream &input_stream);
21 changes: 18 additions & 3 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
@@ -1313,15 +1313,30 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
            storage_type = tt_tensor.storage_type()
        )doc")
        .def(
            "device", [](const Tensor &self) { return self.device(); }, R"doc(
        .def(
            "device",
            [](const Tensor &self) { return self.device(); },
            R"doc(
            Get the device of the tensor.
            .. code-block:: python
                device = tt_tensor.device()
        )doc")
        )doc",
            py::return_value_policy::reference)
        .def(
            "devices",
            [](const Tensor &self) { return self.get_workers(); },
            R"doc(
            Get devices tensor is mapped on to.
            .. code-block:: python
                devices = tt_tensor.devices()
        )doc",
            py::return_value_policy::reference)
        .def(
            "to_torch",
            [](const Tensor &self) -> py::object { return detail::convert_tt_tensor_to_torch_tensor(self); },
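
The new devices() binding is returned with py::return_value_policy::reference, so ownership of the underlying Device* stays with C++ and Python only borrows the handles (issue 2 above). A small usage sketch (the device_mesh handle is assumed):

import torch
import ttnn

# device_mesh is an already-opened ttnn.DeviceMesh (assumed for this sketch)
tensor = ttnn.from_torch(
    torch.rand((32, 32), dtype=torch.float32),
    dtype=ttnn.float32,
    layout=ttnn.TILE_LAYOUT,
    device=device_mesh,
    mesh_mapper=ttnn.ReplicateTensorToMesh(device_mesh),
)

# Borrowed Device handles; C++ remains responsible for their lifetime.
for device in tensor.devices():
    shard = ttnn.get_device_tensor(tensor, device)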
31 changes: 16 additions & 15 deletions ttnn/ttnn/operations/core.py
@@ -4,7 +4,7 @@

import math
import pathlib
from typing import Union, Tuple, Optional, Any, Callable
from typing import Union, Tuple, Optional, Any, Callable, Dict

from loguru import logger
import torch
@@ -578,9 +578,11 @@ def load_tensor(file_name: Union[str, pathlib.Path], *, device: ttnn.Device = No


@ttnn.register_operation(name="ttnn.dump_tensor", validate_input_tensors=lambda *args, **kwargs: None)
def dump_tensor(file_name: Union[str, pathlib.Path], tensor: ttnn.Tensor) -> None:
def dump_tensor(file_name: Union[str, pathlib.Path], tensor: ttnn.Tensor, distribute: Dict[str, str] = None) -> None:
    if distribute is None:
        distribute = dict()
    file_name = pathlib.Path(file_name)
    ttl.tensor.dump_tensor(str(file_name), tensor)
    ttl.tensor.dump_tensor(str(file_name), tensor, distribute)


def _as_tensor_validate_input_tensors(operation_name, tensor, *args, **kwargs):
@@ -661,17 +663,22 @@ def from_torch_and_dump(tensor, dtype, layout, cache_file_name):
            )
            tensor = ttnn.to_layout(tensor, layout, dtype=dtype, memory_config=memory_config, device=device)
        else:
            tensor = ttnn.from_torch(tensor, dtype=dtype, layout=layout, mesh_mapper=mesh_mapper)
            tensor = ttnn.from_torch(
                tensor,
                dtype=dtype,
                layout=layout,
                mesh_mapper=mesh_mapper,
                memory_config=memory_config,
                device=device,
            )
        logger.debug(
            f"Generating cache for {cache_file_name} of shape {tensor.shape}, dtype {dtype_name}, layout {layout_name}"
        )
        pathlib.Path(cache_file_name).parent.mkdir(parents=True, exist_ok=True)
        ttnn.dump_tensor(cache_file_name, tensor)
        distributed_config = mesh_mapper.config() if mesh_mapper else dict()
        ttnn.dump_tensor(cache_file_name, tensor, distributed_config)
        return tensor

    def dispatch_to_device_on_load(device) -> bool:
        return isinstance(device, ttnn.DeviceMesh)

    if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh):
        storage_type = f"_multi_device" if mesh_mapper else ""
    elif mesh_mapper:
@@ -682,11 +689,7 @@
    cache_file_name = f"{cache_file_name}{storage_type}_dtype_{dtype_name}_layout_{layout_name}.bin"

    try:
        tensor = (
            ttnn.load_tensor(cache_file_name, device=device)
            if dispatch_to_device_on_load(device)
            else ttnn.load_tensor(cache_file_name)
        )
        tensor = ttnn.load_tensor(cache_file_name, device=device)
        if tuple(tensor.shape) != tuple(tensor.shape):
            logger.warning(
                f"Cached file {cache_file_name} has shape {tensor.shape}, expected {tensor.shape}, regenerating cache"
@@ -695,8 +698,6 @@
        logger.debug(f"Loaded cache for {cache_file_name} of shape {tensor.shape}")
    except (FileNotFoundError, RuntimeError):
        tensor = from_torch_and_dump(tensor, dtype, layout, cache_file_name)
    if not dispatch_to_device_on_load(device):
        tensor = ttnn.to_device(tensor, device, memory_config=memory_config)
    return tensor


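
Taken together on the Python side: when a tensor is dumped manually, the mesh mapper's distribution config can now be passed through dump_tensor so the on-disk file records how the data should be distributed across the mesh, mirroring what from_torch_and_dump does internally. A minimal sketch (the file name and device_mesh handle are assumptions):

import torch
import ttnn

# device_mesh is an already-opened ttnn.DeviceMesh (assumed for this sketch)
mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh)
tensor = ttnn.from_torch(
    torch.rand((32, 32), dtype=torch.float32),
    dtype=ttnn.float32,
    layout=ttnn.TILE_LAYOUT,
    mesh_mapper=mesh_mapper,
)

# Record the distribution strategy alongside the tensor data.
ttnn.dump_tensor("weights_cache.bin", tensor, mesh_mapper.config())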
