Skip to content

Commit

Permalink
#9527: Remove usage of bcast for cases where multiply could be used
Browse files Browse the repository at this point in the history
  • Loading branch information
eyonland committed Jun 26, 2024
1 parent a8e9e22 commit 8caffae
Show file tree
Hide file tree
Showing 21 changed files with 444 additions and 220 deletions.
4 changes: 4 additions & 0 deletions tt_eager/tensor/host_buffer/functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "tensor/host_buffer/types.hpp"
#include "tensor/tensor.hpp"
#include "tt_metal/tt_stl/reflection.hpp"

namespace tt {

Expand Down Expand Up @@ -41,6 +42,7 @@ Buffer<T> get_as(BorrowedBuffer& buffer) {

template <typename T>
Buffer<T> get_as(const BorrowedBuffer& buffer) {
TT_ASSERT(std::holds_alternative<Buffer<T>>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
return std::get<Buffer<T>>(buffer);
}

Expand Down Expand Up @@ -155,12 +157,14 @@ namespace host_buffer {

template <typename T>
borrowed_buffer::Buffer<T> get_as(OwnedBuffer& buffer) {
TT_ASSERT(std::holds_alternative<owned_buffer::Buffer<T>>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
auto& owned_buffer = std::get<owned_buffer::Buffer<T>>(buffer);
return borrowed_buffer::Buffer<T>(owned_buffer.begin(), owned_buffer.size());
}

template <typename T>
borrowed_buffer::Buffer<T> get_as(const OwnedBuffer& buffer) {
TT_ASSERT(std::holds_alternative<owned_buffer::Buffer<T>>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
auto owned_buffer = std::get<owned_buffer::Buffer<T>>(buffer);
return borrowed_buffer::Buffer<T>(owned_buffer.begin(), owned_buffer.size());
}
Expand Down
4 changes: 4 additions & 0 deletions tt_eager/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ void Tensor::deallocate(bool force) {
auto dealloc_lambda = std::make_shared<std::function<void(Device*)>>(
[force, attr = this->tensor_attributes](Device* worker) mutable {
ZoneScopedN("ShardDeallocate");
TT_ASSERT(std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(attr->storage), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(attr->storage),__FILE__, __LINE__));
auto& s = std::get<MultiDeviceStorage>(attr->storage);
if (s.has_buffer_for_device(worker)) {
auto& device_buffer = s.get_buffer_for_device(worker);
Expand Down Expand Up @@ -809,6 +810,8 @@ const Shape Tensor::strides() const { return detail::compute_strides(this->get_l

uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); }

uint32_t Tensor::intended_volume() const { return tt::tt_metal::compute_volume(this->get_shape()); }

Tensor create_device_tensor(
const Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) {
ZoneScoped;
Expand Down Expand Up @@ -1037,6 +1040,7 @@ void write_tensor(Tensor host_tensor, Tensor device_tensor, uint8_t cq_id) {
auto host_storage = std::get<BorrowedStorage>(async_safe_tensor.get_storage());
std::visit([&host_data](auto&& b) { host_data = b.data(); }, host_storage.buffer);
} else {
TT_ASSERT(std::holds_alternative<OwnedStorage>(async_safe_tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(async_safe_tensor.get_storage()),__FILE__, __LINE__));
auto host_storage = std::get<OwnedStorage>(async_safe_tensor.get_storage());
std::visit([&host_data](auto&& b) { host_data = b.begin(); }, host_storage.get_buffer());
}
Expand Down
1 change: 1 addition & 0 deletions tt_eager/tensor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ struct Tensor {
StorageType storage_type() const;
const Shape strides() const;
uint32_t volume() const;
uint32_t intended_volume() const;

bool is_allocated() const;

Expand Down
7 changes: 7 additions & 0 deletions tt_eager/tensor/tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ bool is_multi_device_tensor(const Tensor& tensor) {
std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) {
std::vector<ttnn::Tensor> tensors;
if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE) {
TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(multi_device_tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(multi_device_tensor.get_storage()),__FILE__, __LINE__));
const auto& tensor_storage = std::get<MultiDeviceStorage>(multi_device_tensor.get_storage());
tensors = std::vector<ttnn::Tensor>(tensor_storage.num_buffers(), Tensor());
for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
Expand All @@ -432,6 +433,7 @@ std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_de
}
return tensors;
} else if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(multi_device_tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(multi_device_tensor.get_storage()),__FILE__, __LINE__));
const auto& tensor_storage = std::get<MultiDeviceHostStorage>(multi_device_tensor.get_storage());
for (int i = 0; i < tensor_storage.num_buffers(); ++i) {
tensors.push_back(Tensor{
Expand All @@ -448,9 +450,11 @@ std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_de

DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor) {
if (tensor.storage_type() == StorageType::MULTI_DEVICE) {
TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
const auto& tensor_storage = std::get<MultiDeviceStorage>(tensor.get_storage());
return tensor_storage.strategy;
} else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
const auto& tensor_storage = std::get<MultiDeviceHostStorage>(tensor.get_storage());
return tensor_storage.strategy;
}
Expand All @@ -468,6 +472,7 @@ Tensor create_multi_device_tensor(
std::unordered_map<int, tt::tt_metal::Shape> shapes;
std::unordered_map<int, DeviceBuffer> device_buffers;
for (const auto& tensor : tensors) {
TT_ASSERT(std::holds_alternative<DeviceStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
Device* device = std::get<DeviceStorage>(tensor.get_storage()).buffer->device();
auto device_id = device->id();
ordered_device_ids.push_back(device_id);
Expand All @@ -483,6 +488,7 @@ Tensor create_multi_device_tensor(
std::vector<OwnedBuffer> owned_buffers;
std::vector<Shape> shapes;
for (const auto& tensor : tensors) {
TT_ASSERT(std::holds_alternative<OwnedStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
owned_buffers.push_back(std::get<OwnedStorage>(tensor.get_storage()).buffer);
shapes.push_back(tensor.get_legacy_shape());
}
Expand Down Expand Up @@ -516,6 +522,7 @@ void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable) {
std::vector<Device*> get_devices(const Tensor& tensor) {
std::vector<Device*> devices;
if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) {
TT_ASSERT(std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
const auto& tensor_storage = std::get<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage());
for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
auto device_id = tensor_storage.ordered_device_ids[i];
Expand Down
18 changes: 18 additions & 0 deletions tt_eager/tt_dnn/op_library/auto_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ Tensor AutoFormat::move_tensor_to_mem_config(const Tensor& input, const MemoryCo
}
}

// This code is a workaround for cases where we need to remove autoformat but other dependent ops
// are not quite ready. So here we basically just put the tensor back on device.
// Used in backward_ops.cpp
// See: Remove auto format within permute_op.cpp #9404
Tensor AutoFormat::move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional<MemoryConfig> target_mem_config){
const auto intended_shape = input.get_shape();
const auto device_shape = input.get_legacy_shape();
const auto new_intended_shape = std::array<std::uint32_t, 4>{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]};
const auto new_device_shape = std::array<std::uint32_t, 4>{
device_shape[0],
device_shape[1],
(device_shape[-2] % TILE_HEIGHT != 0 ? (device_shape[-2] / TILE_HEIGHT + 1) * TILE_HEIGHT : device_shape[-2]),
(device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1])
};
const auto new_shape = tt_metal::Shape(new_intended_shape, new_device_shape);
return AutoFormat::format_input_tensor(input, device, new_shape, 0.0, target_layout, target_mem_config);
}

Tensor AutoFormat::format_input_tensor(
const Tensor& input,
Device* device,
Expand Down
14 changes: 10 additions & 4 deletions tt_eager/tt_dnn/op_library/auto_format.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ class AutoFormat {


static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) {
auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0];
auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
auto n = pad_n ? round_up(unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1, TILE_HEIGHT) : unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1;
auto c = pad_c ? round_up(unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1, TILE_WIDTH) : unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1;
auto h = pad_h ? round_up(unpadded_shape[-2], TILE_HEIGHT) : unpadded_shape[-2];
auto w = pad_w ? round_up(unpadded_shape[-1], TILE_WIDTH) : unpadded_shape[-1];
Shape padded_shape = {n, c, h, w};
return padded_shape;
}
Expand Down Expand Up @@ -83,6 +83,12 @@ class AutoFormat {
return false;
}

// This code is a workaround for cases where we need to remove autoformat but other dependent ops
// are not quite ready. So here we basically just put the tensor back on device.
// Used in backward_ops.cpp
// See: Remove auto format within permute_op.cpp #9404
static Tensor move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional<MemoryConfig> target_mem_config);

static Tensor move_tensor_to_device(const Tensor &input, Device * device, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

static Tensor move_tensor_to_mem_config(const Tensor &input, const MemoryConfig& mem_config);
Expand Down
10 changes: 10 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2218,6 +2218,11 @@ std::vector<Tensor> _prod_bw(
std::vector<int64_t> after_permute_dims = {0, 2, 3, 1};
Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);

// put the tensor back on device because permute throws it off device
// See: Remove auto format within permute_op.cpp #9404
tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(),tensor_1.get_layout(), tensor_1.memory_config());

after_permute_dims = {0, 3, 1, 2};
Tensor result = permute(
bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
Expand Down Expand Up @@ -2250,6 +2255,11 @@ std::vector<Tensor> _prod_bw(
std::vector<int64_t> after_permute_dims = {3, 1, 2, 0};
Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);

// put the tensor back on device because permute throws it off device
// See: Remove auto format within permute_op.cpp #9404
tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(),tensor_1.get_layout(), tensor_1.memory_config());

Tensor result = permute(
bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
after_permute_dims,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ namespace tt_metal {
operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math/*, BcastOpDim bcast_dim*/){
const auto ashape = a.get_legacy_shape();
const auto bshape = b.get_legacy_shape();
uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
uint32_t bN = bshape[0], bC = bshape[1], bH = bshape[2], bW = bshape[3];
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1, C = ashape.rank() >= 3 ? ashape[-3] : 1, H = ashape[-2], W = ashape[-1];
uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1, bC = bshape.rank() >= 3 ? bshape[-3] : 1, bH = bshape[-2], bW = bshape[-1];
uint32_t NC = N*C;


Expand Down
Loading

0 comments on commit 8caffae

Please sign in to comment.