Skip to content

Commit

Permalink
add python-d example.
Browse files Browse the repository at this point in the history
  • Loading branch information
ShigekiKarita committed May 10, 2020
1 parent 557b427 commit 11b3f6b
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 19 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ download/
*.pb
*.dll
*.lib
tmp.bin
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ with (newGraph)
}
```

And more:
- [save/load TF graphs between Python and D](example/graph_import)

## Features

- [x] Setup CI
- [x] Wrap tensor and session for basic usages (see `tfd.session` unittests).
- [x] mir.ndslice.Slice `s` <=> tfd.tensor.Tensor `t` integration by `s.tensor`, `t.slicedAs(s)`.
- [x] [Example](example/graph_import) to save/load TF graphs.
- [ ] Use [pbd](https://github.com/ShigekiKarita/pbd) to save/load proto files.
- [ ] Example using C API to save/load TF graphs.
- [ ] Parse `ops.pbtxt` to generate typed ops bindings.
- [ ] Rewrite C API example with typed bindings.
- [ ] Implement autograd, and simple training APIs in D.
Expand Down
Binary file added example/graph_import/add-d.bin
Binary file not shown.
Binary file added example/graph_import/add-py.bin
Binary file not shown.
20 changes: 20 additions & 0 deletions example/graph_import/graph_export.d
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/env dub
/+ dub.json:
{
"dependencies": {
"tfd": {"path": "../.."}
}
}
+/
import tfd : newGraph;

void main()
{
with (newGraph) {
auto a = placeholder!int("a");
auto b = constant(3, "b");
// TODO(karita): provide name "add" by identity
auto add = a + b;
write("add-d.bin");
}
}
7 changes: 7 additions & 0 deletions example/graph_import/graph_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import tensorflow.compat.v1 as tf

with tf.Session() as sess:
a = tf.placeholder(tf.int32, (), "a")
b = tf.constant(3)
c = tf.identity(a + b, "add")
tf.io.write_graph(sess.graph_def, ".", "add-py.bin", as_text=False)
25 changes: 25 additions & 0 deletions example/graph_import/graph_import.d
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env dub
/+ dub.json:
{
"dependencies": {
"tfd": {"path": "../.."}
}
}
+/
import std.stdio : writeln;
import std.file : read;
import tfd : newGraph, tensor;

void main()
{
with (newGraph) {
// TODO(karita): pbtxt support.
load(read("add-py.bin"));
auto a = operationByName("a");
auto add = operationByName("add");

const t = session.run([add], [a: 1.tensor])[0].tensor;
assert(t.scalar!int == 1 + 3);
writeln(t.scalar!int);
}
}
17 changes: 17 additions & 0 deletions example/graph_import/graph_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import tensorflow.compat.v1 as tf

with tf.Session() as sess:
with open("add-d.bin", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

tf.import_graph_def(graph_def)
graph = tf.get_default_graph()
for op in graph.get_operations():
print(op.name)

a = graph.get_tensor_by_name("import/a:0")
add = graph.get_tensor_by_name("import/add:0")
result = sess.run(add, {a: 1})
print(result)
assert(result == 4)
2 changes: 1 addition & 1 deletion source/tfd/c_api/linux.d
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module tfd.c_api.linux_;
module tfd.c_api.linux;
version (linux):

import core.stdc.config;
Expand Down
2 changes: 1 addition & 1 deletion source/tfd/c_api/linux.dpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/// Tensorflow C API header generated by dpp.
/// NOTE: module name linux will be " 1" (https://github.com/atilaneves/dpp/issues/258)
module tfd.c_api.linux_;
module tfd.c_api.linux;
version (linux):
#include <tensorflow/c/c_api.h>
2 changes: 1 addition & 1 deletion source/tfd/c_api/package.d
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ version (Windows)
}
else version (linux)
{
public import tfd.c_api.linux_;
public import tfd.c_api.linux;
}
else
{
Expand Down
64 changes: 49 additions & 15 deletions source/tfd/graph.d
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,37 @@ struct GraphOwner
TF_DeleteGraph(this.ptr);
TF_DeleteStatus(this.status);
}

/// Loads serialized graph (GraphDef proto).
@nogc nothrow @trusted
void load(const(void)[] proto)
{
auto buffer = TF_NewBufferFromString(proto.ptr, proto.length);
auto opts = TF_NewImportGraphDefOptions;
TF_GraphImportGraphDef(this.ptr, buffer, opts, this.status);
assertStatus(this.status);
}

/// Returns serialized bytes (GraphDef proto).
@nogc nothrow @trusted
TF_Buffer* serialize()
{
auto buffer = TF_NewBuffer;
TF_GraphToGraphDef(this.ptr, buffer, this.status);
assertStatus(this.status);
return buffer;
}

/// Writes serialized bytes (GraphDef proto) to a given file.
@nogc nothrow @trusted
void write(const(char)* fileName)
{
import core.stdc.stdio;

auto buffer = this.serialize();
auto fp = fopen(fileName, "wb");
fwrite(buffer.data, 1, buffer.length, fp);
}
}

/// TF_Operation wrapper used in Graph.
Expand Down Expand Up @@ -199,6 +230,15 @@ struct Graph
/// Base reference counted pointer.
SlimRCPtr!GraphOwner base;
alias base this;
/// Get an operation by name

@nogc nothrow @trusted
Operation operationByName(const(char)* name)
{
auto opr = TF_GraphOperationByName(this.ptr, name);
assert(opr);
return Operation(opr, this);
}

/// Creates a placeholder in this graph.
Operation placeholder(T, size_t N)(
Expand Down Expand Up @@ -245,7 +285,7 @@ unittest
{
import tfd.tensor;

auto buffer = TF_NewBuffer;
TF_Buffer* buffer;
scope (exit) TF_DeleteBuffer(buffer);
{
auto graph = newGraph;
Expand All @@ -259,22 +299,16 @@ unittest
auto add = a + b;
assert(TF_GraphOperationByName(graph, "add"));
}
// Export to a GraphDef (protobuf)
TF_GraphToGraphDef(graph, buffer, graph.status);
assertStatus(graph.status);
buffer = graph.serialize;
// for coverage
graph.write("tmp.bin");
}
{
with (newGraph) {
// Import from the GraphDef (protobuf)
auto graph = newGraph;
auto opts = TF_NewImportGraphDefOptions;
TF_GraphImportGraphDef(graph, buffer, opts, graph.status);
assertStatus(graph.status);

auto a = TF_GraphOperationByName(graph, "a");
assert(a);
auto add = TF_GraphOperationByName(graph, "add");
assert(add);
const t = TensorOwner(graph.session.run([add], [a: 1.tensor])[0]);
load(buffer.data[0 .. buffer.length]);
auto a = operationByName("a");
auto add = operationByName("add");
const t = session.run([add], [a: 1.tensor])[0].tensor;
assert(t.scalar!int == 1 + 3);
}
}
6 changes: 6 additions & 0 deletions source/tfd/tensor.d
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ Tensor tensor(Args ...)(Args args)
return createSlimRC!TensorOwner(makeTF_Tensor(forward!args));
}

@trusted
Tensor tensor(TF_Tensor* t)
{
return createSlimRC!TensorOwner(t);
}

/// Make a scalar RCTensor.
@nogc nothrow @safe
unittest
Expand Down

0 comments on commit 11b3f6b

Please sign in to comment.