-
Notifications
You must be signed in to change notification settings - Fork 2
/
serialization.lua
255 lines (235 loc) · 7.45 KB
/
serialization.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
local torch = require 'torch.env'
local class = require 'class'
local File = class.metatable('torch.File')
--- Write a boolean to the file as a single integer.
-- Truthy values are stored as 1, nil/false as 0.
-- @param value any Lua value; only its truthiness is recorded
function File:writeBool(value)
   -- `and/or` is safe here because the middle operand (1) is never falsy
   self:writeInt(value and 1 or 0)
end
--- Read a boolean that was stored as an integer.
-- @return true when the integer read is exactly 1, false for any other value
function File:readBool()
   local flag = self:readInt()
   return flag == 1
end
-- Type tags. File:writeObject writes one of these integers in front of every
-- value and File:readObject dispatches on it, so the numeric values are part
-- of the serialized format and must not be renumbered.
local TYPE_NIL = 0       -- nil: the tag alone, no payload
local TYPE_NUMBER = 1    -- number, written as a double
local TYPE_STRING = 2    -- string: length followed by the characters
local TYPE_TABLE = 3     -- plain table: reference index, then size and key/value pairs
local TYPE_TORCH = 4     -- Torch object: reference index, version, class name, contents
local TYPE_BOOLEAN = 5   -- boolean, written through writeBool
local TYPE_FUNCTION = 6  -- function: dumped bytecode followed by a table of upvalues
--- Classify a value for serialization.
-- @param object any Lua value
-- @return the TYPE_* tag writeObject would use for this value, or nil when
--         the value cannot be serialized (for example a function that
--         string.dump refuses)
function File:isWritableObject(object)
   local typename = class.type(object)

   -- nil is the only value of type 'nil'; `false` is a boolean and is
   -- classified further down
   if type(object) == 'nil' then
      return TYPE_NIL
   end
   -- registered Torch classes take precedence over the plain Lua types
   if torch.metatable(typename) then
      return TYPE_TORCH
   end
   if typename == 'table' then
      return TYPE_TABLE
   end
   if typename == 'number' then
      return TYPE_NUMBER
   end
   if typename == 'string' then
      return TYPE_STRING
   end
   if typename == 'boolean' then
      return TYPE_BOOLEAN
   end
   -- a function is only writable when it can actually be dumped
   if typename == 'function' and pcall(string.dump, object) then
      return TYPE_FUNCTION
   end
   return nil
end
--- Serialize an arbitrary Lua value to this file.
--
-- Every value is prefixed with an integer type tag. Numbers, booleans,
-- strings and functions are written inline. Tables and Torch objects receive
-- an index the first time they are written; later occurrences write only that
-- index, so shared references are preserved and cyclic structures terminate.
--
-- @param object value to write: nil, number, boolean, string, dumpable
--        function, table or Torch object
-- @param force  when truthy, a table/Torch object that was already written is
--        written again in full under a fresh index instead of by reference
--        (applies to `object` only, not to values nested inside it)
-- Raises an error when the value, or a key/value nested in a plain table,
-- cannot be written.
function File:writeObject(object, force)
   -- keep a record of written objects
   self.__writeObjects = self.__writeObjects or {}
   self.__writeObjectsRef = self.__writeObjectsRef or {}

   -- if nil object, only write the type and return
   if type(object) ~= 'boolean' and not object then
      self:writeInt(TYPE_NIL)
      return
   end

   -- check the type we are dealing with
   local typeidx = self:isWritableObject(object)
   if not typeidx then
      error(string.format('unwritable object <%s>', type(object)))
   end
   self:writeInt(typeidx)

   if typeidx == TYPE_NUMBER then
      self:writeDouble(object)
   elseif typeidx == TYPE_BOOLEAN then
      self:writeBool(object)
   elseif typeidx == TYPE_STRING then
      local stringStorage = torch.CharStorage():string(object)
      self:writeInt(stringStorage:size())
      self:writeChar(stringStorage)
   elseif typeidx == TYPE_FUNCTION then
      -- Collect the upvalues by position. An explicit counter is required:
      -- storing a nil upvalue does not grow the table, so deriving the next
      -- slot from #upvalues would query the same slot forever. Nil upvalues
      -- are simply left out of the table (they become holes).
      local upvalues = {}
      local position = 1
      while true do
         local name, value = debug.getupvalue(object, position)
         if not name then break end
         upvalues[position] = value
         position = position + 1
      end
      local dumped = string.dump(object)
      local stringStorage = torch.CharStorage():string(dumped)
      self:writeInt(stringStorage:size())
      self:writeChar(stringStorage)
      self:writeObject(upvalues)
   elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE then
      -- check it exists already
      local objects = self.__writeObjects
      local objectsRef = self.__writeObjectsRef
      local index = objects[object]

      if index and (not force) then
         -- if already exists, write only its index
         self:writeInt(index)
      else
         -- else write the object itself; the index is registered before the
         -- contents are written so that self-references resolve to it
         index = objects.nWriteObject or 0
         index = index + 1
         objects[object] = index
         objectsRef[object] = index -- we make sure the object is not going to disappear
         self:writeInt(index)
         objects.nWriteObject = index

         if typeidx == TYPE_TORCH then
            local version = 'V ' .. object.__version
            self:writeInt(#version) -- backward compat
            self:write(version .. '\n')
            local className = class.type(object)
            self:writeInt(#className) -- backward compat
            self:write(className .. '\n')
            if object.write then
               -- the class provides its own serializer
               object:write(self)
            else
               -- generic path: write every serializable field as a table
               local var = {}
               for k,v in pairs(object) do
                  if self:isWritableObject(v) then
                     var[k] = v
                  else
                     print(string.format('$ Warning: cannot write object field <%s>', k))
                  end
               end
               self:writeObject(var)
            end
         else -- it is a table
            -- the entry count goes first so the reader knows how many pairs follow
            local size = 0
            for _ in pairs(object) do
               size = size + 1
            end
            self:writeInt(size)
            for k,v in pairs(object) do
               self:writeObject(k)
               self:writeObject(v)
            end
         end
      end
   else
      error('unwritable object')
   end
end
--- Read one serialized value from this file.
--
-- Reads the integer type tag and then the matching payload. Tables and Torch
-- objects are cached by their index as soon as they are created, so repeated
-- and cyclic references resolve to the same Lua object.
--
-- @return the deserialized value (nil, number, boolean, string, function,
--         table or Torch object)
-- Raises an error on an unknown type tag, an unknown Torch class name, or
-- function bytecode that cannot be loaded.
function File:readObject()
   -- keep a record of read objects
   self.__readObjects = self.__readObjects or {}

   -- read the typeidx
   local typeidx = self:readInt()

   -- is it nil?
   if typeidx == TYPE_NIL then
      return nil
   end

   if typeidx == TYPE_NUMBER then
      return self:readDouble()
   elseif typeidx == TYPE_BOOLEAN then
      return self:readBool()
   elseif typeidx == TYPE_STRING then
      local size = self:readInt()
      return self:readChar(size):string()
   elseif typeidx == TYPE_FUNCTION then
      local size = self:readInt()
      local dumped = self:readChar(size):string()
      -- NOTE(security): this loads precompiled bytecode taken straight from
      -- the file; only read files that come from a trusted source.
      local func, err = loadstring(dumped)
      if not func then
         error(string.format('cannot load serialized function: %s', tostring(err)))
      end
      local upvalues = self:readObject()
      -- pairs rather than ipairs: a nil upvalue leaves a hole in the table,
      -- and ipairs would stop there and skip every upvalue stored after it
      for index,upvalue in pairs(upvalues) do
         debug.setupvalue(func, index, upvalue)
      end
      return func
   elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH then
      -- read the index
      local index = self:readInt()

      -- check it is loaded already
      local objects = self.__readObjects
      if objects[index] then
         return objects[index]
      end

      -- otherwise read it
      if typeidx == TYPE_TORCH then
         local version, className, versionNumber
         self:readInt() -- backward compat
         version = self:read('*l')
         versionNumber = tonumber(string.match(version, '^V (.*)$'))
         if not versionNumber then
            -- no 'V <n>' line: the line just read is the class name itself
            className = version
            versionNumber = 0 -- file created before existence of versioning system
         else
            self:readInt() -- backward compat
            className = self:read('*l')
         end
         if not torch.metatable(className) then
            error(string.format('unknown Torch class <%s>', className))
         end
         local object = torch.factory(className)
         -- register before reading the contents so self-references resolve
         objects[index] = object
         if object.read then
            object:read(self, versionNumber)
         else
            local var = self:readObject()
            for k,v in pairs(var) do
               object[k] = v
            end
         end
         return object
      else -- it is a table
         local size = self:readInt()
         local object = {}
         -- register before reading the contents so self-references resolve
         objects[index] = object
         for i = 1,size do
            local k = self:readObject()
            local v = self:readObject()
            object[k] = v
         end
         return object
      end
   else
      error('unknown object')
   end
end
-- simple helpers to save/load arbitrary objects/tables
--- Write an object to a file on disk.
-- @param filename path of the file to create
-- @param object   value to serialize (anything File:writeObject accepts)
-- @param mode     name of the DiskFile method used to select the encoding;
--                 defaults to 'binary'
-- Errors raised while selecting the mode or writing are re-raised after the
-- file has been closed.
function torch.save(filename, object, mode)
   mode = mode or 'binary'
   local file = torch.DiskFile(filename, 'w')
   -- run the fallible part protected so the file is closed on every path
   local ok, err = pcall(function()
      file[mode](file)
      file:writeObject(object)
   end)
   file:close()
   if not ok then
      -- level 0: keep the original message without adding a new position
      error(err, 0)
   end
end
--- Read an object back from a file on disk.
-- @param filename path of the file to read
-- @param mode     name of the DiskFile method used to select the encoding;
--                 defaults to 'binary'
-- @return the deserialized object
-- Errors raised while selecting the mode or reading are re-raised after the
-- file has been closed.
function torch.load(filename, mode)
   mode = mode or 'binary'
   local file = torch.DiskFile(filename, 'r')
   -- run the fallible part protected so the file is closed on every path
   local ok, result = pcall(function()
      file[mode](file)
      return file:readObject()
   end)
   file:close()
   if not ok then
      -- level 0: keep the original message without adding a new position
      error(result, 0)
   end
   return result
end
-- simple helpers to serialize/deserialize arbitrary objects/tables
--- Serialize an object into a Lua string.
-- @param object value to serialize (anything File:writeObject accepts)
-- @return string holding the contents of the in-memory file
function torch.serialize(object)
   local buffer = torch.MemoryFile()
   buffer:writeObject(object)
   -- take the string out of the backing storage before closing the file
   local storage = buffer:storage()
   local bytes = storage:string()
   buffer:close()
   return bytes
end
--- Rebuild an object from a string.
-- @param str serialized data, as returned by torch.serialize
-- @return the deserialized object
function torch.deserialize(str)
   local source = torch.CharStorage():string(str)
   local sourceView = torch.CharTensor(source)
   -- Copy the data into a storage one element longer and set that last
   -- element to 0 (presumably MemoryFile expects a zero-terminated buffer;
   -- TODO confirm).
   local padded = torch.CharStorage(source:size(1)+1)
   local paddedView = torch.CharTensor(padded)
   local length = sourceView:size(1)
   paddedView:narrow(1,1,length):copy(sourceView)
   paddedView[length+1] = 0
   local reader = torch.MemoryFile(padded)
   local object = reader:readObject()
   reader:close()
   return object
end