From 5b784f8d90d236420c2d705853fe80b8761f9537 Mon Sep 17 00:00:00 2001
From: dcslin <13751447+dcslin@users.noreply.github.com>
Date: Tue, 3 Mar 2020 05:05:19 +0000
Subject: [PATCH] added ceil to tensor, and tests

---
 include/singa/core/tensor.h        |  2 ++
 python/singa/tensor.py             | 11 +++++++++++
 src/api/core_tensor.i              |  1 +
 src/core/tensor/math_kernel.cu     | 11 +++++++++++
 src/core/tensor/math_kernel.h      |  1 +
 src/core/tensor/tensor.cc          |  1 +
 src/core/tensor/tensor_math.h      |  5 +++++
 src/core/tensor/tensor_math_cpp.h  |  5 +++++
 src/core/tensor/tensor_math_cuda.h | 14 ++++++++++++++
 test/python/test_api.py            | 15 +++++++++++++++
 test/python/test_tensor.py         | 14 ++++++++++++++
 11 files changed, 80 insertions(+)

diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 02817b4e93..6d4b86b497 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -342,6 +342,7 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
 
 // =============Element-wise operations====================================
 Tensor Abs(const Tensor &in);
+Tensor Ceil(const Tensor &in);
 Tensor Exp(const Tensor &in);
 Tensor Log(const Tensor &in);
 Tensor ReLU(const Tensor &in);
@@ -366,6 +367,7 @@ Tensor Atanh(const Tensor &in);
 Tensor Transform(const Tensor &in);
 
 void Abs(const Tensor &in, Tensor *out);
+void Ceil(const Tensor &in, Tensor *out);
 void Exp(const Tensor &in, Tensor *out);
 void Log(const Tensor &in, Tensor *out);
 void ReLU(const Tensor &in, Tensor *out);
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index f835b44094..b64c8b58f9 100755
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -805,6 +805,17 @@ def exp(t):
     return _call_singa_func(singa.Exp, t.data)
 
 
+def ceil(t):
+    '''
+    Args:
+        t (Tensor): input Tensor
+
+    Returns:
+        a new Tensor whose element y = ceil(x), x is an element of t
+    '''
+    return _call_singa_func(singa.Ceil, t.data)
+
+
 def log(t):
     '''
     Args:
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index 4550e6ae4a..28e8ac5a32 100755
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -171,6 +171,7 @@ namespace singa{
   Tensor Transpose(const Tensor &in);
 
   Tensor Abs(const Tensor &t);
+  Tensor Ceil(const Tensor &t);
   Tensor Exp(const Tensor &t);
   Tensor Log(const Tensor &t);
   Tensor ReLU(const Tensor &t);
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 5ac23cc7f4..7d9df0c1d8 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -117,6 +117,13 @@ __global__ void KernelExp(const size_t n, const float *in, float *out) {
   }
 }
 
+__global__ void KernelCeil2(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = std::ceil(in[i]);
+  }
+}
+
 __global__ void KernelLog(const size_t n, const float *in, float *out) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
        i += blockDim.x * gridDim.x) {
@@ -510,6 +517,10 @@ void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
   KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
+void ceil2(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelCeil2 <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
 void log(const size_t n, const float *in, float *out, cudaStream_t s) {
   KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index af5f93853b..0b9f2faa11 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -44,6 +44,7 @@ void set(const size_t n, const float v, float *out, cudaStream_t s);
 void abs(const size_t n, const float *in, float *out, cudaStream_t s);
 void sign(const size_t n, const float *in, float *out, cudaStream_t s);
 void exp(const size_t n, const float *in, float *out, cudaStream_t s);
+void ceil2(const size_t n, const float *in, float *out, cudaStream_t s);
 void log(const size_t n, const float *in, float *out, cudaStream_t s);
 void sqrt(const size_t n, const float *in, float *out, cudaStream_t s);
 void square(const size_t n, const float *in, float *out, cudaStream_t s);
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 715be806a2..7665009860 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -719,6 +719,7 @@ template void Tensor::GetValue<int>(int *value, const size_t num);
   void fn(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(fn, in, out); }
 
 GenUnaryTensorFn(Abs);
+GenUnaryTensorFn(Ceil);
 GenUnaryTensorFn(Exp);
 GenUnaryTensorFn(Log);
 GenUnaryTensorFn(ReLU);
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index d9440adfde..35ef665573 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -86,6 +86,11 @@ void Abs(const Tensor &in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Abs Not Implemented";
 }
 
+template <typename DType, typename Lang>
+void Ceil(const Tensor &in, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Ceil Not Implemented";
+}
+
 /// out[i] = in[i] + x
 template <typename DType, typename Lang>
 void Add(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 9550d24b7c..7ce7d1459c 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -240,6 +240,11 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   traverse_unary<float>(in, out, [](float x) { return fabs(x); });
 }
 
+template <>
+void Ceil<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  traverse_unary<float>(in, out, [](float x) { return std::ceil(x); });
+}
+
 #ifdef USE_DNNL
 template <>
 void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index ea7e9a5831..3c043ab3ac 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -346,6 +346,20 @@ void Exp<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
   }
 }
 
+template <>
+void Ceil<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+  const size_t num = in.Size();
+
+  if (in.stride() == out->stride()) {
+    cuda::ceil2(num, inPtr, outPtr, ctx->stream);
+  } else {  // else we transform in to out to store first
+    Transform(in, out, ctx);
+    cuda::ceil2(num, outPtr, outPtr, ctx->stream);
+  }
+}
+
 template <>
 void GE<float, lang::Cuda>(const Tensor& in, const float x, Tensor* out,
                            Context* ctx) {
diff --git a/test/python/test_api.py b/test/python/test_api.py
index 3b847aaf80..d62b58ffdd 100644
--- a/test/python/test_api.py
+++ b/test/python/test_api.py
@@ -590,6 +590,21 @@ def test_concat(self):
         np.testing.assert_array_almost_equal(
             tensor.to_numpy(_cTensor_to_pyTensor(t3_ct)), np3)
 
+    def test_ceil(self):
+
+        for dev in [cpu_dev, gpu_dev]:
+
+            np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
+            np1 = np1 * 10
+            np2 = np.ceil(np1)
+
+            t1 = tensor.Tensor(device=dev, data=np1)
+
+            t2_ct = singa_api.Ceil(t1.data)
+
+            np.testing.assert_array_almost_equal(
+                tensor.to_numpy(_cTensor_to_pyTensor(t2_ct)), np2)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 872d65097b..4bf592c8e9 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -324,6 +324,20 @@ def test_subscription_gpu(self):
         np.testing.assert_array_almost_equal((tensor.to_numpy(sg_tensor_ret)),
                                              np1[1:3, :, 1:, :-1])
 
+    def test_ceil(self):
+
+        for dev in [cpu_dev, gpu_dev]:
+
+            np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
+            np1 = np1 * 10
+            np2 = np.ceil(np1)
+
+            t1 = tensor.Tensor(device=dev, data=np1)
+
+            t2 = tensor.ceil(t1)
+
+            np.testing.assert_array_almost_equal(tensor.to_numpy(t2), np2)
+
 
 if __name__ == '__main__':
     unittest.main()
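
Usage note (not part of the patch): a minimal sketch of the new Python-level
API, assuming a SINGA build that includes this change. tensor.ceil,
tensor.Tensor(device=..., data=...), and tensor.to_numpy all appear in the
diff above; device.get_default_device() is the stock CPU-device helper from
singa.device.

    import numpy as np
    from singa import device, tensor

    # element-wise ceiling on the default (CPU) device
    dev = device.get_default_device()
    x = tensor.Tensor(device=dev,
                      data=np.array([0.1, -1.5, 2.0, 3.7], dtype=np.float32))
    y = tensor.ceil(x)  # wraps the new singa.Ceil binding

    print(tensor.to_numpy(y))  # -> [ 1. -1.  2.  4.]

On CUDA devices the same call dispatches to the new ceil2 kernel; the
CPU backend uses std::ceil via traverse_unary.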