diff --git a/include/cppflow/model.h b/include/cppflow/model.h index d3a8d3e..9b378b4 100644 --- a/include/cppflow/model.h +++ b/include/cppflow/model.h @@ -25,6 +25,7 @@ namespace cppflow { FROZEN_GRAPH, }; + model() = default; explicit model(const std::string& filename, const TYPE type=TYPE::SAVED_MODEL); std::vector get_operations() const; diff --git a/include/cppflow/tensor.h b/include/cppflow/tensor.h index 345e5b2..a559387 100644 --- a/include/cppflow/tensor.h +++ b/include/cppflow/tensor.h @@ -35,6 +35,16 @@ namespace cppflow { template tensor(const std::vector& values, const std::vector& shape); + /** + * Creates a flat tensor with the given values, and specified length and shape + * @tparam T A type that can be convertible into a tensor + * @param values The values to be converted + * @param len The length of the converted tensor + * @param shape The shape of the converted tensor + */ + template + tensor(T *values, size_t len, const std::vector& shape); + /** * Creates a flat tensor with the given values * @tparam T A type that can be convertible into a tensor @@ -69,6 +79,23 @@ namespace cppflow { */ datatype dtype() const; + /** + * Converts the tensor into a pointer of primitive type T + * @tparam T The c++ type (must be equivalent to the tensor type) + * @return A pointer of type T representing the flat tensor + */ + template + T *get_raw_data() const; + + /** + * Converts the tensor into a pointer of primitive type T + * @tparam T The c++ type (must be equivalent to the tensor type) + * @return A pointer of type T representing the flat tensor + * @return The size of the array + */ + template + T *get_raw_data(size_t &size) const; + /** * Converts the tensor into a C++ vector * @tparam T The c++ type (must be equivalent to the tensor type) @@ -77,7 +104,6 @@ namespace cppflow { template std::vector get_data() const; - ~tensor() = default; tensor(const tensor &tensor) = default; tensor(tensor &&tensor) = default; @@ -140,6 +166,11 @@ namespace cppflow { tensor::tensor(const std::vector& values, const std::vector& shape) : tensor(deduce_tf_type(), values.data(), values.size() * sizeof(T), shape) {} + + template + tensor::tensor(T *values, size_t len, const std::vector& shape) : + tensor(deduce_tf_type(), values, len * sizeof(T), shape) {} + template tensor::tensor(const std::initializer_list& values) : tensor(std::vector(values), {(int64_t) values.size()}) {} @@ -213,28 +244,49 @@ namespace cppflow { return res; } + template - std::vector tensor::get_data() const { + T *tensor::get_raw_data(size_t &size) const { // Check if asked datatype and tensor datatype match if (this->dtype() != deduce_tf_type()) { auto type1 = cppflow::to_string(deduce_tf_type()); auto type2 = cppflow::to_string(this->dtype()); - auto error = "Datatype in function get_data (" + type1 + ") does not match tensor datatype (" + type2 + ")"; + auto error = "Datatype in function get_raw_data (" + type1 + ") does not match tensor datatype (" + type2 + ")"; throw std::runtime_error(error); } - auto res_tensor = get_tensor(); // Check tensor data is not empty auto raw_data = TF_TensorData(res_tensor.get()); //this->error_check(raw_data != nullptr, "Tensor data is empty"); - size_t size = TF_TensorByteSize(res_tensor.get()) / TF_DataTypeSize(TF_TensorType(res_tensor.get())); + // Get size of array + size = TF_TensorByteSize(res_tensor.get()) / TF_DataTypeSize(TF_TensorType(res_tensor.get())); // Convert to correct type const auto T_data = static_cast(raw_data); + + return T_data; + } + + template + T *tensor::get_raw_data() const { + + // Get the raw data and return + size_t size = 0; + const auto T_data = this->get_raw_data(size); + + return T_data; + } + + template + std::vector tensor::get_data() const { + + // Get the raw data and size of array + size_t size = 0; + const auto T_data = this->get_raw_data(size); std::vector r(T_data, T_data + size); return r;