Skip to content

Commit

Permalink
#0: Remove overhead in calling functions wrapped in tensor_impl_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
yan-zaretskiy committed May 25, 2024
1 parent 9c11c6f commit a4cdb32
Show file tree
Hide file tree
Showing 8 changed files with 851 additions and 726 deletions.
1 change: 0 additions & 1 deletion tt_eager/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

set(TENSOR_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/tensor_impl_wrapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/types.cpp
Expand Down
1 change: 0 additions & 1 deletion tt_eager/tensor/module.mk
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
TENSOR_SRCS = \
tt_eager/tensor/tensor_impl_wrapper.cpp \
tt_eager/tensor/tensor_impl.cpp \
tt_eager/tensor/tensor.cpp \
tt_eager/tensor/types.cpp \
Expand Down
639 changes: 356 additions & 283 deletions tt_eager/tensor/tensor.cpp

Large diffs are not rendered by default.

299 changes: 188 additions & 111 deletions tt_eager/tensor/tensor_impl.cpp

Large diffs are not rendered by default.

206 changes: 158 additions & 48 deletions tt_eager/tensor/tensor_impl.hpp

Large diffs are not rendered by default.

158 changes: 0 additions & 158 deletions tt_eager/tensor/tensor_impl_wrapper.cpp

This file was deleted.

76 changes: 43 additions & 33 deletions tt_eager/tensor/tensor_impl_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,48 @@

#pragma once

#include "tensor/tensor.hpp"
#include "tensor/tensor_impl.hpp"

namespace tt {

namespace tt_metal {

namespace tensor_impl {


uint32_t element_size_bytes_wrapper(DataType dtype);

uint32_t packed_buffer_size_bytes_wrapper(DataType dtype, uint32_t volume_unpacked_data);

Tensor to_host_wrapper(const Tensor &tensor, bool blocking = true);

Tensor to_host_wrapper_sharded(const Tensor &tensor);

Tensor to_extract_shard_wrapper(const Tensor &tensor, const uint32_t & core_id);

Tensor to_device_wrapper(const Tensor &tensor, Device *target_device, const MemoryConfig &mem_config, std::optional<std::reference_wrapper<CommandQueue>> queue = std::nullopt);

Tensor to_layout_wrapper(const Tensor &tensor, Layout target_layout);

Tensor pad_wrapper(const Tensor &tensor, const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value);

Tensor unpad_wrapper(const Tensor &tensor, const Shape &output_tensor_start, const Shape &output_tensor_end);

std::string to_string_wrapper(const Tensor &tensor);

} // namespace tensor_impl

} // namespace tt_metal

} // namespace tt
namespace tt::tt_metal::tensor_impl {

// Utility to convert runtime DataType to compile-time constant and dispatch the function call
template <typename Func, typename... Args>
auto dispatch(DataType dtype, Func &&func, Args &&...args) {
switch (dtype) {
case DataType::BFLOAT16: return func.template operator()<bfloat16>(static_cast<Args &&>(args)...);
case DataType::FLOAT32: return func.template operator()<float>(static_cast<Args &&>(args)...);
case DataType::INT32: return func.template operator()<int32_t>(static_cast<Args &&>(args)...);
case DataType::UINT32: return func.template operator()<uint32_t>(static_cast<Args &&>(args)...);
case DataType::UINT16: return func.template operator()<uint16_t>(static_cast<Args &&>(args)...);
case DataType::BFLOAT8_B: return func.template operator()<bfloat8_b>(static_cast<Args &&>(args)...);
case DataType::BFLOAT4_B: return func.template operator()<bfloat4_b>(static_cast<Args &&>(args)...);
default: TT_THROW("Unsupported data type");
}
}

#define AS_LAMBDA(func) []<typename T>(auto &&...args) { return func<T>(std::forward<decltype(args)>(args)...); }

#define WRAP_FUNCTION(func) \
template <typename... Args> \
auto func##_wrapper(Args &&...args) { \
return dispatch( \
std::get<0>(std::forward_as_tuple(args...)).get_dtype(), AS_LAMBDA(func), std::forward<Args>(args)...); \
}

inline uint32_t packed_buffer_size_bytes_wrapper(DataType dtype, uint32_t volume_unpacked_data) {
return dispatch(dtype, AS_LAMBDA(packed_buffer_size_bytes), volume_unpacked_data);
}

WRAP_FUNCTION(to_host)
WRAP_FUNCTION(extract_shard)
WRAP_FUNCTION(to_host_sharded)
WRAP_FUNCTION(to_device)
WRAP_FUNCTION(to_layout)
WRAP_FUNCTION(pad)
WRAP_FUNCTION(unpad)
WRAP_FUNCTION(to_string)

#undef WRAP_FUNCTION
#undef AS_LAMBDA

} // namespace tt::tt_metal::tensor_impl
Loading

0 comments on commit a4cdb32

Please sign in to comment.