diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp
index 45d5f95ec28..8fab9164e69 100644
--- a/tt_metal/impl/dispatch/command_queue.cpp
+++ b/tt_metal/impl/dispatch/command_queue.cpp
@@ -86,16 +86,18 @@ void EnqueueReadInterleavedBufferCommand::add_prefetch_relay(HugepageDeviceComma
 
 void EnqueueReadShardedBufferCommand::add_prefetch_relay(HugepageDeviceCommand& command) {
     uint32_t padded_page_size = align(this->buffer.page_size(), ADDRESS_ALIGNMENT);
-    CoreCoord logical_core =
-        this->buffer_page_mapping.all_cores_[this->buffer_page_mapping.dev_page_to_core_mapping_[this->src_page_index]];
-    CoreCoord core = this->buffer.device()->worker_core_from_logical_core(logical_core);
+    const CoreCoord worker_core = this->buffer.device()->worker_core_from_logical_core(this->core);
     command.add_prefetch_relay_linear(
-        get_noc_unicast_encoding(core),
+        get_noc_unicast_encoding(worker_core),
         padded_page_size * this->pages_to_read,
-        this->buffer.address() +
-            this->buffer_page_mapping.host_page_to_local_shard_page_mapping_
-                [this->buffer_page_mapping.dev_page_to_host_page_mapping_[this->src_page_index].value()] *
-                padded_page_size);
+        this->buffer.address() + (this->buffer_page_mapping.has_value()
+                                      ? ((*this->buffer_page_mapping)
+                                                 .host_page_to_local_shard_page_mapping_
+                                                     [(*this->buffer_page_mapping)
+                                                          .dev_page_to_host_page_mapping_[this->src_page_index]
+                                                          .value()] *
+                                             padded_page_size)
+                                      : this->src_page_index * padded_page_size));
 }
 
 void EnqueueReadBufferCommand::process() {
@@ -224,30 +226,19 @@ void EnqueueWriteShardedBufferCommand::add_buffer_data(HugepageDeviceCommand& co
     uint32_t data_size_bytes = this->pages_to_write * this->padded_page_size;
     if (this->buffer_page_mapping.has_value()) {
         const auto& page_mapping = this->buffer_page_mapping.value();
-        uint32_t core_index = page_mapping.dev_page_to_core_mapping_[this->dst_page_index];
-        bool width_page_padded =
-            page_mapping.core_shard_shape_[core_index][1] != buffer.shard_spec().shape_in_pages()[1];
-        if (width_page_padded or this->width_split or
-            (this->buffer.page_size() != this->padded_page_size and this->buffer.page_size() != this->buffer.size())) {
-            uint8_t* dst = command_sequence.reserve_space(data_size_bytes);
-            // TODO: Expose getter for cmd_write_offsetB?
-            uint32_t dst_offset = dst - (uint8_t*)command_sequence.data();
-            for (uint32_t dev_page = this->dst_page_index; dev_page < this->dst_page_index + this->pages_to_write;
-                 ++dev_page) {
-                auto& host_page = page_mapping.dev_page_to_host_page_mapping_[dev_page];
-                if (host_page.has_value()) {
-                    command_sequence.update_cmd_sequence(
-                        dst_offset,
-                        (char*)this->src + host_page.value() * this->buffer.page_size(),
-                        this->buffer.page_size());
-                }
-                dst_offset += this->padded_page_size;
+        uint8_t* dst = command_sequence.reserve_space(data_size_bytes);
+        // TODO: Expose getter for cmd_write_offsetB?
+        uint32_t dst_offset = dst - (uint8_t*)command_sequence.data();
+        for (uint32_t dev_page = this->dst_page_index; dev_page < this->dst_page_index + this->pages_to_write;
+             ++dev_page) {
+            auto& host_page = page_mapping.dev_page_to_host_page_mapping_[dev_page];
+            if (host_page.has_value()) {
+                command_sequence.update_cmd_sequence(
+                    dst_offset,
+                    (char*)this->src + host_page.value() * this->buffer.page_size(),
+                    this->buffer.page_size());
             }
-        } else {
-            // There are no padded pages
-            uint32_t unpadded_src_offset =
-                page_mapping.dev_page_to_host_page_mapping_[this->dst_page_index].value() * this->buffer.page_size();
-            command_sequence.add_data((char*)this->src + unpadded_src_offset, data_size_bytes, data_size_bytes);
+            dst_offset += this->padded_page_size;
         }
     } else {
         if (this->buffer.page_size() != this->padded_page_size and this->buffer.page_size() != this->buffer.size()) {
@@ -1316,20 +1307,43 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
     uint32_t src_page_index = 0;
 
     if (is_sharded(buffer.buffer_layout())) {
-        auto buffer_page_mapping = generate_buffer_page_mapping(buffer);
+        bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
+        bool height_sharded = buffer.buffer_layout() == TensorMemoryLayout::HEIGHT_SHARDED;
+        std::optional<BufferPageMapping> buffer_page_mapping = std::nullopt;
+        if (!height_sharded) {
+            buffer_page_mapping = generate_buffer_page_mapping(buffer);
+        }
         // Note that the src_page_index is the device page idx, not the host page idx
         // Since we read core by core we are reading the device pages sequentially
-        bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
+        const auto& cores = !height_sharded ? buffer_page_mapping.value().all_cores_
+                                            : corerange_to_cores(
+                                                  buffer.shard_spec().grid(),
+                                                  buffer.num_cores(),
+                                                  buffer.shard_spec().orientation() == ShardOrientation::ROW_MAJOR);
+        uint32_t num_total_pages = buffer.num_pages();
+        uint32_t max_pages_per_shard = buffer.shard_spec().size();
         for (uint32_t core_id = 0; core_id < buffer.num_cores(); ++core_id) {
-            uint32_t num_pages_to_read =
-                buffer_page_mapping.core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
+            uint32_t num_pages_to_read;
+            bool linear_page_copy;
+            if (!height_sharded) {
+                num_pages_to_read =
+                    buffer_page_mapping.value().core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
+                bool width_page_padded = buffer_page_mapping.value().core_shard_shape_[core_id][1] !=
+                                         buffer.shard_spec().shape_in_pages()[1];
+                linear_page_copy = !(width_split or width_page_padded);
+            } else {
+                num_pages_to_read = min(num_total_pages, max_pages_per_shard);
+                num_total_pages -= num_pages_to_read;
+                linear_page_copy = true;
+            }
             if (num_pages_to_read > 0) {
-                bool width_page_padded =
-                    buffer_page_mapping.core_shard_shape_[core_id][1] != buffer.shard_spec().shape_in_pages()[1];
-                bool linear_page_copy = !(width_split or width_page_padded);
-                uint32_t host_page = buffer_page_mapping.core_host_page_indices_[core_id][0];
-                src_page_index = buffer_page_mapping.host_page_to_dev_page_mapping_[host_page];
-                unpadded_dst_offset = host_page * buffer.page_size();
+                if (height_sharded) {
+                    unpadded_dst_offset = src_page_index * buffer.page_size();
+                } else {
+                    uint32_t host_page = buffer_page_mapping.value().core_host_page_indices_[core_id][0];
+                    src_page_index = buffer_page_mapping.value().host_page_to_dev_page_mapping_[host_page];
+                    unpadded_dst_offset = host_page * buffer.page_size();
+                }
 
                 auto command = EnqueueReadShardedBufferCommand(
                     this->id,
@@ -1339,6 +1353,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
                     this->manager,
                     this->expected_num_workers_completed,
                     buffer_page_mapping,
+                    cores[core_id],
                     src_page_index,
                     num_pages_to_read);
 
@@ -1349,7 +1364,10 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
                     unpadded_dst_offset,
                     num_pages_to_read,
                     src_page_index,
-                    linear_page_copy));
+                    linear_page_copy ? vector<std::optional<uint32_t>>()
+                                     : (*buffer_page_mapping).dev_page_to_host_page_mapping_));
+
+                src_page_index += num_pages_to_read;
                 this->enqueue_command(command, false);
                 this->increment_num_entries_in_completion_q();
             }
@@ -1455,6 +1473,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
             // Currently since writing sharded tensors uses write_linear, we write the padded pages on width
             // Alternative write each page row into separate commands, or have a strided linear write
             uint32_t num_pages;
+            bool linear_write;
             if (!height_sharded) {
                 num_pages = buffer_page_mapping.value().core_shard_shape_[core_id][0] *
                             buffer.shard_spec().shape_in_pages()[1];
@@ -1463,9 +1482,13 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
                 }
                 dst_page_index = buffer_page_mapping.value().host_page_to_dev_page_mapping_
                                      [buffer_page_mapping.value().core_host_page_indices_[core_id][0]];
+                bool width_page_padded = buffer_page_mapping.value().core_shard_shape_[core_id][1] !=
+                                         buffer.shard_spec().shape_in_pages()[1];
+                linear_write = !(width_split or width_page_padded);
             } else {
                 num_pages = min(num_total_pages, max_pages_per_shard);
                 num_total_pages -= num_pages;
+                linear_write = true;
             }
             uint32_t curr_page_idx_in_shard = 0;
             while (num_pages != 0) {
@@ -1498,7 +1521,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
                     issue_wait,
                     this->expected_num_workers_completed,
                     bank_base_address,
-                    buffer_page_mapping,
+                    linear_write ? std::nullopt : buffer_page_mapping,
                     cores[core_id],
                     width_split,
                     padded_page_size,
@@ -1701,7 +1724,7 @@ void HWCommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) {
 
 void HWCommandQueue::copy_into_user_space(
     const detail::ReadBufferDescriptor& read_buffer_descriptor, chip_id_t mmio_device_id, uint16_t channel) {
-    const auto& [buffer_layout, page_size, padded_page_size, linear_page_copy, dev_page_to_host_page_mapping, dst, dst_offset, num_pages_read, cur_dev_page_id] =
+    const auto& [buffer_layout, page_size, padded_page_size, dev_page_to_host_page_mapping, dst, dst_offset, num_pages_read, cur_dev_page_id] =
         read_buffer_descriptor;
 
     uint32_t padded_num_bytes = (num_pages_read * padded_page_size) + sizeof(CQDispatchCmd);
@@ -1747,7 +1770,7 @@ void HWCommandQueue::copy_into_user_space(
 
             remaining_bytes_to_read -= bytes_xfered;
 
-            if (linear_page_copy) {
+            if (dev_page_to_host_page_mapping.empty()) {
                 void* contiguous_dst = (void*)(uint64_t(dst) + contig_dst_offset);
                 if ((page_size % ADDRESS_ALIGNMENT) == 0) {
                     uint32_t data_bytes_xfered = bytes_xfered - offset_in_completion_q_data;
diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp
index 1b42b8224f9..b89fe153c74 100644
--- a/tt_metal/impl/dispatch/command_queue.hpp
+++ b/tt_metal/impl/dispatch/command_queue.hpp
@@ -132,7 +132,8 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
 class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
    private:
     void add_prefetch_relay(HugepageDeviceCommand& command) override;
-    BufferPageMapping buffer_page_mapping;
+    const std::optional<BufferPageMapping>& buffer_page_mapping;
+    const CoreCoord core;
 
    public:
     EnqueueReadShardedBufferCommand(
@@ -142,7 +143,8 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
         void* dst,
         SystemMemoryManager& manager,
         uint32_t expected_num_workers_completed,
-        const BufferPageMapping& buffer_page_mapping,
+        const std::optional<BufferPageMapping>& buffer_page_mapping,
+        const CoreCoord& core,
         uint32_t src_page_index = 0,
         std::optional<uint32_t> pages_to_read = std::nullopt) :
         EnqueueReadBufferCommand(
@@ -154,7 +156,8 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
             expected_num_workers_completed,
             src_page_index,
             pages_to_read),
-        buffer_page_mapping(buffer_page_mapping) {}
+        buffer_page_mapping(buffer_page_mapping),
+        core(core) {}
 };
 
 class EnqueueWriteShardedBufferCommand;
@@ -254,7 +257,7 @@
         uint32_t expected_num_workers_completed,
         uint32_t bank_base_address,
         const std::optional<BufferPageMapping>& buffer_page_mapping,
-        const CoreCoord core,
+        const CoreCoord& core,
         bool width_split,
         uint32_t padded_page_size,
         uint32_t dst_page_index = 0,
@@ -416,7 +419,6 @@ struct ReadBufferDescriptor {
     TensorMemoryLayout buffer_layout;
     uint32_t page_size;
     uint32_t padded_page_size;
-    bool linear_page_copy;
     vector<std::optional<uint32_t>> dev_page_to_host_page_mapping;
     void* dst;
     uint32_t dst_offset;
@@ -430,7 +432,7 @@
         uint32_t dst_offset,
         uint32_t num_pages_read,
         uint32_t cur_dev_page_id,
-        bool linear_page_copy = true) :
+        const std::vector<std::optional<uint32_t>>& dev_page_to_host_page_mapping = {}) :
         buffer_layout(buffer.buffer_layout()),
         page_size(this->page_size = buffer.page_size()),
         padded_page_size(padded_page_size),
@@ -438,11 +440,7 @@
         dst_offset(dst_offset),
         num_pages_read(num_pages_read),
         cur_dev_page_id(cur_dev_page_id),
-        linear_page_copy(linear_page_copy) {
-        if (!linear_page_copy and is_sharded(this->buffer_layout)) {
-            this->dev_page_to_host_page_mapping = generate_buffer_page_mapping(buffer).dev_page_to_host_page_mapping_;
-        }
-    }
+        dev_page_to_host_page_mapping(dev_page_to_host_page_mapping) {}
 };
 
 /*
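
For reference, a minimal, self-contained sketch of the copy semantics the reworked ReadBufferDescriptor encodes: an empty dev_page_to_host_page_mapping selects the linear (height-sharded / unpadded) path, while a non-empty mapping scatters each padded device page to its host page and skips padding pages. The function and parameter names below are illustrative only and are not part of command_queue.cpp; the real copy_into_user_space additionally handles alignment, wrap-around of the completion queue, and partial page reads.

    #include <cstdint>
    #include <cstring>
    #include <optional>
    #include <vector>

    // Hypothetical helper mirroring the descriptor fields used by the dispatch read path.
    void copy_pages_to_host(
        const uint8_t* completion_q_data,  // padded pages read back from the device
        uint8_t* dst,                      // user destination buffer
        uint32_t page_size,
        uint32_t padded_page_size,
        uint32_t num_pages_read,
        uint32_t cur_dev_page_id,
        const std::vector<std::optional<uint32_t>>& dev_page_to_host_page_mapping) {
        if (dev_page_to_host_page_mapping.empty()) {
            // Linear copy: device pages map 1:1 onto consecutive host pages.
            for (uint32_t i = 0; i < num_pages_read; ++i) {
                std::memcpy(dst + i * page_size, completion_q_data + i * padded_page_size, page_size);
            }
        } else {
            // Mapped copy: width-split or padded shards, where some device pages are
            // padding (nullopt) and the rest land at arbitrary host pages.
            for (uint32_t i = 0; i < num_pages_read; ++i) {
                const auto& host_page = dev_page_to_host_page_mapping[cur_dev_page_id + i];
                if (host_page.has_value()) {
                    std::memcpy(
                        dst + host_page.value() * page_size,
                        completion_q_data + i * padded_page_size,
                        page_size);
                }
            }
        }
    }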