Merge pull request apache#619 from dcslin/ceil
added ceil to tensor, and tests
nudles authored Mar 4, 2020
2 parents 0272da5 + 5b784f8 commit 001ba4c
Showing 11 changed files with 80 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/singa/core/tensor.h
@@ -342,6 +342,7 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,

// =============Element-wise operations====================================
Tensor Abs(const Tensor &in);
Tensor Ceil(const Tensor &in);
Tensor Exp(const Tensor &in);
Tensor Log(const Tensor &in);
Tensor ReLU(const Tensor &in);
@@ -366,6 +367,7 @@ Tensor Atanh(const Tensor &in);
Tensor Transform(const Tensor &in);

void Abs(const Tensor &in, Tensor *out);
void Ceil(const Tensor &in, Tensor *out);
void Exp(const Tensor &in, Tensor *out);
void Log(const Tensor &in, Tensor *out);
void ReLU(const Tensor &in, Tensor *out);
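A minimal sketch of how the two new declarations above could be called from C++. The shape-based Tensor constructor and the SetValue fill used here are assumptions for illustration only; they are not part of this diff.

    // Hypothetical usage sketch; Tensor construction and fill APIs are assumed.
    singa::Tensor in(singa::Shape{2, 3});
    in.SetValue(1.4f);                        // assumed fill helper
    singa::Tensor ceiled = singa::Ceil(in);   // value-returning overload declared above
    singa::Ceil(in, &ceiled);                 // in-place overload writing into an existing tensor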
11 changes: 11 additions & 0 deletions python/singa/tensor.py
@@ -805,6 +805,17 @@ def exp(t):
return _call_singa_func(singa.Exp, t.data)


def ceil(t):
'''
Args:
t (Tensor): input Tensor
Returns:
a new Tensor whose element y = ceil(x), x is an element of t
'''
return _call_singa_func(singa.Ceil, t.data)


def log(t):
'''
Args:
1 change: 1 addition & 0 deletions src/api/core_tensor.i
@@ -171,6 +171,7 @@ namespace singa{
Tensor Transpose(const Tensor &in);

Tensor Abs(const Tensor &t);
Tensor Ceil(const Tensor &t);
Tensor Exp(const Tensor &t);
Tensor Log(const Tensor &t);
Tensor ReLU(const Tensor &t);
11 changes: 11 additions & 0 deletions src/core/tensor/math_kernel.cu
@@ -117,6 +117,13 @@ __global__ void KernelExp(const size_t n, const float *in, float *out) {
}
}

__global__ void KernelCeil2(const size_t n, const float *in, float *out) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
i += blockDim.x * gridDim.x) {
out[i] = std::ceil(in[i]);
}
}

__global__ void KernelLog(const size_t n, const float *in, float *out) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
i += blockDim.x * gridDim.x) {
@@ -510,6 +517,10 @@ void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}

void ceil2(const size_t n, const float *in, float *out, cudaStream_t s) {
KernelCeil2 <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}

void log(const size_t n, const float *in, float *out, cudaStream_t s) {
KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
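KernelCeil2 above uses the standard grid-stride loop, and the launcher computes the block count as ceil(n / CU1DBLOCKF) so that every element is covered even when n is not a multiple of the block size. Below is a self-contained CUDA sketch of the same pattern; the block size of 256 and the demo driver are illustrative choices, not SINGA's constants.

    // Standalone grid-stride ceil kernel, mirroring the pattern of KernelCeil2.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void CeilKernel(size_t n, const float *in, float *out) {
      // Each thread starts at its global index and advances by the total
      // number of launched threads, so any n is covered.
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += blockDim.x * gridDim.x) {
        out[i] = ceilf(in[i]);
      }
    }

    int main() {
      const size_t n = 1000;
      float *in, *out;
      cudaMallocManaged(&in, n * sizeof(float));
      cudaMallocManaged(&out, n * sizeof(float));
      for (size_t i = 0; i < n; ++i) in[i] = 0.001f * i;

      const int block = 256;                     // CU1DBLOCKF plays this role above
      const int grid = (n + block - 1) / block;  // same effect as ceil(n / blockF)
      CeilKernel<<<grid, block>>>(n, in, out);
      cudaDeviceSynchronize();

      printf("ceil(%f) = %f\n", in[500], out[500]);  // expect 1.000000
      cudaFree(in);
      cudaFree(out);
      return 0;
    }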
1 change: 1 addition & 0 deletions src/core/tensor/math_kernel.h
@@ -44,6 +44,7 @@ void set(const size_t n, const float v, float *out, cudaStream_t s);
void abs(const size_t n, const float *in, float *out, cudaStream_t s);
void sign(const size_t n, const float *in, float *out, cudaStream_t s);
void exp(const size_t n, const float *in, float *out, cudaStream_t s);
void ceil2(const size_t n, const float *in, float *out, cudaStream_t s);
void log(const size_t n, const float *in, float *out, cudaStream_t s);
void sqrt(const size_t n, const float *in, float *out, cudaStream_t s);
void square(const size_t n, const float *in, float *out, cudaStream_t s);
1 change: 1 addition & 0 deletions src/core/tensor/tensor.cc
@@ -719,6 +719,7 @@ template void Tensor::GetValue<int>(int *value, const size_t num);
void fn(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(fn, in, out); }

GenUnaryTensorFn(Abs);
GenUnaryTensorFn(Ceil);
GenUnaryTensorFn(Exp);
GenUnaryTensorFn(Log);
GenUnaryTensorFn(ReLU);
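GenUnaryTensorFn(Ceil) generates the two Ceil overloads declared in tensor.h. Judging from the fn(...) line shown in the context above, the expansion is roughly the following; this is a sketch inferred from that pattern, not the literal macro, and the construction of the result tensor is an assumption.

    // Inferred expansion of GenUnaryTensorFn(Ceil) -- not the literal macro text.
    Tensor Ceil(const Tensor &in) {
      Tensor ret(in.shape(), in.device(), in.data_type());  // assumed result construction
      Ceil(in, &ret);
      return ret;
    }
    void Ceil(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(Ceil, in, out); }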
5 changes: 5 additions & 0 deletions src/core/tensor/tensor_math.h
@@ -86,6 +86,11 @@ void Abs(const Tensor &in, Tensor *out, Context *ctx) {
LOG(FATAL) << "Abs Not Implemented";
}

template <typename DType, typename Lang>
void Ceil(const Tensor &in, Tensor *out, Context *ctx) {
LOG(FATAL) << "Ceil Not Implemented";
}

/// out[i] = in[i] + x
template <typename DType, typename Lang>
void Add(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
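The generic Ceil<DType, Lang> above is only a fallback that aborts; each backend supplies a full specialization that the call site selects at compile time (lang::Cpp below, lang::Cuda in tensor_math_cuda.h). A self-contained illustration of that dispatch pattern, with hypothetical tag types and simplified signatures:

    #include <cstdio>
    #include <cstdlib>

    struct Cpp {};    // stand-ins for singa's lang::Cpp / lang::Cuda tags
    struct Cuda {};

    template <typename DType, typename Lang>
    void Ceil() {                     // generic fallback, like the LOG(FATAL) version above
      std::fprintf(stderr, "Ceil Not Implemented\n");
      std::abort();
    }

    template <>
    void Ceil<float, Cpp>() {         // backend-specific specialization wins
      std::printf("cpp ceil\n");
    }

    int main() {
      Ceil<float, Cpp>();             // prints "cpp ceil"
      // Ceil<float, Cuda>();         // would hit the fallback and abort
      return 0;
    }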
5 changes: 5 additions & 0 deletions src/core/tensor/tensor_math_cpp.h
@@ -240,6 +240,11 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
traverse_unary<float>(in, out, [](float x) { return fabs(x); });
}

template <>
void Ceil<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
traverse_unary<float>(in, out, [](float x) { return std::ceil(x); });
}

#ifdef USE_DNNL
template <>
void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
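traverse_unary itself is not part of this diff; conceptually it applies the given lambda to every element of the input and writes the results into the output. A self-contained stand-in illustrating that idea (apply_unary and the driver are hypothetical names, not SINGA code):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for traverse_unary: apply op to each element of in.
    template <typename T, typename Op>
    void apply_unary(const std::vector<T> &in, std::vector<T> *out, Op op) {
      out->resize(in.size());
      for (size_t i = 0; i < in.size(); ++i) (*out)[i] = op(in[i]);
    }

    int main() {
      std::vector<float> in = {0.2f, 1.5f, -0.7f}, out;
      apply_unary(in, &out, [](float x) { return std::ceil(x); });
      for (float v : out) std::printf("%g ", v);   // prints: 1 2 -0
      return 0;
    }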
14 changes: 14 additions & 0 deletions src/core/tensor/tensor_math_cuda.h
@@ -346,6 +346,20 @@ void Exp<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
}
}

template <>
void Ceil<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
const float* inPtr = static_cast<const float*>(in.block()->data());
float* outPtr = static_cast<float*>(out->block()->mutable_data());
const size_t num = in.Size();

if (in.stride() == out->stride()) {
cuda::ceil2(num, inPtr, outPtr, ctx->stream);
} else { // else we transform in to out to store first
Transform<float, lang::Cuda>(in, out, ctx);
cuda::ceil2(num, outPtr, outPtr, ctx->stream);
}
}

template <>
void GE<float, lang::Cuda>(const Tensor& in, const float x, Tensor* out,
Context* ctx) {
15 changes: 15 additions & 0 deletions test/python/test_api.py
@@ -590,6 +590,21 @@ def test_concat(self):
np.testing.assert_array_almost_equal(
tensor.to_numpy(_cTensor_to_pyTensor(t3_ct)), np3)

def test_ceil(self):

for dev in [cpu_dev, gpu_dev]:

np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
np1 = np1 * 10
np2 = np.ceil(np1)

t1 = tensor.Tensor(device=dev, data=np1)

t2_ct = singa_api.Ceil(t1.data)

np.testing.assert_array_almost_equal(
tensor.to_numpy(_cTensor_to_pyTensor(t2_ct)), np2)


if __name__ == '__main__':
unittest.main()
14 changes: 14 additions & 0 deletions test/python/test_tensor.py
@@ -324,6 +324,20 @@ def test_subscription_gpu(self):
np.testing.assert_array_almost_equal((tensor.to_numpy(sg_tensor_ret)),
np1[1:3, :, 1:, :-1])

def test_ceil(self):

for dev in [cpu_dev, gpu_dev]:

np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
np1 = np1 * 10
np2 = np.ceil(np1)

t1 = tensor.Tensor(device=dev, data=np1)

t2 = tensor.ceil(t1)

np.testing.assert_array_almost_equal(tensor.to_numpy(t2), np2)


if __name__ == '__main__':
unittest.main()
