Skip to content

Commit

Permalink
implement :inv(), 1.0 / tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
hughperkins committed Jan 20, 2016
1 parent 4bf1d24 commit a6cd275
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ for _,name in ipairs({"log", "log1p", "exp",
"tan", "atan", "tanh",
"sqrt", "sigmoid",
"ceil", "floor",
"abs", "sign", "round", "neg"}) do
"abs", "sign", "round", "neg", "cinv"}) do

wrap(name,
cname(name),
Expand Down
2 changes: 1 addition & 1 deletion src/lib/THClTensorMath.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ THCL_API int THClTensor_logicalall(THClState *state, THClTensor *self);
THCL_API int THClTensor_logicalany(THClState *state, THClTensor *self);


// Original functions, not in torch or cutorch:
THCL_API void THClTensor_cinv(THClState *state, THClTensor *self, THClTensor *src);
THCL_API void THClTensor_neg(THClState *state, THClTensor *self, THClTensor *src);
THCL_API void THClTensor_sub(THClState *state, THClTensor *self, THClTensor *src, float value);
THCL_API void THClTensor_csub(THClState *state, THClTensor *self, THClTensor *src1, float value, THClTensor *src2);
Expand Down
1 change: 1 addition & 0 deletions src/lib/THClTensorMathPointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ IMPLEMENT_CL_TENSOR_BASIC_FUNC(floor, "floor")
IMPLEMENT_CL_TENSOR_BASIC_FUNC(abs, "fabs")
IMPLEMENT_CL_TENSOR_BASIC_FUNC(round, "round")
IMPLEMENT_CL_TENSOR_BASIC_FUNC(neg, "-")
IMPLEMENT_CL_TENSOR_BASIC_FUNC(cinv, "1.0f / ")

#undef IMPLEMENT_CL_TENSOR_BASIC_FUNC

Expand Down
6 changes: 2 additions & 4 deletions src/test/unit_tensor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ end

for _,name in ipairs({'abs', 'sqrt', 'log','exp', 'cos',
'acos', 'sin', 'asin', 'atan', 'tanh', 'ceil', 'floor',
'abs', 'round', 'sign', 'sigmoid'}) do
'abs', 'round', 'sign', 'sigmoid', 'cinv'}) do
cltorch.tests.tensor['inplace_' .. name] = function()
c = torch.ClTensor{{4, 2, -1},
{3.1,1.2, 4.9}}
Expand All @@ -249,7 +249,7 @@ end

for _,name in ipairs({'abs', 'sqrt', 'log','exp', 'cos',
'acos', 'sin', 'asin', 'atan', 'tanh', 'ceil', 'floor',
'abs', 'round', 'sign', 'sigmoid'}) do
'abs', 'round', 'sign', 'sigmoid', 'cinv'}) do
cltorch.tests.tensor['outplace_' .. name] = function()
c = torch.ClTensor{{4, 2, -1},
{3.1,1.2, 4.9}}
Expand Down Expand Up @@ -554,7 +554,6 @@ function cltorch.tests.tensor.test_addcdiv()
tester:asserteq(A:clone():addcdiv(1.234,B,C), (A:clone():cl():addcdiv(1.234, B:clone():cl(),C:clone():cl())):double())
end

-- this function doesnt exist in base torch
function cltorch.tests.tensor.test_neg()
-- no neg for Tensors, only for clTensor, but we can use '-' to
-- compare
Expand All @@ -567,7 +566,6 @@ function cltorch.tests.tensor.test_neg()
tester:asserteq(negA, negAcl2:double())
end

-- this function doesnt exist in base torch
function cltorch.tests.tensor.test_sub()
local s = torch.LongStorage{60,50}
local A = torch.Tensor(s):uniform() - 0.5
Expand Down

0 comments on commit a6cd275

Please sign in to comment.