[WIP] IMPALA-style actor-learner parallelism for DQN variants #477

Status: Open. Wants to merge 37 commits into master.

Commits (37)
47480ff  tmp (muupan, May 9, 2019)
539e25a  Merge branch 'add-env-id-to-replay-buffer' into actor-learner-dqn (muupan, May 10, 2019)
bca85c9  Implement actor-learner parallelism (muupan, May 10, 2019)
e7c816b  Fix a bug of not shaaring models (muupan, May 10, 2019)
091237f  Rename (muupan, May 11, 2019)
e88ad5a  Add tests (muupan, May 11, 2019)
d9ab261  Make tests pass (muupan, May 11, 2019)
174f43d  Merge branch 'recurrent-dqn' into actor-learner-dqn (muupan, May 11, 2019)
d3c64e5  Record statistics in loss computation (muupan, May 11, 2019)
027a208  Remove process_idx from messages (muupan, May 11, 2019)
5dcf397  Support recurrent in actor-learner training (muupan, May 11, 2019)
ca35585  Restore unintentionally deleted update (muupan, May 12, 2019)
a635b07  Restore unintentionally deleted stop_current_episode (muupan, May 12, 2019)
04101b3  Support actor-learner training in IQN (muupan, May 12, 2019)
c1567cb  Add a missing file (muupan, May 12, 2019)
2296f5c  Record q for IQN (muupan, May 12, 2019)
1ef7860  Forward quantile_thresholds_K (muupan, May 12, 2019)
4f5a9f1  Support copy_param between different devices (muupan, May 12, 2019)
4d38460  Set OMP_NUM_THREADS=1 (muupan, May 12, 2019)
f97903f  Rename function (muupan, May 12, 2019)
3e77935  Use poller thread to parallelize learning and polling (muupan, May 12, 2019)
5a3cf6d  Use replay_buffer_lock in IQN as well (muupan, May 12, 2019)
339b701  Support mixing cpu and gpu recurrent states (muupan, May 12, 2019)
c895d7c  Use pipes only (muupan, May 12, 2019)
6bd4305  Merge branch 'actor-learner-dqn' into muupan/actor-learner-dqn-poller (muupan, May 17, 2019)
00a5b4a  Add StoppableThread (muupan, May 21, 2019)
b6c6b53  Modify actor-learner interface to accept n_updates (muupan, May 21, 2019)
e68611d  Update example as well (muupan, May 21, 2019)
39dae2d  Merge branch 'master' into actor-learner-dqn-poller (muupan, May 21, 2019)
7d48c04  Call device.use in case gpu>=1 (muupan, May 21, 2019)
ef21c29  Synchronize model in poller, not learner (muupan, May 21, 2019)
4bf0581  Merge branch 'master' into actor-learner-dqn-poller (muupan, May 24, 2019)
eaff3d7  Merge branch 'share-persistent-values' into actor-learner-dqn-poller (muupan, Jun 17, 2019)
0454b21  Merge branch 'share-persistent-values' into actor-learner-dqn-poller (muupan, Jun 19, 2019)
ab84fa2  Merge branch 'share-persistent-values' into actor-learner-dqn-poller (muupan, Jun 19, 2019)
4944351  Merge branch 'master' into actor-learner-dqn-poller (muupan, Mar 5, 2020)
aa4a534  Fix remaining conflict (muupan, Mar 5, 2020)
2 changes: 2 additions & 0 deletions chainerrl/agents/__init__.py
@@ -10,6 +10,7 @@
from chainerrl.agents.double_pal import DoublePAL # NOQA
from chainerrl.agents.dpp import DPP # NOQA
from chainerrl.agents.dqn import DQN # NOQA
from chainerrl.agents.implicit_quantile_state_q_function_actor import ImplicitQuantileStateQFunctionActor # NOQA
from chainerrl.agents.iqn import IQN # NOQA
from chainerrl.agents.nsq import NSQ # NOQA
from chainerrl.agents.pal import PAL # NOQA
@@ -22,3 +23,4 @@
from chainerrl.agents.soft_actor_critic import SoftActorCritic # NOQA
from chainerrl.agents.td3 import TD3 # NOQA
from chainerrl.agents.trpo import TRPO # NOQA
from chainerrl.agents.state_q_function_actor import StateQFunctionActor # NOQA
199 changes: 169 additions & 30 deletions chainerrl/agents/dqn.py
@@ -1,18 +1,29 @@
import collections
import copy
from logging import getLogger
import multiprocessing as mp
import time

import chainer
from chainer import cuda
import chainer.functions as F
import numpy as np

import chainerrl
from chainerrl import agent
from chainerrl.misc.batch_states import batch_states
from chainerrl.misc.copy_param import copy_param
from chainerrl.misc.copy_param import synchronize_parameters
from chainerrl.replay_buffer import batch_experiences
from chainerrl.replay_buffer import batch_recurrent_experiences
from chainerrl.replay_buffer import ReplayUpdater


def _mean_or_nan(xs):
"""Return its mean a non-empty sequence, numpy.nan for a empty one."""
return np.mean(xs) if xs else np.nan


def compute_value_loss(y, t, clip_delta=True, batch_accumulator='mean'):
"""Compute a loss for value prediction problem.

@@ -115,10 +126,6 @@ class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
target_update_method (str): 'hard' or 'soft'.
soft_update_tau (float): Tau of soft target update.
n_times_update (int): Number of repetition of update
average_q_decay (float): Decay rate of average Q, only used for
recording statistics
average_loss_decay (float): Decay rate of average loss, only used for
recording statistics
batch_accumulator (str): 'mean' or 'sum'
episodic_update_len (int or None): Subsequences of this length are used
for update if set int and episodic_update=True
@@ -139,8 +146,7 @@ def __init__(self, q_function, optimizer, replay_buffer, gamma,
phi=lambda x: x,
target_update_method='hard',
soft_update_tau=1e-2,
n_times_update=1, average_q_decay=0.999,
average_loss_decay=0.99,
n_times_update=1,
batch_accumulator='mean',
episodic_update_len=None,
logger=getLogger(__name__),
@@ -184,6 +190,13 @@ def __init__(self, q_function, optimizer, replay_buffer, gamma,
replay_start_size=replay_start_size,
update_interval=update_interval,
)
self.minibatch_size = minibatch_size
self.episodic_update_len = episodic_update_len
self.replay_start_size = replay_start_size
self.update_interval = update_interval

assert target_update_interval % update_interval == 0,\
"target_update_interval should be a multiple of update_interval"

self.t = 0
self.last_state = None
@@ -192,10 +205,10 @@ def __init__(self, q_function, optimizer, replay_buffer, gamma,
self.sync_target_network()
# For backward compatibility
self.target_q_function = self.target_model
self.average_q = 0
self.average_q_decay = average_q_decay
self.average_loss = 0
self.average_loss_decay = average_loss_decay

# Statistics
self.q_record = collections.deque(maxlen=1000)
self.loss_record = collections.deque(maxlen=100)

# Recurrent states of the model
self.train_recurrent_states = None
@@ -262,9 +275,7 @@ def update(self, experiences, errors_out=None):
if has_weight:
self.replay_buffer.update_errors(errors_out)

# Update stats
self.average_loss *= self.average_loss_decay
self.average_loss += (1 - self.average_loss_decay) * float(loss.array)
self.loss_record.append(float(loss.array))

self.model.cleargrads()
loss.backward()
@@ -281,10 +292,8 @@ def update_from_episodes(self, episodes, errors_out=None):
batch_states=self.batch_states,
)
loss = self._compute_loss(exp_batch, errors_out=None)
# Update stats
self.average_loss *= self.average_loss_decay
self.average_loss += (1 - self.average_loss_decay) * float(loss.array)
self.optimizer.update(lambda: loss)
self.loss_record.append(float(loss.array))

def _compute_target_values(self, exp_batch):
batch_next_state = exp_batch['next_state']
@@ -340,6 +349,8 @@ def _compute_loss(self, exp_batch, errors_out=None):
"""
y, t = self._compute_y_and_t(exp_batch)

self.q_record.extend(cuda.to_cpu(y.array).ravel())

if errors_out is not None:
del errors_out[:]
delta = F.absolute(y - t)
@@ -366,10 +377,6 @@ def act(self, obs):
q = float(action_value.max.array)
action = cuda.to_cpu(action_value.greedy_actions.array)[0]

# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * q

self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
return action

@@ -414,10 +421,6 @@ def act_and_train(self, obs, reward):
action = self.explorer.select_action(
self.t, lambda: greedy_action, action_value=action_value)

# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * q

self.t += 1
self.last_state = obs
self.last_action = action
@@ -443,9 +446,16 @@ def _evaluate_model_and_update_recurrent_states(self, batch_obs, test):

def batch_act_and_train(self, batch_obs):
with chainer.using_config('train', False), chainer.no_backprop_mode():


<<<<<<< HEAD
batch_av = self._evaluate_model_and_update_train_recurrent_states(
batch_obs)
=======
batch_av = self._evaluate_model_and_update_recurrent_states(
batch_obs, test=False)
batch_maxq = batch_av.max.array
>>>>>>> master
batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
batch_action = [
self.explorer.select_action(
@@ -456,10 +466,6 @@ def batch_act_and_train(self, batch_obs):
self.batch_last_obs = list(batch_obs)
self.batch_last_action = list(batch_action)

# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * float(batch_maxq.mean())

return batch_action

def batch_act(self, batch_obs):
@@ -560,13 +566,146 @@ def stop_episode_and_train(self, state, reward, done=False):
self.train_recurrent_states = None
self.replay_buffer.stop_current_episode()

def _can_start_replay(self):
if len(self.replay_buffer) < self.replay_start_size:
return False
if (self.recurrent
and self.replay_buffer.n_episodes < self.minibatch_size):
return False
return True

def _poll_pipe(self, actor_idx, pipe, replay_buffer_lock):
if pipe.closed:
return
try:
while pipe.poll():
cmd, data = pipe.recv()
if cmd == 'get_statistics':
assert data is None
pipe.send(self.get_statistics())
elif cmd == 'load':
self.load(data)
pipe.send(None)
elif cmd == 'save':
self.save(data)
pipe.send(None)
elif cmd == 'transition':
with replay_buffer_lock:
self.replay_buffer.append(**data, env_id=actor_idx)
elif cmd == 'stop_episode':
assert data is None
with replay_buffer_lock:
self.replay_buffer.stop_current_episode(
env_id=actor_idx)
else:
raise RuntimeError(
'Unknown command from actor: {}'.format(cmd))
except EOFError:
pipe.close()

def _learner_loop(self, pipes, replay_buffer_lock, stop_event,
n_updates=None):

# Device.use should be called in a new thread
self.model.device.use()
# To stop this loop, call stop_event.set()
while not stop_event.is_set():
time.sleep(1e-6)
# Update model if possible
if not self._can_start_replay():
continue
if n_updates is not None:
assert self.optimizer.t <= n_updates
if self.optimizer.t == n_updates:
stop_event.set()
break
if self.recurrent:
with replay_buffer_lock:
episodes = self.replay_buffer.sample_episodes(
self.minibatch_size, self.episodic_update_len)
self.update_from_episodes(episodes)
else:
with replay_buffer_lock:
transitions = self.replay_buffer.sample(
self.minibatch_size)
self.update(transitions)
# To keep the ratio of target updates to model updates,
# here we calculate back the effective current timestep
# from update_interval and number of updates so far.
effective_timestep = self.optimizer.t * self.update_interval
if effective_timestep % self.target_update_interval == 0:
self.sync_target_network()

def _poller_loop(self, shared_model, pipes, replay_buffer_lock,
stop_event):
# To stop this loop, call stop_event.set()
while not stop_event.is_set():
time.sleep(1e-6)
# Poll actors for messages
for i, pipe in enumerate(pipes):
self._poll_pipe(i, pipe, replay_buffer_lock)
# Synchronize shared model
copy_param(source_link=self.model,
target_link=shared_model)

def setup_actor_learner_training(self, n_actors, n_updates=None):
# Make a copy on shared memory and share among actors and a learner
shared_model = copy.deepcopy(self.model).to_cpu()
shared_arrays = chainerrl.misc.async_.extract_params_as_shared_arrays(
shared_model)

# Pipes are used for infrequent communication
learner_pipes, actor_pipes = list(zip(*[
mp.Pipe() for _ in range(n_actors)]))

def make_actor(i):
chainerrl.misc.async_.set_shared_params(
shared_model, shared_arrays)
return chainerrl.agents.StateQFunctionActor(
pipe=actor_pipes[i],
model=shared_model,
explorer=self.explorer,
phi=self.phi,
batch_states=self.batch_states,
logger=self.logger,
recurrent=self.recurrent,
)

replay_buffer_lock = mp.Lock()

poller_stop_event = mp.Event()
poller = chainerrl.misc.StoppableThread(
target=self._poller_loop,
kwargs=dict(
shared_model=shared_model,
pipes=learner_pipes,
replay_buffer_lock=replay_buffer_lock,
stop_event=poller_stop_event,
),
stop_event=poller_stop_event,
)

learner_stop_event = mp.Event()
learner = chainerrl.misc.StoppableThread(
target=self._learner_loop,
kwargs=dict(
pipes=learner_pipes,
replay_buffer_lock=replay_buffer_lock,
stop_event=learner_stop_event,
n_updates=n_updates,
),
stop_event=learner_stop_event,
)

return make_actor, learner, poller

def stop_episode(self):
if self.recurrent:
self.test_recurrent_states = None

def get_statistics(self):
return [
('average_q', self.average_q),
('average_loss', self.average_loss),
('average_q', _mean_or_nan(self.q_record)),
('average_loss', _mean_or_nan(self.loss_record)),
('n_updates', self.optimizer.t),
]
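
The _poll_pipe method in the diff above defines the learner's half of a small pipe protocol: 'transition' and 'stop_episode' are fire-and-forget messages that mutate the shared replay buffer, while 'get_statistics', 'load' and 'save' are request/response messages answered by the poller thread. The sending half lives in StateQFunctionActor, which is not part of this diff, so the helpers below are only an illustration of what it might look like (hypothetical function names; the keyword names follow ReplayBuffer.append).

def send_transition(pipe, state, action, reward, next_state,
                    is_state_terminal):
    # The payload is a dict of keyword arguments for replay_buffer.append();
    # the learner adds env_id=actor_idx when it stores the transition.
    pipe.send(('transition', dict(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        is_state_terminal=is_state_terminal,
    )))


def send_stop_episode(pipe):
    # Marks the end of the current episode for this actor's env_id.
    pipe.send(('stop_episode', None))


def request_statistics(pipe):
    # Request/response: the poller thread replies on the same pipe.
    pipe.send(('get_statistics', None))
    return pipe.recv()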
53 changes: 53 additions & 0 deletions chainerrl/agents/implicit_quantile_state_q_function_actor.py
@@ -0,0 +1,53 @@
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

from chainerrl.agents import state_q_function_actor


class ImplicitQuantileStateQFunctionActor(
state_q_function_actor.StateQFunctionActor):
"""Actor that acts according to the implicit quantile Q-function.

This actor specialization is required because the interface of an implicit
quantile Q-function is different from that of a usual Q-function.
"""

def __init__(self, *args, **kwargs):
# K=32 was used in the IQN paper's experiments
# (personal communication)
self.quantile_thresholds_K = kwargs.pop('quantile_thresholds_K', 32)
super().__init__(*args, **kwargs)

@property
def xp(self):
return self.model.xp

def _evaluate_model_and_update_train_recurrent_states(self, batch_obs):
batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
if self.recurrent:
self.train_prev_recurrent_states = self.train_recurrent_states
tau2av, self.train_recurrent_states = self.model(
batch_xs, self.train_recurrent_states)
else:
tau2av = self.model(batch_xs)
taus_tilde = self.xp.random.uniform(
0, 1,
size=(len(batch_obs), self.quantile_thresholds_K)).astype('f')
return tau2av(taus_tilde)

def _evaluate_model_and_update_test_recurrent_states(self, batch_obs):
batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
if self.recurrent:
tau2av, self.test_recurrent_states = self.model(
batch_xs, self.test_recurrent_states)
else:
tau2av = self.model(batch_xs)
taus_tilde = self.xp.random.uniform(
0, 1,
size=(len(batch_obs), self.quantile_thresholds_K)).astype('f')
return tau2av(taus_tilde)
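
The interface difference named in the docstring is the reason this file exists: a regular state Q-function maps a batch of observations directly to an ActionValue, whereas an implicit quantile Q-function returns an intermediate callable that must then be evaluated at sampled quantile thresholds. A toy, self-contained sketch of that two-step call follows (plain NumPy with a stand-in model, not the real chainerrl classes).

import numpy as np


def toy_implicit_quantile_model(batch_obs, n_actions=2):
    # Step 1: condition on observations and return a function of taus.
    def tau2av(taus):
        # Step 2: evaluate return quantiles at the sampled thresholds and
        # average over them to obtain Q-value estimates per action.
        batch_size, n_taus = taus.shape
        # Random numbers stand in for the network's quantile outputs here.
        quantiles = np.random.randn(batch_size, n_taus, n_actions)
        return quantiles.mean(axis=1)
    return tau2av


batch_obs = np.zeros((4, 8), dtype=np.float32)
taus_tilde = np.random.uniform(0, 1, size=(4, 32)).astype(np.float32)
q_values = toy_implicit_quantile_model(batch_obs)(taus_tilde)  # shape (4, 2)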