forked from torch/cutorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FFI.lua
50 lines (39 loc) · 1.1 KB
/
FFI.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
if jit then
   -- Everything below runs only when a `jit` global is present (as under
   -- LuaJIT); otherwise this file installs nothing.
   local ffi = require 'ffi'

   -- C declarations for the CUDA storage/tensor structs so their fields can
   -- be read directly through the FFI. These layouts are not checked against
   -- the C headers here -- NOTE(review): keep them in sync with the THC
   -- definitions by hand. `THAllocator` is referenced but not declared in
   -- this string, so it must already be known to the FFI when this runs
   -- (presumably declared by torch's own FFI bindings -- confirm).
   ffi.cdef([[
typedef struct THCudaStorage
{
float *data;
long size;
int refcount;
char flag;
THAllocator *allocator;
void *allocatorContext;
} THCudaStorage;
typedef struct THCudaTensor
{
long *size;
long *stride;
int nDimension;
THCudaStorage *storage;
long storageOffset;
int refcount;
char flag;
} THCudaTensor;
]])

   --- Build an accessor that reinterprets a torch userdata as `ctype`
   -- (a pointer-to-pointer type) and dereferences it once.
   -- @tparam string ctype C type name handed to `ffi.typeof`, e.g. "T**"
   -- @treturn function accessor taking the userdata, returning a `T*` cdata
   local function make_unwrapper(ctype)
      local cast = ffi.typeof(ctype)
      return function(obj)
         return cast(obj)[0]
      end
   end

   -- torch.CudaStorage: `cdata()` yields the THCudaStorage* struct pointer,
   -- `data()` yields its `data` field (a float* cdata).
   local storage_mt = torch.getmetatable('torch.CudaStorage')
   local unwrap_storage = make_unwrapper('THCudaStorage**')
   rawset(storage_mt, "cdata", unwrap_storage)
   rawset(storage_mt, "data", function(storage)
      return unwrap_storage(storage).data
   end)

   -- torch.CudaTensor: `cdata()` yields the THCudaTensor* struct pointer,
   -- `data()` yields the storage's float* advanced by `storageOffset`
   -- elements, or nil when the tensor has no storage.
   local tensor_mt = torch.getmetatable('torch.CudaTensor')
   local unwrap_tensor = make_unwrapper('THCudaTensor**')
   rawset(tensor_mt, "cdata", unwrap_tensor)
   rawset(tensor_mt, "data", function(tensor)
      local raw = unwrap_tensor(tensor)
      local storage = raw.storage
      if storage == nil then
         -- A NULL storage pointer compares equal to nil in LuaJIT.
         return nil
      end
      return storage.data + raw.storageOffset
   end)
end