diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 03369a12cc2b26..c50e65f44eee06 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -38,10 +38,35 @@ public Graph() {
@Override
public void close() {
synchronized (nativeHandleLock) {
- if (nativeHandle != 0) {
- delete(nativeHandle);
- nativeHandle = 0;
+ if (nativeHandle == 0) {
+ return;
}
+ while (refcount > 0) {
+ try {
+ nativeHandleLock.wait();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ // Possible leak of the graph in this case?
+ return;
+ }
+ }
+ delete(nativeHandle);
+ nativeHandle = 0;
+ }
+ }
+
+ /**
+ * Returns the operation (node in the Graph) with the provided name.
+ *
+ *
Or {@code null} if no such operation exists in the Graph.
+ */
+ public Operation operation(String name) {
+ synchronized (nativeHandleLock) {
+ long oph = operation(nativeHandle, name);
+ if (oph == 0) {
+ return null;
+ }
+ return new Operation(this, oph);
}
}
@@ -89,11 +114,51 @@ public byte[] toGraphDef() {
private final Object nativeHandleLock = new Object();
private long nativeHandle;
+ private int refcount = 0;
+
+ // Related native objects (such as the TF_Operation object backing an Operation instance)
+ // have a validity tied to that of the Graph. The handles to those native objects are not
+ // valid after Graph.close() has been invoked.
+ //
+ // Instances of the Reference class should be used to ensure the Graph has not been closed
+ // while dependent handles are in use.
+ class Reference implements AutoCloseable {
+ private Reference() {
+ synchronized (Graph.this.nativeHandleLock) {
+ active = Graph.this.nativeHandle != 0;
+ if (!active) {
+ throw new IllegalStateException("close() has been called on the Graph");
+ }
+ Graph.this.refcount++;
+ }
+ }
+
+ @Override
+ public void close() {
+ synchronized (Graph.this.nativeHandleLock) {
+ if (!active) {
+ return;
+ }
+ active = false;
+ if (--Graph.this.refcount == 0) {
+ Graph.this.nativeHandleLock.notifyAll();
+ }
+ }
+ }
+
+ private boolean active;
+ }
+
+ Reference ref() {
+ return new Reference();
+ }
private static native long allocate();
private static native void delete(long handle);
+ private static native long operation(long handle, String name);
+
private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
throws IllegalArgumentException;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
new file mode 100644
index 00000000000000..cf552636925ab5
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
@@ -0,0 +1,77 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+/**
+ * A Graph node that performs computation on Tensors.
+ *
+ *
An Operation is a node in a {@link Graph} that takes zero or more {@link Tensor}s (produced by
+ * other Operations in the Graph) as input, and produces zero or more {@link Tensor}s as output.
+ *
+ *
Operation instances are valid only as long as the Graph they are a part of is valid. Thus, if
+ * {@link Graph#close()} has been invoked, then methods on the Operation instance may fail with an
+ * {@code IllegalStateException}.
+ *
+ *
Operation instances are immutable and thread-safe.
+ */
+public final class Operation {
+
+ // Create an Operation instance referring to an operation in g, with the given handle to the C
+ // TF_Operation object. The handle is valid only as long as g has not been closed, hence it is
+ // called unsafeHandle. Graph.ref() is used to safely use the unsafeHandle.
+ Operation(Graph g, long unsafeNativeHandle) {
+ this.graph = g;
+ this.unsafeNativeHandle = unsafeNativeHandle;
+ }
+
+ /** Returns the full name of the Operation. */
+ public String name() {
+ try (Graph.Reference r = graph.ref()) {
+ return name(unsafeNativeHandle);
+ }
+ }
+
+ /**
+ * Returns the type of the operation, i.e., the name of the computation performed by the
+ * operation.
+ */
+ public String type() {
+ try (Graph.Reference r = graph.ref()) {
+ return type(unsafeNativeHandle);
+ }
+ }
+
+ /** Returns the number of tensors output by this operation. */
+ public int numOutputs() {
+ try (Graph.Reference r = graph.ref()) {
+ return numOutputs(unsafeNativeHandle);
+ }
+ }
+
+ /** Returns a symbolic handle to one of the tensors produced by this operation. */
+ public Output output(int idx) {
+ return new Output(this, idx);
+ }
+
+ private final long unsafeNativeHandle;
+ private final Graph graph;
+
+ private static native String name(long handle);
+
+ private static native String type(long handle);
+
+ private static native int numOutputs(long handle);
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java
new file mode 100644
index 00000000000000..f0fffc2c1df20d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java
@@ -0,0 +1,44 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+/**
+ * A symbolic handle to a tensor produced by an {@link Operation}.
+ *
+ *
An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
+ * the {@link Operation} in a {@link Session}.
+ */
+public final class Output {
+
+ /** Handle to the idx-th output of the Operation {@code op}. */
+ public Output(Operation op, int idx) {
+ operation = op;
+ index = idx;
+ }
+
+ /** Returns the Operation that will produce the tensor referred to by this Output. */
+ public Operation op() {
+ return operation;
+ }
+
+ /** Returns the index into the outputs of the Operation. */
+ public int index() {
+ return index;
+ }
+
+ private final Operation operation;
+ private final int index;
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index 09d6916a8f7acd..9ae3d2ceab2e1e 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -21,6 +21,8 @@ limitations under the License.
namespace {
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
+ static_assert(sizeof(jlong) >= sizeof(TF_Graph*),
+ "Cannot package C object pointers as a Java long");
if (handle == 0) {
throwException(env, kNullPointerException,
"close() has been called on the Graph");
@@ -40,6 +42,18 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass,
TF_DeleteGraph(reinterpret_cast(handle));
}
+JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jstring name) {
+ TF_Graph* g = requireHandle(env, handle);
+ if (g == nullptr) return 0;
+ const char* cname = env->GetStringUTFChars(name, nullptr);
+ TF_Operation* op = TF_GraphOperationByName(g, cname);
+ env->ReleaseStringUTFChars(name, cname);
+ return reinterpret_cast(op);
+}
+
JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(
JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def,
jstring prefix) {
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 1e847aa4e58edb..b84c11578ee584 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -37,6 +37,14 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv *, jclass);
JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv *, jclass,
jlong);
+/*
+ * Class: org_tensorflow_Graph
+ * Method: operation
+ * Signature: (JLjava/lang/String;)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass,
+ jlong, jstring);
+
/*
* Class: org_tensorflow_Graph
* Method: importGraphDef
diff --git a/tensorflow/java/src/main/native/operation_jni.cc b/tensorflow/java/src/main/native/operation_jni.cc
new file mode 100644
index 00000000000000..e2eaacd4189e48
--- /dev/null
+++ b/tensorflow/java/src/main/native/operation_jni.cc
@@ -0,0 +1,57 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/java/src/main/native/operation_jni.h"
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/exception_jni.h"
+
+namespace {
+TF_Operation* requireHandle(JNIEnv* env, jlong handle) {
+ static_assert(sizeof(jlong) >= sizeof(TF_Operation*),
+ "Cannot package C object pointers as a Java long");
+ if (handle == 0) {
+ throwException(
+ env, kNullPointerException,
+ "close() has been called on the Graph this Operation was a part of");
+ return nullptr;
+ }
+ return reinterpret_cast(handle);
+}
+} // namespace
+
+JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ TF_Operation* op = requireHandle(env, handle);
+ if (op == nullptr) return nullptr;
+ return env->NewStringUTF(TF_OperationName(op));
+}
+
+JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ TF_Operation* op = requireHandle(env, handle);
+ if (op == nullptr) return nullptr;
+ return env->NewStringUTF(TF_OperationOpType(op));
+}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ TF_Operation* op = requireHandle(env, handle);
+ if (op == nullptr) return 0;
+ return TF_OperationNumOutputs(op);
+}
diff --git a/tensorflow/java/src/main/native/operation_jni.h b/tensorflow/java/src/main/native/operation_jni.h
new file mode 100644
index 00000000000000..ca25ef728eeb54
--- /dev/null
+++ b/tensorflow/java/src/main/native/operation_jni.h
@@ -0,0 +1,53 @@
+
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_OPERATION_JNI_H_
+#define TENSORFLOW_JAVA_OPERATION_JNI_H_
+
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * Class: org_tensorflow_Operation
+ * Method: name
+ * Signature: (J)Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv *, jclass,
+ jlong);
+
+/*
+ * Class: org_tensorflow_Operation
+ * Method: type
+ * Signature: (J)Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv *, jclass,
+ jlong);
+
+/*
+ * Class: org_tensorflow_Operation
+ * Method: numOutputs
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv *,
+ jclass, jlong);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_JAVA_OPERATION_JNI_H_
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index 1b6743531bc960..78df0d6d1e22a5 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -16,6 +16,7 @@
package org.tensorflow;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import java.nio.file.Files;
@@ -29,21 +30,42 @@
public class GraphTest {
@Test
- public void graphDefImportAndExport() {
- try (Graph g = new Graph()) {
- final byte[] inGraphDef = Files.readAllBytes(Paths.get("tensorflow/java/test_graph_def.data"));
- g.importGraphDef(inGraphDef);
- final byte[] outGraphDef = g.toGraphDef();
- // The graphs may not be identical as the proto format allows the same message
- // to be encoded in multiple ways. Once the Graph API is expressive enough
- // to construct graphs and query for nodes/operations, use that.
- // Till then a very crude test:
- assertEquals(inGraphDef.length, outGraphDef.length);
+ public void graphDefRoundTrip() {
+ try (Graph imported = new Graph()) {
+ final byte[] inGraphDef =
+ Files.readAllBytes(Paths.get("tensorflow/java/test_graph_def.data"));
+ imported.importGraphDef(inGraphDef);
+ validateImportedGraph(imported, "");
+
+ final byte[] outGraphDef = imported.toGraphDef();
+ try (Graph exported = new Graph()) {
+ exported.importGraphDef(outGraphDef, "HeyHeyHey");
+ validateImportedGraph(exported, "HeyHeyHey/");
+ }
+ // Knowing how test_graph_def.data was generated, it should have these nodes:
} catch (Exception e) {
fail("Unexpected exception: " + e);
}
}
+ // Helper function whose implementation is based on knowledge of how test_graph_def.data was
+ // produced.
+ private void validateImportedGraph(Graph g, String prefix) {
+ Operation op = g.operation(prefix + "MyConstant");
+ assertNotNull(op);
+ assertEquals(prefix + "MyConstant", op.name());
+ assertEquals("Const", op.type());
+ assertEquals(1, op.numOutputs());
+ assertEquals(op, op.output(0).op());
+
+ op = g.operation(prefix + "while/Less");
+ assertNotNull(op);
+ assertEquals(prefix + "while/Less", op.name());
+ assertEquals("Less", op.type());
+ assertEquals(1, op.numOutputs());
+ assertEquals(op, op.output(0).op());
+ }
+
@Test
public void failImportOnInvalidGraphDefs() {
try (Graph g = new Graph()) {
diff --git a/tensorflow/java/src/test/python/graphdef.py b/tensorflow/java/src/test/python/graphdef.py
index 1f3dfbf85c3424..91ea4cc650266a 100644
--- a/tensorflow/java/src/test/python/graphdef.py
+++ b/tensorflow/java/src/test/python/graphdef.py
@@ -30,7 +30,7 @@ def main():
if len(sys.argv) != 2:
print('Usage: ' + sys.argv[0] + '