#0: Unify create_device_tensor and create_sharded_device_tensor
tt-aho committed May 16, 2024
1 parent cd1be06 commit 07149f7
Showing 30 changed files with 80 additions and 85 deletions.
1 change: 0 additions & 1 deletion tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp
@@ -175,7 +175,6 @@ void test_raw_host_memory_pointer() {
Tensor e_dev = tt::tt_metal::add(c_dev, d_dev);

tt::tt_metal::memcpy(tensor_for_printing, e_dev);
- tensor_for_printing.print();

for (auto& element : owned_buffer::get_as<bfloat16>(tensor_for_printing)) {
TT_ASSERT(element == bfloat16(10.0f));
58 changes: 28 additions & 30 deletions tt_eager/tensor/tensor.cpp
@@ -795,38 +795,36 @@ uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_

Tensor create_device_tensor(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config) {
    ZoneScoped;
-    uint32_t packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type));
-    auto device_buffer = tensor_impl::allocate_buffer_on_device(packed_size_in_bytes, device, shape, data_type, layout, memory_config);
-    return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
-}
-
-Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config) {
-    ZoneScoped;
-    TT_ASSERT(memory_config.is_sharded());
-    TT_ASSERT(memory_config.shard_spec.has_value());
-    TT_ASSERT(memory_config.is_l1());
-
-    auto shard_spec = memory_config.shard_spec.value();
-    auto& shard_shape = shard_spec.shape;
-
-    auto width = shape[-1];
-    auto other_dims = 1;
-    for (int i = 0; i < shape.rank() - 1; i++) {
-        other_dims *= shape[i];
-    }
-
-    auto element_size = tensor_impl::element_size_bytes_wrapper(data_type);
-    auto page_shape = tensor_impl::get_sharded_page_shape(layout, data_type, shard_spec.shape);
-    std::array<uint32_t,2> tensor2d_size = {other_dims/page_shape[0], width/page_shape[1]};
-    ShardSpecBuffer shard_spec_buffer(shard_spec, page_shape, tensor2d_size);
-    uint32_t packed_size_in_bytes;
-
-    packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type));
-    auto device_buffer = tensor_impl::allocate_buffer_on_device(packed_size_in_bytes, device, shape,
-                                                                data_type, layout, memory_config,
-                                                                std::make_optional<ShardSpecBuffer>(shard_spec_buffer)
-                                                                );
-    return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
+    if (memory_config.is_sharded()) {
+        TT_ASSERT(memory_config.shard_spec.has_value());
+        TT_ASSERT(memory_config.is_l1());
+
+        auto shard_spec = memory_config.shard_spec.value();
+        auto& shard_shape = shard_spec.shape;
+
+        auto width = shape[-1];
+        auto other_dims = 1;
+        for (int i = 0; i < shape.rank() - 1; i++) {
+            other_dims *= shape[i];
+        }
+
+        auto element_size = tensor_impl::element_size_bytes_wrapper(data_type);
+        auto page_shape = tensor_impl::get_sharded_page_shape(layout, data_type, shard_spec.shape);
+        std::array<uint32_t,2> tensor2d_size = {other_dims/page_shape[0], width/page_shape[1]};
+        ShardSpecBuffer shard_spec_buffer(shard_spec, page_shape, tensor2d_size);
+        uint32_t packed_size_in_bytes;
+
+        packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type));
+        auto device_buffer = tensor_impl::allocate_buffer_on_device(packed_size_in_bytes, device, shape,
+                                                                    data_type, layout, memory_config,
+                                                                    std::make_optional<ShardSpecBuffer>(shard_spec_buffer)
+                                                                    );
+        return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
+    } else {
+        uint32_t packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type));
+        auto device_buffer = tensor_impl::allocate_buffer_on_device(packed_size_in_bytes, device, shape, data_type, layout, memory_config);
+        return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
+    }
}

void* get_raw_host_data_ptr(const Tensor& tensor) {
@@ -910,7 +908,7 @@ Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout
device->push_work(
[shape, data_type, layout, device, memory_config, device_tensor] () mutable {
if (memory_config.is_sharded()) {
- auto local_tensor = create_sharded_device_tensor(shape, data_type, layout, device, memory_config);
+ auto local_tensor = create_device_tensor(shape, data_type, layout, device, memory_config);
device_tensor.populate_buffers_and_metadata(local_tensor);
} else {
auto local_tensor = create_device_tensor(shape, data_type, layout, device, memory_config);
@@ -934,7 +932,7 @@ Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout
worker->push_work(
[shape, data_type, layout, worker, memory_config, device_tensor, worker_index] () mutable {
if (memory_config.is_sharded()) {
- auto local_tensor = create_sharded_device_tensor(shape, data_type, layout, worker, memory_config);
+ auto local_tensor = create_device_tensor(shape, data_type, layout, worker, memory_config);
insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index);
} else {
auto local_tensor = create_device_tensor(shape, data_type, layout, worker, memory_config);
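With the unified entry point above, the sharded path is selected entirely by the MemoryConfig argument, which is what turns every call site below into a one-line rename. A minimal usage sketch, not part of the diff — the include path, the HEIGHT_SHARDED and BufferType::L1 names, the core set, and the concrete shapes are illustrative assumptions:

// Sketch only: assumes the tt-metal headers are on the include path and that
// `device` and `all_cores` (a CoreRangeSet spanning two cores) exist elsewhere.
#include "tensor/tensor.hpp"   // assumed include path

using namespace tt::tt_metal;

Tensor make_sharded_tensor(Device* device, const CoreRangeSet& all_cores) {
    Shape shape = {1, 1, 64, 128};

    // Interleaved allocation: unchanged, relies on the defaulted MemoryConfig.
    Tensor interleaved = create_device_tensor(shape, DataType::BFLOAT16, Layout::TILE, device);

    // Sharded allocation: formerly create_sharded_device_tensor. The unified
    // function asserts shard_spec.has_value() and is_l1() on this branch.
    MemoryConfig sharded_config{
        .memory_layout = TensorMemoryLayout::HEIGHT_SHARDED,  // assumed enum value
        .buffer_type = BufferType::L1};                       // assumed field/enum names
    sharded_config.shard_spec = ShardSpec{all_cores, {32, 128}, ShardOrientation::ROW_MAJOR};

    // For this shape: other_dims = 1 * 1 * 64 and width = 128, so with a TILE
    // page shape of {32, 32} the code above computes tensor2d_size = {2, 4}.
    return create_device_tensor(shape, DataType::BFLOAT16, Layout::TILE, device, sharded_config);
}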
2 changes: 0 additions & 2 deletions tt_eager/tensor/tensor.hpp
@@ -339,8 +339,6 @@ struct Tensor {

Tensor create_device_tensor(const Shape& shape, DataType dtype, Layout layout, Device *device, const MemoryConfig& memory_config = {.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED});

- Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config);
-
// template<typename Buffer>
// void *get_host_buffer(const Tensor &tensor);
void *get_raw_host_data_ptr(const Tensor &tensor);
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/all_gather/all_gather_op.cpp
@@ -73,7 +73,7 @@ std::vector<Shape> AllGather::compute_output_shapes(const std::vector<Tensor> &i
std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor = input_tensors[0];
if(this->output_mem_config.is_sharded()) {
- return {create_sharded_device_tensor(
+ return {create_device_tensor(
this->compute_output_shapes(input_tensors).at(0),
input_tensor.get_dtype(),
input_tensor.get_layout(),
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
@@ -152,7 +152,7 @@ std::vector<Tensor> EltwiseBinaryBroadcast::create_output_tensors(const std::vec
}
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
- return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), mem_config)};
+ return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), mem_config)};
} else {
return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
}
6 changes: 3 additions & 3 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
@@ -932,7 +932,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
ShardSpec shard_spec = ShardSpec{all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH}, ShardOrientation::ROW_MAJOR};
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
- return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, output_layout, input_tensor_a.device(), mem_config)};
+ return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, output_layout, input_tensor_a.device(), mem_config)};
} else if constexpr (
std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseMultiCastProgramConfig>
) {
@@ -957,7 +957,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
ShardSpec shard_spec = ShardSpec{all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH}, shard_orientation};
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
- return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, output_layout, input_tensor_a.device(), mem_config)};
+ return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, output_layout, input_tensor_a.device(), mem_config)};
} else if constexpr (
std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseProgramConfig>
) {
@@ -981,7 +981,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
ShardSpec shard_spec = ShardSpec{all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH}, shard_orientation};
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
- return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, output_layout, input_tensor_a.device(), mem_config)};
+ return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, output_layout, input_tensor_a.device(), mem_config)};
} else {
TT_FATAL(false, "Unsupported op for output sharding");
return {};
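All three sharded matmul branches above follow the same recipe, differing only in how per_core_M, per_core_N, and the shard orientation are derived from the program config. Condensed, the pattern is (a sketch only — the tile counts are hypothetical, and TILE_HEIGHT/TILE_WIDTH are assumed to be the usual 32):

// Condensed sketch of the shared pattern; per-core tile counts are hypothetical,
// and all_cores / output_mem_config come from the program config in the real code.
uint32_t per_core_M = 2, per_core_N = 4;   // output tiles per core (illustrative)
ShardSpec shard_spec{all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH},
                     ShardOrientation::ROW_MAJOR};
auto mem_config = output_mem_config;       // already carries a sharded memory_layout
mem_config.shard_spec = shard_spec;
// Since this commit, the sharded output is allocated through the unified call:
Tensor output = create_device_tensor(output_shape, output_dtype, Layout::TILE,
                                     input_tensor_a.device(), mem_config);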
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/concat/concat_op.cpp
@@ -81,7 +81,7 @@ std::vector<Tensor> Concat::create_output_tensors(const std::vector<Tensor> &inp
const Tensor &ref_in_tensor = input_tensors.at(0);

if (this->output_mem_config.is_sharded()) {
- return {create_sharded_device_tensor(
+ return {create_device_tensor(
this->compute_output_shapes(input_tensors).at(0),
ref_in_tensor.get_dtype(),
ref_in_tensor.get_layout(),
4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/conv/optimized_conv_op.cpp
@@ -165,7 +165,7 @@ std::vector<Tensor> OptimizedConv::create_output_tensors(const std::vector<Tenso
auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR};
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
- return {create_sharded_device_tensor(output_shape, this->output_dtype, output_layout, input_tensor.device(), mem_config)};
+ return {create_device_tensor(output_shape, this->output_dtype, output_layout, input_tensor.device(), mem_config)};
} else {
auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(this->input_tensor_shape, conv_params, this->parallelization_config.per_core_out_matrix_height_ntiles, extra_padding_for_32B_alignment);
uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1];
@@ -180,7 +180,7 @@ std::vector<Tensor> OptimizedConv::create_output_tensors(const std::vector<Tenso
auto shard_spec = ShardSpec{shard_grid, shard_shape, transpose_mcast ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR};
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
- return {create_sharded_device_tensor(output_shape, this->output_dtype, output_layout, input_tensor.device(), mem_config)};
+ return {create_device_tensor(output_shape, this->output_dtype, output_layout, input_tensor.device(), mem_config)};
}

}
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp
@@ -70,7 +70,7 @@ std::vector<Tensor> Downsample::create_output_tensors(const std::vector<Tensor>
uint32_t output_shard_width = round_up(output_shape[3], num_cores_width_sliced * TILE_WIDTH) / num_cores_width_sliced;
auto mem_config = input_tensor.memory_config();
mem_config.shard_spec = ShardSpec {input_tensor.shard_spec().value().grid, std::array<uint32_t, 2>{{output_shard_height, output_shard_width}}, input_tensor.shard_spec().value().orientation};
- return {create_sharded_device_tensor(output_shape, this->output_dtype, Layout::TILE, input_tensor.device(), mem_config)};
+ return {create_device_tensor(output_shape, this->output_dtype, Layout::TILE, input_tensor.device(), mem_config)};
}

operation::ProgramWithCallbacks Downsample::create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const {
@@ -199,7 +199,7 @@ std::vector<Tensor> EltwiseBinary::create_output_tensors(
}
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
- return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, Layout::TILE, input_tensor_a.device(), mem_config)};
+ return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, Layout::TILE, input_tensor_a.device(), mem_config)};
}
return operation::generic_create_output_tensors(*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
}
@@ -299,7 +299,7 @@ std::vector<Tensor> EltwiseUnary::create_output_tensors(const std::vector<Tensor
const auto& input_tensor = input_tensors.at(0);
if (this->output_mem_config.is_sharded()) {
Shape output_shape = compute_output_shapes(input_tensors).at(0);
- return {create_sharded_device_tensor(
+ return {create_device_tensor(
output_shape,
input_tensor.get_dtype(),
input_tensor.get_layout(),
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/fold/fold_op.cpp
@@ -60,7 +60,7 @@ std::vector<Tensor> Fold::create_output_tensors(const std::vector<Tensor> &input
mem_config.shard_spec->shape[0] /= stride_h * stride_w;
mem_config.shard_spec->shape[1] *= stride_h * stride_w;

- return {create_sharded_device_tensor(
+ return {create_device_tensor(
compute_output_shapes(input_tensors).at(0),
output_dtype,
input_tensor.get_layout(),
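A note on the Fold hunk above: rather than rebuilding the shard spec, it rescales the existing one in place, trading shard height for shard width. With stride_h = stride_w = 2, for example, a {128, 32} shard becomes {128 / 4, 32 * 4} = {32, 128}, leaving the number of elements per core unchanged.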
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.cpp
@@ -1841,7 +1841,7 @@ std::vector<Tensor> GroupNorm::create_output_tensors(const std::vector<Tensor> &
} else {
auto mem_config = this->output_mem_config;
mem_config.shard_spec = input_tensor.shard_spec();
- return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), program_config.out_data_format, Layout::ROW_MAJOR, input_tensor.device(), mem_config)};
+ return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), program_config.out_data_format, Layout::ROW_MAJOR, input_tensor.device(), mem_config)};
}
}
operation::ProgramWithCallbacks GroupNorm::create_program(
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp
@@ -154,7 +154,7 @@ std::vector<Tensor> LayerNorm::create_output_tensors(const std::vector<Tensor> &
} else {
auto mem_config = this->output_mem_config;
mem_config.shard_spec = input_tensor.shard_spec().value();
- return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), input_tensors.at(0).get_dtype(), Layout::TILE, input_tensor.device(), mem_config)};
+ return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), input_tensors.at(0).get_dtype(), Layout::TILE, input_tensor.device(), mem_config)};
}
} else {
return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/move/move_op.hpp
@@ -136,7 +136,7 @@ inline Tensor move_sharded(const Tensor& input_tensor, const std::optional<Memor
// log_debug(LogOp, "OUTPUT SHARD SPEC: {}", out_shard_spec);
auto shard_mem_config = output_mem_config;
shard_mem_config.shard_spec = shard_spec;
- auto output_tensor = create_sharded_device_tensor(input_shape, input_dtype, input_layout, input_tensor.device(), shard_mem_config);
+ auto output_tensor = create_device_tensor(input_shape, input_dtype, input_layout, input_tensor.device(), shard_mem_config);
if (input_tensor.buffer()->address() == output_tensor.buffer()->address()) {
tt::log_debug(tt::LogOp, "WARNING: No space to move the tensor. Move op's input address and output address are equal: {}", input_address);
return output_tensor;
(Diffs for the remaining changed files were not loaded in this view.)
