#0: Skip generating buffer page mapping for linear sharded reads
tt-aho committed May 24, 2024
1 parent 1c456d6 commit ca93d4a
Showing 2 changed files with 76 additions and 55 deletions.
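
The change below hinges on one observation: for height-sharded buffers that are read or written core by core, device pages already land in host-page order, so the per-page translation table produced by generate_buffer_page_mapping is the identity and never needs to be built. The following is a minimal sketch of that dispatch; the types and the host_offset_for helper are simplified stand-ins for illustration, not the actual tt-metal API.

#include <cstdint>
#include <optional>
#include <vector>

// Simplified stand-in for tt-metal's BufferPageMapping (illustration only).
struct BufferPageMapping {
    std::vector<std::optional<uint32_t>> dev_page_to_host_page_mapping_;
};

// Hypothetical helper: byte offset of a device page within the host buffer.
uint32_t host_offset_for(
    uint32_t dev_page,
    uint32_t padded_page_size,
    const std::optional<BufferPageMapping>& mapping) {
    if (mapping.has_value()) {
        // Width-split / padded shards: translate device page -> host page.
        return mapping->dev_page_to_host_page_mapping_[dev_page].value() * padded_page_size;
    }
    // Height-sharded (linear) reads: the mapping is the identity.
    return dev_page * padded_page_size;
}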
111 changes: 67 additions & 44 deletions tt_metal/impl/dispatch/command_queue.cpp
@@ -86,16 +86,18 @@ void EnqueueReadInterleavedBufferCommand::add_prefetch_relay(HugepageDeviceComma

void EnqueueReadShardedBufferCommand::add_prefetch_relay(HugepageDeviceCommand& command) {
uint32_t padded_page_size = align(this->buffer.page_size(), ADDRESS_ALIGNMENT);
CoreCoord logical_core =
this->buffer_page_mapping.all_cores_[this->buffer_page_mapping.dev_page_to_core_mapping_[this->src_page_index]];
CoreCoord core = this->buffer.device()->worker_core_from_logical_core(logical_core);
const CoreCoord worker_core = this->buffer.device()->worker_core_from_logical_core(this->core);
command.add_prefetch_relay_linear(
get_noc_unicast_encoding(core),
get_noc_unicast_encoding(worker_core),
padded_page_size * this->pages_to_read,
this->buffer.address() +
this->buffer_page_mapping.host_page_to_local_shard_page_mapping_
[this->buffer_page_mapping.dev_page_to_host_page_mapping_[this->src_page_index].value()] *
padded_page_size);
this->buffer.address() + (this->buffer_page_mapping.has_value()
? ((*this->buffer_page_mapping)
.host_page_to_local_shard_page_mapping_
[(*this->buffer_page_mapping)
.dev_page_to_host_page_mapping_[this->src_page_index]
.value()] *
padded_page_size)
: this->src_page_index * padded_page_size));
}

void EnqueueReadBufferCommand::process() {
@@ -224,30 +226,19 @@ void EnqueueWriteShardedBufferCommand::add_buffer_data(HugepageDeviceCommand& co
uint32_t data_size_bytes = this->pages_to_write * this->padded_page_size;
if (this->buffer_page_mapping.has_value()) {
const auto& page_mapping = this->buffer_page_mapping.value();
uint32_t core_index = page_mapping.dev_page_to_core_mapping_[this->dst_page_index];
bool width_page_padded =
page_mapping.core_shard_shape_[core_index][1] != buffer.shard_spec().shape_in_pages()[1];
if (width_page_padded or this->width_split or
(this->buffer.page_size() != this->padded_page_size and this->buffer.page_size() != this->buffer.size())) {
uint8_t* dst = command_sequence.reserve_space<uint8_t*, true>(data_size_bytes);
// TODO: Expose getter for cmd_write_offsetB?
uint32_t dst_offset = dst - (uint8_t*)command_sequence.data();
for (uint32_t dev_page = this->dst_page_index; dev_page < this->dst_page_index + this->pages_to_write;
++dev_page) {
auto& host_page = page_mapping.dev_page_to_host_page_mapping_[dev_page];
if (host_page.has_value()) {
command_sequence.update_cmd_sequence(
dst_offset,
(char*)this->src + host_page.value() * this->buffer.page_size(),
this->buffer.page_size());
}
dst_offset += this->padded_page_size;
uint8_t* dst = command_sequence.reserve_space<uint8_t*, true>(data_size_bytes);
// TODO: Expose getter for cmd_write_offsetB?
uint32_t dst_offset = dst - (uint8_t*)command_sequence.data();
for (uint32_t dev_page = this->dst_page_index; dev_page < this->dst_page_index + this->pages_to_write;
++dev_page) {
auto& host_page = page_mapping.dev_page_to_host_page_mapping_[dev_page];
if (host_page.has_value()) {
command_sequence.update_cmd_sequence(
dst_offset,
(char*)this->src + host_page.value() * this->buffer.page_size(),
this->buffer.page_size());
}
} else {
// There are no padded pages
uint32_t unpadded_src_offset =
page_mapping.dev_page_to_host_page_mapping_[this->dst_page_index].value() * this->buffer.page_size();
command_sequence.add_data((char*)this->src + unpadded_src_offset, data_size_bytes, data_size_bytes);
dst_offset += this->padded_page_size;
}
} else {
if (this->buffer.page_size() != this->padded_page_size and this->buffer.page_size() != this->buffer.size()) {
@@ -1316,20 +1307,43 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
uint32_t src_page_index = 0;

if (is_sharded(buffer.buffer_layout())) {
auto buffer_page_mapping = generate_buffer_page_mapping(buffer);
bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
bool height_sharded = buffer.buffer_layout() == TensorMemoryLayout::HEIGHT_SHARDED;
std::optional<BufferPageMapping> buffer_page_mapping = std::nullopt;
if (!height_sharded) {
buffer_page_mapping = generate_buffer_page_mapping(buffer);
}
// Note that src_page_index is the device page index, not the host page index
// Since we read core by core, the device pages are read sequentially
bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
const auto& cores = !height_sharded ? buffer_page_mapping.value().all_cores_
: corerange_to_cores(
buffer.shard_spec().grid(),
buffer.num_cores(),
buffer.shard_spec().orientation() == ShardOrientation::ROW_MAJOR);
uint32_t num_total_pages = buffer.num_pages();
uint32_t max_pages_per_shard = buffer.shard_spec().size();
for (uint32_t core_id = 0; core_id < buffer.num_cores(); ++core_id) {
uint32_t num_pages_to_read =
buffer_page_mapping.core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
uint32_t num_pages_to_read;
bool linear_page_copy;
if (!height_sharded) {
num_pages_to_read =
buffer_page_mapping.value().core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
bool width_page_padded = buffer_page_mapping.value().core_shard_shape_[core_id][1] !=
buffer.shard_spec().shape_in_pages()[1];
linear_page_copy = !(width_split or width_page_padded);
} else {
num_pages_to_read = min(num_total_pages, max_pages_per_shard);
num_total_pages -= num_pages_to_read;
linear_page_copy = true;
}
if (num_pages_to_read > 0) {
bool width_page_padded =
buffer_page_mapping.core_shard_shape_[core_id][1] != buffer.shard_spec().shape_in_pages()[1];
bool linear_page_copy = !(width_split or width_page_padded);
uint32_t host_page = buffer_page_mapping.core_host_page_indices_[core_id][0];
src_page_index = buffer_page_mapping.host_page_to_dev_page_mapping_[host_page];
unpadded_dst_offset = host_page * buffer.page_size();
if (height_sharded) {
unpadded_dst_offset = src_page_index * buffer.page_size();
} else {
uint32_t host_page = buffer_page_mapping.value().core_host_page_indices_[core_id][0];
src_page_index = buffer_page_mapping.value().host_page_to_dev_page_mapping_[host_page];
unpadded_dst_offset = host_page * buffer.page_size();
}

auto command = EnqueueReadShardedBufferCommand(
this->id,
@@ -1339,6 +1353,7 @@
this->manager,
this->expected_num_workers_completed,
buffer_page_mapping,
cores[core_id],
src_page_index,
num_pages_to_read);

@@ -1349,7 +1364,10 @@
unpadded_dst_offset,
num_pages_to_read,
src_page_index,
linear_page_copy));
linear_page_copy ? (*buffer_page_mapping).dev_page_to_host_page_mapping_
: vector<std::optional<uint32_t>>()));

src_page_index += num_pages_to_read;
this->enqueue_command(command, false);
this->increment_num_entries_in_completion_q();
}
@@ -1455,6 +1473,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
// Currently, since writing sharded tensors uses write_linear, we also write the padded pages along the width
// Alternatively, write each page row in a separate command, or use a strided linear write
uint32_t num_pages;
bool linear_write;
if (!height_sharded) {
num_pages =
buffer_page_mapping.value().core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
@@ -1463,9 +1482,13 @@
}
dst_page_index = buffer_page_mapping.value().host_page_to_dev_page_mapping_
[buffer_page_mapping.value().core_host_page_indices_[core_id][0]];
bool width_page_padded = buffer_page_mapping.value().core_shard_shape_[core_id][1] !=
buffer.shard_spec().shape_in_pages()[1];
linear_write = !(width_split or width_page_padded);
} else {
num_pages = min(num_total_pages, max_pages_per_shard);
num_total_pages -= num_pages;
linear_write = true;
}
uint32_t curr_page_idx_in_shard = 0;
while (num_pages != 0) {
@@ -1498,7 +1521,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
issue_wait,
this->expected_num_workers_completed,
bank_base_address,
buffer_page_mapping,
linear_write ? std::nullopt : buffer_page_mapping,
cores[core_id],
width_split,
padded_page_size,
@@ -1701,7 +1724,7 @@ void HWCommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) {

void HWCommandQueue::copy_into_user_space(
const detail::ReadBufferDescriptor& read_buffer_descriptor, chip_id_t mmio_device_id, uint16_t channel) {
const auto& [buffer_layout, page_size, padded_page_size, linear_page_copy, dev_page_to_host_page_mapping, dst, dst_offset, num_pages_read, cur_dev_page_id] =
const auto& [buffer_layout, page_size, padded_page_size, dev_page_to_host_page_mapping, dst, dst_offset, num_pages_read, cur_dev_page_id] =
read_buffer_descriptor;

uint32_t padded_num_bytes = (num_pages_read * padded_page_size) + sizeof(CQDispatchCmd);
@@ -1747,7 +1770,7 @@

remaining_bytes_to_read -= bytes_xfered;

if (linear_page_copy) {
if (dev_page_to_host_page_mapping.empty()) {
void* contiguous_dst = (void*)(uint64_t(dst) + contig_dst_offset);
if ((page_size % ADDRESS_ALIGNMENT) == 0) {
uint32_t data_bytes_xfered = bytes_xfered - offset_in_completion_q_data;
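With the linear_page_copy flag gone from ReadBufferDescriptor, an empty dev_page_to_host_page_mapping is what now selects the contiguous copy path in copy_into_user_space. The sketch below models that dispatch under simplified, hypothetical signatures; the real routine also walks the completion queue in bounded chunks, which is elided here.

#include <cstdint>
#include <cstring>
#include <optional>
#include <vector>

// Model of the host-side copy after this commit (illustration only).
void copy_pages_to_user(
    char* dst, const char* src, uint32_t num_pages,
    uint32_t page_size, uint32_t padded_page_size,
    const std::vector<std::optional<uint32_t>>& dev_to_host) {
    if (dev_to_host.empty()) {
        // Linear path: pages arrive in host order.
        if (page_size == padded_page_size) {
            std::memcpy(dst, src, size_t(num_pages) * page_size);  // one contiguous run
        } else {
            for (uint32_t p = 0; p < num_pages; ++p)  // in order, strided past padding
                std::memcpy(dst + size_t(p) * page_size,
                            src + size_t(p) * padded_page_size, page_size);
        }
        return;
    }
    // Mapped path: scatter each device page to its host page, skipping
    // padding-only pages (those with no host page).
    for (uint32_t dev_page = 0; dev_page < num_pages; ++dev_page) {
        if (dev_to_host[dev_page].has_value()) {
            std::memcpy(dst + size_t(dev_to_host[dev_page].value()) * page_size,
                        src + size_t(dev_page) * padded_page_size, page_size);
        }
    }
}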
20 changes: 9 additions & 11 deletions tt_metal/impl/dispatch/command_queue.hpp
@@ -132,7 +132,8 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
private:
void add_prefetch_relay(HugepageDeviceCommand& command) override;
BufferPageMapping buffer_page_mapping;
const std::optional<BufferPageMapping>& buffer_page_mapping;
const CoreCoord core;

public:
EnqueueReadShardedBufferCommand(
@@ -142,7 +143,8 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
void* dst,
SystemMemoryManager& manager,
uint32_t expected_num_workers_completed,
const BufferPageMapping& buffer_page_mapping,
const std::optional<BufferPageMapping>& buffer_page_mapping,
const CoreCoord& core,
uint32_t src_page_index = 0,
std::optional<uint32_t> pages_to_read = std::nullopt) :
EnqueueReadBufferCommand(
@@ -154,7 +156,8 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
expected_num_workers_completed,
src_page_index,
pages_to_read),
buffer_page_mapping(buffer_page_mapping) {}
buffer_page_mapping(buffer_page_mapping),
core(core) {}
};

class EnqueueWriteShardedBufferCommand;
@@ -254,7 +257,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
uint32_t expected_num_workers_completed,
uint32_t bank_base_address,
const std::optional<BufferPageMapping>& buffer_page_mapping,
const CoreCoord core,
const CoreCoord& core,
bool width_split,
uint32_t padded_page_size,
uint32_t dst_page_index = 0,
@@ -416,7 +419,6 @@ struct ReadBufferDescriptor {
TensorMemoryLayout buffer_layout;
uint32_t page_size;
uint32_t padded_page_size;
bool linear_page_copy;
vector<std::optional<uint32_t>> dev_page_to_host_page_mapping;
void* dst;
uint32_t dst_offset;
@@ -430,19 +432,15 @@
uint32_t dst_offset,
uint32_t num_pages_read,
uint32_t cur_dev_page_id,
bool linear_page_copy = true) :
const std::vector<std::optional<uint32_t>>& dev_page_to_host_page_mapping = {}) :
buffer_layout(buffer.buffer_layout()),
page_size(this->page_size = buffer.page_size()),
padded_page_size(padded_page_size),
dst(dst),
dst_offset(dst_offset),
num_pages_read(num_pages_read),
cur_dev_page_id(cur_dev_page_id),
linear_page_copy(linear_page_copy) {
if (!linear_page_copy and is_sharded(this->buffer_layout)) {
this->dev_page_to_host_page_mapping = generate_buffer_page_mapping(buffer).dev_page_to_host_page_mapping_;
}
}
dev_page_to_host_page_mapping(dev_page_to_host_page_mapping) {}
};

/*
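A note on the defaulted constructor argument above: omitting the mapping now encodes the linear case, so a height-sharded read can build its descriptor without ever calling generate_buffer_page_mapping. Below is a trimmed-down, hypothetical model of that contract, not the real class.

#include <cstdint>
#include <optional>
#include <vector>

// Illustration only; mirrors the new defaulted parameter.
struct ReadDescriptorModel {
    std::vector<std::optional<uint32_t>> dev_page_to_host_page_mapping;
    ReadDescriptorModel(
        const std::vector<std::optional<uint32_t>>& mapping = {})
        : dev_page_to_host_page_mapping(mapping) {}
    bool linear_copy() const { return dev_page_to_host_page_mapping.empty(); }
};

// ReadDescriptorModel{} -> linear_copy() == true (height-sharded reads);
// passing a generated mapping enables the per-page scatter shown earlier.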
