diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index b3783696b8d78..e7cff9a820af6 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -9,6 +9,7 @@ #include #include "core/common/gsl.h" #include "onnxruntime_config.h" +#include "interface/framework/tensor.h" #ifndef DISABLE_ABSEIL // Need to include abseil inlined_vector.h header directly here @@ -65,7 +66,7 @@ inline gsl::span ToConstSpan(const TensorShapeVector& vec) { return gsl::make_span(vec); } -class TensorShape { +class TensorShape: public interface::ITensorShape { // We use negative numbers for unknown symbolic dimension. Each negative // number represents a unique symbolic dimension. public: @@ -126,6 +127,12 @@ class TensorShape { */ gsl::span GetDims() const { return values_; } + const int64_t* GetDimensions(size_t& num_dims) const override { + auto dims = GetDims(); + num_dims = dims.size(); + return dims.data(); + } + TensorShapeVector AsShapeVector() const { return ToShapeVector(values_); } @@ -137,6 +144,9 @@ class TensorShape { */ int64_t Size() const; + int64_t NumberOfElements() const override { + return Size(); + } /** Return the total number of elements up to the specified dimension. If the dimension interval is empty (dimension == 0), return 1. diff --git a/include/onnxruntime/interface/framework/tensor.h b/include/onnxruntime/interface/framework/tensor.h new file mode 100644 index 0000000000000..b465093b7a2dc --- /dev/null +++ b/include/onnxruntime/interface/framework/tensor.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { + +namespace interface { +struct ITensorShape { + virtual int64_t NumberOfElements() const = 0; + virtual const int64_t* GetDimensions(size_t& num_dims) const = 0; +}; + +} // namespace interface + +} // namespace onnxruntime \ No newline at end of file