Skip to content

Commit

Permalink
Update Tensorflow test file (lutzroeder#1162)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 7, 2023
1 parent 4e33db3 commit f5174b4
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 39 deletions.
77 changes: 77 additions & 0 deletions source/tf-proto.js
Original file line number Diff line number Diff line change
Expand Up @@ -9827,6 +9827,83 @@ $root.tensorflow.MemmappedFileSystemDirectory = class MemmappedFileSystemDirecto
}
};

$root.tensorflow.FingerprintDef = class FingerprintDef {

constructor() {
}

static decode(reader, length) {
const message = new $root.tensorflow.FingerprintDef();
const end = length !== undefined ? reader.position + length : reader.length;
while (reader.position < end) {
const tag = reader.uint32();
switch (tag >>> 3) {
case 1:
message.saved_model_checksum = reader.uint64();
break;
case 2:
message.graph_def_program_hash = reader.uint64();
break;
case 3:
message.signature_def_hash = reader.uint64();
break;
case 4:
message.saved_object_graph_hash = reader.uint64();
break;
case 5:
message.checkpoint_hash = reader.uint64();
break;
case 6:
message.version = $root.tensorflow.VersionDef.decode(reader, reader.uint32());
break;
default:
reader.skipType(tag & 7);
break;
}
}
return message;
}

static decodeText(reader) {
const message = new $root.tensorflow.FingerprintDef();
reader.start();
while (!reader.end()) {
const tag = reader.tag();
switch (tag) {
case "saved_model_checksum":
message.saved_model_checksum = reader.uint64();
break;
case "graph_def_program_hash":
message.graph_def_program_hash = reader.uint64();
break;
case "signature_def_hash":
message.signature_def_hash = reader.uint64();
break;
case "saved_object_graph_hash":
message.saved_object_graph_hash = reader.uint64();
break;
case "checkpoint_hash":
message.checkpoint_hash = reader.uint64();
break;
case "version":
message.version = $root.tensorflow.VersionDef.decodeText(reader);
break;
default:
reader.field(tag, message);
break;
}
}
return message;
}
};

$root.tensorflow.FingerprintDef.prototype.saved_model_checksum = protobuf.Uint64.create(0);
$root.tensorflow.FingerprintDef.prototype.graph_def_program_hash = protobuf.Uint64.create(0);
$root.tensorflow.FingerprintDef.prototype.signature_def_hash = protobuf.Uint64.create(0);
$root.tensorflow.FingerprintDef.prototype.saved_object_graph_hash = protobuf.Uint64.create(0);
$root.tensorflow.FingerprintDef.prototype.checkpoint_hash = protobuf.Uint64.create(0);
$root.tensorflow.FingerprintDef.prototype.version = null;

$root.google = {};

$root.google.protobuf = {};
Expand Down
67 changes: 36 additions & 31 deletions source/tf.js
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,13 @@ tf.ModelFactory = class {
const metadata = await context.metadata('tf-metadata.json');
return new tf.Model(metadata, saved_model, format, producer, bundle);
};
const openSavedModel = async (saved_model, format, producer) => {
const openSavedModel = async (context, saved_model, format, producer) => {
if (format === '') {
format = 'TensorFlow Saved Model';
if (saved_model && saved_model.saved_model_schema_version) {
format = format + ' v' + saved_model.saved_model_schema_version.toString();
}
}
if (saved_model.meta_graphs.length === 1 &&
saved_model.meta_graphs[0].object_graph_def &&
saved_model.meta_graphs[0].object_graph_def.nodes &&
Expand All @@ -242,7 +248,7 @@ tf.ModelFactory = class {
return openModel(saved_model, format, producer, null);
}
}
if (saved_model && saved_model.meta_graphs && saved_model.meta_graphs.length > 0 &&
if (saved_model && Array.isArray(saved_model.meta_graphs) && saved_model.meta_graphs.length > 0 &&
saved_model.meta_graphs[0].meta_info_def &&
Object.prototype.hasOwnProperty.call(saved_model.meta_graphs[0].meta_info_def, 'tensorflow_version')) {
producer = 'TensorFlow v' + saved_model.meta_graphs[0].meta_info_def.tensorflow_version;
Expand Down Expand Up @@ -366,7 +372,7 @@ tf.ModelFactory = class {
const updated_saved_model = await openPyTorchMetadata(context, saved_model);
return openModel(updated_saved_model, format, producer, null);
}
return openSavedModel(saved_model, format, producer);
return openSavedModel(context, saved_model, format, producer);
};
const openJson = async (context, type) => {
try {
Expand Down Expand Up @@ -451,7 +457,7 @@ tf.ModelFactory = class {
}
}
}
return openSavedModel(saved_model, format, producer);
return openSavedModel(context, saved_model, format, producer);
};
try {
const streams = await Promise.all(shards.values());
Expand Down Expand Up @@ -494,7 +500,7 @@ tf.ModelFactory = class {
const saved_model = new tf.proto.tensorflow.SavedModel();
saved_model.meta_graphs.push(meta_graph);
const format = 'TensorFlow Graph';
return openSavedModel(saved_model, format, null);
return openSavedModel(context, saved_model, format, null);
} catch (error) {
const message = error && error.message ? error.message : error.toString();
throw new tf.Error('File text format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
Expand All @@ -508,21 +514,15 @@ tf.ModelFactory = class {
const saved_model = new tf.proto.tensorflow.SavedModel();
saved_model.meta_graphs.push(meta_graph);
const format = 'TensorFlow MetaGraph';
return openSavedModel(saved_model, format, null);
return openSavedModel(context, saved_model, format, null);
} catch (error) {
throw new tf.Error('File text format is not tensorflow.MetaGraphDef (' + error.message + ').');
}
};
const openTextSavedModel = (context) => {
const openTextSavedModel = (stream) => {
try {
const stream = context.stream;
const reader = protobuf.TextReader.open(stream);
const saved_model = tf.proto.tensorflow.SavedModel.decodeText(reader);
let format = 'TensorFlow Saved Model';
if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
format = format + ' v' + saved_model.saved_model_schema_version.toString();
}
return openSavedModel(saved_model, format, null);
return tf.proto.tensorflow.SavedModel.decodeText(reader);
} catch (error) {
throw new tf.Error('File text format is not tensorflow.SavedModel (' + error.message + ').');
}
Expand All @@ -542,7 +542,7 @@ tf.ModelFactory = class {
const message = error && error.message ? error.message : error.toString();
throw new tf.Error('File format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
}
return openSavedModel(saved_model, format, null);
return openSavedModel(context, saved_model, format, null);
};
const openBinaryMetaGraphDef = (context) => {
let saved_model = null;
Expand All @@ -557,28 +557,33 @@ tf.ModelFactory = class {
const message = error && error.message ? error.message : error.toString();
throw new tf.Error('File format is not tensorflow.MetaGraphDef (' + message.replace(/\.$/, '') + ').');
}
return openSavedModel(saved_model, format, null);
return openSavedModel(context, saved_model, format, null);
};
const openBinarySavedModel = (context) => {
let saved_model = null;
let format = 'TensorFlow Saved Model';
const openBinarySavedModel = (stream) => {
try {
const stream = context.stream;
const reader = protobuf.BinaryReader.open(stream);
saved_model = tf.proto.tensorflow.SavedModel.decode(reader);
if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
format = format + ' v' + saved_model.saved_model_schema_version.toString();
}
return tf.proto.tensorflow.SavedModel.decode(reader);
} catch (error) {
const message = error && error.message ? error.message : error.toString();
throw new tf.Error('File format is not tensorflow.SavedModel (' + message.replace(/\.$/, '') + ').');
}
return openSavedModel(saved_model, format, null);
};
const openFingerprint = async (context) => {
const identifier = 'saved_model.pb';
const stream = await context.request(identifier, null);
return openBinarySavedModel({ stream: stream });
let format = '';
let saved_model = null;
try {
const identifier = 'saved_model.pb';
const stream = await context.request(identifier, null);
saved_model = openBinarySavedModel(stream);

} catch (error) {
format = 'TensorFlow Fingerprint';
saved_model = new tf.proto.tensorflow.SavedModel();
}
const stream = context.stream;
const reader = protobuf.BinaryReader.open(stream);
saved_model.fingerprint = tf.proto.tensorflow.FingerprintDef.decode(reader);
return openSavedModel(context, saved_model, format, null);
};
const openMemmapped = (context) => {
const stream = context.stream;
Expand Down Expand Up @@ -636,7 +641,7 @@ tf.ModelFactory = class {
meta_graph.graph_def = graph_def;
const saved_model = new tf.proto.tensorflow.SavedModel();
saved_model.meta_graphs.push(meta_graph);
return openSavedModel(saved_model, format, null);
return openSavedModel(context, saved_model, format, null);
};
switch (target) {
case 'tf.bundle':
Expand All @@ -654,13 +659,13 @@ tf.ModelFactory = class {
case 'tf.pbtxt.MetaGraphDef':
return openTextMetaGraphDef(context);
case 'tf.pbtxt.SavedModel':
return openTextSavedModel(context);
return openSavedModel(context, openTextSavedModel(context.stream), '', null);
case 'tf.pb.GraphDef':
return openBinaryGraphDef(context);
case 'tf.pb.MetaGraphDef':
return openBinaryMetaGraphDef(context);
case 'tf.pb.SavedModel':
return openBinarySavedModel(context);
return openSavedModel(context, openBinarySavedModel(context.stream), '', null);
case 'tf.pb.FingerprintDef':
return openFingerprint(context);
case 'tf.pb.mmap':
Expand Down
14 changes: 7 additions & 7 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,6 @@
"error": "Archive contains no model files.",
"link": "https://github.com/lutzroeder/netron/issues/458"
},
{
"type": "_",
"target": "fingerprint.pb",
"source": "https://github.com/lutzroeder/netron/files/10475238/fingerprint.pb.zip[fingerprint.pb]",
"error": "The file 'saved_model.pb' does not exist.",
"link": "https://github.com/lutzroeder/netron/issues/458"
},
{
"type": "_",
"target": "haarcascade_mcs_nose.xml",
Expand Down Expand Up @@ -6083,6 +6076,13 @@
"action": "skip-render",
"link": "https://github.com/Leavingseason/xDeepFM"
},
{
"type": "tf",
"target": "fingerprint.pb",
"source": "https://github.com/lutzroeder/netron/files/10475238/fingerprint.pb.zip[fingerprint.pb]",
"format": "TensorFlow Fingerprint",
"link": "https://github.com/lutzroeder/netron/issues/1162"
},
{
"type": "tf",
"target": "float16.txt",
Expand Down
2 changes: 1 addition & 1 deletion tools/tf
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ install() {
schema() {
echo "tf schema"
[[ $(grep -U $'\x0D' ./source/tf-proto.js) ]] && crlf=1
node ./tools/protoc.js --text --root tf --out ./source/tf-proto.js --path ./third_party/source/tensorflow --path ./third_party/source/tensorflow/third_party/xla/third_party/tsl tensorflow/core/protobuf/saved_model.proto tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/util/saved_tensor_slice.proto tensorflow/core/util/event.proto tensorflow/core/protobuf/config.proto tensorflow/core/util/memmapped_file_system.proto
node ./tools/protoc.js --text --root tf --out ./source/tf-proto.js --path ./third_party/source/tensorflow --path ./third_party/source/tensorflow/third_party/xla/third_party/tsl tensorflow/core/protobuf/saved_model.proto tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/util/saved_tensor_slice.proto tensorflow/core/util/event.proto tensorflow/core/protobuf/config.proto tensorflow/core/util/memmapped_file_system.proto tensorflow/core/protobuf/fingerprint.proto
if [[ -n ${crlf} ]]; then
unix2dos --quiet --newfile ./source/tf-proto.js ./source/tf-proto.js
fi
Expand Down

0 comments on commit f5174b4

Please sign in to comment.