Skip to content

Commit

Permalink
Initial 0.7 shared memory prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Suarez committed Jan 23, 2024
1 parent 5660ccc commit ab2f3f0
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 58 deletions.
6 changes: 3 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -664,9 +664,9 @@ pokemon_red:
package: pokemon_red
train:
total_timesteps: 100_000_000
num_envs: 4
envs_per_worker: 1
envpool_batch_size: 4
num_envs: 24
envs_per_worker: 4
envpool_batch_size: 24
update_epochs: 3
gamma: 0.998
batch_size: 1024
Expand Down
2 changes: 2 additions & 0 deletions pufferlib/emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,14 @@ def _recursion_helper(current, key):
return flat

def concatenate(flat_sample):
'''
if len(flat_sample) == 1:
flat_sample = flat_sample[0]
if isinstance(flat_sample,(np.ndarray,
gymnasium.wrappers.frame_stack.LazyFrames)):
return flat_sample
return np.array([flat_sample])
'''
return np.concatenate([
e.ravel() if isinstance(e, np.ndarray) else np.array([e])
for e in flat_sample]
Expand Down
6 changes: 6 additions & 0 deletions pufferlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def __init__(self, env, *args, framestack, flat_size,
self.channels_last = channels_last
self.downsample = downsample

self.flat_observation_space = env.flat_observation_space
self.flat_observation_structure = env.flat_observation_structure

self.network = nn.Sequential(
pufferlib.pytorch.layer_init(nn.Conv2d(framestack, 32, 8, stride=4)),
nn.ReLU(),
Expand All @@ -209,6 +212,9 @@ def __init__(self, env, *args, framestack, flat_size,
self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(output_size, 1), std=1)

def encode_observations(self, observations):
observations = pufferlib.emulation.unpack_batched_obs(observations,
self.flat_observation_space, self.flat_observation_structure)

if self.channels_last:
observations = observations.permute(0, 3, 1, 2)
if self.downsample > 1:
Expand Down
43 changes: 25 additions & 18 deletions pufferlib/vectorization/gym_multi_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@


def reset(state, seed=None):
if state.preallocated_obs is None:
obs_space = state.envs[0].observation_space
obs_n = obs_space.shape[0]
n_envs = len(state.envs)

state.preallocated_obs = np.empty(
(n_envs, *obs_space.shape), dtype=obs_space.dtype)
state.preallocated_rewards = np.empty(n_envs, dtype=np.float32)
state.preallocated_dones = np.empty(n_envs, dtype=np.bool)
state.preallocated_truncateds = np.empty(n_envs, dtype=np.bool)

infos = []
for idx, e in enumerate(state.envs):
if seed is None:
Expand All @@ -20,38 +31,34 @@ def reset(state, seed=None):

i['mask'] = True
infos.append(i)
if state.preallocated_obs is None:
state.preallocated_obs = np.empty(
(len(state.envs), *ob.shape), dtype=ob.dtype)

state.preallocated_obs[idx] = ob
state.preallocated_rewards[idx] = 0
state.preallocated_dones[idx] = False
state.preallocated_truncateds[idx] = False

rewards = [0] * len(state.preallocated_obs)
dones = [False] * len(state.preallocated_obs)
truncateds = [False] * len(state.preallocated_obs)

return state.preallocated_obs, rewards, dones, truncateds, infos
return (state.preallocated_obs, state.preallocated_rewards,
state.preallocated_dones, state.preallocated_truncateds, infos)

def step(state, actions):
rewards, dones, truncateds, infos = [], [], [], []

infos = []
for idx, (env, atns) in enumerate(zip(state.envs, actions)):
if env.done:
o, i = env.reset()
rewards.append(0)
dones.append(False)
truncateds.append(False)
state.preallocated_rewards[idx] = 0
state.preallocated_dones[idx] = False
state.preallocated_truncateds[idx] = False
else:
o, r, d, t, i = env.step(atns)
rewards.append(r)
dones.append(d)
truncateds.append(t)
state.preallocated_rewards[idx] = r
state.preallocated_dones[idx] = d
state.preallocated_truncateds[idx] = t

i['mask'] = True
infos.append(i)
state.preallocated_obs[idx] = o

return state.preallocated_obs, rewards, dones, truncateds, infos
return (state.preallocated_obs, state.preallocated_rewards,
state.preallocated_dones, state.preallocated_truncateds, infos)

class GymMultiEnv:
__init__ = init
Expand Down
79 changes: 62 additions & 17 deletions pufferlib/vectorization/multiprocessing_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from pdb import set_trace as T
import numpy as np
import psutil
import time

import selectors
from multiprocessing import Process, Queue, Manager, Pipe
from multiprocessing import Process, Queue, Manager, Pipe, Array
from queue import Empty

from pufferlib import namespace
Expand All @@ -11,6 +13,7 @@
calc_scale_params,
setup,
single_observation_space,
_single_observation_space,
single_action_space,
single_action_space,
structured_observation_space,
Expand Down Expand Up @@ -39,13 +42,23 @@ def init(self: object = None,
num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker = calc_scale_params(
num_envs, envs_per_batch, envs_per_worker, agents_per_env)

observation_size = int(np.prod(_single_observation_space(driver_env).shape))
# Shared memory for obs, rewards, terminals, truncateds
shared_mem = [
Array('d', envs_per_worker*(3+observation_size))
for _ in range(num_workers)
]
main_send_pipes, work_recv_pipes = zip(*[Pipe() for _ in range(num_workers)])
work_send_pipes, main_recv_pipes = zip(*[Pipe() for _ in range(num_workers)])

num_cores = psutil.cpu_count()
#curr_process = psutil.Process()
#curr_process.cpu_affinity([num_cores-1])
processes = [Process(
target=_worker_process,
args=(multi_env_cls, env_creator, env_args, env_kwargs,
envs_per_worker, work_send_pipes[i], work_recv_pipes[i]))
args=(multi_env_cls, env_creator, env_args, env_kwargs, envs_per_worker,
i%(num_cores-1), shared_mem[i], work_send_pipes[i], work_recv_pipes[i])
)
for i in range(num_workers)]

for p in processes:
Expand All @@ -59,6 +72,8 @@ def init(self: object = None,
return namespace(self,
processes = processes,
sel = sel,
observation_size = observation_size,
shared_mem = shared_mem,
send_pipes = main_send_pipes,
recv_pipes = main_recv_pipes,
driver_env = driver_env,
Expand All @@ -76,44 +91,74 @@ def init(self: object = None,
env_pool = env_pool,
)

def _worker_process(multi_env_cls, env_creator, env_args, env_kwargs, n, send_pipe, recv_pipe):
def _unpack_shared_mem(shared_mem, n):
np_buf = np.frombuffer(shared_mem.get_obj(), dtype=float)
obs_arr = np_buf[:-3*n]
rewards_arr = np_buf[-3*n:-2*n]
terminals_arr = np_buf[-2*n:-n]
truncated_arr = np_buf[-n:]

return obs_arr, rewards_arr, terminals_arr, truncated_arr

def _worker_process(multi_env_cls, env_creator, env_args, env_kwargs, n,
worker_idx, shared_mem, send_pipe, recv_pipe):

# I don't know if this helps. Sometimes it does, sometimes not.
# Need to run more comprehensive tests
#curr_process = psutil.Process()
#curr_process.cpu_affinity([worker_idx])

envs = multi_env_cls(env_creator, env_args, env_kwargs, n=n)
obs_arr, rewards_arr, terminals_arr, truncated_arr = _unpack_shared_mem(shared_mem, n)

while True:
request, args, kwargs = recv_pipe.recv()
func = getattr(envs, request)
response = func(*args, **kwargs)
send_pipe.send(response)

# TODO: Handle put/get
obs, reward, done, truncated, info = response

# TESTED: There is no overhead associated with 4 assignments to shared memory
# vs. 4 assigns to an intermediate numpy array and then 1 assign to shared memory
obs_arr[:] = obs.ravel()
rewards_arr[:] = reward.ravel()
terminals_arr[:] = done.ravel()
truncated_arr[:] = truncated.ravel()
send_pipe.send(info)

def recv(state):
recv_precheck(state)

recvs = []
next_env_id = []
if state.env_pool:
for env_id in range(state.workers_per_batch):
response_pipe = state.recv_pipes[env_id]
response = response_pipe.recv()

o, r, d, t, i = response
recvs.append((o, r, d, t, i, env_id))
next_env_id.append(env_id)
else:
while len(recvs) < state.workers_per_batch:
for key, _ in state.sel.select(timeout=None):
response_pipe = key.fileobj
env_id = state.recv_pipes.index(response_pipe)

if response_pipe.poll(): # Check if data is available
response = response_pipe.recv()
info = response_pipe.recv()
o, r, d, t = _unpack_shared_mem(
state.shared_mem[env_id], state.envs_per_worker)
o = o.reshape(state.envs_per_worker, state.observation_size)

o, r, d, t, i = response
recvs.append((o, r, d, t, i, env_id))
recvs.append((o, r, d, t, info, env_id))
next_env_id.append(env_id)

if len(recvs) == state.workers_per_batch:
break

else:
for env_id in range(state.workers_per_batch):
response_pipe = state.recv_pipes[env_id]
info = response_pipe.recv()
o, r, d, t = _unpack_shared_mem(
state.shared_mem[env_id], state.envs_per_worker)
o = o.reshape(state.envs_per_worker, state.observation_size)
recvs.append((o, r, d, t, info, env_id))
next_env_id.append(env_id)

state.prev_env_id = next_env_id
return aggregate_recvs(state, recvs)

Expand Down
28 changes: 19 additions & 9 deletions pufferlib/vectorization/pettingzoo_multi_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,21 @@


def reset(state, seed=None):
if state.preallocated_obs is None:
obs_space = state.envs[0].observation_space
obs_n = obs_space.shape[0]
n_agents = len(state.envs[0].possible_agents)
n_envs = len(state.envs)
n = n_envs * n_agents

state.preallocated_obs = np.empty(
(n, *obs_space.shape), dtype=obs_space.dtype)
state.preallocated_rewards = np.empty(n, dtype=np.float32)
state.preallocated_dones = np.empty(n, dtype=np.bool)
state.preallocated_truncateds = np.empty(n, dtype=np.bool)

state.agent_keys = []
infos = []

ptr = 0
for idx, e in enumerate(state.envs):
if seed is None:
Expand All @@ -25,18 +37,16 @@ def reset(state, seed=None):
state.agent_keys.append(list(obs.keys()))
infos.append(i)

if state.preallocated_obs is None:
ob = obs[list(obs.keys())[0]]
state.preallocated_obs = np.empty((len(state.envs)*len(obs), *ob.shape), dtype=ob.dtype)

for o in obs.values():
state.preallocated_obs[ptr] = o
ptr += 1

rewards = [0] * len(state.preallocated_obs)
dones = [False] * len(state.preallocated_obs)
truncateds = [False] * len(state.preallocated_obs)
return state.preallocated_obs, rewards, dones, truncateds, infos
state.preallocated_rewards[:] = 0
state.preallocated_dones[:] = False
state.preallocated_truncateds[:] = False

return (state.preallocated_obs, state.preallocated_rewards,
state.preallocated_dones, state.preallocated_truncateds, infos)

def step(state, actions):
actions = np.array_split(actions, len(state.envs))
Expand Down
13 changes: 9 additions & 4 deletions pufferlib/vectorization/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,15 @@ def aggregate_recvs(state, recvs):
assert all(state.workers_per_batch == len(e) for e in
(obs, rewards, dones, truncateds, infos, env_ids))

obs = np.stack(list(chain.from_iterable(obs)), 0)
rewards = list(chain.from_iterable(rewards))
dones = list(chain.from_iterable(dones))
truncateds = list(chain.from_iterable(truncateds))

obs = np.concatenate(obs)
rewards = np.concatenate(rewards)
dones = np.concatenate(dones)
truncateds = np.concatenate(truncateds)
#obs = np.stack(list(chain.from_iterable(obs)), 0)
#rewards = list(chain.from_iterable(rewards))
#dones = list(chain.from_iterable(dones))
#truncateds = list(chain.from_iterable(truncateds))
infos = [i for ii in infos for i in ii]

# TODO: Masking will break for 1-agent PZ envs
Expand Down
33 changes: 26 additions & 7 deletions tests/pool/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,40 @@
import time

from pufferlib.vectorization import Multiprocessing
from pufferlib.registry import pokemon_red
from pufferlib.environments import pokemon_red

def test_envpool(workers=24, steps=100):
pool = Multiprocessing(pokemon_red.make_env, num_workers=workers)
def test_envpool(num_envs, envs_per_worker, envs_per_batch, steps=1000, env_pool=True):
pool = Multiprocessing(pokemon_red.env_creator(), num_envs=num_envs,
envs_per_worker=envs_per_worker, envs_per_batch=envs_per_batch,
env_pool=True,
)
pool.async_reset()

a = np.array([pool.single_action_space.sample() for _ in range(envs_per_batch)])
start = time.time()
for s in range(steps):
o, r, d, t, i, mask = pool.recv()
a = np.array([pool.single_action_space.sample() for _ in mask])
o, r, d, t, i, mask, env_id = pool.recv()
pool.send(a)
end = time.time()
print('Steps per second: ', steps / (end - start))
print('Steps per second: ', envs_per_batch * steps / (end - start))
pool.close()


if __name__ == '__main__':
test_envpool()
# 225 sps
#test_envpool(num_envs=1, envs_per_worker=1, envs_per_batch=1, env_pool=False)

# 600 sps
#test_envpool(num_envs=6, envs_per_worker=1, envs_per_batch=6, env_pool=False)

# 645 sps
#test_envpool(num_envs=24, envs_per_worker=4, envs_per_batch=24, env_pool=False)

# 755 sps
# test_envpool(num_envs=24, envs_per_worker=4, envs_per_batch=24)

# 1050 sps
# test_envpool(num_envs=48, envs_per_worker=4, envs_per_batch=24)

# 1300 sps
test_envpool(num_envs=48, envs_per_worker=4, envs_per_batch=12)

0 comments on commit ab2f3f0

Please sign in to comment.