From 22812d55368956b74ccf3afd4c1139458b7b29fb Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Mon, 5 Dec 2016 17:16:23 -0800 Subject: [PATCH] Java: Introduce the Operation class. Add the Operation and Output classes and the ability to query an Operation in an Graph by name. Another step in the journey that is #5 Change: 141122031 --- .../src/main/java/org/tensorflow/Graph.java | 71 ++++++++++++++++- .../main/java/org/tensorflow/Operation.java | 77 +++++++++++++++++++ .../src/main/java/org/tensorflow/Output.java | 44 +++++++++++ tensorflow/java/src/main/native/graph_jni.cc | 14 ++++ tensorflow/java/src/main/native/graph_jni.h | 8 ++ .../java/src/main/native/operation_jni.cc | 57 ++++++++++++++ .../java/src/main/native/operation_jni.h | 53 +++++++++++++ .../test/java/org/tensorflow/GraphTest.java | 42 +++++++--- tensorflow/java/src/test/python/graphdef.py | 2 +- 9 files changed, 354 insertions(+), 14 deletions(-) create mode 100644 tensorflow/java/src/main/java/org/tensorflow/Operation.java create mode 100644 tensorflow/java/src/main/java/org/tensorflow/Output.java create mode 100644 tensorflow/java/src/main/native/operation_jni.cc create mode 100644 tensorflow/java/src/main/native/operation_jni.h diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index 03369a12cc2b26..c50e65f44eee06 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -38,10 +38,35 @@ public Graph() { @Override public void close() { synchronized (nativeHandleLock) { - if (nativeHandle != 0) { - delete(nativeHandle); - nativeHandle = 0; + if (nativeHandle == 0) { + return; } + while (refcount > 0) { + try { + nativeHandleLock.wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + // Possible leak of the graph in this case? + return; + } + } + delete(nativeHandle); + nativeHandle = 0; + } + } + + /** + * Returns the operation (node in the Graph) with the provided name. + * + *

Or {@code null} if no such operation exists in the Graph. + */ + public Operation operation(String name) { + synchronized (nativeHandleLock) { + long oph = operation(nativeHandle, name); + if (oph == 0) { + return null; + } + return new Operation(this, oph); } } @@ -89,11 +114,51 @@ public byte[] toGraphDef() { private final Object nativeHandleLock = new Object(); private long nativeHandle; + private int refcount = 0; + + // Related native objects (such as the TF_Operation object backing an Operation instance) + // have a validity tied to that of the Graph. The handles to those native objects are not + // valid after Graph.close() has been invoked. + // + // Instances of the Reference class should be used to ensure the Graph has not been closed + // while dependent handles are in use. + class Reference implements AutoCloseable { + private Reference() { + synchronized (Graph.this.nativeHandleLock) { + active = Graph.this.nativeHandle != 0; + if (!active) { + throw new IllegalStateException("close() has been called on the Graph"); + } + Graph.this.refcount++; + } + } + + @Override + public void close() { + synchronized (Graph.this.nativeHandleLock) { + if (!active) { + return; + } + active = false; + if (--Graph.this.refcount == 0) { + Graph.this.nativeHandleLock.notifyAll(); + } + } + } + + private boolean active; + } + + Reference ref() { + return new Reference(); + } private static native long allocate(); private static native void delete(long handle); + private static native long operation(long handle, String name); + private static native void importGraphDef(long handle, byte[] graphDef, String prefix) throws IllegalArgumentException; diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java new file mode 100644 index 00000000000000..cf552636925ab5 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java @@ -0,0 +1,77 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** + * A Graph node that performs computation on Tensors. + * + *

An Operation is a node in a {@link Graph} that takes zero or more {@link Tensor}s (produced by + * other Operations in the Graph) as input, and produces zero or more {@link Tensor}s as output. + * + *

Operation instances are valid only as long as the Graph they are a part of is valid. Thus, if + * {@link Graph#close()} has been invoked, then methods on the Operation instance may fail with an + * {@code IllegalStateException}. + * + *

Operation instances are immutable and thread-safe. + */ +public final class Operation { + + // Create an Operation instance referring to an operation in g, with the given handle to the C + // TF_Operation object. The handle is valid only as long as g has not been closed, hence it is + // called unsafeHandle. Graph.ref() is used to safely use the unsafeHandle. + Operation(Graph g, long unsafeNativeHandle) { + this.graph = g; + this.unsafeNativeHandle = unsafeNativeHandle; + } + + /** Returns the full name of the Operation. */ + public String name() { + try (Graph.Reference r = graph.ref()) { + return name(unsafeNativeHandle); + } + } + + /** + * Returns the type of the operation, i.e., the name of the computation performed by the + * operation. + */ + public String type() { + try (Graph.Reference r = graph.ref()) { + return type(unsafeNativeHandle); + } + } + + /** Returns the number of tensors output by this operation. */ + public int numOutputs() { + try (Graph.Reference r = graph.ref()) { + return numOutputs(unsafeNativeHandle); + } + } + + /** Returns a symbolic handle to one of the tensors produced by this operation. */ + public Output output(int idx) { + return new Output(this, idx); + } + + private final long unsafeNativeHandle; + private final Graph graph; + + private static native String name(long handle); + + private static native String type(long handle); + + private static native int numOutputs(long handle); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java new file mode 100644 index 00000000000000..f0fffc2c1df20d --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java @@ -0,0 +1,44 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** + * A symbolic handle to a tensor produced by an {@link Operation}. + * + *

An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing + * the {@link Operation} in a {@link Session}. + */ +public final class Output { + + /** Handle to the idx-th output of the Operation {@code op}. */ + public Output(Operation op, int idx) { + operation = op; + index = idx; + } + + /** Returns the Operation that will produce the tensor referred to by this Output. */ + public Operation op() { + return operation; + } + + /** Returns the index into the outputs of the Operation. */ + public int index() { + return index; + } + + private final Operation operation; + private final int index; +} diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index 09d6916a8f7acd..9ae3d2ceab2e1e 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -21,6 +21,8 @@ limitations under the License. namespace { TF_Graph* requireHandle(JNIEnv* env, jlong handle) { + static_assert(sizeof(jlong) >= sizeof(TF_Graph*), + "Cannot package C object pointers as a Java long"); if (handle == 0) { throwException(env, kNullPointerException, "close() has been called on the Graph"); @@ -40,6 +42,18 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass, TF_DeleteGraph(reinterpret_cast(handle)); } +JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env, + jclass clazz, + jlong handle, + jstring name) { + TF_Graph* g = requireHandle(env, handle); + if (g == nullptr) return 0; + const char* cname = env->GetStringUTFChars(name, nullptr); + TF_Operation* op = TF_GraphOperationByName(g, cname); + env->ReleaseStringUTFChars(name, cname); + return reinterpret_cast(op); +} + JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef( JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def, jstring prefix) { diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index 1e847aa4e58edb..b84c11578ee584 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -37,6 +37,14 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv *, jclass); JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Graph + * Method: operation + * Signature: (JLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass, + jlong, jstring); + /* * Class: org_tensorflow_Graph * Method: importGraphDef diff --git a/tensorflow/java/src/main/native/operation_jni.cc b/tensorflow/java/src/main/native/operation_jni.cc new file mode 100644 index 00000000000000..e2eaacd4189e48 --- /dev/null +++ b/tensorflow/java/src/main/native/operation_jni.cc @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/operation_jni.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" + +namespace { +TF_Operation* requireHandle(JNIEnv* env, jlong handle) { + static_assert(sizeof(jlong) >= sizeof(TF_Operation*), + "Cannot package C object pointers as a Java long"); + if (handle == 0) { + throwException( + env, kNullPointerException, + "close() has been called on the Graph this Operation was a part of"); + return nullptr; + } + return reinterpret_cast(handle); +} +} // namespace + +JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv* env, + jclass clazz, + jlong handle) { + TF_Operation* op = requireHandle(env, handle); + if (op == nullptr) return nullptr; + return env->NewStringUTF(TF_OperationName(op)); +} + +JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv* env, + jclass clazz, + jlong handle) { + TF_Operation* op = requireHandle(env, handle); + if (op == nullptr) return nullptr; + return env->NewStringUTF(TF_OperationOpType(op)); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env, + jclass clazz, + jlong handle) { + TF_Operation* op = requireHandle(env, handle); + if (op == nullptr) return 0; + return TF_OperationNumOutputs(op); +} diff --git a/tensorflow/java/src/main/native/operation_jni.h b/tensorflow/java/src/main/native/operation_jni.h new file mode 100644 index 00000000000000..ca25ef728eeb54 --- /dev/null +++ b/tensorflow/java/src/main/native/operation_jni.h @@ -0,0 +1,53 @@ + +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_OPERATION_JNI_H_ +#define TENSORFLOW_JAVA_OPERATION_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_tensorflow_Operation + * Method: name + * Signature: (J)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv *, jclass, + jlong); + +/* + * Class: org_tensorflow_Operation + * Method: type + * Signature: (J)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv *, jclass, + jlong); + +/* + * Class: org_tensorflow_Operation + * Method: numOutputs + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv *, + jclass, jlong); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_OPERATION_JNI_H_ diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index 1b6743531bc960..78df0d6d1e22a5 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -16,6 +16,7 @@ package org.tensorflow; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; import java.nio.file.Files; @@ -29,21 +30,42 @@ public class GraphTest { @Test - public void graphDefImportAndExport() { - try (Graph g = new Graph()) { - final byte[] inGraphDef = Files.readAllBytes(Paths.get("tensorflow/java/test_graph_def.data")); - g.importGraphDef(inGraphDef); - final byte[] outGraphDef = g.toGraphDef(); - // The graphs may not be identical as the proto format allows the same message - // to be encoded in multiple ways. Once the Graph API is expressive enough - // to construct graphs and query for nodes/operations, use that. - // Till then a very crude test: - assertEquals(inGraphDef.length, outGraphDef.length); + public void graphDefRoundTrip() { + try (Graph imported = new Graph()) { + final byte[] inGraphDef = + Files.readAllBytes(Paths.get("tensorflow/java/test_graph_def.data")); + imported.importGraphDef(inGraphDef); + validateImportedGraph(imported, ""); + + final byte[] outGraphDef = imported.toGraphDef(); + try (Graph exported = new Graph()) { + exported.importGraphDef(outGraphDef, "HeyHeyHey"); + validateImportedGraph(exported, "HeyHeyHey/"); + } + // Knowing how test_graph_def.data was generated, it should have these nodes: } catch (Exception e) { fail("Unexpected exception: " + e); } } + // Helper function whose implementation is based on knowledge of how test_graph_def.data was + // produced. + private void validateImportedGraph(Graph g, String prefix) { + Operation op = g.operation(prefix + "MyConstant"); + assertNotNull(op); + assertEquals(prefix + "MyConstant", op.name()); + assertEquals("Const", op.type()); + assertEquals(1, op.numOutputs()); + assertEquals(op, op.output(0).op()); + + op = g.operation(prefix + "while/Less"); + assertNotNull(op); + assertEquals(prefix + "while/Less", op.name()); + assertEquals("Less", op.type()); + assertEquals(1, op.numOutputs()); + assertEquals(op, op.output(0).op()); + } + @Test public void failImportOnInvalidGraphDefs() { try (Graph g = new Graph()) { diff --git a/tensorflow/java/src/test/python/graphdef.py b/tensorflow/java/src/test/python/graphdef.py index 1f3dfbf85c3424..91ea4cc650266a 100644 --- a/tensorflow/java/src/test/python/graphdef.py +++ b/tensorflow/java/src/test/python/graphdef.py @@ -30,7 +30,7 @@ def main(): if len(sys.argv) != 2: print('Usage: ' + sys.argv[0] + ' ') sys.exit(-1) - i = tf.constant(0) + i = tf.constant(0, name='MyConstant') c = lambda i: tf.less(i, 10) b = lambda i: tf.add(i, 1) tf.while_loop(c, b, [i])