Skip to content

Commit

Permalink
Implementation of totable for CUDA tensors and storages.
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikgrewe committed Jan 26, 2016
1 parent c50d7cb commit 924043a
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
18 changes: 18 additions & 0 deletions Tensor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,21 @@ do
rawset(metatable, func, torch[func])
end
end

local CudaTensorTypes = {
float = 'torch.CudaTensor',
double = 'torch.CudaDoubleTensor',
byte = 'torch.CudaByteTensor',
char = 'torch.CudaCharTensor',
int = 'torch.CudaIntTensor',
short = 'torch.CudaShortTensor',
long = 'torch.CudaLongTensor'
}

for ValueType, CudaTensorType in pairs(CudaTensorTypes) do
local function Tensor__totable(self)
local host_tensor = self[ValueType](self)
return host_tensor:totable()
end
rawset(torch.getmetatable(CudaTensorType), 'totable', Tensor__totable)
end
43 changes: 43 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2154,6 +2154,49 @@ function test.cudaStorageTypeCopy()
end
end

function test.tensorToTable()
local types = {
{'CudaTensor', 'FloatTensor'},
{'CudaByteTensor', 'ByteTensor'},
{'CudaCharTensor', 'CharTensor'},
{'CudaShortTensor', 'ShortTensor'},
{'CudaIntTensor', 'IntTensor'},
{'CudaLongTensor', 'LongTensor'},
{'CudaDoubleTensor', 'DoubleTensor'},
}

for _, types in ipairs(types) do
local cudaType, hostType = unpack(types)
local dim = torch.random(5)
local size = torch.LongTensor(dim):random(5):totable()
hostTensor = torch[hostType](size):random()
cudaTensor = torch[cudaType](size):copy(hostTensor)
tester:assertTableEq(hostTensor:totable(), cudaTensor:totable(),
'wrong result for ' .. cudaType .. ':totable()')
end
end

function test.storageToTable()
local types = {
{'CudaStorage', 'FloatTensor'},
{'CudaByteStorage', 'ByteTensor'},
{'CudaCharStorage', 'CharTensor'},
{'CudaShortStorage', 'ShortTensor'},
{'CudaIntStorage', 'IntTensor'},
{'CudaLongStorage', 'LongTensor'},
{'CudaDoubleStorage', 'DoubleTensor'},
}

for _, types in ipairs(types) do
local cudaStorageType, hostTensorType = unpack(types)
local size = torch.random(10)
hostTensor = torch[hostTensorType](size):random()
cudaStorage = torch[cudaStorageType](size):copy(hostTensor:storage())
tester:assertTableEq(hostTensor:storage():totable(), cudaStorage:totable(),
'wrong result for ' .. cudaStorageType .. ':totable()')
end
end

function test.maskedSelect()
local n_row = math.random(minsize,maxsize)
local n_col = math.random(minsize,maxsize)
Expand Down
9 changes: 8 additions & 1 deletion torch/generic/Storage.c
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,22 @@ static int torch_Storage_(string)(lua_State *L)

static int torch_Storage_(totable)(lua_State *L)
{
THCState *state = cutorch_getstate(L);
THCStorage *storage = luaT_checkudata(L, 1, torch_Storage);
THStorage *host_storage;
long i;

/* Copy storage from device to host. */
host_storage = THStorage_(newWithSize)(THCStorage_(size)(state, storage));
THStorage_(copyCuda)(state, host_storage, storage);

lua_newtable(L);
for(i = 0; i < storage->size; i++)
{
lua_pushnumber(L, (lua_Number)storage->data[i]);
lua_pushnumber(L, (lua_Number)host_storage->data[i]);
lua_rawseti(L, -2, i+1);
}
THStorage_(free)(host_storage);
return 1;
}

Expand Down

0 comments on commit 924043a

Please sign in to comment.