diff --git a/dreamer.py b/dreamer.py index fa6236d..9fbcb6a 100644 --- a/dreamer.py +++ b/dreamer.py @@ -113,7 +113,7 @@ def __call__(self, obs, reset, state=None, training=True): if state is not None and reset.any(): mask = tf.cast(1 - reset, self._float)[:, None] state = tf.nest.map_structure(lambda x: x * mask, state) - if self._should_train(step): + if self._should_train(step) and training: log = self._should_log(step) n = self._c.pretrain if self._should_pretrain() else self._c.train_steps print(f'Training for {n} steps.') diff --git a/models.py b/models.py index 0f40316..0b2d9bb 100644 --- a/models.py +++ b/models.py @@ -172,5 +172,5 @@ def __call__(self, features): x = self.get(f'hout', tfkl.Dense, self._size)(x) dist = tools.OneHotDist(x) else: - raise NotImplementedError(dist) + raise NotImplementedError(self._dist) return dist