Commit b8fba78: GradClip

nicholas-leonard committed Apr 9, 2015
1 parent 267f5e2 commit b8fba78
Showing 5 changed files with 82 additions and 24 deletions.
1 change: 1 addition & 0 deletions init.lua
@@ -117,6 +117,7 @@ torch.include('dp', 'visitor/maxnorm.lua')
torch.include('dp', 'visitor/weightdecay.lua')
torch.include('dp', 'visitor/learn.lua')
torch.include('dp', 'visitor/momentum.lua')
torch.include('dp', 'visitor/gradclip.lua')

--[[ observer ]]--
torch.include('dp', 'observer/observer.lua')
18 changes: 18 additions & 0 deletions model/layer.lua
@@ -180,6 +180,24 @@ function Layer:maxNorm(max_out_norm, max_in_norm)
end
end

function Layer:gradClip(cutoff)
   assert(self.backwarded, "Should call gradClip after a backward pass")
   cutoff = self.mvstate.cutoff or cutoff
   local params, gradParams = self:parameters()
   -- combined L2 norm over all of this Layer's gradParams
   local norm = 0
   for k,gradParam in pairs(gradParams) do
      norm = norm + math.pow(gradParam:norm(),2)
   end
   norm = math.sqrt(norm)
   if norm > cutoff then
      -- rescale gradParams to obtain desired norm
      for k,gradParam in pairs(gradParams) do
         gradParam:mul(cutoff/norm)
      end
   end
   return norm
end

function Layer:share(layer, ...)
   assert(layer.isLayer)
   local arg = {...}
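Layer:gradClip rescales a Layer's gradients only when their combined L2 norm exceeds the cutoff, and leaves them untouched otherwise. A minimal standalone sketch of the same rescaling rule using plain torch tensors (the grads table and cutoff value are hypothetical placeholders, not part of this commit):

require 'torch'

-- hypothetical gradient tensors standing in for a Layer's gradParams
local grads = {torch.randn(10, 5), torch.randn(5)}
local cutoff = 1

-- combined L2 norm over all gradient tensors
local norm = 0
for _, g in ipairs(grads) do
   norm = norm + g:norm()^2
end
norm = math.sqrt(norm)

-- rescale in place so the combined norm shrinks to the cutoff
if norm > cutoff then
   for _, g in ipairs(grads) do
      g:mul(cutoff/norm)
   end
end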
14 changes: 0 additions & 14 deletions model/sequential.lua
@@ -82,17 +82,3 @@ function Sequential:_toModule()
      self._models[i]:_toModule()
   end
end

--[[
-- experimental
function Sequential:flux(state)
   local output = self.output
   -- setup
   for i=1,#self._models-1 do
      self._models[i]:setSuccessor(self._models[i+1])
   end
   return self._model[1]:flux()
   self.input = output
   return carry
end
--]]
10 changes: 0 additions & 10 deletions node.lua
@@ -210,13 +210,3 @@ end
function Node:_evaluate(carry)
   return self:_forward(carry)
end

--[[
-- experimental (would allow for one chained RPC call for both backward forward)
function Node:flux(carry)
   local output, carry = self:forward()
   local input, carry = self._successor:flux{output, carry}
   local input, carry = self:backward{input, carry}
   return input, carry
end
--]]
63 changes: 63 additions & 0 deletions visitor/gradclip.lua
@@ -0,0 +1,63 @@
------------------------------------------------------------------------
--[[ GradClip ]]--
-- Ref.: A. http://goo.gl/Zxza8m
--       B. http://jmlr.org/proceedings/papers/v28/pascanu13.pdf
-- Visitor
-- Hard constraint on the upper bound of the norm of the gradient with
-- respect to parameters (gradParams). Unlike refs A and B, which apply
-- the constraint to the norm of all parameters taken together, here the
-- constraint is applied to the norm of each Layer's parameters.
-- Should occur before Learn in the VisitorChain.
------------------------------------------------------------------------
local GradClip, parent = torch.class("dp.GradClip", "dp.Visitor")
GradClip.isGradClip = true

function GradClip:__init(config)
   config = config or {}
   assert(torch.type(config) == 'table' and not config[1],
      "Constructor requires key-value arguments")
   local args, cutoff, name = xlua.unpack(
      {config},
      'GradClip',
      'Hard constraint on the upper bound of the norm of gradParams.',
      {arg='cutoff', type='number', default=1,
       help="max norm of a Layer's gradParams"},
      {arg='name', type='string', default='gradclip',
       help='identifies visitor in reports.'}
   )
   self._cutoff = cutoff
   config.include = config.include or {}
   table.insert(config.include, 'hasParams')
   config.exclude = config.exclude or {}
   table.insert(config.exclude, 'no-gradclip')
   config.name = name
   parent.__init(self, config)
   self.norms = {}
end

function GradClip:_visitModel(model)
   if model.gradClip then
      local norm = model:gradClip(self._cutoff)
      -- keep a moving average of norms
      self.norms[model:id():toString()] = (self.norms[model:id():toString()] or 0)*0.8 + norm*0.2
   else
      if not model.mvstate[self:id():name()].warned then
         print("Warning: GradClip not implemented for model " ..
            torch.typename(model) .. ". Ignoring model-visitor pair")
         model.mvstate[self:id():name()].warned = true
      end
   end
end

function GradClip:report()
   local norms = _.values(self.norms)
   if self._verbose then
      print(self:id():name().." norms: ", unpack(norms))
   end
   local report = {
      [self:name()] = {
         norms = self.norms
      }
   }
   return report
end
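
For context, a brief usage sketch, assuming the usual dp visitor setup (the values below are illustrative placeholders, not taken from this commit). GradClip is listed ahead of Learn so that clipping happens before the parameter update, as the header comment advises:

-- hypothetical visitor chain: clip each Layer's gradParams, then update
local visitor = {
   dp.GradClip{cutoff = 1},
   dp.Learn{learning_rate = 0.1}
}
-- the table would then be handed to a propagator, e.g. dp.Optimizer{visitor = visitor, ...}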
