Skip to content

Commit

Permalink
add tensorshape interface
Browse files Browse the repository at this point in the history
  • Loading branch information
RandyShuai committed Nov 17, 2023
1 parent 48dcc85 commit 7defe5d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
12 changes: 11 additions & 1 deletion include/onnxruntime/core/framework/tensor_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cstring>
#include "core/common/gsl.h"
#include "onnxruntime_config.h"
#include "interface/framework/tensor.h"

#ifndef DISABLE_ABSEIL
// Need to include abseil inlined_vector.h header directly here
Expand Down Expand Up @@ -65,7 +66,7 @@ inline gsl::span<const int64_t> ToConstSpan(const TensorShapeVector& vec) {
return gsl::make_span(vec);
}

class TensorShape {
class TensorShape: public interface::ITensorShape {
// We use negative numbers for unknown symbolic dimension. Each negative
// number represents a unique symbolic dimension.
public:
Expand Down Expand Up @@ -126,6 +127,12 @@ class TensorShape {
*/
gsl::span<const int64_t> GetDims() const { return values_; }

const int64_t* GetDimensions(size_t& num_dims) const override {
auto dims = GetDims();
num_dims = dims.size();
return dims.data();
}

TensorShapeVector AsShapeVector() const {
return ToShapeVector(values_);
}
Expand All @@ -137,6 +144,9 @@ class TensorShape {
*/
int64_t Size() const;

int64_t NumberOfElements() const override {
return Size();
}
/**
Return the total number of elements up to the specified dimension.
If the dimension interval is empty (dimension == 0), return 1.
Expand Down
16 changes: 16 additions & 0 deletions include/onnxruntime/interface/framework/tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {

namespace interface {
struct ITensorShape {
virtual int64_t NumberOfElements() const = 0;
virtual const int64_t* GetDimensions(size_t& num_dims) const = 0;
};

} // namespace interface

} // namespace onnxruntime

0 comments on commit 7defe5d

Please sign in to comment.