Skip to content

Commit

Permalink
Java: Introduce the Operation class.
Browse files Browse the repository at this point in the history
Add the Operation and Output classes and the ability to query
an Operation in an Graph by name.

Another step in the journey that is zplizzi#5
Change: 141122031
  • Loading branch information
asimshankar authored and tensorflower-gardener committed Dec 6, 2016
1 parent 57004b1 commit 22812d5
Show file tree
Hide file tree
Showing 9 changed files with 354 additions and 14 deletions.
71 changes: 68 additions & 3 deletions tensorflow/java/src/main/java/org/tensorflow/Graph.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,35 @@ public Graph() {
@Override
public void close() {
synchronized (nativeHandleLock) {
if (nativeHandle != 0) {
delete(nativeHandle);
nativeHandle = 0;
if (nativeHandle == 0) {
return;
}
while (refcount > 0) {
try {
nativeHandleLock.wait();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
// Possible leak of the graph in this case?
return;
}
}
delete(nativeHandle);
nativeHandle = 0;
}
}

/**
* Returns the operation (node in the Graph) with the provided name.
*
* <p>Or {@code null} if no such operation exists in the Graph.
*/
public Operation operation(String name) {
synchronized (nativeHandleLock) {
long oph = operation(nativeHandle, name);
if (oph == 0) {
return null;
}
return new Operation(this, oph);
}
}

Expand Down Expand Up @@ -89,11 +114,51 @@ public byte[] toGraphDef() {

private final Object nativeHandleLock = new Object();
private long nativeHandle;
private int refcount = 0;

// Related native objects (such as the TF_Operation object backing an Operation instance)
// have a validity tied to that of the Graph. The handles to those native objects are not
// valid after Graph.close() has been invoked.
//
// Instances of the Reference class should be used to ensure the Graph has not been closed
// while dependent handles are in use.
class Reference implements AutoCloseable {
private Reference() {
synchronized (Graph.this.nativeHandleLock) {
active = Graph.this.nativeHandle != 0;
if (!active) {
throw new IllegalStateException("close() has been called on the Graph");
}
Graph.this.refcount++;
}
}

@Override
public void close() {
synchronized (Graph.this.nativeHandleLock) {
if (!active) {
return;
}
active = false;
if (--Graph.this.refcount == 0) {
Graph.this.nativeHandleLock.notifyAll();
}
}
}

private boolean active;
}

Reference ref() {
return new Reference();
}

private static native long allocate();

private static native void delete(long handle);

private static native long operation(long handle, String name);

private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
throws IllegalArgumentException;

Expand Down
77 changes: 77 additions & 0 deletions tensorflow/java/src/main/java/org/tensorflow/Operation.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

/**
* A Graph node that performs computation on Tensors.
*
* <p>An Operation is a node in a {@link Graph} that takes zero or more {@link Tensor}s (produced by
* other Operations in the Graph) as input, and produces zero or more {@link Tensor}s as output.
*
* <p>Operation instances are valid only as long as the Graph they are a part of is valid. Thus, if
* {@link Graph#close()} has been invoked, then methods on the Operation instance may fail with an
* {@code IllegalStateException}.
*
* <p>Operation instances are immutable and thread-safe.
*/
public final class Operation {

// Create an Operation instance referring to an operation in g, with the given handle to the C
// TF_Operation object. The handle is valid only as long as g has not been closed, hence it is
// called unsafeHandle. Graph.ref() is used to safely use the unsafeHandle.
Operation(Graph g, long unsafeNativeHandle) {
this.graph = g;
this.unsafeNativeHandle = unsafeNativeHandle;
}

/** Returns the full name of the Operation. */
public String name() {
try (Graph.Reference r = graph.ref()) {
return name(unsafeNativeHandle);
}
}

/**
* Returns the type of the operation, i.e., the name of the computation performed by the
* operation.
*/
public String type() {
try (Graph.Reference r = graph.ref()) {
return type(unsafeNativeHandle);
}
}

/** Returns the number of tensors output by this operation. */
public int numOutputs() {
try (Graph.Reference r = graph.ref()) {
return numOutputs(unsafeNativeHandle);
}
}

/** Returns a symbolic handle to one of the tensors produced by this operation. */
public Output output(int idx) {
return new Output(this, idx);
}

private final long unsafeNativeHandle;
private final Graph graph;

private static native String name(long handle);

private static native String type(long handle);

private static native int numOutputs(long handle);
}
44 changes: 44 additions & 0 deletions tensorflow/java/src/main/java/org/tensorflow/Output.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

/**
* A symbolic handle to a tensor produced by an {@link Operation}.
*
* <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
* the {@link Operation} in a {@link Session}.
*/
public final class Output {

/** Handle to the idx-th output of the Operation {@code op}. */
public Output(Operation op, int idx) {
operation = op;
index = idx;
}

/** Returns the Operation that will produce the tensor referred to by this Output. */
public Operation op() {
return operation;
}

/** Returns the index into the outputs of the Operation. */
public int index() {
return index;
}

private final Operation operation;
private final int index;
}
14 changes: 14 additions & 0 deletions tensorflow/java/src/main/native/graph_jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License.

namespace {
TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
static_assert(sizeof(jlong) >= sizeof(TF_Graph*),
"Cannot package C object pointers as a Java long");
if (handle == 0) {
throwException(env, kNullPointerException,
"close() has been called on the Graph");
Expand All @@ -40,6 +42,18 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass,
TF_DeleteGraph(reinterpret_cast<TF_Graph*>(handle));
}

JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
jclass clazz,
jlong handle,
jstring name) {
TF_Graph* g = requireHandle(env, handle);
if (g == nullptr) return 0;
const char* cname = env->GetStringUTFChars(name, nullptr);
TF_Operation* op = TF_GraphOperationByName(g, cname);
env->ReleaseStringUTFChars(name, cname);
return reinterpret_cast<jlong>(op);
}

JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(
JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def,
jstring prefix) {
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/java/src/main/native/graph_jni.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv *, jclass);
JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv *, jclass,
jlong);

/*
* Class: org_tensorflow_Graph
* Method: operation
* Signature: (JLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass,
jlong, jstring);

/*
* Class: org_tensorflow_Graph
* Method: importGraphDef
Expand Down
57 changes: 57 additions & 0 deletions tensorflow/java/src/main/native/operation_jni.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/java/src/main/native/operation_jni.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"

namespace {
TF_Operation* requireHandle(JNIEnv* env, jlong handle) {
static_assert(sizeof(jlong) >= sizeof(TF_Operation*),
"Cannot package C object pointers as a Java long");
if (handle == 0) {
throwException(
env, kNullPointerException,
"close() has been called on the Graph this Operation was a part of");
return nullptr;
}
return reinterpret_cast<TF_Operation*>(handle);
}
} // namespace

JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv* env,
jclass clazz,
jlong handle) {
TF_Operation* op = requireHandle(env, handle);
if (op == nullptr) return nullptr;
return env->NewStringUTF(TF_OperationName(op));
}

JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv* env,
jclass clazz,
jlong handle) {
TF_Operation* op = requireHandle(env, handle);
if (op == nullptr) return nullptr;
return env->NewStringUTF(TF_OperationOpType(op));
}

JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env,
jclass clazz,
jlong handle) {
TF_Operation* op = requireHandle(env, handle);
if (op == nullptr) return 0;
return TF_OperationNumOutputs(op);
}
53 changes: 53 additions & 0 deletions tensorflow/java/src/main/native/operation_jni.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_JAVA_OPERATION_JNI_H_
#define TENSORFLOW_JAVA_OPERATION_JNI_H_

#include <jni.h>

#ifdef __cplusplus
extern "C" {
#endif

/*
* Class: org_tensorflow_Operation
* Method: name
* Signature: (J)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv *, jclass,
jlong);

/*
* Class: org_tensorflow_Operation
* Method: type
* Signature: (J)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv *, jclass,
jlong);

/*
* Class: org_tensorflow_Operation
* Method: numOutputs
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv *,
jclass, jlong);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_JAVA_OPERATION_JNI_H_
Loading

0 comments on commit 22812d5

Please sign in to comment.