Merge pull request #15 from lake4790k/async
sharedRmsProp and async nature params
Kaixhin committed May 7, 2016
2 parents 592dd39 + 9c09354 commit 4d2559c
Showing 5 changed files with 45 additions and 9 deletions.
7 changes: 4 additions & 3 deletions AsyncAgent.lua
@@ -2,15 +2,15 @@ local AsyncModel = require 'AsyncModel'
local CircularQueue = require 'structures/CircularQueue'
local classic = require 'classic'
local optim = require 'optim'
-require 'modules/rmspropm' -- Add RMSProp with momentum
+require 'modules/sharedRmsProp'
require 'classic.torch'

local AsyncAgent = classic.class('AsyncAgent')

local EPSILON_ENDS = { 0.01, 0.1, 0.5}
local EPSILON_PROBS = { 0.4, 0.7, 1 }

-function AsyncAgent:_init(opt, policyNet, targetNet, theta, counters)
+function AsyncAgent:_init(opt, policyNet, targetNet, theta, counters, sharedG)
log.info('creating AsyncAgent')
local asyncModel = AsyncModel(opt)
self.env, self.model = asyncModel:getEnvAndModel()
@@ -21,7 +21,8 @@ function AsyncAgent:_init(opt, policyNet, targetNet, theta, counters)
self.optimiser = optim[opt.optimiser]
self.optimParams = {
learningRate = opt.eta,
-momentum = opt.momentum
+momentum = opt.momentum,
+g = sharedG
}

local actionSpec = self.env:getActionSpec()
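Passing sharedG into optimParams.g is the crux of this file's change: optim-style routines keep their running statistics in the state table they are handed, so pre-seeding g with one shared tensor makes every worker thread accumulate squared gradients into the same storage instead of a private copy. A usage sketch exercising this mechanism follows the new modules/sharedRmsProp.lua at the end of this diff.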
13 changes: 9 additions & 4 deletions AsyncMaster.lua
@@ -1,3 +1,4 @@
+require 'socket'
local AsyncModel = require 'AsyncModel'
local AsyncAgent = require 'AsyncAgent'
local class = require 'classic'
@@ -10,9 +11,10 @@ local TARGET_UPDATER = 1
local VALIDATOR = 2

local function checkNotNan(t)
-local ok = t:ne(t):sum() == 0
+local sum = t:sum()
+local ok = sum == sum
if not ok then
-log.error('ERROR'.. t:sum())
+log.error('ERROR'.. sum)
end
assert(ok)
end
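The rewritten check relies on IEEE-754 semantics: NaN is the only value that does not compare equal to itself, and NaN propagates through summation, so one sum plus a scalar compare replaces the allocating elementwise t:ne(t) pass over the whole tensor. A toy illustration (not repo code):

local nan = 0 / 0           -- division produces NaN
print(nan == nan)           --> false: NaN never equals itself
local t = torch.Tensor({1, nan})
print(t:sum() == t:sum())   --> false: NaN poisons the sum

One caveat: the sum also becomes NaN when +inf and -inf cancel without any element being NaN, but that case is just as pathological for training, so failing the assert there is arguably correct too.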
@@ -82,6 +84,7 @@ function AsyncMaster:_init(opt)

self.theta = policyNet:getParameters()
self.targetTheta = targetNet:getParameters()
+local sharedG = self.theta:clone():zero()

self.controlPool = threads.Threads(2)
self.controlPool:specific(true)
@@ -93,7 +96,7 @@
self.controlPool:addjob(VALIDATOR, torchSetup(opt))
self.controlPool:addjob(VALIDATOR, function()
local AsyncAgent = require 'AsyncAgent'
-evalAgent = AsyncAgent(opt, policyNet, targetNet, theta, counters)
+evalAgent = AsyncAgent(opt, policyNet, targetNet, theta, counters, sharedG)
end)

self.controlPool:synchronize()
@@ -113,7 +116,7 @@
local mutex1 = threads1.Mutex(mutexId)
mutex1:lock()
local AsyncAgent = require 'AsyncAgent'
-agent = AsyncAgent(opt, policyNet, targetNet, theta, counters)
+agent = AsyncAgent(opt, policyNet, targetNet, theta, counters, sharedG)
mutex1:unlock()
end
)
@@ -158,6 +161,8 @@ function AsyncMaster:start()
local countSum = counters:sum()
if countSum < 0 then return end

+-- TODO report speed

local countSince = countSum - lastUpdate
if countSince > opt.tau then
lastUpdate = countSum
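The counters logic above throttles the master loop: once the step count summed across all threads has advanced by more than opt.tau since the last refresh, the truncated hunk body copies the online parameters into the target network. A sketch of the step this hunk leads into, assuming the conventional DQN-style hard copy (the copy itself is not shown in the diff):

local countSince = countSum - lastUpdate
if countSince > opt.tau then
  lastUpdate = countSum
  self.targetTheta:copy(self.theta) -- assumed: hard target-network refresh
end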
3 changes: 2 additions & 1 deletion async_main.lua
@@ -28,7 +28,7 @@ cmd:option('-doubleQ', 'true', 'Use Double Q-learning')
-- Note from Georg Ostrovski: The advantage operators and Double DQN are not entirely orthogonal as the increased action gap seems to reduce the statistical bias that leads to value over-estimation in a similar way that Double DQN does
cmd:option('-PALpha', 0.9, 'Persistent advantage learning parameter α (0 to disable)')
-- Training options
-cmd:option('-optimiser', 'rmspropm', 'Training algorithm') -- RMSProp with momentum as found in "Generating Sequences With Recurrent Neural Networks"
+cmd:option('-optimiser', 'sharedRmsProp', 'Training algorithm')
cmd:option('-eta', 0.0000625, 'Learning rate η') -- Prioritised experience replay learning rate (1/4 that of DQN; does not account for Duel either)
cmd:option('-momentum', 0.95, 'Gradient descent momentum')
cmd:option('-batchSize', 5, 'Accumulate gradient x batchSize')
@@ -111,6 +111,7 @@ opt.Tensor = function(...)
return torch.Tensor(...)
end

+log.info(opt)

local master = AsyncMaster(opt)

2 changes: 1 addition & 1 deletion async_run.sh
@@ -33,7 +33,7 @@ if [ "$PAPER" == "demo" ]; then
th async_main.lua -async $ASYNC -eta 0.00025 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -PALpha 0 "$@"
elif [ "$PAPER" == "nature" ]; then
# Nature
-th async_main.lua -async $ASYNC -game $GAME -duel false -tau 320000 -doubleQ false -PALpha 0 -eta 0.00025 -gradClip 0 "$@"
+th async_main.lua -async $ASYNC -game $GAME -duel false -tau 40000 -optimiser sharedRmsProp -epsilonSteps 4000000 -doubleQ false -PALpha 0 -eta 0.0016 -gradClip 0 "$@"
elif [ "$PAPER" == "doubleq" ]; then
# Double-Q (tuned)
th async_main.lua -async $ASYNC -game $GAME -duel false -PALpha 0 -eta 0.00025 -gradClip 0 "$@"
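With this commit the nature preset moves away from DQN's Nature hyperparameters towards the asynchronous one-step Q-learning setup of Mnih et al., 2016 ("Asynchronous Methods for Deep Reinforcement Learning"): the target network is synchronised every 40000 steps rather than 320000, ε is annealed over 4000000 steps, the learning rate rises to 0.0016, and the optimiser switches to the shared-statistics RMSProp added below. The exact values are this repository's choices within the ranges the paper explores.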
29 changes: 29 additions & 0 deletions modules/sharedRmsProp.lua
@@ -0,0 +1,29 @@
require 'optim' -- sharedRmsProp is registered on the optim package table
function optim.sharedRmsProp(opfunc, x, config, state)
-- Get state
local config = config or {}
local state = state or config
local lr = config.learningRate or 1e-2
local momentum = config.momentum or 0.95
local epsilon = config.epsilon or 0.01

-- Evaluate f(x) and df/dx
local fx, dfdx = opfunc(x)

-- Initialise storage
if not state.g then
state.g = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
end

if not state.tmp then
state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx)
end

state.g:mul(momentum):addcmul(1 - momentum, dfdx, dfdx)
state.tmp:copy(state.g):add(epsilon):sqrt()

-- Update x = x - lr * df/dx / tmp
x:addcdiv(-lr, dfdx, state.tmp)

-- Return x*, f(x) before optimisation
return x, {fx}
end
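Each call implements the decayed squared-gradient update g ← momentum·g + (1 − momentum)·(df/dx)², followed by x ← x − lr·(df/dx)/√(g + ε), the "RMSProp with shared statistics" variant from the asynchronous methods paper. Because state = state or config, a g tensor injected through the config table is used directly instead of being allocated per caller, which is what lets every worker share one accumulator. A minimal usage sketch (toy objective and names, not repo code):

require 'torch'
require 'optim'
require 'modules/sharedRmsProp'

local theta = torch.Tensor(3):fill(0.5)
local sharedG = theta:clone():zero() -- one accumulator shared by all callers

local function feval(x)
  local loss = 0.5 * x:dot(x) -- toy quadratic objective
  return loss, x:clone()      -- gradient of 0.5*||x||^2 is x itself
end

-- Two configs standing in for two worker threads; both alias sharedG
local workerA = { learningRate = 0.0016, momentum = 0.95, g = sharedG }
local workerB = { learningRate = 0.0016, momentum = 0.95, g = sharedG }

optim.sharedRmsProp(feval, theta, workerA)
optim.sharedRmsProp(feval, theta, workerB) -- sees the statistics workerA wrote
print(sharedG)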
