From a6cd27536c6fe46e9d2a8dd658567217a9847442 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 20 Jan 2016 13:40:13 +0100 Subject: [PATCH] implement :inv(), 1.0 / tensor --- src/TensorMath.lua | 2 +- src/lib/THClTensorMath.h | 2 +- src/lib/THClTensorMathPointwise.cpp | 1 + src/test/unit_tensor.lua | 6 ++---- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/TensorMath.lua b/src/TensorMath.lua index 133e394..bd1fabf 100644 --- a/src/TensorMath.lua +++ b/src/TensorMath.lua @@ -694,7 +694,7 @@ for _,name in ipairs({"log", "log1p", "exp", "tan", "atan", "tanh", "sqrt", "sigmoid", "ceil", "floor", - "abs", "sign", "round", "neg"}) do + "abs", "sign", "round", "neg", "cinv"}) do wrap(name, cname(name), diff --git a/src/lib/THClTensorMath.h b/src/lib/THClTensorMath.h index b45d460..5fe80f5 100644 --- a/src/lib/THClTensorMath.h +++ b/src/lib/THClTensorMath.h @@ -121,7 +121,7 @@ THCL_API int THClTensor_logicalall(THClState *state, THClTensor *self); THCL_API int THClTensor_logicalany(THClState *state, THClTensor *self); -// Original functions, not in torch or cutorch: +THCL_API void THClTensor_cinv(THClState *state, THClTensor *self, THClTensor *src); THCL_API void THClTensor_neg(THClState *state, THClTensor *self, THClTensor *src); THCL_API void THClTensor_sub(THClState *state, THClTensor *self, THClTensor *src, float value); THCL_API void THClTensor_csub(THClState *state, THClTensor *self, THClTensor *src1, float value, THClTensor *src2); diff --git a/src/lib/THClTensorMathPointwise.cpp b/src/lib/THClTensorMathPointwise.cpp index c1856df..d990601 100644 --- a/src/lib/THClTensorMathPointwise.cpp +++ b/src/lib/THClTensorMathPointwise.cpp @@ -55,6 +55,7 @@ IMPLEMENT_CL_TENSOR_BASIC_FUNC(floor, "floor") IMPLEMENT_CL_TENSOR_BASIC_FUNC(abs, "fabs") IMPLEMENT_CL_TENSOR_BASIC_FUNC(round, "round") IMPLEMENT_CL_TENSOR_BASIC_FUNC(neg, "-") +IMPLEMENT_CL_TENSOR_BASIC_FUNC(cinv, "1.0f / ") #undef IMPLEMENT_CL_TENSOR_BASIC_FUNC diff --git a/src/test/unit_tensor.lua b/src/test/unit_tensor.lua index 60ba585..b09ddf3 100644 --- a/src/test/unit_tensor.lua +++ b/src/test/unit_tensor.lua @@ -235,7 +235,7 @@ end for _,name in ipairs({'abs', 'sqrt', 'log','exp', 'cos', 'acos', 'sin', 'asin', 'atan', 'tanh', 'ceil', 'floor', - 'abs', 'round', 'sign', 'sigmoid'}) do + 'abs', 'round', 'sign', 'sigmoid', 'cinv'}) do cltorch.tests.tensor['inplace_' .. name] = function() c = torch.ClTensor{{4, 2, -1}, {3.1,1.2, 4.9}} @@ -249,7 +249,7 @@ end for _,name in ipairs({'abs', 'sqrt', 'log','exp', 'cos', 'acos', 'sin', 'asin', 'atan', 'tanh', 'ceil', 'floor', - 'abs', 'round', 'sign', 'sigmoid'}) do + 'abs', 'round', 'sign', 'sigmoid', 'cinv'}) do cltorch.tests.tensor['outplace_' .. name] = function() c = torch.ClTensor{{4, 2, -1}, {3.1,1.2, 4.9}} @@ -554,7 +554,6 @@ function cltorch.tests.tensor.test_addcdiv() tester:asserteq(A:clone():addcdiv(1.234,B,C), (A:clone():cl():addcdiv(1.234, B:clone():cl(),C:clone():cl())):double()) end --- this function doesnt exist in base torch function cltorch.tests.tensor.test_neg() -- no neg for Tensors, only for clTensor, but we can use '-' to -- compare @@ -567,7 +566,6 @@ function cltorch.tests.tensor.test_neg() tester:asserteq(negA, negAcl2:double()) end --- this function doesnt exist in base torch function cltorch.tests.tensor.test_sub() local s = torch.LongStorage{60,50} local A = torch.Tensor(s):uniform() - 0.5