Skip to content

Commit

Permalink
Guard all Caffe2 protobuf string serializations with CAFFE_ENFORCE (f…
Browse files Browse the repository at this point in the history
…ixed reverted bug) (pytorch#12848)

Summary:
Pull Request resolved: pytorch#12848

Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call
SerializeAsString_EnforceCheck so that the return value is checked and can
throw an exception if failing.

Most of the affected code was called from classes derived from  BlobSerializeBase.
Didn't touch most tests and ENFORCE calls because they usually do checks
anyway.

Original commit changeset: c0760e73ecc7

Reviewed By: dzhulgakov

Differential Revision: D10453456

fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
  • Loading branch information
Michael Antonov authored and facebook-github-bot committed Oct 23, 2018
1 parent dd00c29 commit a6949ab
Show file tree
Hide file tree
Showing 14 changed files with 58 additions and 18 deletions.
3 changes: 1 addition & 2 deletions binaries/convert_caffe_image_db.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ int main(int argc, char** argv) {
data->add_dims(datum.channels());
data->set_byte_data(buffer, datum.data().size());
}
transaction->Put(cursor->key(), protos.SerializeAsString());
transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos));
if (++count % FLAGS_batch_size == 0) {
transaction->Commit();
LOG(INFO) << "Converted " << count << " items so far.";
Expand All @@ -88,4 +88,3 @@ int main(int argc, char** argv) {
LOG(INFO) << "A total of " << count << " items processed.";
return 0;
}

23 changes: 21 additions & 2 deletions caffe2/core/blob_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class StringSerializer : public BlobSerializerBase {
blob_proto.set_name(name);
blob_proto.set_type("std::string");
blob_proto.set_content(*static_cast<const std::string*>(pointer));
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

Expand Down Expand Up @@ -134,7 +134,7 @@ void TensorSerializer::SerializeWithChunkSize(
tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
acceptor(
c10::str(name, kChunkIdSeparator, chunkStart / chunk_size),
blob_proto.SerializeAsString());
SerializeBlobProtoAsString_EnforceCheck(blob_proto));
};

#ifndef __ANDROID__
Expand Down Expand Up @@ -543,6 +543,25 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
context->FinishDeviceComputation();
}

////////////////////////////////////////////////////////////////////////////////
// Serialization Helpers
////////////////////////////////////////////////////////////////////////////////

std::string SerializeAsString_EnforceCheck(
const google::protobuf::MessageLite& msg,
const char* error_location) {
std::string serialize_output;
bool result = msg.SerializeToString(&serialize_output);
if (!error_location) {
CAFFE_ENFORCE(result, "protobuf::SerializeToString failed");
} else {
CAFFE_ENFORCE(result,
"protobuf::SerializeToString failed for ", error_location);
}
return serialize_output;
}


namespace {
// Serialize Tensor
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Tensor>()), TensorSerializer);
Expand Down
18 changes: 18 additions & 0 deletions caffe2/core/blob_serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,24 @@ inline void CopyFromProtoWithCast(
}

} // namespace detail

////////////////////////////////////////////////////////////////////////////////
// Serialization Helpers
////////////////////////////////////////////////////////////////////////////////

// Converts MessageLite to string while also checking that SerializeAsString
// succeeds. Pass description of class/function of the call if you'd
// like it appended to the error message.
CAFFE2_API std::string SerializeAsString_EnforceCheck(
const google::protobuf::MessageLite&,
const char* error_location = nullptr);

// Convert BlobProto to string with success checks.
inline std::string SerializeBlobProtoAsString_EnforceCheck(
const BlobProto& blob) {
return SerializeAsString_EnforceCheck(blob, blob.name().c_str());
}

} // namespace caffe2

#endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_
2 changes: 1 addition & 1 deletion caffe2/core/blob_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class BlobTestFooSerializer : public BlobSerializerBase {
reinterpret_cast<const char*>(
&static_cast<const BlobTestFoo*>(pointer)->val),
sizeof(int32_t)));
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

Expand Down
4 changes: 2 additions & 2 deletions caffe2/core/db.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ void DBReaderSerializer::Serialize(
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("DBReader");
blob_proto.set_content(proto.SerializeAsString());
acceptor(name, blob_proto.SerializeAsString());
blob_proto.set_content(SerializeAsString_EnforceCheck(proto));
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}

void DBReaderDeserializer::Deserialize(const BlobProto& proto, Blob* blob) {
Expand Down
2 changes: 1 addition & 1 deletion caffe2/core/int8_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Int8TensorCPUSerializer : public BlobSerializerBase {
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
}

acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}

private:
Expand Down
2 changes: 1 addition & 1 deletion caffe2/core/qtensor_serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void QTensorSerializer<Context>::Serialize(
proto.set_is_signed(qtensor.is_signed());
detail::CopyToProtoWithCast(
qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}

template <class Context>
Expand Down
5 changes: 4 additions & 1 deletion caffe2/db/protodb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ class ProtoDBCursor : public Cursor {
void SeekToFirst() override { iter_ = 0; }
void Next() override { ++iter_; }
string key() override { return proto_->protos(iter_).name(); }
string value() override { return proto_->protos(iter_).SerializeAsString(); }
string value() override {
return
SerializeAsString_EnforceCheck(proto_->protos(iter_), "ProtoDBCursor");
}
bool Valid() override { return iter_ < proto_->protos_size(); }

private:
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/counter_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class CounterSerializer : public BlobSerializerBase {
proto.add_int64_data(
(*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
->retrieve());
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

Expand Down
4 changes: 2 additions & 2 deletions caffe2/operators/dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,7 @@ class TreeCursorSerializer : public BlobSerializerBase {
}
blob_proto.set_content(os.str());

acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

Expand Down Expand Up @@ -1513,7 +1513,7 @@ void SharedTensorVectorPtrSerializer::Serialize(
blob_proto.set_name(name);
blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>");
blob_proto.set_content("");
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
};

void SharedTensorVectorPtrDeserializer::Deserialize(
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/index_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ class IndexSerializer : public BlobSerializerBase {
os << base->maxElements() << " " << base->isFrozen();
blob_proto.set_content(os.str());

acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}

private:
Expand Down
4 changes: 2 additions & 2 deletions caffe2/operators/map_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ class MapSerializer : public BlobSerializerBase {
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
blob_proto.set_content(tensor_protos.SerializeAsString());
acceptor(name, blob_proto.SerializeAsString());
blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos));
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

Expand Down
3 changes: 2 additions & 1 deletion caffe2/python/pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ void addObjectMethods(py::module& m) {
const auto& meta = GetGradientForOp(def, output_gradients);
std::vector<py::bytes> grad_ops;
for (const auto& op : meta.ops_) {
grad_ops.push_back(op.SerializeAsString());
grad_ops.push_back(
SerializeAsString_EnforceCheck(op, "addObjectMethods"));
}
return std::pair<std::vector<py::bytes>, std::vector<GradientWrapper>>{
grad_ops, meta.g_input_};
Expand Down
2 changes: 1 addition & 1 deletion caffe2/sgd/iter_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ void MutexSerializer::Serialize(
blob_proto.set_name(name);
blob_proto.set_type("std::unique_ptr<std::mutex>");
blob_proto.set_content("");
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}

void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {
Expand Down

0 comments on commit a6949ab

Please sign in to comment.