diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index a64f8241e80135..cb0da8d8650ac6 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -21,6 +21,8 @@
/**
* A typed multi-dimensional array.
*
+ * Instances of a Tensor are not thread-safe.
+ *
*
WARNING: Resources consumed by the Tensor object must be explicitly freed by
* invoking the {@link #close()} method when the object is no longer needed. For example, using a
* try-with-resources block like:
@@ -78,7 +80,7 @@ public static Tensor create(Object obj) {
*
*
WARNING:If not invoked, memory will be leaked.
*
- *
The Tensor object is no longer usable after {@code close} is invoked.
+ *
The Tensor object is no longer usable after {@code close} returns.
*/
@Override
public void close() {
diff --git a/tensorflow/java/src/main/native/exception_jni.cc b/tensorflow/java/src/main/native/exception_jni.cc
new file mode 100644
index 00000000000000..3c1c9c2c27a60f
--- /dev/null
+++ b/tensorflow/java/src/main/native/exception_jni.cc
@@ -0,0 +1,35 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/exception_jni.h"
+
+const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
+const char kIllegalStateException[] = "java/lang/IllegalStateException";
+const char kNullPointerException[] = "java/lang/NullPointerException";
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
+ va_list args;
+ va_start(args, fmt);
+ char* message = nullptr;
+ if (vasprintf(&message, fmt, args) >= 0) {
+ env->ThrowNew(env->FindClass(clazz), message);
+ } else {
+ env->ThrowNew(env->FindClass(clazz), "");
+ }
+ va_end(args);
+}
diff --git a/tensorflow/java/src/main/native/exception_jni.h b/tensorflow/java/src/main/native/exception_jni.h
new file mode 100644
index 00000000000000..b1239937f12962
--- /dev/null
+++ b/tensorflow/java/src/main/native/exception_jni.h
@@ -0,0 +1,34 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_EXCEPTION_JNI_H_
+#define TENSORFLOW_JAVA_EXCEPTION_JNI_H_
+
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern const char kIllegalArgumentException[];
+extern const char kIllegalStateException[];
+extern const char kNullPointerException[];
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_JAVA_EXCEPTION_JNI_H_
diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc
index 4cc7f096486c1e..3b9a081a36ac85 100644
--- a/tensorflow/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/java/src/main/native/tensor_jni.cc
@@ -16,29 +16,22 @@ limitations under the License.
#include "tensorflow/java/src/main/native/tensor_jni.h"
#include
-#include
#include
#include
#include
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
-const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
-const char kIllegalStateException[] = "java/lang/IllegalStateException";
-const char kNullPointerException[] = "java/lang/kNullPointerException";
-
-void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
- va_list args;
- va_start(args, fmt);
- char* message = nullptr;
- if (vasprintf(&message, fmt, args) >= 0) {
- env->ThrowNew(env->FindClass(clazz), message);
- } else {
- env->ThrowNew(env->FindClass(clazz), "");
+TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kNullPointerException,
+ "close() was called on the Tensor");
+ return nullptr;
}
- va_end(args);
+ return reinterpret_cast(handle);
}
size_t elemByteSize(TF_DataType dtype) {
@@ -272,9 +265,7 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
jclass clazz,
jlong handle) {
- if (handle == 0) {
- return;
- }
+ if (handle == 0) return;
TF_DeleteTensor(reinterpret_cast(handle));
}
@@ -282,8 +273,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
jclass clazz,
jlong handle,
jobject value) {
- assert(handle != 0);
- TF_Tensor* t = reinterpret_cast(handle);
+ TF_Tensor* t = requireHandle(env, handle);
+ if (t == nullptr) return;
int num_dims = TF_NumDims(t);
TF_DataType dtype = TF_TensorType(t);
void* data = TF_TensorData(t);
@@ -299,8 +290,9 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
#define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix) \
JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix( \
JNIEnv* env, jclass clazz, jlong handle) { \
- TF_Tensor* t = reinterpret_cast(handle); \
jtype ret = 0; \
+ TF_Tensor* t = requireHandle(env, handle); \
+ if (t == nullptr) return ret; \
if (TF_NumDims(t) != 0) { \
throwException(env, kIllegalStateException, "Tensor is not a scalar"); \
} else if (TF_TensorType(t) != dtype) { \
@@ -322,14 +314,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
jclass clazz,
jlong handle,
jobject value) {
- // The exceptions thrown use "copyTo()" since readNDArray is a private
- // function meant to serve the public Java copyTo() method.
- if (handle == 0) {
- throwException(env, kNullPointerException,
- "copyTo() cannot be called after close()");
- return;
- }
- TF_Tensor* t = reinterpret_cast(handle);
+ TF_Tensor* t = requireHandle(env, handle);
+ if (t == nullptr) return;
int num_dims = TF_NumDims(t);
TF_DataType dtype = TF_TensorType(t);
const void* data = TF_TensorData(t);
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
index 7e74da3d170c66..41b35b27d216c0 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
@@ -184,4 +184,16 @@ public void failOnZeroDimension() {
// The expected exception.
}
}
+
+ @Test
+ public void useAfterClose() {
+ int n = 4;
+ Tensor t = Tensor.create(n);
+ t.close();
+ try {
+ t.intValue();
+ } catch (NullPointerException e) {
+ // The expected exception.
+ }
+ }
}