diff --git a/moojoco/environment/mjc_env.py b/moojoco/environment/mjc_env.py index a710ab0..7c2a889 100644 --- a/moojoco/environment/mjc_env.py +++ b/moojoco/environment/mjc_env.py @@ -277,7 +277,11 @@ def reset( self, rng: List[np.random.RandomState], *args, **kwargs ) -> VectorMJCEnvState: self._states = list( - self._pool.map(lambda env, sub_rng: env.reset(sub_rng, *args, **kwargs), self._envs, rng) + self._pool.map( + lambda env, sub_rng: env.reset(sub_rng, *args, **kwargs), + self._envs, + rng, + ) ) return self._merged_states