Skip to content

Commit

Permalink
Back out "Revert D12967258: Support more data types in ONNXIFI transf…
Browse files Browse the repository at this point in the history
…orm" (pytorch#13812)

Summary:
Pull Request resolved: pytorch#13812

Original commit changeset: 2cf95bdc5ed8

It looks like on iOS, `uint64_t` is not the same type as `size_t`. :( Fixed that here.

Reviewed By: houseroad

Differential Revision: D13017390

fbshipit-source-id: d33854ce341225aba372fb945c3704edc14f9411
  • Loading branch information
Yinghai Lu authored and facebook-github-bot committed Nov 11, 2018
1 parent 786f9ba commit d97ac82
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 52 deletions.
23 changes: 23 additions & 0 deletions caffe2/onnx/onnx_exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,29 @@ NodeProto AddShapeNode(const std::string& input, const std::string& output) {

} // namespace

// Translates a Caffe2 tensor element type into the ONNX element type that
// carries the same enumerator name. Any Caffe2 type without an entry below is
// logged as a warning and reported as ONNX FLOAT.
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
    caffe2::TensorProto::DataType t) {
  using C2Tensor = caffe2::TensorProto;
  using OnnxTensor = ::ONNX_NAMESPACE::TensorProto;

  switch (t) {
    case C2Tensor::FLOAT:
      return OnnxTensor::FLOAT;
    case C2Tensor::BOOL:
      return OnnxTensor::BOOL;
    case C2Tensor::INT8:
      return OnnxTensor::INT8;
    case C2Tensor::UINT8:
      return OnnxTensor::UINT8;
    case C2Tensor::UINT16:
      return OnnxTensor::UINT16;
    case C2Tensor::INT16:
      return OnnxTensor::INT16;
    case C2Tensor::INT32:
      return OnnxTensor::INT32;
    case C2Tensor::INT64:
      return OnnxTensor::INT64;
    case C2Tensor::FLOAT16:
      return OnnxTensor::FLOAT16;
    default:
      break;
  }

  // No direct counterpart: warn and fall back to FLOAT.
  LOG(WARNING) << "Unsupported Caffe2 tensor type: " << t
               << ", fallback to FLOAT";
  return OnnxTensor::FLOAT;
}

std::unordered_map<std::string, std::string> SsaRewrite(
caffe2::NetDef* init_net,
caffe2::NetDef* pred_net) {
Expand Down
4 changes: 4 additions & 0 deletions caffe2/onnx/onnx_exporter.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "caffe2/core/common.h"
#include "caffe2/core/tensor.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/proto/caffe2_pb.h"
#include "onnx/onnx_pb.h"
Expand Down Expand Up @@ -29,6 +30,9 @@ CAFFE2_API std::unordered_map<std::string, std::string> SsaRewrite(
caffe2::NetDef* init_net,
caffe2::NetDef* pred_net);

::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
caffe2::TensorProto::DataType t);

class CAFFE2_API OnnxExporter {
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
const caffe2::OperatorDef&,
Expand Down
93 changes: 75 additions & 18 deletions caffe2/operators/onnxifi_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,75 @@ namespace caffe2 {

namespace {

void SetInputTensorDescriptorTypeAndBuffer(
const Tensor& cpu_tensor,
onnxTensorDescriptorV1* desc) {
if (cpu_tensor.template IsType<float>()) {
desc->dataType = ONNXIFI_DATATYPE_FLOAT32;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<float>());
} else if (cpu_tensor.template IsType<int32_t>()) {
desc->dataType = ONNXIFI_DATATYPE_INT32;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int32_t>());
} else if (cpu_tensor.template IsType<int8_t>()) {
desc->dataType = ONNXIFI_DATATYPE_INT8;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int8_t>());
} else if (cpu_tensor.template IsType<uint8_t>()) {
desc->dataType = ONNXIFI_DATATYPE_UINT8;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<uint8_t>());
} else if (cpu_tensor.template IsType<int64_t>()) {
desc->dataType = ONNXIFI_DATATYPE_INT64;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int64_t>());
} else if (cpu_tensor.template IsType<int16_t>()) {
desc->dataType = ONNXIFI_DATATYPE_INT16;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int16_t>());
} else if (cpu_tensor.template IsType<uint16_t>()) {
desc->dataType = ONNXIFI_DATATYPE_UINT16;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<uint16_t>());
} else {
CAFFE_THROW(
"Unsupported tensor type in ONNXIFI: ", cpu_tensor.dtype().name());
}
}

// Fills in the data type and buffer pointer of an output tensor descriptor.
//
// `onnxifi_type` is one of the ONNXIFI_DATATYPE_* constants; it is stored in
// `desc->dataType` and selects the element type used to obtain the mutable
// data pointer of `cpu_tensor`, which becomes `desc->buffer`.
// Supported types are FLOAT32, INT32, INT8, UINT8, INT64, INT16 and UINT16;
// any other value raises via CAFFE_THROW (after `desc->dataType` has already
// been written).
void SetOutputTensorDescriptorTypeAndBuffer(
    uint64_t onnxifi_type,
    Tensor* cpu_tensor,
    onnxTensorDescriptorV1* desc) {
  desc->dataType = onnxifi_type;
  switch (onnxifi_type) {
    case (ONNXIFI_DATATYPE_FLOAT32):
      desc->buffer =
          reinterpret_cast<onnxPointer>(cpu_tensor->mutable_data<float>());
      break;
    case (ONNXIFI_DATATYPE_INT32):
      desc->buffer =
          reinterpret_cast<onnxPointer>(cpu_tensor->mutable_data<int32_t>());
      break;
    case (ONNXIFI_DATATYPE_INT8):
      desc->buffer =
          reinterpret_cast<onnxPointer>(cpu_tensor->mutable_data<int8_t>());
      break;
    case (ONNXIFI_DATATYPE_UINT8):
      desc->buffer =
          reinterpret_cast<onnxPointer>(cpu_tensor->mutable_data<uint8_t>());
      break;
    case (ONNXIFI_DATATYPE_INT64):
      desc->buffer =
          reinterpret_cast<onnxPointer>(cpu_tensor->mutable_data<int64_t>());
      break;
    case (ONNXIFI_DATATYPE_INT16):
      desc->buffer =
          reinterpret_cast<onnxPointer>(cpu_tensor->mutable_data<int16_t>());
      break;
    case (ONNXIFI_DATATYPE_UINT16):
      desc->buffer =
          reinterpret_cast<onnxPointer>(cpu_tensor->mutable_data<uint16_t>());
      break;
    default:
      // Message previously misspelled the API name as "ONXNIFI".
      CAFFE_THROW("Unsupported ONNXIFI data type: ", onnxifi_type);
  }
}

void BlobToTensorDescriptor(
const std::string& name,
Workspace* ws,
Expand All @@ -24,16 +93,7 @@ void BlobToTensorDescriptor(

// Data type
const auto& cpu_tensor = blob->template Get<TensorCPU>();
if (cpu_tensor.template IsType<float>()) {
desc->dataType = ONNXIFI_DATATYPE_FLOAT32;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<float>());
} else if (cpu_tensor.template IsType<int64_t>()) {
desc->dataType = ONNXIFI_DATATYPE_INT64;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int64_t>());
} else if (cpu_tensor.template IsType<int32_t>()) {
desc->dataType = ONNXIFI_DATATYPE_INT32;
desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int32_t>());
}
SetInputTensorDescriptorTypeAndBuffer(cpu_tensor, desc);

// Set dims
const auto shape = cpu_tensor.sizes();
Expand Down Expand Up @@ -79,23 +139,20 @@ bool OnnxifiOp<float, CPUContext>::RunOnDevice() {
const auto tensor_dims = input_tensor.sizes();
auto& tensor_descriptor = input_desc_.at(i);
tensor_descriptor.tag = ONNXIFI_TAG_TENSOR_DESCRIPTOR_V1;
tensor_descriptor.dataType = ONNXIFI_DATATYPE_FLOAT32;
tensor_descriptor.memoryType = ONNXIFI_MEMORY_TYPE_CPU;
tensor_descriptor.dimensions = tensor_dims.size();
input_shapes_.emplace_back(tensor_dims.cbegin(), tensor_dims.cend());
tensor_descriptor.shape = input_shapes_.back().data();
tensor_descriptor.buffer =
reinterpret_cast<onnxPointer>(input_tensor.data<float>());
SetInputTensorDescriptorTypeAndBuffer(input_tensor, &tensor_descriptor);
}

for (unsigned i = 0U; i < OutputSize(); ++i) {
auto* output_tensor = Output(i);
std::vector<int64_t> tensor_dims;
SetOutputShape(i, &tensor_dims);
std::vector<size_t> tensor_dims;
uint64_t type = SetOutputShapeAndType(i, &tensor_dims);
output_tensor->Resize(tensor_dims);
auto& tensor_descriptor = output_desc_.at(i);
tensor_descriptor.tag = ONNXIFI_TAG_TENSOR_DESCRIPTOR_V1;
tensor_descriptor.dataType = ONNXIFI_DATATYPE_FLOAT32;
tensor_descriptor.memoryType = ONNXIFI_MEMORY_TYPE_CPU;
tensor_descriptor.dimensions = tensor_dims.size();
CAFFE_ENFORCE(
Expand All @@ -104,8 +161,8 @@ bool OnnxifiOp<float, CPUContext>::RunOnDevice() {
" has 0 dim");
output_shapes_.emplace_back(tensor_dims.cbegin(), tensor_dims.cend());
tensor_descriptor.shape = output_shapes_.back().data();
tensor_descriptor.buffer =
reinterpret_cast<onnxPointer>(output_tensor->mutable_data<float>());
SetOutputTensorDescriptorTypeAndBuffer(
type, output_tensor, &tensor_descriptor);
}

CAFFE_ENFORCE_EQ(
Expand Down
39 changes: 27 additions & 12 deletions caffe2/operators/onnxifi_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ namespace caffe2 {

template <typename T, typename Context>
class OnnxifiOp final : public Operator<Context> {
// Shape-and-type hint for a single output tensor. Move-only: declaring the
// move operations suppresses the implicit copy operations.
struct TensorInfo {
  TensorInfo() = default;
  TensorInfo(TensorInfo&&) = default;
  TensorInfo& operator=(TensorInfo&&) = default;

  // Dimensions of the output tensor.
  std::vector<uint64_t> dims;
  // ONNXIFI_DATATYPE_* code of the output tensor. Zero-initialized so a
  // default-constructed TensorInfo never carries an indeterminate value
  // (0 is presumably ONNXIFI_DATATYPE_UNDEFINED -- confirm against onnxifi.h).
  uint64_t onnxifi_type{0};
};

public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
OnnxifiOp(const OperatorDef& operator_def, Workspace* ws)
Expand All @@ -35,14 +43,15 @@ class OnnxifiOp final : public Operator<Context> {
output_desc_.back().name = output.c_str();

// For output, we try to get its output size hint
const std::string key = c10::str("output_size_hint_", output_idx);
auto output_size_hint = this->template GetRepeatedArgument<int>(key);
if (!output_size_hint.empty()) {
std::vector<int64_t> dims;
for (const auto v : output_size_hint) {
dims.push_back(v);
const std::string key = c10::str("output_shape_hint_", output_idx);
auto output_shape_hint = this->template GetRepeatedArgument<int>(key);
if (!output_shape_hint.empty()) {
TensorInfo info;
info.onnxifi_type = output_shape_hint.front();
for (int i = 1; i < output_shape_hint.size(); ++i) {
info.dims.push_back(output_shape_hint[i]);
}
output_size_hints_.emplace(output_idx, std::move(dims));
output_shape_hints_.emplace(output_idx, std::move(info));
}
++output_idx;
}
Expand Down Expand Up @@ -127,11 +136,17 @@ class OnnxifiOp final : public Operator<Context> {
bool RunOnDevice() override;

private:
void SetOutputShape(int output_idx, std::vector<int64_t>* dims) {
const auto it = output_size_hints_.find(output_idx);
if (it != output_size_hints_.end()) {
*dims = it->second;
uint64_t SetOutputShapeAndType(int output_idx, std::vector<size_t>* dims) {
uint64_t type = ONNXIFI_DATATYPE_FLOAT32;
const auto it = output_shape_hints_.find(output_idx);
if (it != output_shape_hints_.end()) {
std::copy(
it->second.dims.begin(),
it->second.dims.end(),
std::back_inserter(*dims));
type = it->second.onnxifi_type;
}
return type;
}

void BuildPropertyList(
Expand Down Expand Up @@ -163,7 +178,7 @@ class OnnxifiOp final : public Operator<Context> {
std::vector<std::vector<uint64_t>> output_shapes_;

// output shape hints
std::unordered_map<int, std::vector<int64_t>> output_size_hints_;
std::unordered_map<int, TensorInfo> output_shape_hints_;
};

} // namespace caffe2
54 changes: 36 additions & 18 deletions caffe2/opt/onnxifi_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,27 @@ void AnnotateOpIndex(NetDef* net) {
}
}

// TODO(yinghai): Remove the awkward conversion between unordered_map and map
// Maps a Caffe2 tensor element type to the matching ONNXIFI_DATATYPE_* code.
// Types without an entry below are logged as a warning and reported as
// ONNXIFI_DATATYPE_FLOAT32.
uint64_t OnnxifiDataType(caffe2::TensorProto::DataType t) {
  switch (t) {
    case caffe2::TensorProto::FLOAT:
      return ONNXIFI_DATATYPE_FLOAT32;
    case caffe2::TensorProto::INT8:
      return ONNXIFI_DATATYPE_INT8;
    case caffe2::TensorProto::UINT8:
      return ONNXIFI_DATATYPE_UINT8;
    case caffe2::TensorProto::INT16:
      return ONNXIFI_DATATYPE_INT16;
    case caffe2::TensorProto::UINT16:
      return ONNXIFI_DATATYPE_UINT16;
    case caffe2::TensorProto::INT32:
      return ONNXIFI_DATATYPE_INT32;
    case caffe2::TensorProto::INT64:
      return ONNXIFI_DATATYPE_INT64;
    case caffe2::TensorProto::FLOAT16:
      return ONNXIFI_DATATYPE_FLOAT16;
    default:
      break;
  }

  // No ONNXIFI counterpart listed: warn and fall back to FLOAT32.
  LOG(WARNING) << "Unsupported Caffe2 tensor type: " << t
               << ", fallback to FLOAT";
  return ONNXIFI_DATATYPE_FLOAT32;
}

std::unordered_map<std::string, TensorShape> InferShapes(
Workspace* ws,
NetDef* pred_net,
Expand Down Expand Up @@ -55,7 +75,8 @@ std::unordered_map<std::string, TensorShape> InferShapes(
ws_local.RunNetOnce(*pred_net);
const std::vector<std::string> ws_blobs = ws_local.Blobs();
for (const auto& s : ws_blobs) {
auto shape = GetTensorShapeOfBlob(ws_local.GetBlob(s));
const Blob* b = ws_local.GetBlob(s);
auto shape = GetTensorShapeOfBlob(b);
if (!shape.unknown_shape()) {
shape_hints.emplace(s, std::move(shape));
}
Expand Down Expand Up @@ -89,7 +110,7 @@ std::vector<::ONNX_NAMESPACE::ValueInfoProto> ConvertToValueInfo(
} else {
auto* tensor_type = value_info.mutable_type()->mutable_tensor_type();
tensor_type->set_elem_type(
::ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT);
onnx::Caffe2TypeToOnnxType(it->second.data_type()));
auto* shape = tensor_type->mutable_shape();
for (int i = 0; i < it->second.dims().size(); ++i) {
shape->add_dim()->set_dim_value(it->second.dims(i));
Expand Down Expand Up @@ -125,7 +146,7 @@ OnnxifiTransformer::OnnxifiTransformer(bool infer_shapes, bool debug)

OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
const std::string& onnx_model_str,
const std::unordered_map<std::string, std::vector<int>>& output_size_hints,
const std::unordered_map<std::string, TensorShape>& output_shape_hints,
const std::unordered_set<std::string>& initialization_list,
const caffe2::NetDef& net) {
OperatorDef op;
Expand Down Expand Up @@ -156,13 +177,14 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
// Add output size hints
for (int i = 0; i < op.output_size(); ++i) {
const auto& o = op.output(i);
const auto it = output_size_hints.find(o);
if (it != output_size_hints.end()) {
const auto& dims = it->second;
auto* output_size_hint_arg = op.add_arg();
output_size_hint_arg->set_name(c10::str("output_size_hint_", i));
for (const auto& d : dims) {
output_size_hint_arg->add_ints(d);
const auto it = output_shape_hints.find(o);
if (it != output_shape_hints.end()) {
const auto& shape = it->second;
auto* output_shape_hint_arg = op.add_arg();
output_shape_hint_arg->set_name(c10::str("output_shape_hint_", i));
output_shape_hint_arg->add_ints(OnnxifiDataType(shape.data_type()));
for (const auto& d : shape.dims()) {
output_shape_hint_arg->add_ints(d);
}

VLOG(2) << "Adding output hint: " << o;
Expand Down Expand Up @@ -234,20 +256,16 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp(
io_names.emplace_back(output);
}
auto io_vec = ConvertToValueInfo(io_names, *shape_hints);
std::unordered_map<std::string, std::vector<int>> output_shape_hints;
std::unordered_map<std::string, TensorShape> output_shape_hints;
for (const auto& i : io_vec) {
onnx_model.mutable_graph()->add_output()->CopyFrom(i);
auto ret = output_shape_hints.emplace(i.name(), std::vector<int>());
auto& vec = ret.first->second;
const auto it = shape_hints->find(i.name());
CAFFE_ENFORCE(
it != shape_hints->end(),
"Cannot find shape info for output ",
i.name());
const auto& shape = it->second;
for (int k = 0; k < shape.dims().size(); ++k) {
vec.push_back(shape.dims(k));
}
output_shape_hints.emplace(i.name(), shape);
}

// Convert inputs and figure out weights
Expand Down Expand Up @@ -352,7 +370,7 @@ void OnnxifiTransformer::Transform(
auto shape_hints_ordered =
SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
Workspace mapped_ws(ws, input_mapping_);
auto shape_hints =
std::unordered_map<std::string, TensorShape> shape_hints =
InferShapes(&mapped_ws, pred_net, &shape_hints_ordered, infer_shapes_);

CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
Expand Down
3 changes: 1 addition & 2 deletions caffe2/opt/onnxifi_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ class CAFFE2_API OnnxifiTransformer {

OperatorDef BuildOnnxifiOp(
const std::string& onnx_model_str,
const std::unordered_map<std::string, std::vector<int>>&
output_size_hints,
const std::unordered_map<std::string, TensorShape>& output_size_hints,
const std::unordered_set<std::string>& initialization_list,
const caffe2::NetDef& net);

Expand Down
6 changes: 4 additions & 2 deletions caffe2/python/onnx/test_onnxifi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
from caffe2.python.onnx.tests.test_utils import TestCase

ONNXIFI_DATATYPE_FLOAT32 = 1


def _print_net(net):
for i in net.external_input:
Expand Down Expand Up @@ -51,7 +53,7 @@ def test_relu_graph(self):
["X"],
["Y"],
onnx_model=model_def.SerializeToString(),
output_size_hint_0=[batch_size, 1, 3, 2])
output_shape_hint_0=[ONNXIFI_DATATYPE_FLOAT32, batch_size, 1, 3, 2])
workspace.FeedBlob("X", X)
workspace.RunOperatorOnce(op)
Y = workspace.FetchBlob("Y")
Expand Down Expand Up @@ -92,7 +94,7 @@ def test_conv_graph(self):
["Y"],
onnx_model=model_def.SerializeToString(),
initializers=["W", "W"],
output_size_hint_0=[1, 1, 3, 3])
output_shape_hint_0=[ONNXIFI_DATATYPE_FLOAT32, 1, 1, 3, 3])
workspace.FeedBlob("X", X)
workspace.FeedBlob("W", W)
workspace.RunOperatorOnce(op)
Expand Down

0 comments on commit d97ac82

Please sign in to comment.