#0: Rewrite threading, no mutexes
sminakov-tt committed Oct 23, 2024
1 parent d2cff39 commit 51b0755
Showing 10 changed files with 169 additions and 225 deletions.
2 changes: 1 addition & 1 deletion tt_metal/detail/tt_metal.hpp
@@ -276,7 +276,7 @@ inline namespace v0 {

void SetLazyCommandQueueMode(bool lazy);

void AllocateBuffer(Buffer* buffer, bool bottom_up);
DeviceAddr AllocateBuffer(const Buffer* buffer, bool bottom_up);

void DeallocateBuffer(Buffer *buffer);
} // namespace detail
4 changes: 2 additions & 2 deletions tt_metal/graph/graph_tracking.cpp
@@ -27,7 +27,7 @@ bool GraphTracker::add_hook(const std::shared_ptr<IGraphHooks>& new_hook) {
return true;
}

void GraphTracker::track_allocate(Buffer* buffer, bool bottom_up) {
void GraphTracker::track_allocate(const Buffer* buffer, bool bottom_up) {
if (processors.empty()) {
return;
}
@@ -73,7 +73,7 @@ void GraphTracker::track_program(Program* program) {
}
}

bool GraphTracker::hook_allocate(Buffer* buffer, bool bottom_up) {
bool GraphTracker::hook_allocate(const Buffer* buffer, bool bottom_up) {
if (hook == nullptr)
return false;

8 changes: 4 additions & 4 deletions tt_metal/graph/graph_tracking.hpp
@@ -28,7 +28,7 @@ inline namespace v0 {

IGraphProcessor() = default;

virtual void track_allocate(tt::tt_metal::Buffer* buffer, bool bottom_up) {};
virtual void track_allocate(const tt::tt_metal::Buffer* buffer, bool bottom_up) {};

virtual void track_deallocate(tt::tt_metal::Buffer* buffer) {};

@@ -54,7 +54,7 @@ inline namespace v0 {
class IGraphHooks {
public:
IGraphHooks() = default;
virtual bool hook_allocate(tt::tt_metal::Buffer* buffer, bool bottom_up) = 0;
virtual bool hook_allocate(const tt::tt_metal::Buffer* buffer, bool bottom_up) = 0;

virtual bool hook_deallocate(tt::tt_metal::Buffer* buffer) = 0;

@@ -77,7 +77,7 @@ inline namespace v0 {

bool add_hook(const std::shared_ptr<IGraphHooks>& hook);

void track_allocate(Buffer* buffer, bool bottom_up);
void track_allocate(const Buffer* buffer, bool bottom_up);

void track_deallocate(Buffer* buffer);

@@ -118,7 +118,7 @@ inline namespace v0 {
}
}

bool hook_allocate(Buffer* buffer, bool bottom_up);
bool hook_allocate(const Buffer* buffer, bool bottom_up);

bool hook_deallocate(Buffer* buffer);

211 changes: 112 additions & 99 deletions tt_metal/impl/buffers/buffer.cpp
@@ -116,50 +116,6 @@ inline std::tuple<std::vector<std::vector<uint32_t>>, std::vector<std::array<uin
return {ret_vec, ret_shard_shape};
}

std::shared_ptr<Buffer> Buffer::create(
Device *device,
DeviceAddr size,
DeviceAddr page_size,
const BufferType buffer_type,
const TensorMemoryLayout buffer_layout,
const std::optional<ShardSpecBuffer>& shard_parameters,
const std::optional<bool> bottom_up,
bool allocate) {
auto bufferPtr = new Buffer(device, size, page_size, buffer_type, buffer_layout, shard_parameters, bottom_up);
auto buffer = std::shared_ptr<Buffer>(bufferPtr, deallocateAndDelete);
buffer->weak_self = buffer;
if (allocate) {
buffer->allocate();
}
return buffer;
}

Buffer::Buffer(
Device *device,
DeviceAddr size,
DeviceAddr page_size,
const BufferType buffer_type,
const TensorMemoryLayout buffer_layout,
const std::optional<ShardSpecBuffer>& shard_parameters,
const std::optional<bool> bottom_up) :
device_(device),
size_(size),
page_size_(page_size),
buffer_type_(buffer_type),
buffer_layout_(buffer_layout),
shard_parameters_(shard_parameters),
bottom_up_(bottom_up),
buffer_page_mapping_(nullptr) {
TT_FATAL(this->device_ != nullptr and this->device_->allocator_ != nullptr, "Device and allocator need to not be null.");

if (size == 0) {
allocation_status_.store(AllocationStatus::ALLOCATED, std::memory_order::relaxed);
return;
}

validate_buffer_size_and_page_size(size, page_size, buffer_type, buffer_layout, shard_parameters);
}

BufferPageMapping generate_buffer_page_mapping(const Buffer& buffer) {
BufferPageMapping buffer_page_mapping;

@@ -226,71 +182,129 @@ BufferPageMapping generate_buffer_page_mapping(const Buffer& buffer) {
return buffer_page_mapping;
}

void Buffer::allocate() {
{
std::unique_lock lock(allocation_mutex_);
TT_FATAL(allocation_status_.load(std::memory_order::relaxed) == AllocationStatus::NOT_ALLOCATED, "Can't allocate buffer after it was already allocated");
allocation_status_.store(AllocationStatus::ALLOCATION_REQUESTED, std::memory_order::relaxed);
Buffer::Buffer(
Device *device,
DeviceAddr size,
DeviceAddr page_size,
const BufferType buffer_type,
const TensorMemoryLayout buffer_layout,
const std::optional<ShardSpecBuffer>& shard_parameters,
const std::optional<bool> bottom_up) :
device_(device),
size_(size),
page_size_(page_size),
buffer_type_(buffer_type),
buffer_layout_(buffer_layout),
shard_parameters_(shard_parameters),
bottom_up_(bottom_up),
buffer_page_mapping_(nullptr) {
TT_FATAL(this->device_ != nullptr && this->device_->allocator_ != nullptr, "Device and allocator need to not be null.");

if (size != 0) {
validate_buffer_size_and_page_size(size, page_size, buffer_type, buffer_layout, shard_parameters);
}
}

device_->push_work([self = weak_self.lock()] {
std::unique_lock lock(self->allocation_mutex_);
if (self->allocation_status_.load(std::memory_order::relaxed) != AllocationStatus::ALLOCATION_REQUESTED) {
// The allocation was interrupted by a deallocation
std::shared_ptr<Buffer> Buffer::create(
Device *device,
DeviceAddr size,
DeviceAddr page_size,
const BufferType buffer_type,
const TensorMemoryLayout buffer_layout,
const std::optional<ShardSpecBuffer>& shard_parameters,
const std::optional<bool> bottom_up) {
auto* bufferPtr = new Buffer(device, size, page_size, buffer_type, buffer_layout, shard_parameters, bottom_up);
auto buffer = std::shared_ptr<Buffer>(bufferPtr, deleter);
buffer->weak_self = buffer;

if (buffer->size_ == 0) {
buffer->allocation_status_ = AllocationStatus::ALLOCATED;
return buffer;
}

// Faster path for single-threaded mode
if (buffer->device_->can_use_passthrough_scheduling()) {
buffer->allocate_impl();
buffer->allocation_status_ = AllocationStatus::ALLOCATED;
return buffer;
}

buffer->device_->push_work([buffer] {
auto expected_status = AllocationStatus::ALLOCATION_REQUESTED;
if (!buffer->allocation_status_.compare_exchange_strong(expected_status, AllocationStatus::ALLOCATING)) {
// Buffer was already deallocated before we got here
return;
}

bool bottom_up = self->bottom_up_.value_or(self->is_dram());
detail::AllocateBuffer(self.get(), bottom_up);
detail::BUFFER_MAP.insert({self->device_->id(), self->address_}, self.get());
buffer->allocate_impl();

self->allocation_status_.store(AllocationStatus::ALLOCATED, std::memory_order::relaxed);
lock.unlock();
self->allocation_cv_.notify_all();
// We need compare exchange here to handle the case of deallocation being requested before we finished allocating
expected_status = AllocationStatus::ALLOCATING;
if (buffer->allocation_status_.compare_exchange_strong(expected_status, AllocationStatus::ALLOCATED)) {
buffer->allocation_status_.notify_all();
}
});

return buffer;
}

void Buffer::deallocate() {
if (size_ == 0) {
// 0-size buffer, no need to deallocate
return;
}
void Buffer::allocate_impl() {
bool bottom_up = bottom_up_.value_or(is_dram());
address_ = detail::AllocateBuffer(this, bottom_up);
detail::BUFFER_MAP.insert({device_->id(), address_}, this);
}

{
std::unique_lock lock(allocation_mutex_);
auto status = allocation_status_.load(std::memory_order::relaxed);
if (status != AllocationStatus::ALLOCATED && status != AllocationStatus::ALLOCATION_REQUESTED) {
// Buffer isn't allocated, nothing to be done
return;
bool Buffer::prepare_deallocation(std::atomic<AllocationStatus>& status) {
while (true) {
auto current_status = status.load();
switch (current_status) {
case AllocationStatus::ALLOCATION_REQUESTED:
// Allocation was requested but not started, canceling allocation, nothing else to be done
if (status.compare_exchange_weak(current_status, AllocationStatus::DEALLOCATED)) {
status.notify_all();
return false;
}
break;
case AllocationStatus::ALLOCATING:
case AllocationStatus::ALLOCATED:
// Allocation already started, will have to deallocate
if (status.compare_exchange_weak(current_status, AllocationStatus::DEALLOCATION_REQUESTED)) {
status.notify_all();
return true;
}
break;
case AllocationStatus::DEALLOCATION_REQUESTED:
case AllocationStatus::DEALLOCATED:
// Deallocation was already started, nothing to be done
return false;
}
// Overwriting either ALLOCATED or ALLOCATION_REQUESTED with DEALLOCATION_REQUESTED
allocation_status_.store(AllocationStatus::DEALLOCATION_REQUESTED, std::memory_order::relaxed);
}
}

void Buffer::deallocate() {
if (!prepare_deallocation(allocation_status_)) {
return;
}

device_->push_work([self = weak_self.lock()] {
// Because the status is DEALLOCATION_REQUESTED, it won't be changed by anyone else, no need to lock a mutex
if (!self->device_->initialized_) {
return;
if (self->device_->initialized_ && self->size_ != 0) {
detail::BUFFER_MAP.erase({self->device_->id(), self->address_});
detail::DeallocateBuffer(self.get());
}

detail::BUFFER_MAP.erase({self->device()->id(), self->address()});
detail::DeallocateBuffer(self.get());
self->allocation_status_.store(AllocationStatus::DEALLOCATED, std::memory_order::relaxed);
self->allocation_status_ = AllocationStatus::DEALLOCATED;
});
}

void Buffer::deallocateAndDelete(Buffer* buffer) {
// This is the last reference to the buffer, no need to lock or update AllocationStatus
void Buffer::deleter(Buffer* buffer) {
// There is no concurrent allocations/deallocations happening, so no extra checks are required
if (buffer->allocation_status_ == AllocationStatus::DEALLOCATED) {
return;
}

buffer->device_->push_work([buffer] {
// Buffer will be deleted at the end of this block
std::unique_ptr<Buffer> unique_buffer = std::unique_ptr<Buffer>(buffer);

auto status = buffer->allocation_status_.load(std::memory_order::relaxed);
if (status == AllocationStatus::NOT_ALLOCATED || status == AllocationStatus::ALLOCATION_REQUESTED || status == AllocationStatus::DEALLOCATED) {
// Buffer isn't allocated, nothing to be done
return;
}

if (!buffer->device_->initialized_ || buffer->size_ == 0) {
return;
}
@@ -301,30 +315,29 @@ void Buffer::deallocateAndDelete(Buffer* buffer) {
}

bool Buffer::is_allocated() const {
auto allocation_status = allocation_status_.load(std::memory_order::relaxed);
auto allocation_status = allocation_status_.load();

if (device_->can_use_passthrough_scheduling()) {
return allocation_status == AllocationStatus::ALLOCATED;
}
// For calls from different threads we consider buffer to be allocated even if it's just ALLOCATION_REQUESTED,

// For calls from different threads we consider buffer to be allocated even if it's just ALLOCATION_REQUESTED or ALLOCATING,
// because once the caller will try to access it, the buffer will already be fully allocated
return allocation_status == AllocationStatus::ALLOCATED || allocation_status == AllocationStatus::ALLOCATION_REQUESTED;
return allocation_status == AllocationStatus::ALLOCATION_REQUESTED
|| allocation_status == AllocationStatus::ALLOCATING
|| allocation_status == AllocationStatus::ALLOCATED;
}

uint32_t Buffer::address() const {
if (device_->can_use_passthrough_scheduling()) {
// No locking required, because address can only be modified from the same thread
return address_;
}

std::unique_lock lock(allocation_mutex_);
allocation_cv_.wait(lock, [this] { return this->allocation_status_.load(std::memory_order::relaxed) != AllocationStatus::ALLOCATION_REQUESTED; });
return address_;
}
// Waiting for the buffer to be allocated if the allocation is pending
allocation_status_.wait(AllocationStatus::ALLOCATION_REQUESTED);
allocation_status_.wait(AllocationStatus::ALLOCATING);

void Buffer::set_address(uint64_t addr) {
TT_FATAL(device_->can_use_passthrough_scheduling() , "Buffer::set_address must be called in device worker thread");
TT_FATAL(allocation_status_.load(std::memory_order::relaxed) == AllocationStatus::ALLOCATION_REQUESTED, "Buffer address can only be set during allocation");
address_ = addr;
return address_;
}

DeviceAddr Buffer::page_size() const {
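Note on the buffer.cpp changes above: the commit replaces the allocation_mutex_/allocation_cv_ pair with a single std::atomic<AllocationStatus> driven by compare-exchange transitions plus C++20 atomic wait/notify. The following is a minimal, self-contained sketch of that pattern, not the tt-metal API: Buffer, allocate_on_worker, wait_for_address, and prepare_deallocation below are simplified stand-ins, and the device work queue, allocator calls, and error handling are omitted.

// Requires C++20 for std::atomic<T>::wait / notify_all.
#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>

enum class AllocationStatus : uint8_t {
    ALLOCATION_REQUESTED,
    ALLOCATING,
    ALLOCATED,
    DEALLOCATION_REQUESTED,
    DEALLOCATED,
};

struct Buffer {
    std::atomic<AllocationStatus> status{AllocationStatus::ALLOCATION_REQUESTED};
    std::atomic<uint64_t> address{0};

    // Worker thread: perform the allocation unless a deallocation already won the race.
    void allocate_on_worker(uint64_t allocated_address) {
        auto expected = AllocationStatus::ALLOCATION_REQUESTED;
        if (!status.compare_exchange_strong(expected, AllocationStatus::ALLOCATING)) {
            return;  // deallocation was requested before allocation started
        }
        address.store(allocated_address);
        expected = AllocationStatus::ALLOCATING;
        if (status.compare_exchange_strong(expected, AllocationStatus::ALLOCATED)) {
            status.notify_all();  // wake readers blocked in wait_for_address()
        }
    }

    // Any other thread: block until the address has been published.
    uint64_t wait_for_address() {
        status.wait(AllocationStatus::ALLOCATION_REQUESTED);
        status.wait(AllocationStatus::ALLOCATING);
        return address.load();
    }

    // Decide whether a deallocation job needs to be scheduled on the worker.
    bool prepare_deallocation() {
        while (true) {
            auto current = status.load();
            switch (current) {
                case AllocationStatus::ALLOCATION_REQUESTED:
                    // Allocation never started: cancel it outright, nothing to free.
                    if (status.compare_exchange_weak(current, AllocationStatus::DEALLOCATED)) {
                        status.notify_all();
                        return false;
                    }
                    break;
                case AllocationStatus::ALLOCATING:
                case AllocationStatus::ALLOCATED:
                    // Allocation already started: the worker must free the memory.
                    if (status.compare_exchange_weak(current, AllocationStatus::DEALLOCATION_REQUESTED)) {
                        status.notify_all();
                        return true;
                    }
                    break;
                case AllocationStatus::DEALLOCATION_REQUESTED:
                case AllocationStatus::DEALLOCATED:
                    return false;  // someone else already handled it
            }
        }
    }
};

int main() {
    Buffer buffer;
    std::thread worker([&] { buffer.allocate_on_worker(0x1000); });
    assert(buffer.wait_for_address() == 0x1000);
    worker.join();
    assert(buffer.prepare_deallocation());  // ALLOCATED -> DEALLOCATION_REQUESTED
}

The two consecutive waits correspond to the new Buffer::address() above, and the compare-exchange loop corresponds to Buffer::prepare_deallocation().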
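The other half of the mutex-free design is ownership: Buffer::create hands out a shared_ptr whose custom deleter pushes the final teardown onto the device worker thread instead of destroying the buffer inline. Below is a small sketch of that idiom under invented names (WorkQueue, Widget, and drain are placeholders, and the queue runs tasks synchronously for brevity); it is illustrative only, not the tt-metal implementation.

#include <functional>
#include <iostream>
#include <memory>
#include <queue>

// Stand-in for the device worker queue; tt-metal's push_work is asynchronous,
// here tasks simply run when drain() is called.
struct WorkQueue {
    std::queue<std::function<void()>> tasks;
    void push_work(std::function<void()> task) { tasks.push(std::move(task)); }
    void drain() {
        while (!tasks.empty()) {
            tasks.front()();
            tasks.pop();
        }
    }
};

struct Widget {
    WorkQueue* queue;

    static std::shared_ptr<Widget> create(WorkQueue* queue) {
        auto* raw = new Widget{queue};
        // The custom deleter runs when the last shared_ptr goes away...
        return std::shared_ptr<Widget>(raw, deleter);
    }

    static void deleter(Widget* widget) {
        // ...and forwards the actual teardown to the worker, where a
        // unique_ptr frees the object after any device-side cleanup.
        widget->queue->push_work([widget] {
            std::unique_ptr<Widget> owned(widget);
            std::cout << "cleaned up on worker\n";
        });
    }
};

int main() {
    WorkQueue queue;
    { auto widget = Widget::create(&queue); }  // deleter enqueues cleanup here
    queue.drain();                             // worker executes it later
}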