Skip to content

Commit

Permalink
align tensor hierarchy
Browse files Browse the repository at this point in the history
  • Loading branch information
RandyShuai committed Nov 20, 2023
1 parent 7defe5d commit c04c1f2
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 168 deletions.
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip
cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159
dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445
eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;738915dcf6d17856a85111f78c3c8d84384461aa
eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;90414f534834cb5041a4c5ddece24978033d87c7
flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf
fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
Expand Down
38 changes: 19 additions & 19 deletions include/onnxruntime/core/framework/op_kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace concurrency {
class ThreadPool;
}

class OpKernelContext : public interface::IKernelContext {
class OpKernelContext /*: public interface::IKernelContext*/ {
public:
using ArgMap = std::unordered_map<std::string, size_t>;

Expand Down Expand Up @@ -45,11 +45,11 @@ class OpKernelContext : public interface::IKernelContext {
}
}

const void* InputData(int index) const override {
//todo - check tensor type
const auto* tensor = Input<onnxruntime::Tensor>(index);
return tensor->DataRaw();
}
//const void* InputData(int index) const override {
// //todo - check tensor type
// const auto* tensor = Input<onnxruntime::Tensor>(index);
// return tensor->DataRaw();
//}

// Fetch a required input, enforcing that it is present.
// Fetch a required input, enforcing that it is present.
Expand Down Expand Up @@ -78,19 +78,19 @@ class OpKernelContext : public interface::IKernelContext {
Tensor* Output(int index, const std::vector<int64_t>& shape);
Tensor* Output(int index, const std::initializer_list<int64_t>& shape);

void* AllocateOutput(int index, const interface::TensorShape& shape) override {
auto* tensor = Output(index, shape);
ORT_ENFORCE(tensor);
return tensor->MutableDataRaw();
}

const int64_t* InputShape(int index, size_t* num_dims) const override {
const auto* tensor = Input<onnxruntime::Tensor>(index);
const auto& shape = tensor->Shape();
auto dims = shape.GetDims();
*num_dims = dims.size();
return dims.data();
};
//void* AllocateOutput(int index, const interface::TensorShape& shape) override {
// auto* tensor = Output(index, shape);
// ORT_ENFORCE(tensor);
// return tensor->MutableDataRaw();
//}

//const int64_t* InputShape(int index, size_t* num_dims) const override {
// const auto* tensor = Input<onnxruntime::Tensor>(index);
// const auto& shape = tensor->Shape();
// auto dims = shape.GetDims();
// *num_dims = dims.size();
// return dims.data();
//};

// Fetch a required tensor output, enforcing that it is present.
Tensor& RequiredOutput(int index, const TensorShape& shape) {
Expand Down
17 changes: 16 additions & 1 deletion include/onnxruntime/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "onnxruntime_config.h"
#include "core/framework/data_types.h"
#include "core/framework/data_types_internal.h"
#include "interface/framework/tensor.h"

struct OrtValue;

Expand All @@ -36,7 +37,7 @@ namespace onnxruntime {
it, and won't do any allocation / release.
*/

class Tensor final {
class Tensor final: public interface::ITensor {
public:
// NB! Removing Create() methods returning unique_ptr<Tensor>.
// Still available in other EPs that are dynamically linked.
Expand Down Expand Up @@ -293,6 +294,20 @@ class Tensor final {
void SetShapeAndStrides(const TensorShape& new_shape, gsl::span<const int64_t> new_strides);
#endif

// interface::ITensor override: exposes this tensor's shape through the
// plugin-interface hierarchy. Returns shape_ by const reference, so the
// returned ITensorShape& stays valid for the lifetime of this Tensor
// (TensorShape must derive from interface::ITensorShape for this to compile).
const interface::ITensorShape& GetShape() const override {
  return shape_;
}

interface::TensorDataType GetDataType() const override;

// interface::ITensor override: read-only pointer to the underlying buffer,
// delegating to the existing DataRaw() accessor (no type check is performed here).
const void* GetRawData() const override {
  return DataRaw();
}

// interface::ITensor override: writable pointer to the underlying buffer,
// delegating to the existing MutableDataRaw() accessor.
void* GetMutableRawData() override {
  return MutableDataRaw();
}

// More API methods.
private:
void Init(MLDataType elt_type,
Expand Down
102 changes: 18 additions & 84 deletions include/onnxruntime/interface/framework/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,18 @@

#include <cassert>
#include <numeric>
#include "interface/common/data_types.h"
#include "interface/framework/tensor.h"
#include "core/common/status.h"

namespace onnxruntime {

namespace interface {

using TensorShape = std::vector<int64_t>;

struct IKernelContext {
virtual ~IKernelContext() = default;
virtual const void* InputData(int index) const = 0;
virtual const int64_t* InputShape(int index, size_t* num_dims) const = 0;
virtual void* AllocateOutput(int index, const TensorShape& shape) = 0;
};

struct IArg {
virtual ~IArg() = default;
};

using ArgPtr = std::unique_ptr<IArg>;
using ArgPtrs = std::vector<ArgPtr>;

struct ITensor : public IArg {
//using MyType = ITensor<T>;
ITensor(IKernelContext* ctx = {}, int index = -1) : ctx_(ctx), index_(index){};
const TensorShape& Shape() { return shape_; }
size_t NumberOfElements() const {
if (shape_.empty()) {
return 0;
} else {
return std::accumulate(shape_.begin(), shape_.end(), 1ULL, std::multiplies<size_t>{});
}
}
protected:
IKernelContext* ctx_ = {};
int index_ = {};
TensorShape shape_;
};

template <typename T>
struct TensorView : public ITensor {
TensorView(IKernelContext* ctx, int index) : ITensor(ctx, index) {
data_ = reinterpret_cast<const T*>(ctx->InputData(index));
size_t num_dims = 0;
const auto* dims = ctx->InputShape(index, &num_dims);
shape_ = TensorShape{dims, dims + num_dims};
}
TensorView(const T* data, const TensorShape& shape) : data_(data) {
shape_ = shape;
};
const T* Data() const {
return data_;
}

protected:
const T* data_ = {};
};

template <typename T>
struct Tensor : public ITensor {
Tensor(IKernelContext* ctx, int index) : ITensor(ctx, index) {}
Tensor(T* data, const TensorShape& shape) : data_(data) {
shape_ = shape;
};
T* Allocate(const TensorShape& shape) {
if (data_) {
return data_;
} else {
// assert ctx
shape_ = shape;
data_ = reinterpret_cast<T*>(ctx_->AllocateOutput(index_, shape_));
return data_;
}
}

protected:
T* data_ = {};
virtual const ITensor& GetInputTensor(int index) const = 0;
virtual const ITensorShape& GetInputShape(int index) const = 0;
virtual ITensor* AllocOutputTensor(int index, const int64_t* dims, size_t num_dims) = 0;
};

struct IKernelInfo {
Expand All @@ -91,31 +25,34 @@ struct IKernelInfo {

struct IKernel {
explicit IKernel() = default;
explicit IKernel(const IKernelInfo&){};
virtual ~IKernel() = default;
virtual onnxruntime::Status Compute(IKernelContext*) const = 0;

template <int ith_input, int ith_output, typename... Ts>
static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
CreateTuple(IKernelContext*, ArgPtrs&) {
CreateTuple(IKernelContext*, IArgPtrs&) {
return std::make_tuple();
}

// inputs
template <int ith_input, int ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, TensorView<float>&>::value, std::tuple<T, Ts...>>::type
CreateTuple(IKernelContext* context, ArgPtrs& args) {
args.push_back(std::make_unique<TensorView<float>>(context, ith_input));
static typename std::enable_if<std::is_same<T, IReadonlyTensor<float>&>::value, std::tuple<T, Ts...>>::type
CreateTuple(IKernelContext* context, IArgPtrs& args) {
const ITensor& input_tensor = context->GetInputTensor(ith_input);
args.push_back(std::make_unique<ReadonlyTensor<float>>(input_tensor));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args);
return std::tuple_cat(current, next);
}

// outputs
template <int ith_input, int ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Tensor<float>&>::value, std::tuple<T, Ts...>>::type
CreateTuple(IKernelContext* context, ArgPtrs& args) {
args.push_back(std::make_unique<Tensor<float>>(context, ith_output));
static typename std::enable_if<std::is_same<T, IMutableTensor<float>&>::value, std::tuple<T, Ts...>>::type
CreateTuple(IKernelContext* context, IArgPtrs& args) {
MutableTensor<float>::AllocFn alloc_fn = [context](const int64_t* dims, size_t num_dims) {
return context->AllocOutputTensor(ith_output, dims, num_dims);
};
args.push_back(std::make_unique<MutableTensor<float>>(alloc_fn));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args);
return std::tuple_cat(current, next);
Expand All @@ -128,7 +65,7 @@ struct FnKernel : public IKernel {
FnKernel(ComputeFn compute_fn) : compute_fn_(compute_fn) {}

onnxruntime::Status Compute(IKernelContext* context) const override {
ArgPtrs args;
IArgPtrs args;
auto t = CreateTuple<0, 0, Args...>(context, args);
return std::apply([this](Args const&... t_args) { return compute_fn_(t_args...); }, t);
}
Expand All @@ -151,19 +88,16 @@ struct StructKernel : public IKernel {

template <typename... Args>
onnxruntime::Status InvokeCompute(ComputeFn<Args...>, IKernelContext* context) const {
ArgPtrs args;
IArgPtrs args;
auto t = CreateTuple<0, 0, Args...>(context, args);
return std::apply([this](Args const&... t_args) { return kernel_->Compute(t_args...); }, t);
}
std::unique_ptr<K> kernel_;
};

struct IKernelBuilder {
// IKernelBuilder() = default;
// IKernelBuilder(const IKernelBuilder&) = delete;
explicit IKernelBuilder() = default;
IKernelBuilder(const IKernelBuilder&) = delete;

virtual ~IKernelBuilder() = default;
virtual IKernelBuilder& Provider(const char*) = 0;
virtual IKernelBuilder& SetDomain(const char*) = 0;
Expand All @@ -175,7 +109,7 @@ struct IKernelBuilder {
template <size_t, size_t, typename... Ts>
typename std::enable_if<sizeof...(Ts) >= 0, IKernelBuilder&>::type
ParseArgs() {
//todo - generate constraints by args...
// todo - generate constraints by args...
return *this;
}

Expand Down
86 changes: 86 additions & 0 deletions include/onnxruntime/interface/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
// Licensed under the MIT License.

#pragma once
#include <cassert>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include <interface/common/data_types.h>

namespace onnxruntime {

Expand All @@ -11,6 +16,87 @@ struct ITensorShape {
virtual const int64_t* GetDimensions(size_t& num_dims) const = 0;
};

// Base class for any argument handed to a kernel's Compute call.
// Tensors derive from it today; other argument kinds (e.g. tensor
// sequences) can derive from it later.
struct IArg {
  virtual ~IArg() = default;
};

// Owning pointer to a kernel argument, and the container the kernel
// machinery uses to keep arguments alive for the duration of a Compute call.
using IArgPtr = std::unique_ptr<IArg>;
using IArgPtrs = std::vector<IArgPtr>;

// Type-erased tensor interface shared across the kernel-interface boundary.
// Implementations must provide shape and element type; the raw-data
// accessors default to returning nullptr so implementations that cannot
// expose a buffer (or are read-only / write-only) need not override both.
struct ITensor : public IArg {
  // Shape of this tensor; the returned reference must remain valid for the
  // tensor's lifetime.
  virtual const ITensorShape& GetShape() const = 0;
  // Read-only pointer to the element buffer, or nullptr if unavailable.
  virtual const void* GetRawData() const {
    return {};
  }
  // Writable pointer to the element buffer, or nullptr if unavailable.
  virtual void* GetMutableRawData() {
    return {};
  }
  // Element data type of this tensor.
  virtual TensorDataType GetDataType() const = 0;
};

// readonly tensors
// Typed read-only tensor: adds a typed data accessor on top of ITensor.
template <typename T>
struct IReadonlyTensor : public ITensor {
  // Typed, read-only pointer to the element buffer.
  virtual const T* GetData() const = 0;
};

// Typed read-only view over an existing type-erased ITensor.
// Non-owning: the wrapped tensor must outlive this view (it is held by
// const reference). No element-type check is performed against the wrapped
// tensor's GetDataType(); callers are responsible for picking the right T.
template <typename T>
struct ReadonlyTensor : public IReadonlyTensor<T> {
  using DataType = const T*;
  // explicit: a view over a reference should never be created by implicit
  // conversion — that is how dangling wrappers of temporaries are made.
  explicit ReadonlyTensor(const ITensor& readonly_tensor) : readonly_tensor_(readonly_tensor) {}
  const ITensorShape& GetShape() const override { return readonly_tensor_.GetShape(); }
  // static_cast is the correct (and checkable) cast from const void*;
  // reinterpret_cast is not needed here.
  DataType GetData() const override { return static_cast<DataType>(readonly_tensor_.GetRawData()); }
  TensorDataType GetDataType() const override { return readonly_tensor_.GetDataType(); }
  const ITensor& readonly_tensor_;
};

// mutable tensors
// Typed writable tensor: output storage is requested through Allocate()
// rather than being present up front.
template <typename T>
struct IMutableTensor : public ITensor {
  // Allocate (or retrieve) an element buffer for the given shape and return
  // a typed, writable pointer to it.
  virtual T* Allocate(const ITensorShape& shape) = 0;
};

// Typed writable tensor whose backing storage is created lazily through a
// caller-supplied allocation callback (typically bound to a kernel context
// output slot). Allocate() invokes the callback on first use only; later
// calls reuse the already-allocated tensor regardless of the shape argument.
// GetShape()/GetDataType() are valid only after Allocate() has been called.
template <typename T>
struct MutableTensor : public IMutableTensor<T> {
  // Callback that allocates the output tensor for (dims, num_dims) and
  // returns a non-owning pointer to it.
  using AllocFn = std::function<ITensor*(const int64_t*, size_t)>;
  // explicit to block implicit conversions; take by value + move to avoid
  // copying the std::function.
  explicit MutableTensor(AllocFn alloc_fn) : alloc_fn_(std::move(alloc_fn)) {}
  const ITensorShape& GetShape() const override {
    assert(mutable_tensor_);  // must call Allocate() first
    return mutable_tensor_->GetShape();
  }
  T* Allocate(const ITensorShape& shape) override {
    if (!mutable_tensor_) {
      size_t num_dims = 0;
      const int64_t* dims = shape.GetDimensions(num_dims);
      mutable_tensor_ = alloc_fn_(dims, num_dims);
    }
    // static_cast is sufficient from void*; no reinterpret_cast needed.
    return static_cast<T*>(mutable_tensor_->GetMutableRawData());
  }
  TensorDataType GetDataType() const override {
    assert(mutable_tensor_);  // must call Allocate() first
    return mutable_tensor_->GetDataType();
  }
  AllocFn alloc_fn_;
  ITensor* mutable_tensor_ = {};  // non-owning; set on first Allocate()
};

// Forwarding adapter: presents an existing IMutableTensor<T> through the
// same interface without taking ownership. The referenced tensor must
// outlive this wrapper.
template <typename T>
struct MutableTensorRef : public IMutableTensor<T> {
  MutableTensorRef(IMutableTensor<T>& ref) : ref_(ref) {}
  // Every operation simply delegates to the wrapped tensor.
  const ITensorShape& GetShape() const override { return ref_.GetShape(); }
  TensorDataType GetDataType() const override { return ref_.GetDataType(); }
  T* Allocate(const ITensorShape& shape) override { return ref_.Allocate(shape); }
  IMutableTensor<T>& ref_;
};

// struct ITensorSeq

} // namespace interface

} // namespace onnxruntime
Loading

0 comments on commit c04c1f2

Please sign in to comment.