#13127: Add compute_output_specs which allows user to return output tensor layout

- Op must implement either compute_output_shapes + create_output_tensors OR compute_output_specs
  * If compute_output_specs is implemented, infra will handle tensor creation (see the sketch after this list)
  * Also works for optional output tensors
- Update generic reduction which carries over padding along N and C
  * pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py::test_min_max_for_dim_hw
  * pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py::test_sharded_reduce_h
- Update transpose op which transposes and maintains padding
  * pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
- Update slice op which uses optional output tensors
  * pytest tests/ttnn/unit_tests/operations/test_slice.py::test_slice_output_tensor_tile
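
A minimal sketch of the new compute_output_specs contract (hypothetical op name
ExampleOp; the ttnn types are the ones used in this diff, everything else is
illustrative):

    // With compute_output_specs, the infra creates the output tensors, so the
    // op no longer needs a create_output_tensors override.
    struct ExampleOp {
        MemoryConfig output_mem_config;

        void validate(const std::vector<Tensor>& input_tensors) const;

        std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const {
            const auto& input = input_tensors.at(0);
            // Same logical shape as the input; the page layout is derived from
            // the input's dtype and layout plus this op's memory config.
            return {ttnn::TensorSpec(
                input.get_logical_shape(),
                tt::tt_metal::TensorLayout(
                    input.get_dtype(), PageConfig(input.get_layout()), output_mem_config))};
        }

        operation::ProgramWithCallbacks create_program(
            const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
    };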
TT-BrianLiu committed Nov 8, 2024
1 parent b58292f commit 5de5bc9
Showing 10 changed files with 162 additions and 129 deletions.
83 changes: 72 additions & 11 deletions ttnn/cpp/ttnn/operation.hpp
@@ -287,6 +287,22 @@ constexpr bool implements_validate_with_output_tensors_and_optional_input_tensor
const OptionalTensors&>; // optional output_tensors
}

template <class T, class... Args>
using has_compute_output_shapes_t = decltype(std::declval<T>().compute_output_shapes(std::declval<Args>()...));

template <class T>
constexpr bool implements_compute_output_shapes() {
return std::experimental::is_detected_v<has_compute_output_shapes_t, T, const Tensors&>;
}

template <class T, class... Args>
using has_compute_output_specs_t = decltype(std::declval<T>().compute_output_specs(std::declval<Args>()...));

template <class T>
constexpr bool implements_compute_output_specs() {
return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&>;
}

template <class T, class... Args>
using has_create_output_tensors_t = decltype(std::declval<T>().create_output_tensors(std::declval<Args>()...));

@@ -378,13 +394,41 @@ constexpr bool implements_get_parallelization_strategy() {
return std::experimental::is_detected_v<has_get_parallelization_strategy_t, T, const Tensors&>;
}

template <typename ConcreteOperation>
auto default_create_output_tensors(
const ConcreteOperation& operation,
const Tensors& input_tensors,
const OptionalTensors& optional_output_tensors) -> ProgramOutputTensors<ConcreteOperation> {
using OutputTensors = ProgramOutputTensors<ConcreteOperation>;
OutputTensors output_tensors;

if (!optional_output_tensors.empty() and optional_output_tensors[0].has_value()) {
output_tensors.reserve(optional_output_tensors.size());
for (const auto& optional_output_tensor : optional_output_tensors) {
TT_FATAL(optional_output_tensor.has_value(), "If using optional output tensors, all output tensors must have a value");
output_tensors.emplace_back(optional_output_tensor.value());
}
return output_tensors;
}
const auto& device = input_tensors.at(0).device();
const auto& output_specs = operation.compute_output_specs(input_tensors);
output_tensors.reserve(output_specs.size());
for (const auto& [output_shape, output_layout] : output_specs) {
output_tensors.emplace_back(create_device_tensor(
output_shape,
output_layout,
device));
}
return output_tensors;
}

} // namespace detail

template <class OutputTensorsT = Tensors>
struct DeviceOperation final {
using storage_t = std::array<std::byte, 1152>;
using OutputTensors = OutputTensorsT;
using ComputedShapes = std::variant<std::vector<tt::tt_metal::LegacyShape>, std::vector<ttnn::SimpleShape>>;
using ComputedShapes = std::variant<std::vector<tt::tt_metal::LegacyShape>, std::vector<ttnn::SimpleShape>, std::vector<ttnn::TensorSpec>>;

inline const std::string get_type_name() const { return this->get_type_name_impl_(this->type_erased_storage); }

@@ -396,6 +440,7 @@ struct DeviceOperation final {
this->type_erased_storage, input_tensors, optional_input_tensors, optional_output_tensors);
}

// TODO: Rename into compute_output_specs in later PR
inline const ComputedShapes compute_output_shapes(const Tensors& input_tensors) const {
return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors);
}
@@ -502,14 +547,6 @@ struct DeviceOperation final {
static_assert(
tt::stl::concepts::always_false_v<T>,
"You cannot implement both validate and validate_with_output_tensors");
} else if constexpr (
(detail::implements_validate_with_output_tensors<T>() or
detail::implements_validate_with_output_tensors_and_optional_input_tensors<T>()) and
not detail::implements_create_output_tensors_with_optional_output_tensors<T>()) {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation doesn't implement create_output_tensors with ant optional output tensors argument "
"when using validate_with_output_tensors");
} else if constexpr (detail::implements_validate<T>() and not detail::implements_create_program<T>()) {
static_assert(
tt::stl::concepts::always_false_v<T>,
@@ -547,17 +584,41 @@ struct DeviceOperation final {
compute_output_shapes_impl_{
[](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
return operation.compute_output_shapes(input_tensors);
if constexpr (detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation cannot implement both compute_output_shapes and compute_output_specs");
} else if constexpr (detail::implements_compute_output_shapes<T>()) {
return operation.compute_output_shapes(input_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return operation.compute_output_specs(input_tensors);
} else {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation must implement either compute_output_shapes or compute_output_specs");
}
}},
create_output_tensors_impl_{
[](const storage_t& storage,
const Tensors& input_tensors,
const OptionalTensors& output_tensors) -> const OutputTensors {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (detail::implements_create_output_tensors_with_optional_output_tensors<T>()) {
static_assert(
detail::implements_compute_output_shapes<T>(),
"Operation must implement compute_output_shapes if it implements create_output_tensors");
return operation.create_output_tensors(input_tensors, output_tensors);
} else {
} else if constexpr (detail::implements_create_output_tensors<T>()) {
static_assert(
detail::implements_compute_output_shapes<T>(),
"Operation must implement compute_output_shapes if it implements create_output_tensors");
return operation.create_output_tensors(input_tensors);
} else if constexpr (detail::implements_compute_output_specs<T>()) {
return detail::default_create_output_tensors(operation, input_tensors, output_tensors);
} else {
static_assert(
tt::stl::concepts::always_false_v<T>,
"Operation must implement either create_output_tensors or compute_output_specs");
}
}},
create_program_impl_{
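The implements_* helpers added above rely on the detection idiom to probe, at
compile time, whether an op type declares a given member function, and the
constructor then dispatches with if constexpr. A self-contained sketch of the
same pattern (standalone names; std::experimental::is_detected_v is hand-rolled
here so the snippet builds with plain C++17):

    #include <type_traits>
    #include <utility>
    #include <vector>

    // Hand-rolled equivalent of std::experimental::is_detected_v.
    template <class Void, template <class...> class Op, class... Args>
    struct detector : std::false_type {};
    template <template <class...> class Op, class... Args>
    struct detector<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {};
    template <template <class...> class Op, class... Args>
    constexpr bool is_detected_v = detector<void, Op, Args...>::value;

    // Probe for T::compute_output_specs(const std::vector<float>&).
    template <class T, class... Args>
    using has_compute_output_specs_t =
        decltype(std::declval<T>().compute_output_specs(std::declval<Args>()...));

    template <class T>
    constexpr bool implements_compute_output_specs() {
        return is_detected_v<has_compute_output_specs_t, T, const std::vector<float>&>;
    }

    struct NewStyleOp {
        std::vector<int> compute_output_specs(const std::vector<float>&) const { return {}; }
    };
    struct OldStyleOp {};

    static_assert(implements_compute_output_specs<NewStyleOp>());
    static_assert(!implements_compute_output_specs<OldStyleOp>());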
42 changes: 10 additions & 32 deletions ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
@@ -82,11 +82,11 @@ void SliceDeviceOperation::validate_with_output_tensors(
TT_FATAL(this->slice_start[i] <= this->slice_end[i], "Error");
}
if(!output_tensors.empty() && output_tensors[0].has_value()){
const auto output_shape_required = this->compute_output_shapes(input_tensors)[0];
const auto output_shape_required = std::get<0>(this->compute_output_specs(input_tensors)[0]);
const auto& out_tensor = output_tensors[0].value();
TT_FATAL(out_tensor.get_legacy_shape() == output_shape_required, "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_legacy_shape());
TT_FATAL(out_tensor.get_padded_shape() == output_shape_required, "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_padded_shape());
}
auto output_tensor_shape = this->compute_output_shapes(input_tensors)[0];
auto output_tensor_shape = std::get<0>(this->compute_output_specs(input_tensors)[0]);
if (has_step) { // if all ones modify before passing in to function
TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Strided slice is only supported for row major layout");
TT_FATAL(!input_tensor_a.is_sharded(), "Strided slice is not supported for sharded tensor");
@@ -117,40 +117,18 @@ void SliceDeviceOperation::validate_with_output_tensors(
}
}

std::vector<tt::tt_metal::LegacyShape> SliceDeviceOperation::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
SmallVector<uint32_t> out_shape;
auto rank = input_tensors[0].get_legacy_shape().rank();
out_shape.reserve(rank);
std::vector<ttnn::TensorSpec> SliceDeviceOperation::compute_output_specs(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor = input_tensors[0];
SmallVector<uint32_t> out_shape(input_tensor.get_logical_shape().rank());

auto output_dim_i = [this] (size_t i) {
return (this->slice_end[i] - this->slice_start[i] + this->step[i] - 1) / this->step[i];
};
for (uint32_t i = 0; i < rank; i++) {
out_shape.push_back(output_dim_i(i));
}
tt::tt_metal::LegacyShape output_tensor_shape(out_shape);
return {output_tensor_shape};
}

std::vector<Tensor> SliceDeviceOperation::create_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const {
if (!output_tensors.empty() && output_tensors[0].has_value()) {
return {output_tensors[0].value()};
}
const auto &input_tensor_a = input_tensors.at(0);
const auto shapes = compute_output_shapes(input_tensors);

if (input_tensor_a.is_sharded()) {
return {create_device_tensor(
shapes[0],
input_tensor_a.get_dtype(),
input_tensor_a.get_layout(),
input_tensor_a.device(),
this->output_mem_config)};
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), this->output_mem_config);
for (uint32_t i = 0; i < out_shape.size(); i++) {
out_shape[i] = output_dim_i(i);
}
ttnn::SimpleShape output_tensor_shape(std::move(out_shape));
return {ttnn::TensorSpec(output_tensor_shape, tt::tt_metal::TensorLayout(input_tensor.get_dtype(), PageConfig(input_tensor.get_layout()), this->output_mem_config))};
}

operation::ProgramWithCallbacks SliceDeviceOperation::create_program(
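The per-dimension output size computed by output_dim_i above is a ceiling
division: with an exclusive slice_end, (end - start + step - 1) / step indices
survive the stride. A standalone sketch with a hypothetical worked case:

    // ceil((end - start) / step), matching output_dim_i in the diff above
    uint32_t output_dim(uint32_t start, uint32_t end, uint32_t step) {
        return (end - start + step - 1) / step;
    }
    // e.g. start = 1, end = 8, step = 3 keeps indices {1, 4, 7}:
    // (8 - 1 + 3 - 1) / 3 = 9 / 3 = 3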
ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.hpp
@@ -22,8 +22,7 @@ struct SliceDeviceOperation {


void validate_with_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;

ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp
@@ -77,72 +77,57 @@ void Transpose::validate(const std::vector<Tensor> &input_tensors) const {
}


std::vector<tt::tt_metal::LegacyShape> Transpose::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
std::vector<ttnn::TensorSpec> Transpose::compute_output_specs(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
auto out_shape = input_tensor.get_legacy_shape();
auto padding = out_shape.padding();

// TODO: Remove usage of input/output padded shape
// - Get output alignment from input alignment and output dtype, layout, mem_config
// - Get shard spec from output strides (logical shape + alignment)?
auto output_shape = input_tensor.get_logical_shape();
auto output_padded_shape = input_tensor.get_padded_shape();

switch (this->dim){
case TransposeOpDim::CN:
std::swap(out_shape[0], out_shape[1]);
std::swap(padding[0], padding[1]);
std::swap(output_shape[0], output_shape[1]);
std::swap(output_padded_shape[0], output_padded_shape[1]);
break;
case TransposeOpDim::HC:
std::swap(out_shape[1], out_shape[2]);
std::swap(padding[1], padding[2]);
std::swap(output_shape[1], output_shape[2]);
std::swap(output_padded_shape[1], output_padded_shape[2]);
break;
case TransposeOpDim::WH:
std::swap(out_shape[2], out_shape[3]);
std::swap(padding[2], padding[3]);
std::swap(output_shape[2], output_shape[3]);
std::swap(output_padded_shape[2], output_padded_shape[3]);
break;
case TransposeOpDim::NH:
std::swap(out_shape[0], out_shape[2]);
std::swap(padding[0], padding[2]);
std::swap(output_shape[0], output_shape[2]);
std::swap(output_padded_shape[0], output_padded_shape[2]);
break;
case TransposeOpDim::NW:
std::swap(out_shape[0], out_shape[3]);
std::swap(padding[0], padding[3]);
std::swap(output_shape[0], output_shape[3]);
std::swap(output_padded_shape[0], output_padded_shape[3]);
break;
case TransposeOpDim::CW:
std::swap(out_shape[1], out_shape[3]);
std::swap(padding[1], padding[3]);
std::swap(output_shape[1], output_shape[3]);
std::swap(output_padded_shape[1], output_padded_shape[3]);
break;
}
return {tt::tt_metal::LegacyShape(out_shape, padding)};
}


std::vector<Tensor> Transpose::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
// This is only for WH
auto output_mem_config = this->output_mem_config;
if (this->output_mem_config.is_sharded()) {
if (this->dim == TransposeOpDim::WH) {
const auto& input_padded_shape = input_tensor.get_padded_shape();
ShardSpec shard_spec = input_tensor.shard_spec().value();
shard_spec.shape[0] = shard_spec.shape[0] / input_tensor.get_legacy_shape()[-2] * input_tensor.get_legacy_shape()[-1];
shard_spec.shape[1] = input_tensor.get_legacy_shape()[-2];
const auto output_shape = this->compute_output_shapes(input_tensors)[0];
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
return {create_device_tensor(
output_shape,
input_tensor.get_dtype(),
input_tensor.get_layout(),
input_tensor.device(),
mem_config)};
shard_spec.shape[0] = shard_spec.shape[0] / input_padded_shape[-2] * input_padded_shape[-1];
shard_spec.shape[1] = input_padded_shape[-2];
output_mem_config.shard_spec = shard_spec;
} else if (this->dim == TransposeOpDim::HC) {
const auto output_shape = this->compute_output_shapes(input_tensors)[0];
auto mem_config = this->output_mem_config;
mem_config.shard_spec = input_tensor.shard_spec().value();
return {create_device_tensor(
output_shape,
input_tensor.get_dtype(),
input_tensor.get_layout(),
input_tensor.device(),
mem_config)};
output_mem_config.shard_spec = input_tensor.shard_spec().value();
} else {
TT_ASSERT(false, "Unsupported sharding");
}
}
return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config);
return {ttnn::TensorSpec(output_shape, TensorLayout::fromLegacyPaddedShape(input_tensor.get_dtype(), PageConfig(input_tensor.get_layout()), output_mem_config, ttnn::Shape(output_shape.view(), output_padded_shape.view())))};
}

operation::ProgramWithCallbacks Transpose::create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const {
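For the sharded WH branch above, the shard shape is recomputed per H-by-W
block: each such block of the input shard becomes a W-by-H block of the output
shard, so the shard height scales by W per H-block and the shard width becomes
H. A hypothetical worked case: with padded H = 32, W = 64 and an input shard of
[64, 64] (two H-blocks tall), the output shard is [64 / 32 * 64, 32] = [128, 32].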
ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.hpp
@@ -23,8 +23,7 @@ struct Transpose {
const MemoryConfig output_mem_config;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
TransposeOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const;
