Skip to content

Commit

Permalink
#0: rewrite linkedlist using shared pointers to avoid memory leaks
Browse files Browse the repository at this point in the history
  • Loading branch information
Muthu committed Dec 21, 2023
1 parent 85e7d6f commit 6daa904
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 84 deletions.
2 changes: 1 addition & 1 deletion tt_metal/detail/tt_metal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ namespace tt::tt_metal{
{
std::lock_guard<std::mutex> lock(cq_creation_mutex);
if (not command_queues[id] or (command_queues[id] and command_queues[id]->device != device)) {
command_queues[device->id()] = std::make_unique<CommandQueue>(device);
command_queues[device->id()] = std::move( std::make_unique<CommandQueue>(device) );
}
}
return *(command_queues[id]);
Expand Down
119 changes: 60 additions & 59 deletions tt_metal/impl/allocator/algorithms/free_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,21 @@ FreeList::FreeList(uint64_t max_size_bytes, uint64_t offset_bytes, uint64_t min_
}

void FreeList::init() {
auto block = new FreeList::Block{.address = 0, .size = this->max_size_bytes_};
this->block_head_ = block;
this->block_tail_ = block;
this->free_block_head_ = block;
this->free_block_tail_ = block;
this->block_head_.reset( new FreeList::Block{.address = 0, .size = this->max_size_bytes_} );
this->block_tail_ = block_head_;
this->free_block_head_ = block_head_;
this->free_block_tail_ = block_head_;
}

bool FreeList::is_allocated(const Block *block) const {
return block->prev_free == nullptr and block->next_free == nullptr and block != this->free_block_head_ and block != this->free_block_tail_;
return (block->prev_free == nullptr) and (block->next_free == nullptr and block != this->free_block_head_.get()) and (block != this->free_block_tail_.get());
}

std::vector<std::pair<uint64_t, uint64_t>> FreeList::available_addresses(uint64_t size_bytes) const {
uint64_t alloc_size = size_bytes < this->min_allocation_size_ ? this->min_allocation_size_ : size_bytes;
alloc_size = this->align(alloc_size);
std::vector<std::pair<uint64_t, uint64_t>> addresses;
FreeList::Block *curr_block = this->free_block_head_;
std::shared_ptr<Block> curr_block = this->free_block_head_;
while (curr_block != nullptr) {
if (curr_block->size >= alloc_size) {
uint64_t end_range = (curr_block->address + curr_block->size) - alloc_size;
Expand All @@ -46,50 +45,48 @@ std::vector<std::pair<uint64_t, uint64_t>> FreeList::available_addresses(uint64_
return addresses;
}

FreeList::Block *FreeList::search_best(uint64_t size_bytes, bool bottom_up) {
FreeList::Block *best_block = nullptr;
FreeList::Block *curr_block = bottom_up ? this->free_block_head_ : this->free_block_tail_;
std::shared_ptr<FreeList::Block> FreeList::search_best(uint64_t size_bytes, bool bottom_up) {
std::shared_ptr<FreeList::Block> best_block;
std::shared_ptr<FreeList::Block> curr_block = bottom_up ? this->free_block_head_ : this->free_block_tail_;
while (curr_block != nullptr) {
if (curr_block->size == size_bytes) {
best_block = curr_block;
return best_block;
return std::move(best_block);
}
curr_block = bottom_up ? curr_block->next_free : curr_block->prev_free;
}
//best search fail over to search first
return search_first(size_bytes,bottom_up);
}

FreeList::Block *FreeList::search_first(uint64_t size_bytes, bool bottom_up) {
FreeList::Block *curr_block = bottom_up ? this->free_block_head_ : this->free_block_tail_;
FreeList::Block *first_fit_block = nullptr;
std::shared_ptr<FreeList::Block> FreeList::search_first(uint64_t size_bytes, bool bottom_up) {
std::shared_ptr<FreeList::Block> curr_block = bottom_up ? this->free_block_head_ : this->free_block_tail_;
while (curr_block != nullptr) {
if (curr_block->size >= size_bytes) {
first_fit_block = curr_block;
break;
return std::move(curr_block);
}
curr_block = bottom_up ? curr_block->next_free : curr_block->prev_free;
}

return first_fit_block;
return std::shared_ptr<FreeList::Block>{};
}

FreeList::Block *FreeList::search(uint64_t size_bytes, bool bottom_up) {
std::shared_ptr<FreeList::Block> FreeList::search(uint64_t size_bytes, bool bottom_up) {
switch (this->search_policy_) {
case FreeList::SearchPolicy::BEST:
return search_best(size_bytes, bottom_up);
return std::move(search_best(size_bytes, bottom_up));
break;
case FreeList::SearchPolicy::FIRST:
return search_first(size_bytes, bottom_up);
return std::move(search_first(size_bytes, bottom_up));
break;
default:
TT_ASSERT(false && "Unsupported search policy");
}
return nullptr;
return std::shared_ptr<FreeList::Block>{};
}

void FreeList::allocate_entire_free_block(Block *free_block_to_allocate) {
TT_ASSERT(not is_allocated(free_block_to_allocate));
void FreeList::allocate_entire_free_block(std::shared_ptr<FreeList::Block>& free_block_to_allocate) {
TT_ASSERT(not is_allocated(free_block_to_allocate.get()));
if (free_block_to_allocate->prev_free != nullptr) {
free_block_to_allocate->prev_free->next_free = free_block_to_allocate->next_free;
}
Expand All @@ -116,14 +113,15 @@ void FreeList::allocate_entire_free_block(Block *free_block_to_allocate) {

// free_block range: [a, b)
// allocated_block range: [a, c), where c < b
void FreeList::update_left_aligned_allocated_block_connections(Block *free_block, Block *allocated_block) {
void FreeList::update_left_aligned_allocated_block_connections(std::shared_ptr<FreeList::Block>& free_block,
std::shared_ptr<FreeList::Block>& allocated_block) {
allocated_block->prev_block = free_block->prev_block;
allocated_block->next_block = free_block;
if (free_block->prev_block != nullptr) {
free_block->prev_block->next_block = allocated_block;
}
if (free_block == this->block_head_) {
this->block_head_ = allocated_block;
this->block_head_ = allocated_block ;
}
// next_free and prev_free connections of free_block are still valid
free_block->prev_block = allocated_block;
Expand All @@ -133,7 +131,7 @@ void FreeList::update_left_aligned_allocated_block_connections(Block *free_block

// free_block range: [a, b)
// allocated_block range: [c, b), where c > a
void FreeList::update_right_aligned_allocated_block_connections(Block *free_block, Block *allocated_block) {
void FreeList::update_right_aligned_allocated_block_connections(std::shared_ptr<FreeList::Block>& free_block, std::shared_ptr<FreeList::Block>& allocated_block) {
allocated_block->prev_block = free_block;
allocated_block->next_block = free_block->next_block;
if (free_block->next_block != nullptr) {
Expand All @@ -148,7 +146,7 @@ void FreeList::update_right_aligned_allocated_block_connections(Block *free_bloc
}

// Offset marks the start of the allocated block
FreeList::Block *FreeList::allocate_slice_of_free_block(Block *free_block, uint64_t offset, uint64_t size_bytes) {
std::shared_ptr<FreeList::Block> FreeList::allocate_slice_of_free_block(std::shared_ptr<FreeList::Block>& free_block, uint64_t offset, uint64_t size_bytes) {
TT_ASSERT(free_block->address + offset + size_bytes <= free_block->address + free_block->size);

// Allocated slice spans the entire space of free_block
Expand All @@ -157,10 +155,10 @@ FreeList::Block *FreeList::allocate_slice_of_free_block(Block *free_block, uint6
return free_block;
}

auto allocated_block = new FreeList::Block{
std::shared_ptr<FreeList::Block> allocated_block(new FreeList::Block{
.address = free_block->address + offset,
.size = size_bytes,
};
});

// Allocated slice takes up a portion of free_block, three cases to consider:
// 1. allocated_block is left aligned with free_block with free space remaining on the right
Expand All @@ -181,25 +179,24 @@ FreeList::Block *FreeList::allocate_slice_of_free_block(Block *free_block, uint6
// Result: | free_block_mod | allocated_block | next_free_block |
uint64_t next_free_block_addr = free_block->address + offset + size_bytes;
uint64_t next_free_block_size = (free_block->address + free_block->size) - next_free_block_addr;
auto next_free_block = new FreeList::Block{
std::shared_ptr<FreeList::Block> next_free_block(new Block{
.address = next_free_block_addr,
.size = next_free_block_size,
.prev_block = allocated_block,
.next_block = free_block->next_block,
.prev_free = free_block,
.next_free = free_block->next_free
};
.next_free = free_block->next_free});
if (free_block->next_block != nullptr) {
free_block->next_block->prev_block = next_free_block;
}
if (free_block->next_free != nullptr) {
free_block->next_free->prev_free = next_free_block;
}
if (this->free_block_tail_ == free_block) {
this->free_block_tail_ = next_free_block;
this->free_block_tail_ = next_free_block ;
}
if (this->block_tail_ == free_block) {
this->block_tail_ = next_free_block;
this->block_tail_ = next_free_block ;
}
free_block->next_free = next_free_block;
free_block->next_block = allocated_block;
Expand All @@ -210,7 +207,7 @@ FreeList::Block *FreeList::allocate_slice_of_free_block(Block *free_block, uint6
free_block->size -= (allocated_block->size + next_free_block->size);
}

return allocated_block;
return std::move(allocated_block);
}

void FreeList::update_lowest_occupied_address(uint64_t address) {
Expand Down Expand Up @@ -244,7 +241,7 @@ std::optional<uint64_t> FreeList::allocate(uint64_t size_bytes, bool bottom_up,
std::optional<uint64_t> FreeList::allocate_at_address(uint64_t absolute_start_address, uint64_t size_bytes) {
TT_ASSERT(absolute_start_address % this->alignment_ == 0, "Requested address " + std::to_string(absolute_start_address) + " should be " + std::to_string(this->alignment_) + "B aligned");
auto start_address = absolute_start_address - this->offset_bytes_;
FreeList::Block *curr_block = this->free_block_head_;
auto& curr_block = this->free_block_head_;
uint64_t alloc_size = size_bytes < this->min_allocation_size_ ? this->min_allocation_size_ : size_bytes;
alloc_size = this->align(alloc_size);
// Look for a free block of size at least size_bytes that encompasses start_address
Expand All @@ -269,22 +266,21 @@ std::optional<uint64_t> FreeList::allocate_at_address(uint64_t absolute_start_ad
return absolute_start_address;
}

FreeList::Block *FreeList::find_block(uint64_t address) {
FreeList::Block *block = nullptr;
FreeList::Block *curr_block = this->block_head_;
std::shared_ptr<FreeList::Block> FreeList::find_block(uint64_t address) {
auto curr_block = this->block_head_;
while (curr_block != nullptr) {
if (curr_block->address == address) {
return curr_block;
}
curr_block = curr_block->next_block;
}
return block;
return std::shared_ptr<FreeList::Block>(nullptr);
}

void FreeList::update_lowest_occupied_address() {
FreeList::Block *block = this->block_head_;
auto& block = this->block_head_;
while (block != nullptr) {
if (this->is_allocated(block)) {
if (this->is_allocated(block.get())) {
break;
}
block = block->next_block;
Expand All @@ -298,8 +294,8 @@ void FreeList::update_lowest_occupied_address() {

void FreeList::deallocate(uint64_t absolute_address) {
uint64_t address = absolute_address - this->offset_bytes_;
FreeList::Block *block_to_free = find_block(address);
if (block_to_free == nullptr or not this->is_allocated(block_to_free)) {
auto block_to_free = find_block(address);
if (block_to_free == nullptr or not this->is_allocated(block_to_free.get())) {
return;
}

Expand All @@ -308,7 +304,7 @@ void FreeList::deallocate(uint64_t absolute_address) {

bool merged_prev = false;
bool merged_next = false;
if (prev != nullptr and not is_allocated(prev)) {
if (prev != nullptr and not is_allocated(prev.get())) {
prev->next_block = block_to_free->next_block;
if (block_to_free->next_block != nullptr) {
block_to_free->next_block->prev_block = prev;
Expand All @@ -318,16 +314,16 @@ void FreeList::deallocate(uint64_t absolute_address) {
merged_prev = true;
}

if (next != nullptr and not is_allocated(next)) {
if (next != nullptr and not is_allocated(next.get())) {
block_to_free->next_block = next->next_block;
if (next->next_block != nullptr) {
next->next_block->prev_block = block_to_free;
}
if (next == this->free_block_head_) {
this->free_block_head_ = block_to_free;
this->free_block_head_ = block_to_free ;
}
if (next == this->free_block_tail_) {
this->free_block_tail_ = block_to_free;
this->free_block_tail_ = block_to_free ;
}
block_to_free->next_free = next->next_free;
if (next->next_free != nullptr) {
Expand All @@ -346,11 +342,11 @@ void FreeList::deallocate(uint64_t absolute_address) {
if (not merged_prev and not merged_next) {
// Find where to include deallocated block in free list
auto prev_free_block = block_to_free->prev_block;
while (prev_free_block != nullptr and is_allocated(prev_free_block)) {
while (prev_free_block != nullptr and is_allocated(prev_free_block.get())) {
prev_free_block = prev_free_block->prev_block;
}
auto next_free_block = block_to_free->next_block;
while (next_free_block != nullptr and is_allocated(next_free_block)) {
while (next_free_block != nullptr and is_allocated(next_free_block.get())) {
next_free_block = next_free_block->next_block;
}
block_to_free->prev_free = prev_free_block;
Expand All @@ -362,9 +358,9 @@ void FreeList::deallocate(uint64_t absolute_address) {

block_to_free->next_free = next_free_block;
if (next_free_block != nullptr) {
next_free_block->prev_free = block_to_free;
next_free_block->prev_free = block_to_free ;
} else {
this->free_block_tail_ = block_to_free;
this->free_block_tail_ = block_to_free ;
}
}

Expand All @@ -374,11 +370,11 @@ void FreeList::deallocate(uint64_t absolute_address) {
}

void FreeList::reset() {
Block *curr_block = this->block_head_;
Block *next;
std::shared_ptr<Block> curr_block = this->block_head_;
std::shared_ptr<Block> next;
while (curr_block != nullptr) {
next = curr_block->next_block;
delete curr_block;
curr_block.reset();
curr_block = next;
}
this->block_head_ = nullptr;
Expand All @@ -398,7 +394,7 @@ Statistics FreeList::get_statistics() const {
.largest_free_block_bytes = 0
};

Block *curr_block = this->block_head_;
Block *curr_block = this->block_head_.get();
while (curr_block != nullptr) {
if (this->is_allocated(curr_block)) {
stats.total_allocated_bytes += curr_block->size;
Expand All @@ -409,7 +405,7 @@ Statistics FreeList::get_statistics() const {
stats.largest_free_block_addrs.push_back(curr_block->address + this->offset_bytes_);
}
}
curr_block = curr_block->next_block;
curr_block = curr_block->next_block.get();
}
if (stats.total_allocated_bytes == 0) {
stats.total_free_bytes = this->max_size_bytes_;
Expand All @@ -419,7 +415,12 @@ Statistics FreeList::get_statistics() const {
}

FreeList::~FreeList() {
// this->block_head_ and this->free_block_head_ are reset
this->reset();
this->block_head_.reset();
this->free_block_head_.reset();
this->block_tail_.reset();
this->free_block_tail_.reset();
}

void FreeList::dump_block(const Block *block, std::ofstream &out) const {
Expand All @@ -431,10 +432,10 @@ void FreeList::dump_block(const Block *block, std::ofstream &out) const {

void FreeList::dump_blocks(std::ofstream &out) const {
out << ",,Blocks:\n";
Block *curr_block = this->block_head_;
Block *curr_block = this->block_head_.get();
while (curr_block != nullptr) {
this->dump_block(curr_block, out);
curr_block = curr_block->next_block;
curr_block = curr_block->next_block.get();
}
out << "\n";
}
Expand Down
Loading

0 comments on commit 6daa904

Please sign in to comment.