diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index b604484..4b8741b 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -331,6 +331,8 @@ def train( vec = pufferlib.vector.Multiprocessing elif vec == Vectorization.ray: vec = pufferlib.vector.Ray + else: + vec = pufferlib.vector.Multiprocessing # TODO: Remove the +1 once the driver env doesn't permanently increase the env id env_send_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)]