#8264: Async Engine Optimizations
  - copy_borrowed_tensor_in_async_mode no longer stalls for device tensors
  - Type checking moved from runtime to compile time (std::visit with
    if constexpr; sketched below)
  - work_executor optimizations: pass shared_ptrs down to workers instead
    of lambda objects (sketched below); the lock-free queue is now
    statically initialized
  - launch_op optimization: the lambda is initialized once, outside the
    multi-device for loop
  - Tensor deallocate optimization: pass the attribute pointer to the
    lambda instead of copying the entire tensor object (sketched below)
  - System-level optimizations: set process priority to 0; bind the CQ
    reader thread to a core and use a condition variable to toggle its
    state instead of calling sleep (sketched below)
tt-asaigal committed May 15, 2024
1 parent a20cb5c commit cd0587b
Showing 10 changed files with 271 additions and 139 deletions.
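
The compile-time type-checking bullet refers to replacing chains of std::holds_alternative / std::get on the storage variant with a single std::visit whose branches are selected by if constexpr, as get_shard_for_device and insert_buffer_and_shape_for_device now do in the tensor_utils.cpp diff below. A minimal, self-contained sketch of that dispatch pattern, using simplified stand-in storage types rather than the real tt_eager ones:

#include <cstdio>
#include <string>
#include <type_traits>
#include <variant>

// Simplified stand-ins for the real storage types.
struct DeviceStorage      { int buffer_id; };
struct OwnedStorage       { std::string host_buffer; };
struct MultiDeviceStorage { int num_shards; };

using Storage = std::variant<DeviceStorage, OwnedStorage, MultiDeviceStorage>;

// One visit, with the branch chosen per alternative at compile time; no
// repeated runtime holds_alternative checks and no throwing std::get.
void describe(const Storage& storage) {
    std::visit([](auto&& s) {
        using T = std::decay_t<decltype(s)>;
        if constexpr (std::is_same_v<T, DeviceStorage>) {
            std::printf("device buffer %d\n", s.buffer_id);
        } else if constexpr (std::is_same_v<T, OwnedStorage>) {
            std::printf("owned buffer of %zu bytes\n", s.host_buffer.size());
        } else if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
            std::printf("multi-device storage with %d shards\n", s.num_shards);
        }
    }, storage);
}

int main() {
    describe(Storage{DeviceStorage{42}});
    describe(Storage{MultiDeviceStorage{8}});
    return 0;
}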
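
The work_executor, deallocate, and launch_op bullets share one pattern in the tensor.cpp diff below: the work is wrapped once in a std::shared_ptr<std::function<...>>, it captures only the tensor's attribute shared_ptr rather than a full Tensor copy, and for multi-device tensors the per-device lambda is built once, outside the worker loop. A minimal sketch of that pattern; Worker, Device, and TensorAttributes here are illustrative stand-ins, and the mutex-guarded queue stands in for the repository's lock-free queue:

#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <vector>

// Illustrative stand-in for the tensor's shared attribute block.
struct TensorAttributes {
    bool deallocated = false;
    // storage, ref counts, flags, ... omitted
};

// Illustrative worker: work items are shared_ptr<std::function<void()>>, so
// enqueueing moves a single pointer instead of copying the captured state.
class Worker {
public:
    void push_work(std::shared_ptr<std::function<void()>> work) {
        {
            std::lock_guard<std::mutex> lock(mtx_);
            queue_.push(std::move(work));
        }
        cv_.notify_one();
    }

    void run_one() {
        std::shared_ptr<std::function<void()>> work;
        {
            std::unique_lock<std::mutex> lock(mtx_);
            cv_.wait(lock, [this] { return !queue_.empty(); });
            work = std::move(queue_.front());
            queue_.pop();
        }
        (*work)();  // execute outside the lock
    }

private:
    std::queue<std::shared_ptr<std::function<void()>>> queue_;
    std::mutex mtx_;
    std::condition_variable cv_;
};

// Illustrative per-device handle.
struct Device {
    int id = 0;
    Worker worker;
};

// Single-device deallocate request: captures only the attribute pointer,
// mirroring the `attr = this->tensor_attributes` capture in tensor.cpp.
void enqueue_deallocate(Worker& worker,
                        std::shared_ptr<TensorAttributes> attrs, bool force) {
    worker.push_work(std::make_shared<std::function<void()>>(
        [force, attr = std::move(attrs)]() mutable {
            attr->deallocated = true;  // real code releases attr->storage here
            (void)force;
        }));
}

// Multi-device deallocate: the shared lambda is built once, outside the loop;
// each worker only receives a thin forwarding closure.
void enqueue_deallocate_all(std::vector<Device*>& devices,
                            std::shared_ptr<TensorAttributes> attrs,
                            bool force) {
    auto dealloc_lambda = std::make_shared<std::function<void(Device*)>>(
        [force, attrs](Device* device) {
            // Real code frees the shard buffer registered for device->id.
            (void)force; (void)attrs; (void)device;
        });
    for (Device* device : devices) {
        device->worker.push_work(std::make_shared<std::function<void()>>(
            [device, dealloc_lambda]() { (*dealloc_lambda)(device); }));
    }
}

In the actual diff, Tensor::deallocate builds exactly this kind of request: a make_shared<std::function<void()>> wrapping a lambda that captures attr = this->tensor_attributes, and, for MultiDeviceStorage, a dealloc_lambda created once and pushed to every worker.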
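
The system-level bullet covers three host-side changes: keeping the process at the default priority (nice 0), pinning the completion-queue (CQ) reader thread to a core, and parking that thread on a condition variable instead of a sleep loop when it is toggled off. A Linux-only sketch of those mechanisms, assuming POSIX/pthread APIs; the CQReader class and its members are illustrative, not the repository's actual dispatch code:

// Build with g++ on Linux; pthread_setaffinity_np and CPU_SET are GNU extensions.
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

#include <pthread.h>
#include <sched.h>
#include <sys/resource.h>

// Pin the calling thread to a single CPU core (Linux-specific).
void bind_current_thread_to_core(int core_id) {
    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(core_id, &cpuset);
    pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
}

// Reader that blocks on a condition variable when idle instead of polling
// with sleep(); toggling `enabled` wakes it immediately.
class CQReader {
public:
    explicit CQReader(int core_id) : thread_([this, core_id] {
        bind_current_thread_to_core(core_id);
        loop();
    }) {}

    ~CQReader() {
        running_ = false;
        toggle(true);  // wake the thread so it can observe running_ == false
        thread_.join();
    }

    void toggle(bool enabled) {
        {
            std::lock_guard<std::mutex> lock(mtx_);
            enabled_ = enabled;
        }
        cv_.notify_one();
    }

private:
    void loop() {
        while (running_) {
            std::unique_lock<std::mutex> lock(mtx_);
            cv_.wait(lock, [this] { return enabled_ || !running_; });
            if (!running_) break;
            lock.unlock();
            // ... drain the completion queue here ...
        }
    }

    std::atomic<bool> running_{true};
    bool enabled_ = false;
    std::mutex mtx_;
    std::condition_variable cv_;
    std::thread thread_;
};

int main() {
    // Keep the whole process at the default scheduling priority (nice 0).
    setpriority(PRIO_PROCESS, 0, 0);

    CQReader reader(/*core_id=*/1);
    reader.toggle(true);   // start reading
    reader.toggle(false);  // park the reader without busy-waiting
    return 0;
}

The toggle call replaces sleep-based polling: while disabled, the reader blocks on the condition variable and consumes no CPU until it is re-enabled.
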
52 changes: 21 additions & 31 deletions tt_eager/tensor/tensor.cpp
@@ -114,18 +114,16 @@ void Tensor::deallocate(bool force) {
uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or not this->tensor_attributes->main_thread_tensor) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count;
if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) {
this->tensor_attributes->deallocated = true;
// Record ref count before sending to worker
uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count();
this->workers.at(0)->push_work([force, *this] () mutable {
this->workers.at(0)->push_work(std::make_shared<std::function<void()>>([force, attr = this->tensor_attributes] () mutable {
// Cross worker synchronization: If the tensor being deallocated is shared across workers (ex: all_gather op),
// wait until all workers are done with this tensor before deallocating.
bool num_threads_sharing_tensor = this->tensor_attributes->num_sibling_workers_sharing_tensor;
bool num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;
if (num_threads_sharing_tensor) {
while (num_threads_sharing_tensor) {
num_threads_sharing_tensor = this->tensor_attributes->num_sibling_workers_sharing_tensor;;
num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;;
}
}
std::visit([force, this] (auto&& s) {
std::visit([force, attr] (auto&& s) {
using type = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<type, DeviceStorage>) {
if (force or s.buffer.use_count() == 1) {
@@ -138,13 +136,11 @@ void Tensor::deallocate(bool force) {
} else if constexpr(std::is_same_v<type, OwnedStorage>) {
// Manage Dynamic Storage (due to autoformat in async mode): Main thread sees this tensor as a device tensor, since worker has not updated
// storage time. When the worker executes the dealloc request, the storage type has been appropriately updated to Owned.
TT_ASSERT(this->tensor_attributes->dynamic_storage, "Tensor storage type changed during runtime (device -> host), but dynamic storage was not marked.");
TT_ASSERT(attr->dynamic_storage, "Tensor storage type changed during runtime (device -> host), but dynamic storage was not marked.");
std::visit([] (auto&& buffer) { buffer.reset(); }, s.buffer);
}
}, this->tensor_attributes->storage);
});
// Update ref count after sending to worker
this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count);
}, attr->storage);
}));
}
} else {
TT_FATAL(this->deallocate_through_destructor, "Device tensors created in the main thread cannot be explictly deallocated in worker threads.");
@@ -155,32 +151,26 @@
}
} else if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
if (this->workers.at(0)->in_main_thread() or not this->tensor_attributes->main_thread_tensor) {
if (not this->tensor_attributes->main_thread_tensor) {
TT_ASSERT(not this->tensor_attributes->main_thread_ref_count, "main_thread_ref_count for tensors created inside a worker thread must be 0");
}
// If owned by the main thread, deallocate this tensor only from the main thread. If owned by worker thread, allow deallocation in worker and use shared_ptr ref count, since this is a thread_local tensor
uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or not this->tensor_attributes->main_thread_tensor) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count;
if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) {
this->tensor_attributes->deallocated = true;
// Record ref count before sending to workers
uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count();
auto dealloc_lambda = std::make_shared<std::function<void(Device*)>>([force, attr = this->tensor_attributes] (Device* worker) mutable {
ZoneScopedN("ShardDeallocate");
auto& s = std::get<MultiDeviceStorage>(attr->storage);
if (s.buffers.find(worker->id()) != s.buffers.end()) {
if ((force or s.buffers.at(worker->id()).use_count() == 1)) {
DeallocateBuffer(*(s.buffers.at(worker->id())));
}
s.buffers.at(worker->id()).reset();
}
});

for (auto worker : this->workers) {
worker->push_work([force, *this, worker] () mutable {
std::visit([force, worker] (auto&& s) {
using type = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<type, MultiDeviceStorage>) {
if (s.buffers.find(worker->id()) != s.buffers.end()) {
if (force or s.buffers.at(worker->id()).use_count() == 1) {
DeallocateBuffer(*(s.buffers.at(worker->id())));
}
s.buffers.at(worker->id()).reset();
}
}
}, this->tensor_attributes->storage);
});
worker->push_work(std::make_shared<std::function<void()>>([worker, dealloc_lambda] () mutable {
(*dealloc_lambda)(worker);
}));
}
// Update ref count after sending to workers
this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count);
}
} else {
TT_FATAL(this->deallocate_through_destructor, "Device tensors created in the main thread cannot be explictly deallocated in worker threads.");
72 changes: 44 additions & 28 deletions tt_eager/tensor/tensor_utils.cpp
@@ -346,47 +346,63 @@ uint32_t num_buffers_in_tensor(const Tensor& tensor) {
} else if (std::holds_alternative<DeviceStorage>(tensor.get_storage()) || std::holds_alternative<OwnedStorage>(tensor.get_storage()) || std::holds_alternative<BorrowedStorage>(tensor.get_storage())) {
return 1;
} else {
TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors");
TT_FATAL(false, "num_buffers_in_tensor only supports multi-device or device tensors");
}
}

Tensor get_shard_for_device(const Tensor& tensor, Device* target_device, std::optional<int> buffer_index) {
if (std::holds_alternative<MultiDeviceStorage>(tensor.get_storage())) {
auto device_storage = std::get<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage());
auto shard_shape = device_storage.get_tensor_shape_for_device(target_device);
auto shard_buffer = device_storage.get_buffer_for_device(target_device);
return Tensor{DeviceStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()};
} else if (std::holds_alternative<MultiDeviceHostStorage>(tensor.get_storage())) {
auto host_storage = std::get<tt::tt_metal::MultiDeviceHostStorage>(tensor.get_storage());
auto shard_shape = host_storage.get_tensor_shape(buffer_index.value());
auto shard_buffer = host_storage.get_buffer(buffer_index.value());
return Tensor{OwnedStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()};
} else if (std::holds_alternative<DeviceStorage>(tensor.get_storage()) || std::holds_alternative<OwnedStorage>(tensor.get_storage()) || std::holds_alternative<BorrowedStorage>(tensor.get_storage())) {
return tensor;
} else {
TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors");
}
ZoneScopedN("GetShardForDevice");
Tensor shard = Tensor();
auto& storage = tensor.get_storage();
std::visit([target_device, buffer_index, &tensor, &shard] (auto&& s) {
using T = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
auto shard_shape = s.get_tensor_shape_for_device(target_device);
auto shard_buffer = s.get_buffer_for_device(target_device);
shard = Tensor{DeviceStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()};
} else if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
auto shard_shape = s.get_tensor_shape(buffer_index.value());
auto shard_buffer = s.get_buffer(buffer_index.value());
shard = Tensor{OwnedStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()};
} else if constexpr (std::is_same_v<T, OwnedStorage> || std::is_same_v<T, BorrowedStorage> || std::is_same_v<T, DeviceStorage>) {
shard = tensor;
} else {
TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors");
}
}, storage);
return shard;
}

void insert_buffer_and_shape_for_device(Device* target_device, const Tensor& shard, Tensor& tensor_to_modify, std::optional<int> buffer_index) {
if (std::holds_alternative<MultiDeviceHostStorage>(tensor_to_modify.tensor_attributes->storage)) {
std::get<MultiDeviceHostStorage>(tensor_to_modify.tensor_attributes->storage).insert_buffer_and_shape_for_device(buffer_index.value(), std::get<OwnedStorage>(shard.get_storage()).get_buffer(), shard.get_legacy_shape());
} else if (std::holds_alternative<MultiDeviceStorage>(tensor_to_modify.tensor_attributes->storage)) {
std::get<MultiDeviceStorage>(tensor_to_modify.tensor_attributes->storage).insert_buffer_and_shape_for_device(target_device, std::get<DeviceStorage>(shard.get_storage()).get_buffer(), shard.get_legacy_shape());
} else if (std::holds_alternative<OwnedStorage>(tensor_to_modify.tensor_attributes->storage)) {
std::get<OwnedStorage>(tensor_to_modify.tensor_attributes->storage).insert_buffer(std::get<OwnedStorage>(shard.get_storage()).get_buffer());
} else if (std::holds_alternative<DeviceStorage>(tensor_to_modify.tensor_attributes->storage)) {
std::get<DeviceStorage>(tensor_to_modify.tensor_attributes->storage).insert_buffer(std::get<DeviceStorage>(shard.get_storage()).get_buffer());
} else {
TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device");
}
ZoneScopedN("InsertBufferAndShapeForDevice");
std::visit([target_device, &shard, &tensor_to_modify, buffer_index] (auto&& s) {
using T = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
s.insert_buffer_and_shape_for_device(buffer_index.value(), std::get<OwnedStorage>(shard.get_storage()).get_buffer(), shard.get_legacy_shape());
} else if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
s.insert_buffer_and_shape_for_device(target_device, std::get<DeviceStorage>(shard.get_storage()).get_buffer(), shard.get_legacy_shape());
} else if constexpr (std::is_same_v<T, OwnedStorage>) {
s.insert_buffer(std::get<OwnedStorage>(shard.get_storage()).get_buffer());
} else if constexpr (std::is_same_v<T, DeviceStorage>) {
s.insert_buffer(std::get<DeviceStorage>(shard.get_storage()).get_buffer());
} else {
TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device");
}
}, tensor_to_modify.tensor_attributes->storage);
}


Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) {
// When using async mode, tensors with borrowed storage cannot be passed to workers.
// They need to be copied to owned storage before being passed to the worker.
ZoneScopedN("ConvertBorrowedToOwned");
if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS and tensor.storage_type() == StorageType::BORROWED) {
// Tensor has workers (on device) or runtime mode is synchronous or tensor has multiple buffers.
// No need to check for borrowed storage.
if (worker->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or
tensor.get_workers().size() or
tensor.tensor_attributes->tensor_populated.size() > 1) return tensor;

if (tensor.storage_type() == StorageType::BORROWED) {
ZoneScopedN("CopyBorrowedStorage");
auto borrowed_buffer = std::get<BorrowedStorage>(tensor.get_storage()).buffer;
Tensor owned_tensor;
20 changes: 15 additions & 5 deletions tt_eager/tensor/types.hpp
@@ -366,29 +366,33 @@ struct MultiDeviceHostStorage {
DistributedTensorConfig strategy;
std::vector<OwnedBuffer> buffers;
std::vector<Shape> shapes;
std::mutex mtx;
mutable std::mutex mtx;
MultiDeviceHostStorage() = default;
MultiDeviceHostStorage(DistributedTensorConfig strategy_, std::vector<OwnedBuffer> buffers_, std::vector<Shape> shapes_) : strategy(strategy_), buffers(buffers_), shapes(shapes_) {}
MultiDeviceHostStorage(MultiDeviceHostStorage &&other) {
std::lock_guard<std::mutex> lock(mtx);
strategy = other.strategy;
buffers = other.buffers;
shapes = other.shapes;
}

MultiDeviceHostStorage(const MultiDeviceHostStorage &other) {
std::lock_guard<std::mutex> lock(mtx);
strategy = other.strategy;
buffers = other.buffers;
shapes = other.shapes;
}

MultiDeviceHostStorage &operator=(const MultiDeviceHostStorage &other) {
std::lock_guard<std::mutex> lock(mtx);
strategy = other.strategy;
buffers = other.buffers;
shapes = other.shapes;
return *this;
}

MultiDeviceHostStorage &operator=( MultiDeviceHostStorage &&other) {
std::lock_guard<std::mutex> lock(mtx);
strategy = other.strategy;
buffers = other.buffers;
shapes = other.shapes;
@@ -410,13 +414,13 @@ struct MultiDeviceHostStorage {
shapes[buffer_index] = shape;
}

OwnedBuffer get_buffer(int buffer_index) {
OwnedBuffer get_buffer(int buffer_index) const {
std::lock_guard<std::mutex> lock(mtx);
TT_FATAL(buffer_index < buffers.size(), "Buffer not found for buffer_index " + std::to_string(buffer_index));
return buffers[buffer_index];;
}

Shape get_tensor_shape(int shape_index) {
Shape get_tensor_shape(int shape_index) const {
std::lock_guard<std::mutex> lock(mtx);
TT_FATAL(shape_index < shapes.size(), "Buffer not found for device " + std::to_string(shape_index));
return shapes[shape_index];
@@ -443,19 +447,22 @@ struct MultiDeviceHostStorage {
std::unordered_map<int, Shape> shapes_) : strategy(strategy_), ordered_device_ids(ordered_device_ids_), buffers(buffers_), shapes(shapes_) {}

MultiDeviceStorage(MultiDeviceStorage &&other) {
std::lock_guard<std::mutex> lock(mtx);
ordered_device_ids = other.ordered_device_ids;
strategy = other.strategy;
buffers = other.buffers;
shapes = other.shapes;
}
MultiDeviceStorage(const MultiDeviceStorage &other) {
std::lock_guard<std::mutex> lock(other.mtx);
ordered_device_ids = other.ordered_device_ids;
strategy = other.strategy;
buffers = other.buffers;
shapes = other.shapes;
}

MultiDeviceStorage &operator=(const MultiDeviceStorage &other) {
std::lock_guard<std::mutex> lock(other.mtx);
ordered_device_ids = other.ordered_device_ids;
strategy = other.strategy;
buffers = other.buffers;
@@ -464,6 +471,7 @@ struct MultiDeviceHostStorage {
}

MultiDeviceStorage &operator=( MultiDeviceStorage &&other) {
std::lock_guard<std::mutex> lock(mtx);
ordered_device_ids = other.ordered_device_ids;
strategy = other.strategy;
buffers = other.buffers;
@@ -497,18 +505,20 @@ struct MultiDeviceHostStorage {
// Helper Functions - Getters and setters to get/modify storage attributes. These are needed to
// preinitialize empty tensor handles and use/populate them in the worker threads.
void insert_buffer_and_shape_for_device(Device* device, const DeviceBuffer buffer, const Shape shape) {
TT_FATAL(device == buffer->device(), "Mismatch between device derived from buffer and device derived from MultiDeviceStorage.");
std::lock_guard<std::mutex> lock(mtx);
buffers.insert({device->id(), buffer});
shapes.insert({device->id(), shape});
}

DeviceBuffer get_buffer_for_device(Device* device) {
DeviceBuffer get_buffer_for_device(Device* device) const {
std::lock_guard<std::mutex> lock(mtx);
TT_FATAL(buffers.find(device->id()) != buffers.end(), "Buffer not found for device " + std::to_string(device->id()));
TT_FATAL(buffers.at(device->id())->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage.");
return buffers.at(device->id());
}

Shape get_tensor_shape_for_device(Device* device) {
Shape get_tensor_shape_for_device(Device* device) const {
std::lock_guard<std::mutex> lock(mtx);
TT_FATAL(shapes.find(device->id()) != shapes.end(), "Shape not found for device " + std::to_string(device->id()));
return shapes.at(device->id());
(Diffs for the remaining 7 changed files are not shown.)
