Skip to content

Commit

Permalink
Add attempt QPS metric
Browse files Browse the repository at this point in the history
Summary:
Add an attempt QPS metric to measure the throughput (QPS) of an individual job attempt.

This is relevant when different job attempts can be scheduled on different hardware types. When that happens, the lifetime QPS metric becomes an average across hardware types whose capabilities can differ widely, so it is no longer useful for performance analysis. A metric that calculates QPS at the attempt level allows meaningful performance analysis even across different hardware types.

Differential Revision: D64878139
  • Loading branch information
Karthik Jayaraman authored and facebook-github-bot committed Oct 25, 2024
1 parent cd64b9d commit ddcbc78
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 9 deletions.
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class MetricName(MetricNameBase):
LOG_LOSS = "logloss"
THROUGHPUT = "throughput"
TOTAL_EXAMPLES = "total_examples"
ATTEMPT_EXAMPLES = "attempt_examples"
CTR = "ctr"
CALIBRATION = "calibration"
MSE = "mse"
Expand Down Expand Up @@ -124,6 +125,7 @@ class MetricPrefix(StrValueMixin, Enum):
DEFAULT = ""
LIFETIME = "lifetime_"
WINDOW = "window_"
ATTEMPT = "attempt_"


def task_wildcard_metrics_pattern(
Expand Down
53 changes: 48 additions & 5 deletions torchrec/metrics/tests/test_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import unittest
from unittest.mock import Mock, patch

import torch

from torchrec.metrics.metrics_config import BatchSizeStage

from torchrec.metrics.throughput import ThroughputMetric
Expand All @@ -32,7 +34,11 @@ def test_no_batches(self, time_mock: Mock) -> None:
batch_size=self.batch_size, world_size=self.world_size, window_seconds=100
)
self.assertEqual(
throughput_metric.compute(), {"throughput-throughput|total_examples": 0}
throughput_metric.compute(),
{
"throughput-throughput|total_examples": 0,
"throughput-throughput|attempt_examples": 0,
},
)

@patch(THROUGHPUT_PATH + ".time.monotonic")
Expand All @@ -44,7 +50,12 @@ def test_one_batch(self, time_mock: Mock) -> None:
throughput_metric.update()
self.assertEqual(
throughput_metric.compute(),
{"throughput-throughput|total_examples": self.batch_size * self.world_size},
{
"throughput-throughput|total_examples": self.batch_size
* self.world_size,
"throughput-throughput|attempt_examples": self.batch_size
* self.world_size,
},
)

@patch(THROUGHPUT_PATH + ".time.monotonic")
Expand Down Expand Up @@ -73,7 +84,11 @@ def _test_throughput(self, time_mock: Mock, warmup_steps: int) -> None:
total_examples = self.world_size * self.batch_size * (i + 1)
if i < warmup_steps:
self.assertEqual(
ret, {"throughput-throughput|total_examples": total_examples}
ret,
{
"throughput-throughput|total_examples": total_examples,
"throughput-throughput|attempt_examples": total_examples,
},
)
continue

Expand Down Expand Up @@ -102,6 +117,13 @@ def _test_throughput(self, time_mock: Mock, warmup_steps: int) -> None:
self.assertEqual(
ret["throughput-throughput|total_examples"], total_examples
)
# only one attempt so attempt examples and throughput are the same as total/lifetime
self.assertEqual(
ret["throughput-throughput|attempt_examples"], total_examples
)
self.assertEqual(
ret["throughput-throughput|attempt_throughput"], lifetime_throughput
)

def test_throughput_warmup_steps_0(self) -> None:
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -140,8 +162,23 @@ def test_warmup_checkpointing(self) -> None:
* self.world_size,
)

self.assertEqual(
throughput_metric.attempt_warmup_examples.item(),
warmup_steps * self.batch_size * self.world_size,
)
self.assertEqual(
throughput_metric.attempt_examples.item(),
(warmup_steps + extra_steps) * self.batch_size * self.world_size,
)
# Mimic trainer crashing and loading a checkpoint
throughput_metric._steps = 0
throughput_metric.attempt_examples = torch.tensor(0, dtype=torch.long)
throughput_metric.attempt_warmup_examples = torch.tensor(
0, dtype=torch.long
)
throughput_metric.attempt_time_lapse_after_warmup = torch.tensor(
0, dtype=torch.double
)

@patch(THROUGHPUT_PATH + ".time.monotonic")
def test_batch_size_schedule(self, time_mock: Mock) -> None:
Expand All @@ -159,12 +196,18 @@ def test_batch_size_schedule(self, time_mock: Mock) -> None:
total_examples += batch_size_stages[0].batch_size * self.world_size
self.assertEqual(
throughput_metric.compute(),
{"throughput-throughput|total_examples": total_examples},
{
"throughput-throughput|total_examples": total_examples,
"throughput-throughput|attempt_examples": total_examples,
},
)

throughput_metric.update()
total_examples += batch_size_stages[1].batch_size * self.world_size
self.assertEqual(
throughput_metric.compute(),
{"throughput-throughput|total_examples": total_examples},
{
"throughput-throughput|total_examples": total_examples,
"throughput-throughput|attempt_examples": total_examples,
},
)
50 changes: 46 additions & 4 deletions torchrec/metrics/throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ class ThroughputMetric(nn.Module):
_previous_ts: float
_lifetime_throughput_key: str
_window_throughput_key: str
_attempt_throughput_key: str
_total_examples_key: str
_attempt_examples_key: str
_steps: int

def __init__(
Expand Down Expand Up @@ -119,6 +121,20 @@ def __init__(
"time_lapse_after_warmup", torch.tensor(0, dtype=torch.double)
)

self.register_buffer(
"attempt_examples", torch.tensor(0, dtype=torch.long), persistent=False
)
self.register_buffer(
"attempt_warmup_examples",
torch.tensor(0, dtype=torch.long),
persistent=False,
)
self.register_buffer(
"attempt_time_lapse_after_warmup",
torch.tensor(0, dtype=torch.double),
persistent=False,
)

self._window_time_lapse_buffer = deque(maxlen=MAX_WINDOW_TS)
self._window_time_lapse = 0
self._previous_ts = 0
Expand All @@ -135,12 +151,22 @@ def __init__(
self._metric_name,
MetricPrefix.WINDOW,
)
self._attempt_throughput_key = compose_metric_key(
self._namespace,
str(self._namespace),
self._metric_name,
MetricPrefix.ATTEMPT,
)
self._total_examples_key = compose_metric_key(
self._namespace,
str(self._namespace),
MetricName.TOTAL_EXAMPLES,
)

self._attempt_examples_key = compose_metric_key(
self._namespace,
str(self._namespace),
MetricName.ATTEMPT_EXAMPLES,
)
self._steps = 0

def _get_batch_size(self) -> int:
Expand Down Expand Up @@ -177,28 +203,38 @@ def update(self) -> None:
self._steps += 1
if self._batch_size_stages is not None:
self.num_batch += 1
self.total_examples += self._batch_examples()
batch_examples = self._batch_examples()
self.total_examples += batch_examples
self.attempt_examples += batch_examples

if self._steps <= self._warmup_steps:
self.warmup_examples += self._batch_examples()
self.warmup_examples += batch_examples
self.attempt_warmup_examples += batch_examples
if self._steps == self._warmup_steps:
self._previous_ts = ts
else:
time_lapse = ts - self._previous_ts
self.time_lapse_after_warmup += time_lapse
self.attempt_time_lapse_after_warmup += time_lapse
self._window_time_lapse += time_lapse
self._window_time_lapse_buffer.append(time_lapse)
self._check_window()
self._previous_ts = ts

def compute(self) -> Dict[str, torch.Tensor]:
ret = {self._total_examples_key: self.total_examples}
ret = {
self._total_examples_key: self.total_examples,
self._attempt_examples_key: self.attempt_examples,
}
if self._steps > self._warmup_steps and (
not math.isclose(self.time_lapse_after_warmup.item(), 0)
):
lifetime_throughput = (
self.total_examples - self.warmup_examples
) / self.time_lapse_after_warmup
attempt_throughput = (
self.attempt_examples - self.attempt_warmup_examples
) / self.attempt_time_lapse_after_warmup
if not math.isclose(self._window_time_lapse, 0):
window_throughput = (
len(self._window_time_lapse_buffer)
Expand All @@ -216,4 +252,10 @@ def compute(self) -> Dict[str, torch.Tensor]:
),
}
)
if not math.isclose(attempt_throughput.item(), 0):
ret.update(
{
self._attempt_throughput_key: attempt_throughput.clone().detach(),
}
)
return ret

0 comments on commit ddcbc78

Please sign in to comment.