diff --git a/t5x/trainer.py b/t5x/trainer.py
index 2fa7cd45f..752beb379 100644
--- a/t5x/trainer.py
+++ b/t5x/trainer.py
@@ -293,11 +293,7 @@ def __init__(
     """
     self._name = name
     if jax.process_index() == 0:
-      self._writer = metric_writers.create_default_writer(
-          summary_dir,
-          collection=name,
-          asynchronous=True,
-      )
+      self._writer = self._create_writer(name, summary_dir)
     else:
       self._writer = metric_writers.MultiWriter([])
     self.summary_dir = os.path.join(summary_dir, name) if summary_dir else None
@@ -305,10 +301,21 @@ def __init__(
     # We use a thread pool with a single worker to ensure that calls to the
     # function are run in order (but in a background thread).
     self._summary_pool = asynclib.Pool(
-        thread_name_prefix="MetricsManager", max_workers=1)
+        thread_name_prefix="MetricsManager", max_workers=1
+    )
     # Times the duration between steps.
     self._duration_timer = _AsyncTimer()

+  def _create_writer(
+      self, name: str, summary_dir: Optional[str] = None
+  ) -> metric_writers.MetricWriter:
+    """Creates the writer for host 0."""
+    return metric_writers.create_default_writer(
+        summary_dir,
+        collection=name,
+        asynchronous=True,
+    )
+
   def __del__(self):
     self.close()

@@ -327,6 +334,7 @@ def close(self):
   def summary_writer(self) -> metric_writers.MetricWriter:
     """Returns the MetricWriter used by this class."""
     # TODO(adarob): Make returned writer threadsafe.
+    assert self._writer is not None
     return self._writer

   def write_scalar(self, key: str, val: metric_writers.interface.Scalar,
@@ -674,7 +682,9 @@ def accumulate_grads_microbatched(

   # Note: Default t5x models don't support flax_mutables. One needs to subclass
   # them and return flax_mutables from `get_initial_variables` and `loss_fn`.
-  initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else None
+  initial_flax_mutables = (
+      train_state.flax_mutables if train_state.flax_mutables else None
+  )

   if num_microbatches is None or num_microbatches <= 1:
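
Note: extracting writer creation into the `_create_writer` hook makes it straightforward for a subclass to swap in a different `MetricWriter` on host 0. Below is a minimal sketch of how that might look; the `LoggingMetricsManager` class is hypothetical and not part of this change. It assumes `MetricsManager` is importable from `t5x.trainer` and uses only `clu.metric_writers` APIs already referenced in the diff (`create_default_writer`, `MultiWriter`, `MetricWriter`) plus `LoggingWriter`.

```python
# Hypothetical example only -- not part of the diff above.
from typing import Optional

from clu import metric_writers
from t5x import trainer


class LoggingMetricsManager(trainer.MetricsManager):
  """MetricsManager variant that also mirrors metrics to the Python logger."""

  def _create_writer(
      self, name: str, summary_dir: Optional[str] = None
  ) -> metric_writers.MetricWriter:
    # Keep the default (asynchronous) summary writer for host 0 ...
    default_writer = metric_writers.create_default_writer(
        summary_dir, collection=name, asynchronous=True
    )
    # ... and fan every write out to a LoggingWriter as well.
    return metric_writers.MultiWriter(
        [default_writer, metric_writers.LoggingWriter(collection=name)]
    )
```

Since `__init__` only calls `_create_writer` when `jax.process_index() == 0` and still installs the empty `MultiWriter` elsewhere, an override like this only changes behavior on the lead host.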