
Commit a257cac
Moved the logic that creates the (default) writer to a class method.
PiperOrigin-RevId: 554711699
sgirgin authored and t5-copybara committed Aug 8, 2023
1 parent 3a43448 commit a257cac
Showing 1 changed file with 17 additions and 7 deletions.
t5x/trainer.py (24 changes: 17 additions & 7 deletions)
@@ -293,22 +293,29 @@ def __init__(
     """
     self._name = name
     if jax.process_index() == 0:
-      self._writer = metric_writers.create_default_writer(
-          summary_dir,
-          collection=name,
-          asynchronous=True,
-      )
+      self._writer = self._create_writer(name, summary_dir)
     else:
       self._writer = metric_writers.MultiWriter([])
     self.summary_dir = os.path.join(summary_dir, name) if summary_dir else None
     self._writer_lock = threading.Lock()
     # We use a thread pool with a single worker to ensure that calls to the
     # function are run in order (but in a background thread).
     self._summary_pool = asynclib.Pool(
-        thread_name_prefix="MetricsManager", max_workers=1)
+        thread_name_prefix="MetricsManager", max_workers=1
+    )
     # Times the duration between steps.
     self._duration_timer = _AsyncTimer()
 
+  def _create_writer(
+      self, name: str, summary_dir: Optional[str] = None
+  ) -> metric_writers.MetricWriter:
+    """Creates the writer for host 0."""
+    return metric_writers.create_default_writer(
+        summary_dir,
+        collection=name,
+        asynchronous=True,
+    )
+
   def __del__(self):
     self.close()
 
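With the creation logic in its own method, a subclass can swap in a different writer without re-implementing __init__. A minimal sketch of that pattern (the subclass is hypothetical and not part of this commit; the enclosing class is assumed to be trainer.MetricsManager, as the thread_name_prefix above suggests; only _create_writer and create_default_writer come from the diff):

from typing import Optional

from clu import metric_writers
from t5x import trainer


class SyncMetricsManager(trainer.MetricsManager):  # hypothetical subclass
  """Writes summaries synchronously, e.g. to simplify tests."""

  def _create_writer(
      self, name: str, summary_dir: Optional[str] = None
  ) -> metric_writers.MetricWriter:
    # Same default writer as the base class, but synchronous.
    return metric_writers.create_default_writer(
        summary_dir, collection=name, asynchronous=False
    )
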
@@ -327,6 +334,7 @@ def close(self):
   def summary_writer(self) -> metric_writers.MetricWriter:
     """Returns the MetricWriter used by this class."""
     # TODO(adarob): Make returned writer threadsafe.
+    assert self._writer is not None
     return self._writer
 
   def write_scalar(self, key: str, val: metric_writers.interface.Scalar,
@@ -674,7 +682,9 @@ def accumulate_grads_microbatched(
   # Note: Default t5x models don't support flax_mutables. One needs to subclass
   # them and return flax_mutables from `get_initial_variables` and `loss_fn`.
 
-  initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else None
+  initial_flax_mutables = (
+      train_state.flax_mutables if train_state.flax_mutables else None
+  )
 
   if num_microbatches is None or num_microbatches <= 1:
 
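The single-worker pool in the first hunk is what keeps summary writes ordered. That ordering guarantee is the standard one for a one-worker executor; here is a stdlib illustration of the property (using concurrent.futures rather than clu's asynclib, which the real code uses):

from concurrent.futures import ThreadPoolExecutor

# With max_workers=1, submitted callables run one at a time, in FIFO
# submission order, on a background thread.
results = []
with ThreadPoolExecutor(max_workers=1) as pool:
  for step in range(5):
    pool.submit(results.append, step)
# The context manager waits for all queued tasks before exiting.
assert results == [0, 1, 2, 3, 4]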
