diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py
index 20e257d6d..63386b481 100644
--- a/torchrec/metrics/metrics_namespace.py
+++ b/torchrec/metrics/metrics_namespace.py
@@ -45,6 +45,7 @@ class MetricName(MetricNameBase):
     LOG_LOSS = "logloss"
     THROUGHPUT = "throughput"
     TOTAL_EXAMPLES = "total_examples"
+    ATTEMPT_EXAMPLES = "attempt_examples"
     CTR = "ctr"
     CALIBRATION = "calibration"
     MSE = "mse"
@@ -124,6 +125,7 @@ class MetricPrefix(StrValueMixin, Enum):
     DEFAULT = ""
     LIFETIME = "lifetime_"
     WINDOW = "window_"
+    ATTEMPT = "attempt_"
 
 
 def task_wildcard_metrics_pattern(
diff --git a/torchrec/metrics/tests/test_throughput.py b/torchrec/metrics/tests/test_throughput.py
index f907f13c3..55c7373ef 100644
--- a/torchrec/metrics/tests/test_throughput.py
+++ b/torchrec/metrics/tests/test_throughput.py
@@ -12,6 +12,8 @@
 import unittest
 from unittest.mock import Mock, patch
 
+import torch
+
 from torchrec.metrics.metrics_config import BatchSizeStage
 from torchrec.metrics.throughput import ThroughputMetric
 
@@ -32,7 +34,11 @@ def test_no_batches(self, time_mock: Mock) -> None:
             batch_size=self.batch_size, world_size=self.world_size, window_seconds=100
         )
         self.assertEqual(
-            throughput_metric.compute(), {"throughput-throughput|total_examples": 0}
+            throughput_metric.compute(),
+            {
+                "throughput-throughput|total_examples": 0,
+                "throughput-throughput|attempt_examples": 0,
+            },
         )
 
     @patch(THROUGHPUT_PATH + ".time.monotonic")
@@ -44,7 +50,12 @@ def test_one_batch(self, time_mock: Mock) -> None:
         throughput_metric.update()
         self.assertEqual(
             throughput_metric.compute(),
-            {"throughput-throughput|total_examples": self.batch_size * self.world_size},
+            {
+                "throughput-throughput|total_examples": self.batch_size
+                * self.world_size,
+                "throughput-throughput|attempt_examples": self.batch_size
+                * self.world_size,
+            },
         )
 
     @patch(THROUGHPUT_PATH + ".time.monotonic")
@@ -73,7 +84,11 @@ def _test_throughput(self, time_mock: Mock, warmup_steps: int) -> None:
             total_examples = self.world_size * self.batch_size * (i + 1)
             if i < warmup_steps:
                 self.assertEqual(
-                    ret, {"throughput-throughput|total_examples": total_examples}
+                    ret,
+                    {
+                        "throughput-throughput|total_examples": total_examples,
+                        "throughput-throughput|attempt_examples": total_examples,
+                    },
                 )
                 continue
 
@@ -102,6 +117,13 @@ def _test_throughput(self, time_mock: Mock, warmup_steps: int) -> None:
             self.assertEqual(
                 ret["throughput-throughput|total_examples"], total_examples
             )
+            # only one attempt so attempt examples and throughput are the same as total/lifetime
+            self.assertEqual(
+                ret["throughput-throughput|attempt_examples"], total_examples
+            )
+            self.assertEqual(
+                ret["throughput-throughput|attempt_throughput"], lifetime_throughput
+            )
 
     def test_throughput_warmup_steps_0(self) -> None:
         with self.assertRaises(ValueError):
@@ -140,8 +162,23 @@ def test_warmup_checkpointing(self) -> None:
                 * self.world_size,
             )
 
+            self.assertEqual(
+                throughput_metric.attempt_warmup_examples.item(),
+                warmup_steps * self.batch_size * self.world_size,
+            )
+            self.assertEqual(
+                throughput_metric.attempt_examples.item(),
+                (warmup_steps + extra_steps) * self.batch_size * self.world_size,
+            )
             # Mimic trainer crashing and loading a checkpoint
             throughput_metric._steps = 0
+            throughput_metric.attempt_examples = torch.tensor(0, dtype=torch.long)
+            throughput_metric.attempt_warmup_examples = torch.tensor(
+                0, dtype=torch.long
+            )
+            throughput_metric.attempt_time_lapse_after_warmup = torch.tensor(
+                0, dtype=torch.double
+            )
 
     @patch(THROUGHPUT_PATH + ".time.monotonic")
     def test_batch_size_schedule(self, time_mock: Mock) -> None:
@@ -159,12 +196,18 @@ def test_batch_size_schedule(self, time_mock: Mock) -> None:
         total_examples += batch_size_stages[0].batch_size * self.world_size
         self.assertEqual(
             throughput_metric.compute(),
-            {"throughput-throughput|total_examples": total_examples},
+            {
+                "throughput-throughput|total_examples": total_examples,
+                "throughput-throughput|attempt_examples": total_examples,
+            },
         )
 
         throughput_metric.update()
         total_examples += batch_size_stages[1].batch_size * self.world_size
         self.assertEqual(
             throughput_metric.compute(),
-            {"throughput-throughput|total_examples": total_examples},
+            {
+                "throughput-throughput|total_examples": total_examples,
+                "throughput-throughput|attempt_examples": total_examples,
+            },
         )
diff --git a/torchrec/metrics/throughput.py b/torchrec/metrics/throughput.py
index aba333c75..f1758dc2c 100644
--- a/torchrec/metrics/throughput.py
+++ b/torchrec/metrics/throughput.py
@@ -71,7 +71,9 @@ class ThroughputMetric(nn.Module):
     _previous_ts: float
     _lifetime_throughput_key: str
     _window_throughput_key: str
+    _attempt_throughput_key: str
     _total_examples_key: str
+    _attempt_examples_key: str
     _steps: int
 
     def __init__(
@@ -119,6 +121,20 @@ def __init__(
             "time_lapse_after_warmup", torch.tensor(0, dtype=torch.double)
         )
 
+        self.register_buffer(
+            "attempt_examples", torch.tensor(0, dtype=torch.long), persistent=False
+        )
+        self.register_buffer(
+            "attempt_warmup_examples",
+            torch.tensor(0, dtype=torch.long),
+            persistent=False,
+        )
+        self.register_buffer(
+            "attempt_time_lapse_after_warmup",
+            torch.tensor(0, dtype=torch.double),
+            persistent=False,
+        )
+
         self._window_time_lapse_buffer = deque(maxlen=MAX_WINDOW_TS)
         self._window_time_lapse = 0
         self._previous_ts = 0
@@ -135,12 +151,22 @@ def __init__(
             self._metric_name,
             MetricPrefix.WINDOW,
         )
+        self._attempt_throughput_key = compose_metric_key(
+            self._namespace,
+            str(self._namespace),
+            self._metric_name,
+            MetricPrefix.ATTEMPT,
+        )
         self._total_examples_key = compose_metric_key(
             self._namespace,
             str(self._namespace),
             MetricName.TOTAL_EXAMPLES,
         )
-
+        self._attempt_examples_key = compose_metric_key(
+            self._namespace,
+            str(self._namespace),
+            MetricName.ATTEMPT_EXAMPLES,
+        )
         self._steps = 0
 
     def _get_batch_size(self) -> int:
@@ -177,28 +203,38 @@ def update(self) -> None:
         self._steps += 1
         if self._batch_size_stages is not None:
             self.num_batch += 1
-        self.total_examples += self._batch_examples()
+        batch_examples = self._batch_examples()
+        self.total_examples += batch_examples
+        self.attempt_examples += batch_examples
 
         if self._steps <= self._warmup_steps:
-            self.warmup_examples += self._batch_examples()
+            self.warmup_examples += batch_examples
+            self.attempt_warmup_examples += batch_examples
             if self._steps == self._warmup_steps:
                 self._previous_ts = ts
         else:
             time_lapse = ts - self._previous_ts
             self.time_lapse_after_warmup += time_lapse
+            self.attempt_time_lapse_after_warmup += time_lapse
             self._window_time_lapse += time_lapse
             self._window_time_lapse_buffer.append(time_lapse)
             self._check_window()
             self._previous_ts = ts
 
     def compute(self) -> Dict[str, torch.Tensor]:
-        ret = {self._total_examples_key: self.total_examples}
+        ret = {
+            self._total_examples_key: self.total_examples,
+            self._attempt_examples_key: self.attempt_examples,
+        }
         if self._steps > self._warmup_steps and (
             not math.isclose(self.time_lapse_after_warmup.item(), 0)
         ):
             lifetime_throughput = (
                 self.total_examples - self.warmup_examples
             ) / self.time_lapse_after_warmup
+            attempt_throughput = (
+                self.attempt_examples - self.attempt_warmup_examples
+            ) / self.attempt_time_lapse_after_warmup
             if not math.isclose(self._window_time_lapse, 0):
                 window_throughput = (
                     len(self._window_time_lapse_buffer)
@@ -216,4 +252,10 @@ def compute(self) -> Dict[str, torch.Tensor]:
                     ),
                 }
             )
+            if not math.isclose(attempt_throughput.item(), 0):
+                ret.update(
+                    {
+                        self._attempt_throughput_key: attempt_throughput.clone().detach(),
+                    }
+                )
         return ret
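Note on the metric key strings asserted in the tests: they follow the layout "<namespace>-<task>|<prefix><metric_name>", so the new MetricPrefix.ATTEMPT ("attempt_") prepended to the metric name "throughput" yields "throughput-throughput|attempt_throughput", while MetricName.ATTEMPT_EXAMPLES uses the default empty prefix and yields "throughput-throughput|attempt_examples". A minimal sketch of that layout, using a hypothetical helper (example_metric_key is not torchrec's compose_metric_key, whose exact signature the diff does not show):

def example_metric_key(namespace: str, task: str, name: str, prefix: str = "") -> str:
    # Hypothetical stand-in: only the resulting string layout is taken
    # from the assertions in test_throughput.py above.
    return f"{namespace}-{task}|{prefix}{name}"

assert example_metric_key("throughput", "throughput", "total_examples") == (
    "throughput-throughput|total_examples"
)
assert example_metric_key("throughput", "throughput", "throughput", "attempt_") == (
    "throughput-throughput|attempt_throughput"
)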
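Note on the checkpoint semantics: the three attempt_* buffers are registered with persistent=False, which keeps them out of the module's state_dict(), so after a crash and checkpoint reload they restart from zero while total_examples (a persistent buffer) carries over; test_warmup_checkpointing mimics that reload by zeroing them by hand. A minimal sketch of this PyTorch behavior, assuming plain torch.nn (ExampleCounter and its fields are illustrative, not part of the patch):

import torch
import torch.nn as nn


class ExampleCounter(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Persistent buffer: saved in state_dict(), survives a reload.
        self.register_buffer("total_examples", torch.tensor(0, dtype=torch.long))
        # Non-persistent buffer: excluded from state_dict(), resets on reload.
        self.register_buffer(
            "attempt_examples", torch.tensor(0, dtype=torch.long), persistent=False
        )

    def update(self, batch_examples: int) -> None:
        self.total_examples += batch_examples
        self.attempt_examples += batch_examples


counter = ExampleCounter()
counter.update(256)
restored = ExampleCounter()
restored.load_state_dict(counter.state_dict())
print(restored.total_examples.item())    # 256, carried across the "restart"
print(restored.attempt_examples.item())  # 0, the new attempt starts fresh

Because total_examples persists while attempt_examples resets, the two diverge after the first restart, which is exactly the lifetime-versus-current-attempt distinction the new attempt_examples and attempt_throughput metrics report.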