diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index a64f8241e80135..cb0da8d8650ac6 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -21,6 +21,8 @@ /** * A typed multi-dimensional array. * + * Instances of a Tensor are not thread-safe. + * *

WARNING: Resources consumed by the Tensor object must be explicitly freed by * invoking the {@link #close()} method when the object is no longer needed. For example, using a * try-with-resources block like: @@ -78,7 +80,7 @@ public static Tensor create(Object obj) { * *

WARNING:If not invoked, memory will be leaked. * - *

The Tensor object is no longer usable after {@code close} is invoked. + *

The Tensor object is no longer usable after {@code close} returns. */ @Override public void close() { diff --git a/tensorflow/java/src/main/native/exception_jni.cc b/tensorflow/java/src/main/native/exception_jni.cc new file mode 100644 index 00000000000000..3c1c9c2c27a60f --- /dev/null +++ b/tensorflow/java/src/main/native/exception_jni.cc @@ -0,0 +1,35 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" + +const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; +const char kIllegalStateException[] = "java/lang/IllegalStateException"; +const char kNullPointerException[] = "java/lang/NullPointerException"; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + char* message = nullptr; + if (vasprintf(&message, fmt, args) >= 0) { + env->ThrowNew(env->FindClass(clazz), message); + } else { + env->ThrowNew(env->FindClass(clazz), ""); + } + va_end(args); +} diff --git a/tensorflow/java/src/main/native/exception_jni.h b/tensorflow/java/src/main/native/exception_jni.h new file mode 100644 index 00000000000000..b1239937f12962 --- /dev/null +++ b/tensorflow/java/src/main/native/exception_jni.h @@ -0,0 +1,34 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_EXCEPTION_JNI_H_ +#define TENSORFLOW_JAVA_EXCEPTION_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +extern const char kIllegalArgumentException[]; +extern const char kIllegalStateException[]; +extern const char kNullPointerException[]; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_EXCEPTION_JNI_H_ diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc index 4cc7f096486c1e..3b9a081a36ac85 100644 --- a/tensorflow/java/src/main/native/tensor_jni.cc +++ b/tensorflow/java/src/main/native/tensor_jni.cc @@ -16,29 +16,22 @@ limitations under the License. #include "tensorflow/java/src/main/native/tensor_jni.h" #include -#include #include #include #include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" namespace { -const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; -const char kIllegalStateException[] = "java/lang/IllegalStateException"; -const char kNullPointerException[] = "java/lang/kNullPointerException"; - -void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) { - va_list args; - va_start(args, fmt); - char* message = nullptr; - if (vasprintf(&message, fmt, args) >= 0) { - env->ThrowNew(env->FindClass(clazz), message); - } else { - env->ThrowNew(env->FindClass(clazz), ""); +TF_Tensor* requireHandle(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kNullPointerException, + "close() was called on the Tensor"); + return nullptr; } - va_end(args); + return reinterpret_cast(handle); } size_t elemByteSize(TF_DataType dtype) { @@ -272,9 +265,7 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env, JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env, jclass clazz, jlong handle) { - if (handle == 0) { - return; - } + if (handle == 0) return; TF_DeleteTensor(reinterpret_cast(handle)); } @@ -282,8 +273,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env, jclass clazz, jlong handle, jobject value) { - assert(handle != 0); - TF_Tensor* t = reinterpret_cast(handle); + TF_Tensor* t = requireHandle(env, handle); + if (t == nullptr) return; int num_dims = TF_NumDims(t); TF_DataType dtype = TF_TensorType(t); void* data = TF_TensorData(t); @@ -299,8 +290,9 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env, #define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix) \ JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix( \ JNIEnv* env, jclass clazz, jlong handle) { \ - TF_Tensor* t = reinterpret_cast(handle); \ jtype ret = 0; \ + TF_Tensor* t = requireHandle(env, handle); \ + if (t == nullptr) return ret; \ if (TF_NumDims(t) != 0) { \ throwException(env, kIllegalStateException, "Tensor is not a scalar"); \ } else if (TF_TensorType(t) != dtype) { \ @@ -322,14 +314,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env, jclass clazz, jlong handle, jobject value) { - // The exceptions thrown use "copyTo()" since readNDArray is a private - // function meant to serve the public Java copyTo() method. - if (handle == 0) { - throwException(env, kNullPointerException, - "copyTo() cannot be called after close()"); - return; - } - TF_Tensor* t = reinterpret_cast(handle); + TF_Tensor* t = requireHandle(env, handle); + if (t == nullptr) return; int num_dims = TF_NumDims(t); TF_DataType dtype = TF_TensorType(t); const void* data = TF_TensorData(t); diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java index 7e74da3d170c66..41b35b27d216c0 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java @@ -184,4 +184,16 @@ public void failOnZeroDimension() { // The expected exception. } } + + @Test + public void useAfterClose() { + int n = 4; + Tensor t = Tensor.create(n); + t.close(); + try { + t.intValue(); + } catch (NullPointerException e) { + // The expected exception. + } + } }