From 9c09354b445fcc43f268d82ea58f93162b290abd Mon Sep 17 00:00:00 2001
From: lake4790k
Date: Fri, 6 May 2016 20:48:17 +0200
Subject: [PATCH] sharedRmsProp and async nature params

---
 AsyncAgent.lua            |  7 ++++---
 AsyncMaster.lua           | 13 +++++++++----
 async_main.lua            |  3 ++-
 async_run.sh              |  2 +-
 modules/sharedRmsProp.lua | 29 +++++++++++++++++++++++++++++
 5 files changed, 45 insertions(+), 9 deletions(-)
 create mode 100644 modules/sharedRmsProp.lua

diff --git a/AsyncAgent.lua b/AsyncAgent.lua
index 867a89e..b3633b2 100644
--- a/AsyncAgent.lua
+++ b/AsyncAgent.lua
@@ -2,7 +2,7 @@ local AsyncModel = require 'AsyncModel'
 local CircularQueue = require 'structures/CircularQueue'
 local classic = require 'classic'
 local optim = require 'optim'
-require 'modules/rmspropm' -- Add RMSProp with momentum
+require 'modules/sharedRmsProp'
 require 'classic.torch'
 
 local AsyncAgent = classic.class('AsyncAgent')
@@ -10,7 +10,7 @@ local AsyncAgent = classic.class('AsyncAgent')
 local EPSILON_ENDS = { 0.01, 0.1, 0.5}
 local EPSILON_PROBS = { 0.4, 0.7, 1 }
 
-function AsyncAgent:_init(opt, policyNet, targetNet, theta, counters)
+function AsyncAgent:_init(opt, policyNet, targetNet, theta, counters, sharedG)
   log.info('creating AsyncAgent')
   local asyncModel = AsyncModel(opt)
   self.env, self.model = asyncModel:getEnvAndModel()
@@ -21,7 +21,8 @@ function AsyncAgent:_init(opt, policyNet, targetNet, theta, counters)
   self.optimiser = optim[opt.optimiser]
   self.optimParams = {
     learningRate = opt.eta,
-    momentum = opt.momentum
+    momentum = opt.momentum,
+    g = sharedG
   }
 
   local actionSpec = self.env:getActionSpec()
diff --git a/AsyncMaster.lua b/AsyncMaster.lua
index c9e7a4a..b10bb94 100644
--- a/AsyncMaster.lua
+++ b/AsyncMaster.lua
@@ -1,3 +1,4 @@
+require 'socket'
 local AsyncModel = require 'AsyncModel'
 local AsyncAgent = require 'AsyncAgent'
 local class = require 'classic'
@@ -10,9 +11,10 @@ local TARGET_UPDATER = 1
 local VALIDATOR = 2
 
 local function checkNotNan(t)
-  local ok = t:ne(t):sum() == 0
+  local sum = t:sum()
+  local ok = sum == sum
   if not ok then
-    log.error('ERROR'.. t:sum())
+    log.error('ERROR'.. sum)
   end
   assert(ok)
 end
@@ -82,6 +84,7 @@ function AsyncMaster:_init(opt)
 
   self.theta = policyNet:getParameters()
   self.targetTheta = targetNet:getParameters()
+  local sharedG = self.theta:clone():zero()
 
   self.controlPool = threads.Threads(2)
   self.controlPool:specific(true)
@@ -93,7 +96,7 @@ function AsyncMaster:_init(opt)
   self.controlPool:addjob(VALIDATOR, torchSetup(opt))
   self.controlPool:addjob(VALIDATOR, function()
     local AsyncAgent = require 'AsyncAgent'
-    evalAgent = AsyncAgent(opt, policyNet, targetNet, theta, counters)
+    evalAgent = AsyncAgent(opt, policyNet, targetNet, theta, counters, sharedG)
   end)
 
   self.controlPool:synchronize()
@@ -113,7 +116,7 @@ function AsyncMaster:_init(opt)
       local mutex1 = threads1.Mutex(mutexId)
       mutex1:lock()
       local AsyncAgent = require 'AsyncAgent'
-      agent = AsyncAgent(opt, policyNet, targetNet, theta, counters)
+      agent = AsyncAgent(opt, policyNet, targetNet, theta, counters, sharedG)
       mutex1:unlock()
     end
   )
@@ -158,6 +161,8 @@ function AsyncMaster:start()
     local countSum = counters:sum()
     if countSum < 0 then return end
 
+    -- TODO report speed
+
     local countSince = countSum - lastUpdate
     if countSince > opt.tau then
       lastUpdate = countSum
diff --git a/async_main.lua b/async_main.lua
index 5004bc9..fb17398 100644
--- a/async_main.lua
+++ b/async_main.lua
@@ -28,7 +28,7 @@ cmd:option('-doubleQ', 'true', 'Use Double Q-learning') -- Note from Georg Ostrovski: The advantage operators and Double DQN are not entirely orthogonal as the increased action gap seems to reduce the statistical bias that leads to value over-estimation in a similar way that Double DQN does
 cmd:option('-PALpha', 0.9, 'Persistent advantage learning parameter α (0 to disable)')
 
 -- Training options
-cmd:option('-optimiser', 'rmspropm', 'Training algorithm') -- RMSProp with momentum as found in "Generating Sequences With Recurrent Neural Networks"
+cmd:option('-optimiser', 'sharedRmsProp', 'Training algorithm')
 cmd:option('-eta', 0.0000625, 'Learning rate η') -- Prioritied experience replay learning rate (1/4 that of DQN; does not account for Duel as well)
 cmd:option('-momentum', 0.95, 'Gradient descent momentum')
 cmd:option('-batchSize', 5, 'Accumulate gradient x batchSize')
@@ -111,6 +111,7 @@
 opt.Tensor = function(...)
   return torch.Tensor(...)
 end
+log.info(opt)
 
 local master = AsyncMaster(opt)
 
diff --git a/async_run.sh b/async_run.sh
index 1299f84..687e9b7 100755
--- a/async_run.sh
+++ b/async_run.sh
@@ -33,7 +33,7 @@ if [ "$PAPER" == "demo" ]; then
   th async_main.lua -async $ASYNC -eta 0.00025 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -PALpha 0 "$@"
 elif [ "$PAPER" == "nature" ]; then
   # Nature
-  th async_main.lua -async $ASYNC -game $GAME -duel false -tau 320000 -doubleQ false -PALpha 0 -eta 0.00025 -gradClip 0 "$@"
+  th async_main.lua -async $ASYNC -game $GAME -duel false -tau 40000 -optimiser sharedRmsProp -epsilonSteps 4000000 -doubleQ false -PALpha 0 -eta 0.0016 -gradClip 0 "$@"
 elif [ "$PAPER" == "doubleq" ]; then
   # Double-Q (tuned)
   th async_main.lua -async $ASYNC -game $GAME -duel false -PALpha 0 -eta 0.00025 -gradClip 0 "$@"
diff --git a/modules/sharedRmsProp.lua b/modules/sharedRmsProp.lua
new file mode 100644
index 0000000..679436b
--- /dev/null
+++ b/modules/sharedRmsProp.lua
@@ -0,0 +1,29 @@
+function optim.sharedRmsProp(opfunc, x, config, state)
+  -- Get state
+  local config = config or {}
+  local state = state or config
+  local lr = config.learningRate or 1e-2
+  local momentum = config.momentum or 0.95
+  local epsilon = config.epsilon or 0.01
+
+  -- Evaluate f(x) and df/dx
+  local fx, dfdx = opfunc(x)
+
+  -- Initialise storage
+  if not state.g then
+    state.g = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
+  end
+
+  if not state.tmp then
+    state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx)
+  end
+
+  state.g:mul(momentum):addcmul(1 - momentum, dfdx, dfdx)
+  state.tmp:copy(state.g):add(epsilon):sqrt()
+
+  -- Update x = x - lr x df/dx / tmp
+  x:addcdiv(-lr, dfdx, state.tmp)
+
+  -- Return x*, f(x) before optimisation
+  return x, {fx}
+end
\ No newline at end of file
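
For reference, a minimal usage sketch of the shared optimiser added above (not part of the patch). It assumes Torch7 with the optim package, a working directory at the repository root so that require 'modules/sharedRmsProp' resolves, and a hypothetical quadratic loss standing in for the agent's Q-learning gradient. As in AsyncAgent, the shared second-moment tensor is handed in through the g field of the optimiser config; because state defaults to config, every agent given the same tensor (the one AsyncMaster creates with self.theta:clone():zero()) accumulates into the same running mean square, which is what distinguishes sharedRmsProp from the per-thread rmspropm it replaces.

-- Usage sketch (assumption, not from the patch): drive optim.sharedRmsProp
-- the way AsyncAgent does, with a shared second-moment tensor in config.g.
require 'optim'                       -- provides the global `optim` table
require 'modules/sharedRmsProp'       -- adds optim.sharedRmsProp (file added above)

local theta = torch.randn(10)         -- stands in for the flattened network parameters
local sharedG = theta:clone():zero()  -- shared RMS statistics, as created in AsyncMaster

local optimParams = {
  learningRate = 0.0016,  -- eta used by the "nature" preset in async_run.sh
  momentum = 0.95,        -- used as the decay factor of the running mean square
  g = sharedG             -- state defaults to config, so updates land in the shared tensor
}

-- Hypothetical loss f(x) = 0.5 * ||x||^2 with gradient x, in place of the DQN loss
local function feval(x)
  return 0.5 * x:dot(x), x:clone()
end

-- One optimisation step; sharedG now holds the updated running mean square
local newTheta, fs = optim.sharedRmsProp(feval, theta, optimParams)
print(fs[1], sharedG:sum())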