From 4b91f8c86e1329162c8118824ad44dd9be6b9e3d Mon Sep 17 00:00:00 2001
From: Yingge He
Date: Tue, 6 Aug 2024 22:20:15 -0700
Subject: [PATCH] Add histogram test

---
 .../metrics_test/vllm_metrics_test.py | 25 ++++---
 src/utils/metrics.py                  | 74 ++++++++++++++++++-
 2 files changed, 89 insertions(+), 10 deletions(-)

diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
index 8284835b..196a0d64 100644
--- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
+++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -112,21 +112,28 @@ def vllm_infer(
         self.triton_client.stop_stream()
 
     def test_vllm_metrics(self):
-        # All vLLM metrics from tritonserver
-        expected_metrics_dict = {
-            "vllm:prompt_tokens_total": 0,
-            "vllm:generation_tokens_total": 0,
-        }
-
         # Test vLLM metrics
         self.vllm_infer(
             prompts=self.prompts,
             sampling_parameters=self.sampling_parameters,
             model_name=self.vllm_model_name,
         )
-        expected_metrics_dict["vllm:prompt_tokens_total"] = 18
-        expected_metrics_dict["vllm:generation_tokens_total"] = 48
-        self.assertEqual(self.get_metrics(), expected_metrics_dict)
+        metrics_dict = self.get_metrics()
+
+        # vllm:prompt_tokens_total
+        self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
+        # vllm:generation_tokens_total
+        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)
+        # vllm:time_to_first_token_seconds
+        self.assertEqual(metrics_dict["vllm:time_to_first_token_seconds_count"], 3)
+        self.assertTrue(
+            0 < metrics_dict["vllm:time_to_first_token_seconds_sum"] < 0.0005
+        )
+        # vllm:time_per_output_token_seconds
+        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
+        self.assertTrue(
+            0 <= metrics_dict["vllm:time_per_output_token_seconds_sum"] <= 0.005
+        )
 
     def tearDown(self):
         self.triton_client.close()
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index e8c58372..d8c71ebc 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -24,7 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Dict, Union
+from typing import Dict, List, Union
 
 import triton_python_backend_utils as pb_utils
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
@@ -46,6 +46,16 @@ def __init__(self, labels):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.COUNTER,
         )
+        self.histogram_time_to_first_token_family = pb_utils.MetricFamily(
+            name="vllm:time_to_first_token_seconds",
+            description="Histogram of time to first token in seconds.",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
+        self.histogram_time_per_output_token_family = pb_utils.MetricFamily(
+            name="vllm:time_per_output_token_seconds",
+            description="Histogram of time per output token in seconds.",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
 
         # Initialize metrics
         # Iteration stats
@@ -55,6 +65,49 @@ def __init__(self, labels):
         self.counter_generation_tokens = self.counter_generation_tokens_family.Metric(
             labels=labels
         )
+        self.histogram_time_to_first_token = (
+            self.histogram_time_to_first_token_family.Metric(
+                labels=labels,
+                buckets=[
+                    0.001,
+                    0.005,
+                    0.01,
+                    0.02,
+                    0.04,
+                    0.06,
+                    0.08,
+                    0.1,
+                    0.25,
+                    0.5,
+                    0.75,
+                    1.0,
+                    2.5,
+                    5.0,
+                    7.5,
+                    10.0,
+                ],
+            )
+        )
+        self.histogram_time_per_output_token = (
+            self.histogram_time_per_output_token_family.Metric(
+                labels=labels,
+                buckets=[
+                    0.01,
+                    0.025,
+                    0.05,
+                    0.075,
+                    0.1,
+                    0.15,
+                    0.2,
+                    0.3,
+                    0.4,
+                    0.5,
+                    0.75,
+                    1.0,
+                    2.5,
+                ],
+            )
+        )
 
 
 class VllmStatLogger(VllmStatLoggerBase):
@@ -93,6 +146,19 @@ def _log_counter(self, counter, data: Union[int, float]) -> None:
         """
         if data != 0:
             counter.increment(data)
+
+    def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
+        """Convenience function for logging a list of data to a histogram.
+
+        Args:
+            histogram: A histogram metric instance.
+            data: A list of int or float data to observe into the histogram metric.
+
+        Returns:
+            None
+        """
+        for datum in data:
+            histogram.observe(datum)
 
     def log(self, stats: VllmStats) -> None:
         """Logs tracked stats to triton metrics server every iteration.
@@ -108,4 +174,11 @@ def log(self, stats: VllmStats) -> None:
         )
         self._log_counter(
             self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
         )
+        self._log_histogram(
+            self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
+        )
+        self._log_histogram(
+            self.metrics.histogram_time_per_output_token,
+            stats.time_per_output_tokens_iter,
+        )
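
-- 
Reviewer note, outside the patch proper (text after the "-- " signature cut
is conventionally ignored by git am): a metric family created with
kind=pb_utils.MetricFamily.HISTOGRAM is exported by Triton in Prometheus
exposition format as three series per family: <name>_count, <name>_sum, and
cumulative <name>_bucket{le="..."}. That is why the updated test asserts on
the "_count" and "_sum" keys rather than the bare metric name. Below is a
minimal sketch of scraping those series directly; the port (Triton's usual
metrics default, 8002), the scrape_series helper, and the "model"/"version"
label names are illustrative assumptions, not part of this patch.

    # Hypothetical helper for manual verification; not part of the patch.
    import requests  # third-party HTTP client

    METRICS_URL = "http://localhost:8002/metrics"  # assumed default metrics port

    def scrape_series(prefix: str) -> dict:
        """Collect every exported series whose name starts with prefix."""
        body = requests.get(METRICS_URL).text
        series = {}
        for line in body.splitlines():
            # Skip "# HELP" / "# TYPE" comment lines and unrelated metrics.
            if line.startswith("#") or not line.startswith(prefix):
                continue
            # Exposition lines look like: name{labels} value
            key, _, value = line.rpartition(" ")
            series[key] = float(value)
        return series

    ttft = scrape_series("vllm:time_to_first_token_seconds")
    # Expected shape for the test's three prompts, with illustrative values:
    #   vllm:time_to_first_token_seconds_count{model="...",version="1"} 3
    #   vllm:time_to_first_token_seconds_sum{model="...",version="1"} 0.0003
    #   vllm:time_to_first_token_seconds_bucket{model="...",le="0.001"} 3

Since the test's sum assertion (0 < sum < 0.0005) bounds every observation
below the first 0.001 bucket boundary, all three observations should land in
the lowest le="0.001" bucket.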