diff --git a/chainerrl/agents/__init__.py b/chainerrl/agents/__init__.py index 99d878935..d70b50425 100644 --- a/chainerrl/agents/__init__.py +++ b/chainerrl/agents/__init__.py @@ -10,6 +10,7 @@ from chainerrl.agents.double_pal import DoublePAL # NOQA from chainerrl.agents.dpp import DPP # NOQA from chainerrl.agents.dqn import DQN # NOQA +from chainerrl.agents.implicit_quantile_state_q_function_actor import ImplicitQuantileStateQFunctionActor # NOQA from chainerrl.agents.iqn import IQN # NOQA from chainerrl.agents.nsq import NSQ # NOQA from chainerrl.agents.pal import PAL # NOQA @@ -22,3 +23,4 @@ from chainerrl.agents.soft_actor_critic import SoftActorCritic # NOQA from chainerrl.agents.td3 import TD3 # NOQA from chainerrl.agents.trpo import TRPO # NOQA +from chainerrl.agents.state_q_function_actor import StateQFunctionActor # NOQA diff --git a/chainerrl/agents/dqn.py b/chainerrl/agents/dqn.py index 2c74e1c24..4e055c9a7 100644 --- a/chainerrl/agents/dqn.py +++ b/chainerrl/agents/dqn.py @@ -1,18 +1,29 @@ +import collections import copy from logging import getLogger +import multiprocessing as mp +import time import chainer from chainer import cuda import chainer.functions as F +import numpy as np +import chainerrl from chainerrl import agent from chainerrl.misc.batch_states import batch_states +from chainerrl.misc.copy_param import copy_param from chainerrl.misc.copy_param import synchronize_parameters from chainerrl.replay_buffer import batch_experiences from chainerrl.replay_buffer import batch_recurrent_experiences from chainerrl.replay_buffer import ReplayUpdater +def _mean_or_nan(xs): + """Return the mean of a non-empty sequence, numpy.nan for an empty one.""" + return np.mean(xs) if xs else np.nan + + def compute_value_loss(y, t, clip_delta=True, batch_accumulator='mean'): """Compute a loss for value prediction problem. @@ -115,10 +126,6 @@ class DQN(agent.AttributeSavingMixin, agent.BatchAgent): target_update_method (str): 'hard' or 'soft'. soft_update_tau (float): Tau of soft target update.
n_times_update (int): Number of repetition of update - average_q_decay (float): Decay rate of average Q, only used for - recording statistics - average_loss_decay (float): Decay rate of average loss, only used for - recording statistics batch_accumulator (str): 'mean' or 'sum' episodic_update_len (int or None): Subsequences of this length are used for update if set int and episodic_update=True @@ -139,8 +146,7 @@ def __init__(self, q_function, optimizer, replay_buffer, gamma, phi=lambda x: x, target_update_method='hard', soft_update_tau=1e-2, - n_times_update=1, average_q_decay=0.999, - average_loss_decay=0.99, + n_times_update=1, batch_accumulator='mean', episodic_update_len=None, logger=getLogger(__name__), @@ -184,6 +190,13 @@ def __init__(self, q_function, optimizer, replay_buffer, gamma, replay_start_size=replay_start_size, update_interval=update_interval, ) + self.minibatch_size = minibatch_size + self.episodic_update_len = episodic_update_len + self.replay_start_size = replay_start_size + self.update_interval = update_interval + + assert target_update_interval % update_interval == 0,\ + "target_update_interval should be a multiple of update_interval" self.t = 0 self.last_state = None @@ -192,10 +205,10 @@ def __init__(self, q_function, optimizer, replay_buffer, gamma, self.sync_target_network() # For backward compatibility self.target_q_function = self.target_model - self.average_q = 0 - self.average_q_decay = average_q_decay - self.average_loss = 0 - self.average_loss_decay = average_loss_decay + + # Statistics + self.q_record = collections.deque(maxlen=1000) + self.loss_record = collections.deque(maxlen=100) # Recurrent states of the model self.train_recurrent_states = None @@ -262,9 +275,7 @@ def update(self, experiences, errors_out=None): if has_weight: self.replay_buffer.update_errors(errors_out) - # Update stats - self.average_loss *= self.average_loss_decay - self.average_loss += (1 - self.average_loss_decay) * float(loss.array) + self.loss_record.append(float(loss.array)) self.model.cleargrads() loss.backward() @@ -281,10 +292,8 @@ def update_from_episodes(self, episodes, errors_out=None): batch_states=self.batch_states, ) loss = self._compute_loss(exp_batch, errors_out=None) - # Update stats - self.average_loss *= self.average_loss_decay - self.average_loss += (1 - self.average_loss_decay) * float(loss.array) self.optimizer.update(lambda: loss) + self.loss_record.append(float(loss.array)) def _compute_target_values(self, exp_batch): batch_next_state = exp_batch['next_state'] @@ -340,6 +349,8 @@ def _compute_loss(self, exp_batch, errors_out=None): """ y, t = self._compute_y_and_t(exp_batch) + self.q_record.extend(cuda.to_cpu(y.array).ravel()) + if errors_out is not None: del errors_out[:] delta = F.absolute(y - t) @@ -366,10 +377,6 @@ def act(self, obs): q = float(action_value.max.array) action = cuda.to_cpu(action_value.greedy_actions.array)[0] - # Update stats - self.average_q *= self.average_q_decay - self.average_q += (1 - self.average_q_decay) * q - self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value) return action @@ -414,10 +421,6 @@ def act_and_train(self, obs, reward): action = self.explorer.select_action( self.t, lambda: greedy_action, action_value=action_value) - # Update stats - self.average_q *= self.average_q_decay - self.average_q += (1 - self.average_q_decay) * q - self.t += 1 self.last_state = obs self.last_action = action @@ -443,9 +446,16 @@ def _evaluate_model_and_update_recurrent_states(self, batch_obs, test): def 
batch_act_and_train(self, batch_obs): with chainer.using_config('train', False), chainer.no_backprop_mode(): batch_av = self._evaluate_model_and_update_recurrent_states( batch_obs, test=False) - batch_maxq = batch_av.max.array batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array) batch_action = [ self.explorer.select_action( @@ -456,10 +466,6 @@ def batch_act_and_train(self, batch_obs): self.batch_last_obs = list(batch_obs) self.batch_last_action = list(batch_action) - # Update stats - self.average_q *= self.average_q_decay - self.average_q += (1 - self.average_q_decay) * float(batch_maxq.mean()) - return batch_action def batch_act(self, batch_obs): @@ -560,13 +566,146 @@ def stop_episode_and_train(self, state, reward, done=False): self.train_recurrent_states = None self.replay_buffer.stop_current_episode() + def _can_start_replay(self): + if len(self.replay_buffer) < self.replay_start_size: + return False + if (self.recurrent + and self.replay_buffer.n_episodes < self.minibatch_size): + return False + return True + + def _poll_pipe(self, actor_idx, pipe, replay_buffer_lock): + if pipe.closed: + return + try: + while pipe.poll(): + cmd, data = pipe.recv() + if cmd == 'get_statistics': + assert data is None + pipe.send(self.get_statistics()) + elif cmd == 'load': + self.load(data) + pipe.send(None) + elif cmd == 'save': + self.save(data) + pipe.send(None) + elif cmd == 'transition': + with replay_buffer_lock: + self.replay_buffer.append(**data, env_id=actor_idx) + elif cmd == 'stop_episode': + assert data is None + with replay_buffer_lock: + self.replay_buffer.stop_current_episode( + env_id=actor_idx) + else: + raise RuntimeError( + 'Unknown command from actor: {}'.format(cmd)) + except EOFError: + pipe.close() + + def _learner_loop(self, pipes, replay_buffer_lock, stop_event, + n_updates=None): + + # Device.use should be called in a new thread + self.model.device.use() + # To stop this loop, call stop_event.set() + while not stop_event.is_set(): + time.sleep(1e-6) + # Update model if possible + if not self._can_start_replay(): + continue + if n_updates is not None: + assert self.optimizer.t <= n_updates + if self.optimizer.t == n_updates: + stop_event.set() + break + if self.recurrent: + with replay_buffer_lock: + episodes = self.replay_buffer.sample_episodes( + self.minibatch_size, self.episodic_update_len) + self.update_from_episodes(episodes) + else: + with replay_buffer_lock: + transitions = self.replay_buffer.sample( + self.minibatch_size) + self.update(transitions) + # To keep the ratio of target updates to model updates, + # here we calculate back the effective current timestep + # from update_interval and number of updates so far.
+ effective_timestep = self.optimizer.t * self.update_interval + if effective_timestep % self.target_update_interval == 0: + self.sync_target_network() + + def _poller_loop(self, shared_model, pipes, replay_buffer_lock, + stop_event): + # To stop this loop, call stop_event.set() + while not stop_event.is_set(): + time.sleep(1e-6) + # Poll actors for messages + for i, pipe in enumerate(pipes): + self._poll_pipe(i, pipe, replay_buffer_lock) + # Synchronize shared model + copy_param(source_link=self.model, + target_link=shared_model) + + def setup_actor_learner_training(self, n_actors, n_updates=None): + # Make a copy on shared memory and share among actors and a learner + shared_model = copy.deepcopy(self.model).to_cpu() + shared_arrays = chainerrl.misc.async_.extract_params_as_shared_arrays( + shared_model) + + # Pipes are used for infrequent communication + learner_pipes, actor_pipes = list(zip(*[ + mp.Pipe() for _ in range(n_actors)])) + + def make_actor(i): + chainerrl.misc.async_.set_shared_params( + shared_model, shared_arrays) + return chainerrl.agents.StateQFunctionActor( + pipe=actor_pipes[i], + model=shared_model, + explorer=self.explorer, + phi=self.phi, + batch_states=self.batch_states, + logger=self.logger, + recurrent=self.recurrent, + ) + + replay_buffer_lock = mp.Lock() + + poller_stop_event = mp.Event() + poller = chainerrl.misc.StoppableThread( + target=self._poller_loop, + kwargs=dict( + shared_model=shared_model, + pipes=learner_pipes, + replay_buffer_lock=replay_buffer_lock, + stop_event=poller_stop_event, + ), + stop_event=poller_stop_event, + ) + + learner_stop_event = mp.Event() + learner = chainerrl.misc.StoppableThread( + target=self._learner_loop, + kwargs=dict( + pipes=learner_pipes, + replay_buffer_lock=replay_buffer_lock, + stop_event=learner_stop_event, + n_updates=n_updates, + ), + stop_event=learner_stop_event, + ) + + return make_actor, learner, poller + def stop_episode(self): if self.recurrent: self.test_recurrent_states = None def get_statistics(self): return [ - ('average_q', self.average_q), - ('average_loss', self.average_loss), + ('average_q', _mean_or_nan(self.q_record)), + ('average_loss', _mean_or_nan(self.loss_record)), ('n_updates', self.optimizer.t), ] diff --git a/chainerrl/agents/implicit_quantile_state_q_function_actor.py b/chainerrl/agents/implicit_quantile_state_q_function_actor.py new file mode 100644 index 000000000..ac26d62e6 --- /dev/null +++ b/chainerrl/agents/implicit_quantile_state_q_function_actor.py @@ -0,0 +1,53 @@ +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA + +from chainerrl.agents import state_q_function_actor + + +class ImplicitQuantileStateQFunctionActor( + state_q_function_actor.StateQFunctionActor): + """Actor that acts according to the implicit quantile Q-function. + + This actor specialization is required because the interface of an implicit + quantile Q-function is different from that of a usual Q-function. 
+ """ + + def __init__(self, *args, **kwargs): + # K=32 were used in the IQN paper's experiments + # (personal communication) + self.quantile_thresholds_K = kwargs.pop('quantile_thresholds_K', 32) + super().__init__(*args, **kwargs) + + @property + def xp(self): + return self.model.xp + + def _evaluate_model_and_update_train_recurrent_states(self, batch_obs): + batch_xs = self.batch_states(batch_obs, self.xp, self.phi) + if self.recurrent: + self.train_prev_recurrent_states = self.train_recurrent_states + tau2av, self.train_recurrent_states = self.model( + batch_xs, self.train_recurrent_states) + else: + tau2av = self.model(batch_xs) + taus_tilde = self.xp.random.uniform( + 0, 1, + size=(len(batch_obs), self.quantile_thresholds_K)).astype('f') + return tau2av(taus_tilde) + + def _evaluate_model_and_update_test_recurrent_states(self, batch_obs): + batch_xs = self.batch_states(batch_obs, self.xp, self.phi) + if self.recurrent: + tau2av, self.test_recurrent_states = self.model( + batch_xs, self.test_recurrent_states) + else: + tau2av = self.model(batch_xs) + taus_tilde = self.xp.random.uniform( + 0, 1, + size=(len(batch_obs), self.quantile_thresholds_K)).astype('f') + return tau2av(taus_tilde) diff --git a/chainerrl/agents/iqn.py b/chainerrl/agents/iqn.py index 1ca3f8cfb..2c168b933 100644 --- a/chainerrl/agents/iqn.py +++ b/chainerrl/agents/iqn.py @@ -1,8 +1,12 @@ +import copy +import multiprocessing as mp +import threading import chainer from chainer import cuda import chainer.functions as F import chainer.links as L +import chainerrl from chainerrl.action_value import QuantileDiscreteActionValue from chainerrl.agents import dqn from chainerrl.links import StatelessRecurrentChainList @@ -380,6 +384,8 @@ def _compute_loss(self, exp_batch, errors_out=None): with chainer.no_backprop_mode(): t = self._compute_target_values(exp_batch) + self.q_record.extend(cuda.to_cpu(y.array.mean(axis=1)).ravel()) + eltwise_loss = compute_eltwise_huber_quantile_loss(y, t, taus) if errors_out is not None: del errors_out[:] @@ -420,3 +426,58 @@ def _evaluate_model_and_update_recurrent_states(self, batch_obs, test): 0, 1, size=(len(batch_obs), self.quantile_thresholds_K)).astype('f') return tau2av(taus_tilde) + + def setup_actor_learner_training(self, n_actors, n_updates=None): + # Override DQN.setup_actor_learner_training to use + # `ImplicitQuantileStateQFunctionActor`, not `StateQFunctionActor`. 
+ + # Make a copy on shared memory and share among actors and a learner + shared_model = copy.deepcopy(self.model).to_cpu() + shared_arrays = chainerrl.misc.async_.extract_params_as_shared_arrays( + shared_model) + + # Pipes are used for infrequent communication + learner_pipes, actor_pipes = list(zip(*[ + mp.Pipe() for _ in range(n_actors)])) + + def make_actor(i): + chainerrl.misc.async_.set_shared_params( + shared_model, shared_arrays) + return chainerrl.agents.ImplicitQuantileStateQFunctionActor( + pipe=actor_pipes[i], + model=shared_model, + explorer=self.explorer, + phi=self.phi, + batch_states=self.batch_states, + logger=self.logger, + recurrent=self.recurrent, + quantile_thresholds_K=self.quantile_thresholds_K, + ) + + replay_buffer_lock = threading.Lock() + + poller_stop_event = mp.Event() + poller = chainerrl.misc.StoppableThread( + target=self._poller_loop, + kwargs=dict( + shared_model=shared_model, + pipes=learner_pipes, + replay_buffer_lock=replay_buffer_lock, + stop_event=poller_stop_event, + ), + stop_event=poller_stop_event, + ) + + learner_stop_event = mp.Event() + learner = chainerrl.misc.StoppableThread( + target=self._learner_loop, + kwargs=dict( + pipes=learner_pipes, + replay_buffer_lock=replay_buffer_lock, + stop_event=learner_stop_event, + n_updates=n_updates, + ), + stop_event=learner_stop_event, + ) + + return make_actor, learner, poller diff --git a/chainerrl/agents/state_q_function_actor.py b/chainerrl/agents/state_q_function_actor.py new file mode 100644 index 000000000..c3be1a606 --- /dev/null +++ b/chainerrl/agents/state_q_function_actor.py @@ -0,0 +1,165 @@ +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA + +from logging import getLogger + +import chainer +from chainer import cuda + +from chainerrl import agent +from chainerrl.misc.batch_states import batch_states + + +class StateQFunctionActor(agent.AsyncAgent): + """Actor that acts according to the Q-function.""" + + process_idx = None + shared_attributes = [] + + def __init__( + self, + pipe, + model, + explorer, + phi=lambda x: x, + recurrent=False, + logger=getLogger(__name__), + batch_states=batch_states, + ): + self.pipe = pipe + self.model = model + self.explorer = explorer + self.phi = phi + self.recurrent = recurrent + self.logger = logger + self.batch_states = batch_states + + self.t = 0 + self.last_state = None + self.last_action = None + + # Recurrent states of the model + self.train_recurrent_states = None + self.train_prev_recurrent_states = None + self.test_recurrent_states = None + + @property + def xp(self): + return self.model.xp + + def _evaluate_model_and_update_train_recurrent_states(self, batch_obs): + batch_xs = self.batch_states(batch_obs, self.xp, self.phi) + if self.recurrent: + self.train_prev_recurrent_states = self.train_recurrent_states + batch_av, self.train_recurrent_states = self.model( + batch_xs, self.train_recurrent_states) + else: + batch_av = self.model(batch_xs) + return batch_av + + def _evaluate_model_and_update_test_recurrent_states(self, batch_obs): + batch_xs = self.batch_states(batch_obs, self.xp, self.phi) + if self.recurrent: + batch_av, self.test_recurrent_states = self.model( + batch_xs, self.test_recurrent_states) + else: + batch_av = self.model(batch_xs) + return batch_av + + def act(self, obs): + with chainer.using_config('train', False), 
chainer.no_backprop_mode(): + action_value =\ + self._evaluate_model_and_update_test_recurrent_states([obs]) + action = cuda.to_cpu(action_value.greedy_actions.array)[0] + return action + + def _send_to_learner(self, transition, stop_episode=False): + self.pipe.send(('transition', transition)) + if stop_episode: + self.pipe.send(('stop_episode', None)) + + def act_and_train(self, obs, reward): + + with chainer.using_config('train', False), chainer.no_backprop_mode(): + action_value =\ + self._evaluate_model_and_update_train_recurrent_states([obs]) + greedy_action = cuda.to_cpu(action_value.greedy_actions.array)[0] + + action = self.explorer.select_action( + self.t, lambda: greedy_action, action_value=action_value) + self.t += 1 + + if self.last_state is not None: + assert self.last_action is not None + # Add a transition to the replay buffer + transition = { + 'state': self.last_state, + 'action': self.last_action, + 'reward': reward, + 'next_state': obs, + 'is_state_terminal': False, + } + if self.recurrent: + transition['recurrent_state'] =\ + self.model.get_recurrent_state_at( + self.train_prev_recurrent_states, + 0, unwrap_variable=True) + self.train_prev_recurrent_states = None + transition['next_recurrent_state'] =\ + self.model.get_recurrent_state_at( + self.train_recurrent_states, 0, unwrap_variable=True) + self._send_to_learner(transition) + + self.last_state = obs + self.last_action = action + + return self.last_action + + def stop_episode_and_train(self, state, reward, done=False): + + assert self.last_state is not None + assert self.last_action is not None + + # Add a transition to the replay buffer + transition = { + 'state': self.last_state, + 'action': self.last_action, + 'reward': reward, + 'next_state': state, + 'is_state_terminal': done, + } + if self.recurrent: + transition['recurrent_state'] =\ + self.model.get_recurrent_state_at( + self.train_prev_recurrent_states, 0, unwrap_variable=True) + self.train_prev_recurrent_states = None + transition['next_recurrent_state'] =\ + self.model.get_recurrent_state_at( + self.train_recurrent_states, 0, unwrap_variable=True) + self._send_to_learner(transition, stop_episode=True) + + self.last_state = None + self.last_action = None + if self.recurrent: + self.train_recurrent_states = None + + def stop_episode(self): + if self.recurrent: + self.test_recurrent_states = None + + def save(self, dirname): + self.pipe.send(('save', dirname)) + self.pipe.recv() + + def load(self, dirname): + self.pipe.send(('load', dirname)) + self.pipe.recv() + + def get_statistics(self): + self.pipe.send(('get_statistics', None)) + return self.pipe.recv() diff --git a/chainerrl/experiments/train_agent_async.py b/chainerrl/experiments/train_agent_async.py index 9cc89a9f5..d83c9a733 100644 --- a/chainerrl/experiments/train_agent_async.py +++ b/chainerrl/experiments/train_agent_async.py @@ -2,13 +2,15 @@ import multiprocessing as mp import os +import numpy as np + from chainerrl.experiments.evaluator import AsyncEvaluator from chainerrl.misc import async_ from chainerrl.misc import random_seed def train_loop(process_idx, env, agent, steps, outdir, counter, - episodes_counter, training_done, + episodes_counter, stop_event, max_episode_len=None, evaluator=None, eval_env=None, successful_score=None, logger=None, global_step_hooks=()): @@ -50,7 +52,7 @@ def train_loop(process_idx, env, agent, steps, outdir, counter, reset = (episode_len == max_episode_len or info.get('needs_reset', False)) - if done or reset or global_t >= steps or training_done.value: + if done 
or reset or global_t >= steps or stop_event.is_set(): agent.stop_episode_and_train(obs, r, done) if process_idx == 0: @@ -67,10 +69,8 @@ def train_loop(process_idx, env, agent, steps, outdir, counter, if (eval_score is not None and successful_score is not None and eval_score >= successful_score): - with training_done.get_lock(): - if not training_done.value: - training_done.value = True - successful = True + stop_event.set() + successful = True # Break immediately in order to avoid an additional # call of agent.act_and_train break @@ -79,7 +79,7 @@ def train_loop(process_idx, env, agent, steps, outdir, counter, episodes_counter.value += 1 global_episodes = episodes_counter.value - if global_t >= steps or training_done.value: + if global_t >= steps or stop_event.is_set(): break # Start a new episode @@ -121,21 +121,26 @@ def set_shared_objects(agent, shared_objects): setattr(agent, attr, new_value) -def train_agent_async(outdir, processes, make_env, - profile=False, - steps=8 * 10 ** 7, - eval_interval=10 ** 6, - eval_n_steps=None, - eval_n_episodes=10, - max_episode_len=None, - step_offset=0, - successful_score=None, - agent=None, - make_agent=None, - global_step_hooks=(), - save_best_so_far_agent=True, - logger=None, - ): +def train_agent_async( + outdir, + processes, + make_env, + profile=False, + steps=8 * 10 ** 7, + eval_interval=10 ** 6, + eval_n_steps=None, + eval_n_episodes=10, + max_episode_len=None, + step_offset=0, + successful_score=None, + agent=None, + make_agent=None, + global_step_hooks=[], + save_best_so_far_agent=True, + logger=None, + random_seeds=None, + stop_event=None, +): """Train agent asynchronously using multiprocessing. Either `agent` or `make_agent` must be specified. @@ -163,6 +168,10 @@ def train_agent_async(outdir, processes, make_env, if the score (= mean return of evaluation episodes) exceeds the best-so-far score, the current agent is saved. logger (logging.Logger): Logger used in this function. + random_seeds (array-like of ints or None): Random seeds for processes. + If set to None, [0, 1, ..., processes-1] are used. + stop_event (multiprocessing.Event or None): Event to stop training. + If set to None, a new Event object is created and used internally. Returns: Trained agent. 
@@ -175,7 +184,9 @@ def train_agent_async(outdir, processes, make_env, counter = mp.Value('l', 0) episodes_counter = mp.Value('l', 0) - training_done = mp.Value('b', False) # bool + + if stop_event is None: + stop_event = mp.Event() if agent is None: assert make_agent is not None @@ -197,8 +208,11 @@ def train_agent_async(outdir, processes, make_env, logger=logger, ) + if random_seeds is None: + random_seeds = np.arange(processes) + def run_func(process_idx): - random_seed.set_random_seed(process_idx) + random_seed.set_random_seed(random_seeds[process_idx]) env = make_env(process_idx, test=False) if evaluator is None: @@ -224,7 +238,7 @@ def f(): max_episode_len=max_episode_len, evaluator=evaluator, successful_score=successful_score, - training_done=training_done, + stop_event=stop_event, eval_env=eval_env, global_step_hooks=global_step_hooks, logger=logger) @@ -242,4 +256,6 @@ def f(): async_.run_async(processes, run_func) + stop_event.set() + return agent diff --git a/chainerrl/links/stateless_recurrent.py b/chainerrl/links/stateless_recurrent.py index 79ad73f35..10f799a28 100644 --- a/chainerrl/links/stateless_recurrent.py +++ b/chainerrl/links/stateless_recurrent.py @@ -304,12 +304,21 @@ def get_recurrent_state_at(link, recurrent_state, indices, unwrap_variable): raise ValueError('{} is not a recurrent link'.format(link)) +def _to_device_variable_or_ndarray(device, x): + if isinstance(x, chainer.Variable): + x.to_device(device) + return x + else: + return chainer.dataset.to_device(device, x) + + def concatenate_recurrent_states(link, split_recurrent_states): if isinstance(link, L.NStepLSTM): # shape: (n_layers, batch_size, out_size) n_layers = link.n_layers out_size = link.out_size xp = link.xp + device = link.device hs = [] cs = [] for srs in split_recurrent_states: @@ -318,6 +327,8 @@ def concatenate_recurrent_states(link, split_recurrent_states): c = xp.zeros((n_layers, 1, out_size), dtype=np.float32) else: h, c = srs + h = _to_device_variable_or_ndarray(device, h) + c = _to_device_variable_or_ndarray(device, c) if h.ndim == 2: assert h.shape == (n_layers, out_size) assert c.shape == (n_layers, out_size) @@ -333,12 +344,14 @@ def concatenate_recurrent_states(link, split_recurrent_states): n_layers = link.n_layers out_size = link.out_size xp = link.xp + device = link.device hs = [] for srs in split_recurrent_states: if srs is None: h = xp.zeros((n_layers, 1, out_size), dtype=np.float32) else: h = srs + h = _to_device_variable_or_ndarray(device, h) if h.ndim == 2: assert h.shape == (n_layers, out_size) # add batch axis diff --git a/chainerrl/misc/__init__.py b/chainerrl/misc/__init__.py index 483070f68..98ebf6df3 100644 --- a/chainerrl/misc/__init__.py +++ b/chainerrl/misc/__init__.py @@ -7,4 +7,5 @@ from chainerrl.misc.namedpersistent import namedpersistent # NOQA from chainerrl.misc.is_return_code_zero import is_return_code_zero # NOQA from chainerrl.misc.random_seed import set_random_seed # NOQA +from chainerrl.misc.stoppable_thread import StoppableThread # NOQA from chainerrl.misc.pretrained_models import download_model # NOQA diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index 480e7213a..bd89baa9c 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -5,7 +5,6 @@ import numpy as np import chainerrl -from chainerrl.misc import random_seed class AbnormalExitWarning(Warning): @@ -170,13 +169,8 @@ def run_async(n_process, run_func): processes = [] - def set_seed_and_run(process_idx, run_func): - random_seed.set_random_seed(np.random.randint(0, 2 
** 32)) - run_func(process_idx) - for process_idx in range(n_process): - processes.append(mp.Process(target=set_seed_and_run, args=( - process_idx, run_func))) + processes.append(mp.Process(target=run_func, args=(process_idx,))) for p in processes: p.start() diff --git a/chainerrl/misc/copy_param.py b/chainerrl/misc/copy_param.py index 09a24c2b7..e852279aa 100644 --- a/chainerrl/misc/copy_param.py +++ b/chainerrl/misc/copy_param.py @@ -4,6 +4,7 @@ def copy_param(target_link, source_link): """Copy parameters of a link to another link.""" target_params = dict(target_link.namedparams()) + target_device = target_link.device for param_name, param in source_link.namedparams(): if target_params[param_name].array is None: raise TypeError( @@ -11,15 +12,15 @@ def copy_param(target_link, source_link): 'not initialized.\nPlease try to forward dummy input ' 'beforehand to determine parameter shape of the model.'.format( param_name)) - target_params[param_name].array[...] = param.array + target_params[param_name].array[...] = target_device.send(param.array) # Copy Batch Normalization's statistics target_links = dict(target_link.namedlinks()) for link_name, link in source_link.namedlinks(): if isinstance(link, L.BatchNormalization): target_bn = target_links[link_name] - target_bn.avg_mean[...] = link.avg_mean - target_bn.avg_var[...] = link.avg_var + target_bn.avg_mean[...] = target_device.send(link.avg_mean) + target_bn.avg_var[...] = target_device.send(link.avg_var) def soft_copy_param(target_link, source_link, tau): diff --git a/chainerrl/misc/stoppable_thread.py b/chainerrl/misc/stoppable_thread.py new file mode 100644 index 000000000..f501044d2 --- /dev/null +++ b/chainerrl/misc/stoppable_thread.py @@ -0,0 +1,20 @@ +import threading + + +class StoppableThread(threading.Thread): + """Thread with an event object to stop itself. + + Args: + stop_event (threading.Event): Event that stops the thread if it is set. + *args, **kwargs: Forwarded to `threading.Thread`. + """ + + def __init__(self, stop_event, *args, **kwargs): + super(StoppableThread, self).__init__(*args, **kwargs) + self.stop_event = stop_event + + def stop(self): + self.stop_event.set() + + def is_stopped(self): + return self.stop_event.is_set() diff --git a/chainerrl/replay_buffer.py b/chainerrl/replay_buffer.py index 71f67f57f..dfeb9e1de 100644 --- a/chainerrl/replay_buffer.py +++ b/chainerrl/replay_buffer.py @@ -260,15 +260,23 @@ def __init__(self, replay_buffer, update_func, batchsize, episodic_update, self.update_interval = update_interval def update_if_necessary(self, iteration): + """Update the model if the condition is met. + + Args: + iteration (int): Timestep. + + Returns: + bool: True iff the model was updated this time.
+ """ if len(self.replay_buffer) < self.replay_start_size: - return + return False if (self.episodic_update and self.replay_buffer.n_episodes < self.batchsize): - return + return False if iteration % self.update_interval != 0: - return + return False for _ in range(self.n_times_update): if self.episodic_update: @@ -278,3 +286,4 @@ def update_if_necessary(self, iteration): else: transitions = self.replay_buffer.sample(self.batchsize) self.update_func(transitions) + return True diff --git a/examples/grasping/train_dqn_batch_grasping.py b/examples/grasping/train_dqn_batch_grasping.py index 55a06f854..49ccda61d 100644 --- a/examples/grasping/train_dqn_batch_grasping.py +++ b/examples/grasping/train_dqn_batch_grasping.py @@ -2,6 +2,9 @@ import functools import os +# Prevent numpy from using multiple threads +os.environ['OMP_NUM_THREADS'] = '1' # NOQA + import chainer from chainer import functions as F from chainer import links as L @@ -293,18 +296,27 @@ def phi(x): args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev'])) else: - experiments.train_agent_batch_with_evaluation( - agent=agent, - env=make_batch_env(test=False), - eval_env=eval_env, + + make_actor, learner, poller = agent.setup_actor_learner_training( + args.num_envs) + + poller.start() + learner.start() + experiments.train_agent_async( + processes=args.num_envs, + make_agent=make_actor, + make_env=make_env, steps=args.steps, eval_n_steps=None, eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval, outdir=args.outdir, - save_best_so_far_agent=False, - log_interval=1000, + stop_event=learner.stop_event, ) + learner.stop() + learner.join() + poller.stop() + poller.join() if __name__ == '__main__': diff --git a/tests/agents_tests/basetest_training.py b/tests/agents_tests/basetest_training.py index 0a8c2d994..094f13d62 100644 --- a/tests/agents_tests/basetest_training.py +++ b/tests/agents_tests/basetest_training.py @@ -8,6 +8,7 @@ import chainerrl from chainerrl.experiments.evaluator import batch_run_evaluation_episodes +from chainerrl.experiments import train_agent_async from chainerrl.experiments import train_agent_batch_with_evaluation from chainerrl.experiments import train_agent_with_evaluation from chainerrl.misc import random_seed @@ -178,3 +179,72 @@ def test_batch_training_cpu_fast(self): self._test_batch_training(-1, steps=10, require_success=False) self._test_batch_training( -1, steps=0, load_model=True, require_success=False) + + +class _TestActorLearnerTrainingMixin(object): + """Mixin for testing actor-learner training. + + Inherit this after _TestTraining to enable test cases for batch training. 
+ """ + + def _test_actor_learner_training(self, gpu, steps=100000, + require_success=True): + + logging.basicConfig(level=logging.DEBUG) + + test_env, successful_return = self.make_env_and_successful_return( + test=True) + agent = self.make_agent(test_env, gpu) + + def make_env(process_idx, test): + env, _ = self.make_env_and_successful_return(test=test) + return env + + # Train + if steps > 0: + make_actor, learner, poller =\ + agent.setup_actor_learner_training(n_actors=2) + + poller.start() + learner.start() + train_agent_async( + processes=2, + steps=steps, + outdir=self.tmpdir, + eval_interval=200, + eval_n_steps=None, + eval_n_episodes=5, + successful_score=successful_return, + make_env=make_env, + make_agent=make_actor, + stop_event=learner.stop_event, + ) + learner.stop() + learner.join() + poller.stop() + poller.join() + + # Test + + # Because in actor-learner traininig the model can be updated between + # evaluation and saving, it is difficult too guarantee the learned + # model succeeds. Thus we only check if the training was successful. + + if require_success: + assert os.path.exists(os.path.join(self.tmpdir, 'successful')) + + @testing.attr.slow + @testing.attr.gpu + def test_actor_learner_training_gpu(self): + self._test_actor_learner_training(0, steps=100000) + + @testing.attr.slow + def test_actor_learner_training_cpu(self): + self._test_actor_learner_training(-1, steps=100000) + + @testing.attr.gpu + def test_actor_learner_training_gpu_fast(self): + self._test_actor_learner_training(0, steps=10, require_success=False) + + def test_actor_learner_training_cpu_fast(self): + self._test_actor_learner_training(-1, steps=10, require_success=False) diff --git a/tests/agents_tests/test_dqn.py b/tests/agents_tests/test_dqn.py index d8c53b64b..5a47025c0 100644 --- a/tests/agents_tests/test_dqn.py +++ b/tests/agents_tests/test_dqn.py @@ -9,11 +9,14 @@ from chainerrl.agents.dqn import compute_weighted_value_loss from chainerrl.agents.dqn import DQN +from basetest_training import _TestActorLearnerTrainingMixin from basetest_training import _TestBatchTrainingMixin class TestDQNOnDiscreteABC( - _TestBatchTrainingMixin, base._TestDQNOnDiscreteABC): + _TestActorLearnerTrainingMixin, + _TestBatchTrainingMixin, + base._TestDQNOnDiscreteABC): def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): return DQN(q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer, @@ -31,7 +34,9 @@ def test_replay_capacity_checked(self): class TestDQNOnDiscreteABCBoltzmann( - _TestBatchTrainingMixin, base._TestDQNOnDiscreteABC): + _TestActorLearnerTrainingMixin, + _TestBatchTrainingMixin, + base._TestDQNOnDiscreteABC): def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): explorer = chainerrl.explorers.Boltzmann() @@ -40,7 +45,9 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): class TestDQNOnContinuousABC( - _TestBatchTrainingMixin, base._TestDQNOnContinuousABC): + _TestActorLearnerTrainingMixin, + _TestBatchTrainingMixin, + base._TestDQNOnContinuousABC): def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): return DQN(q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer, @@ -48,7 +55,9 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): class TestDQNOnDiscretePOABC( - _TestBatchTrainingMixin, base._TestDQNOnDiscretePOABC): + _TestActorLearnerTrainingMixin, + _TestBatchTrainingMixin, + base._TestDQNOnDiscretePOABC): def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): return DQN(q_func, opt, rbuf, gpu=gpu, gamma=0.9, 
explorer=explorer, diff --git a/tests/agents_tests/test_iqn.py b/tests/agents_tests/test_iqn.py index 9413b788a..a75036cb4 100644 --- a/tests/agents_tests/test_iqn.py +++ b/tests/agents_tests/test_iqn.py @@ -8,6 +8,7 @@ from chainer import testing import basetest_dqn_like as base +from basetest_training import _TestActorLearnerTrainingMixin from basetest_training import _TestBatchTrainingMixin import chainerrl from chainerrl.agents import iqn @@ -18,7 +19,9 @@ 'quantile_thresholds_N_prime': [1, 7], })) class TestIQNOnDiscreteABC( - _TestBatchTrainingMixin, base._TestDQNOnDiscreteABC): + _TestActorLearnerTrainingMixin, + _TestBatchTrainingMixin, + base._TestDQNOnDiscreteABC): def make_q_func(self, env): obs_size = env.observation_space.low.size @@ -46,7 +49,9 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): class TestIQNOnDiscretePOABC( - _TestBatchTrainingMixin, base._TestDQNOnDiscretePOABC): + _TestActorLearnerTrainingMixin, + _TestBatchTrainingMixin, + base._TestDQNOnDiscretePOABC): def make_q_func(self, env): obs_size = env.observation_space.low.size