
Commit

Support GPUs in RaySampler
krzentner committed Apr 17, 2022
1 parent c56513f commit 3af5eeb
Showing 3 changed files with 62 additions and 26 deletions.
35 changes: 26 additions & 9 deletions src/garage/sampler/ray_sampler.py
@@ -37,11 +37,17 @@ class RaySampler(Sampler):
             The maximum length episodes which will be sampled.
         is_tf_worker (bool): Whether it is workers for TFTrainer.
         seed(int): The seed to use to initialize random number generators.
-        n_workers(int): The number of workers to use.
+        n_workers(int or None): The number of workers to use. Defaults to
+            the number of physical CPUs if `worker_factory` is also None.
         worker_class(type): Class of the workers. Instances should implement
             the Worker interface.
         worker_args (dict or None): Additional arguments that should be passed
             to the worker.
+        n_gpus (int or float): Number of GPUs to use in total for sampling.
+            If `n_workers` is not a power of two, this may need to be set
+            slightly below the true value, since `n_gpus / n_workers` GPUs
+            are allocated to each worker. Defaults to zero, because otherwise
+            nothing would run if no GPUs were available.
     """

@@ -54,26 +60,32 @@ def __init__(
                  max_episode_length=None,
                  is_tf_worker=False,
                  seed=get_seed(),
-                 n_workers=psutil.cpu_count(logical=False),
+                 n_workers=None,
                  worker_class=DefaultWorker,
-                 worker_args=None):
-        # pylint: disable=super-init-not-called
+                 worker_args=None,
+                 n_gpus=0):
         if not ray.is_initialized():
             ray.init(log_to_driver=False, ignore_reinit_error=True)
         if worker_factory is None and max_episode_length is None:
             raise TypeError('Must construct a sampler from WorkerFactory or '
                             'parameters (at least max_episode_length)')
-        if isinstance(worker_factory, WorkerFactory):
+        if worker_factory is not None:
+            if n_workers is None:
+                n_workers = worker_factory.n_workers
             self._worker_factory = worker_factory
         else:
+            if n_workers is None:
+                n_workers = psutil.cpu_count(logical=False)
             self._worker_factory = WorkerFactory(
                 max_episode_length=max_episode_length,
                 is_tf_worker=is_tf_worker,
                 seed=seed,
                 n_workers=n_workers,
                 worker_class=worker_class,
                 worker_args=worker_args)
-        self._sampler_worker = ray.remote(SamplerWorker)
+        remote_wrapper = ray.remote(num_gpus=n_gpus / n_workers)
+        self._n_gpus = n_gpus
+        self._sampler_worker = remote_wrapper(SamplerWorker)
         self._agents = agents
         self._envs = self._worker_factory.prepare_worker_messages(envs)
         self._all_workers = defaultdict(None)
@@ -103,7 +115,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
             Sampler: An instance of `cls`.
         """
-        return cls(agents, envs, worker_factory=worker_factory)
+        return cls(agents,
+                   envs,
+                   worker_factory=worker_factory,
+                   n_workers=worker_factory.n_workers)
 
     def start_worker(self):
         """Initialize a new ray worker."""
@@ -308,7 +323,8 @@ def __getstate__(self):
         """
         return dict(factory=self._worker_factory,
                     agents=self._agents,
-                    envs=self._envs)
+                    envs=self._envs,
+                    n_gpus=self._n_gpus)
 
     def __setstate__(self, state):
         """Unpickle the state.
@@ -319,7 +335,8 @@ def __setstate__(self, state):
         """
         self.__init__(state['agents'],
                       state['envs'],
-                      worker_factory=state['factory'])
+                      worker_factory=state['factory'],
+                      n_gpus=state['n_gpus'])
 
 
 class SamplerWorker:
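
Editor's aside (not part of the commit): a minimal sketch of the fractional-GPU scheduling this change relies on. Ray accepts a fractional num_gpus, so n_workers actors can share n_gpus physical GPUs, and because an actor with num_gpus > 0 is only scheduled on a machine that actually has a GPU, the new n_gpus argument defaults to zero. Everything below other than the Ray API itself (ray.init, ray.remote, ray.get, ray.get_gpu_ids) is a placeholder.

import ray

ray.init(ignore_reinit_error=True, log_to_driver=False)

n_gpus = 1     # total GPUs to share across all sampling workers
n_workers = 4  # each actor requests n_gpus / n_workers = 0.25 GPUs


@ray.remote(num_gpus=n_gpus / n_workers)
class GpuShareWorker:
    """Placeholder actor standing in for SamplerWorker."""

    def gpu_ids(self):
        # Ray reports which GPU(s) this actor holds a share of.
        return ray.get_gpu_ids()


# All four actors fit on a single GPU, since 4 * 0.25 == 1.
workers = [GpuShareWorker.remote() for _ in range(n_workers)]
print(ray.get([w.gpu_ids.remote() for w in workers]))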
18 changes: 1 addition & 17 deletions src/garage/sampler/sampler.py
@@ -13,25 +13,9 @@ class Sampler(abc.ABC):
     `Sampler` needs. Specifically, it specifies how to construct `Worker`s,
     which know how to collect episodes and update both agents and environments.
-    Currently, `__init__` is also part of the interface, but calling it is
-    deprecated. `start_worker` is also deprecated, and does not need to be
-    implemented.
+    `start_worker` is deprecated, and does not need to be implemented.
     """
 
-    def __init__(self, algo, env):
-        """Construct a Sampler from an Algorithm.
-
-        Args:
-            algo (RLAlgorithm): The RL Algorithm controlling this
-                sampler.
-            env (Environment): The environment being sampled from.
-
-        Calling this method is deprecated.
-
-        """
-        self.algo = algo
-        self.env = env
-
     @classmethod
     def from_worker_factory(cls, worker_factory, agents, envs):
         """Construct this sampler.
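
Editor's aside (not part of the commit): with the deprecated Sampler.__init__(algo, env) gone, a sampler is built either directly from parameters or from a WorkerFactory; both paths run through RaySampler.__init__ above. A minimal sketch, assuming garage.sampler exports RaySampler and WorkerFactory and that garage.envs provides PointEnv; `policy` is a placeholder for any garage policy matching env.spec.

from garage.envs import PointEnv
from garage.sampler import RaySampler, WorkerFactory

env = PointEnv()
policy = ...  # placeholder: any garage policy compatible with env.spec

# Path 1: construct directly from parameters. A WorkerFactory is built
# internally, and n_workers now defaults to the physical CPU count.
sampler = RaySampler(policy, env, max_episode_length=200, n_gpus=0)

# Path 2: construct from an explicit WorkerFactory. As of this commit,
# from_worker_factory forwards the factory's n_workers explicitly.
factory = WorkerFactory(seed=100, max_episode_length=200, n_workers=4)
sampler = RaySampler.from_worker_factory(factory, policy, env)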
35 changes: 35 additions & 0 deletions tests/garage/sampler/test_ray_batched_sampler.py
@@ -1,4 +1,5 @@
 """Tests for ray_batched_sampler."""
+import pickle
 from unittest.mock import Mock
 
 import numpy as np
@@ -138,6 +139,40 @@ def test_init_with_env_updates(ray_local_session_fixture):
     assert sum(episodes.lengths) >= 160
 
 
+def test_pickle(ray_local_session_fixture):
+    del ray_local_session_fixture
+    assert ray.is_initialized()
+    max_episode_length = 16
+    env = PointEnv()
+    policy = FixedPolicy(env.spec,
+                         scripted_actions=[
+                             env.action_space.sample()
+                             for _ in range(max_episode_length)
+                         ])
+    tasks = SetTaskSampler(PointEnv)
+    n_workers = 4
+    workers = WorkerFactory(seed=100,
+                            max_episode_length=max_episode_length,
+                            n_workers=n_workers)
+    sampler = RaySampler.from_worker_factory(workers, policy, env)
+    sampler_pickled = pickle.dumps(sampler)
+    sampler.shutdown_worker()
+    sampler2 = pickle.loads(sampler_pickled)
+    episodes = sampler2.obtain_samples(0,
+                                       500,
+                                       np.asarray(policy.get_param_values()),
+                                       env_update=tasks.sample(n_workers))
+    mean_rewards = []
+    goals = []
+    for eps in episodes.split():
+        mean_rewards.append(eps.rewards.mean())
+        goals.append(eps.env_infos['task'][0]['goal'])
+    assert np.var(mean_rewards) > 0
+    assert np.var(goals) > 0
+    sampler2.shutdown_worker()
+    env.close()
+
+
 def test_init_without_worker_factory(ray_local_session_fixture):
     del ray_local_session_fixture
     assert ray.is_initialized()
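
Editor's note: the new test_pickle exercises the __getstate__/__setstate__ changes above. Since n_gpus is now part of the pickled state, a sampler restored from a snapshot re-requests the same per-worker GPU fraction rather than reverting to CPU-only workers.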
