feat: Report more histogram metrics #61

Merged: 5 commits, Aug 24, 2024
47 changes: 45 additions & 2 deletions README.md
@@ -224,6 +224,16 @@ counter_generation_tokens
histogram_time_to_first_token
# Histogram of time per output token in seconds.
histogram_time_per_output_token
# Histogram of end to end request latency in seconds.
histogram_e2e_time_request
# Number of prefill tokens processed.
histogram_num_prompt_tokens_request
# Number of generation tokens processed.
histogram_num_generation_tokens_request
# Histogram of the best_of request parameter.
histogram_best_of_request
# Histogram of the n request parameter.
histogram_n_request
```
Your output for these fields should look similar to the following:
```bash
@@ -238,17 +248,50 @@ vllm:generation_tokens_total{model="vllm_model",version="1"} 16
vllm:time_to_first_token_seconds_count{model="vllm_model",version="1"} 1
vllm:time_to_first_token_seconds_sum{model="vllm_model",version="1"} 0.03233122825622559
vllm:time_to_first_token_seconds_bucket{model="vllm_model",version="1",le="0.001"} 0
vllm:time_to_first_token_seconds_bucket{model="vllm_model",version="1",le="0.005"} 0
...
vllm:time_to_first_token_seconds_bucket{model="vllm_model",version="1",le="+Inf"} 1
# HELP vllm:time_per_output_token_seconds Histogram of time per output token in seconds.
# TYPE vllm:time_per_output_token_seconds histogram
vllm:time_per_output_token_seconds_count{model="vllm_model",version="1"} 15
vllm:time_per_output_token_seconds_sum{model="vllm_model",version="1"} 0.04501533508300781
vllm:time_per_output_token_seconds_bucket{model="vllm_model",version="1",le="0.01"} 14
vllm:time_per_output_token_seconds_bucket{model="vllm_model",version="1",le="0.025"} 15
...
vllm:time_per_output_token_seconds_bucket{model="vllm_model",version="1",le="+Inf"} 15
# HELP vllm:e2e_request_latency_seconds Histogram of end to end request latency in seconds.
# TYPE vllm:e2e_request_latency_seconds histogram
vllm:e2e_request_latency_seconds_count{model="vllm_model",version="1"} 1
vllm:e2e_request_latency_seconds_sum{model="vllm_model",version="1"} 0.08686184883117676
vllm:e2e_request_latency_seconds_bucket{model="vllm_model",version="1",le="1"} 1
...
vllm:e2e_request_latency_seconds_bucket{model="vllm_model",version="1",le="+Inf"} 1
# HELP vllm:request_prompt_tokens Number of prefill tokens processed.
# TYPE vllm:request_prompt_tokens histogram
vllm:request_prompt_tokens_count{model="vllm_model",version="1"} 1
vllm:request_prompt_tokens_sum{model="vllm_model",version="1"} 10
vllm:request_prompt_tokens_bucket{model="vllm_model",version="1",le="1"} 0
...
vllm:request_prompt_tokens_bucket{model="vllm_model",version="1",le="+Inf"} 1
# HELP vllm:request_generation_tokens Number of generation tokens processed.
# TYPE vllm:request_generation_tokens histogram
vllm:request_generation_tokens_count{model="vllm_model",version="1"} 1
vllm:request_generation_tokens_sum{model="vllm_model",version="1"} 16
vllm:request_generation_tokens_bucket{model="vllm_model",version="1",le="1"} 0
...
vllm:request_generation_tokens_bucket{model="vllm_model",version="1",le="+Inf"} 1
# HELP vllm:request_params_best_of Histogram of the best_of request parameter.
# TYPE vllm:request_params_best_of histogram
vllm:request_params_best_of_count{model="vllm_model",version="1"} 1
vllm:request_params_best_of_sum{model="vllm_model",version="1"} 1
vllm:request_params_best_of_bucket{model="vllm_model",version="1",le="1"} 1
...
vllm:request_params_best_of_bucket{model="vllm_model",version="1",le="+Inf"} 1
# HELP vllm:request_params_n Histogram of the n request parameter.
# TYPE vllm:request_params_n histogram
vllm:request_params_n_count{model="vllm_model",version="1"} 1
vllm:request_params_n_sum{model="vllm_model",version="1"} 1
vllm:request_params_n_bucket{model="vllm_model",version="1",le="1"} 1
...
vllm:request_params_n_bucket{model="vllm_model",version="1",le="+Inf"} 1
```
To enable vLLM engine metrics collection, the "disable_log_stats" option needs to be either false
or left empty (false by default) in [model.json](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json).
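For reference, the relevant part of a model.json could look like the following; the model name and the gpu_memory_utilization value are illustrative, and only disable_log_stats is the option discussed above:
```json
{
    "model": "facebook/opt-125m",
    "disable_log_stats": false,
    "gpu_memory_utilization": 0.5
}
```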
72 changes: 63 additions & 9 deletions ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -50,7 +50,7 @@ def setUp(self):
]
self.sampling_parameters = {"temperature": "0", "top_p": "1"}

def get_vllm_metrics(self):
def parse_vllm_metrics(self):
"""
Store vllm metrics in a dictionary.
"""
@@ -112,27 +112,81 @@ def vllm_infer(
self.triton_client.stop_stream()

def test_vllm_metrics(self):
# Adding sampling parameters for testing metrics.
# Definitions can be found here https://docs.vllm.ai/en/latest/dev/sampling_params.html
n, best_of = 2, 4
custom_sampling_parameters = self.sampling_parameters.copy()
# Changing "temperature" because "best_of" must be 1 when using greedy
# sampling, i.e. "temperature": "0".
custom_sampling_parameters.update(
{"n": str(n), "best_of": str(best_of), "temperature": "1"}
)

# Test vLLM metrics
self.vllm_infer(
prompts=self.prompts,
sampling_parameters=self.sampling_parameters,
sampling_parameters=custom_sampling_parameters,
model_name=self.vllm_model_name,
)
metrics_dict = self.get_vllm_metrics()
metrics_dict = self.parse_vllm_metrics()
total_prompts = len(self.prompts)

# vllm:prompt_tokens_total
self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
# vllm:generation_tokens_total
self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)

self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 188)
# vllm:time_to_first_token_seconds
self.assertEqual(metrics_dict["vllm:time_to_first_token_seconds_count"], 3)
self.assertEqual(
metrics_dict["vllm:time_to_first_token_seconds_count"], total_prompts
)
self.assertGreater(metrics_dict["vllm:time_to_first_token_seconds_sum"], 0)
self.assertEqual(metrics_dict["vllm:time_to_first_token_seconds_bucket"], 3)
self.assertEqual(
metrics_dict["vllm:time_to_first_token_seconds_bucket"], total_prompts
)
# vllm:time_per_output_token_seconds
self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
self.assertGreater(metrics_dict["vllm:time_per_output_token_seconds_sum"], 0)
self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_bucket"], 45)
# vllm:e2e_request_latency_seconds
self.assertEqual(
metrics_dict["vllm:e2e_request_latency_seconds_count"], total_prompts
)
rmccorm4 (Contributor) commented on lines +151 to +153, Aug 22, 2024:
nit: I think it would be easier to read/review/maintain this function if we could group the common parts:

histogram_metrics = ["vllm:e2e_request_latency_seconds", "..."]
for metric in histogram_metrics:
  # Expect exactly one observation and bucket per prompt
  self.assertEqual(metrics_dict[f"{metric}_count"], total_prompts)
  self.assertEqual(metrics_dict[f"{metric}_bucket"], total_prompts)

  # Compare the exact expected sum where it makes sense, otherwise assert non-zero
  if metric.endswith("_best_of"):
    self.assertEqual(metrics_dict[f"{metric}_sum"], best_of * total_prompts)
  elif metric.endswith("_n"):
    self.assertEqual(metrics_dict[f"{metric}_sum"], n * total_prompts)
  else:
    self.assertGreater(metrics_dict[f"{metric}_sum"], 0)

for as many of the metrics as make sense to follow the same pattern.

We can have separate/special cases for the metrics that don't fit this pattern.

Feel free to modify or change if any of the above is incorrect, just an example.

Contributor (PR author) replied:

Thanks for your input. I personally don't think the grouped loop stays maintainable as new metric tests are added, since there are a lot of special cases. Counterexample:

# vllm:time_per_output_token_seconds
self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
# This line is fine
self.assertGreater(metrics_dict["vllm:time_per_output_token_seconds_sum"], 0)
self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_bucket"], 45)

self.assertGreater(metrics_dict["vllm:e2e_request_latency_seconds_sum"], 0)
self.assertEqual(
metrics_dict["vllm:e2e_request_latency_seconds_bucket"], total_prompts
)
# vllm:request_prompt_tokens
self.assertEqual(
metrics_dict["vllm:request_prompt_tokens_count"], total_prompts
)
self.assertEqual(metrics_dict["vllm:request_prompt_tokens_sum"], 18)
self.assertEqual(
metrics_dict["vllm:request_prompt_tokens_bucket"], total_prompts
)
# vllm:request_generation_tokens
self.assertEqual(
metrics_dict["vllm:request_generation_tokens_count"],
best_of * total_prompts,
)
self.assertEqual(metrics_dict["vllm:request_generation_tokens_sum"], 188)
self.assertEqual(
metrics_dict["vllm:request_generation_tokens_bucket"],
best_of * total_prompts,
)
# vllm:request_params_best_of
self.assertEqual(
metrics_dict["vllm:request_params_best_of_count"], total_prompts
)
self.assertEqual(
metrics_dict["vllm:request_params_best_of_sum"], best_of * total_prompts
)
self.assertEqual(
metrics_dict["vllm:request_params_best_of_bucket"], total_prompts
)
# vllm:request_params_n
self.assertEqual(metrics_dict["vllm:request_params_n_count"], total_prompts)
self.assertEqual(metrics_dict["vllm:request_params_n_sum"], n * total_prompts)
self.assertEqual(metrics_dict["vllm:request_params_n_bucket"], total_prompts)

def test_vllm_metrics_disabled(self):
# Test vLLM metrics
@@ -141,7 +195,7 @@ def test_vllm_metrics_disabled(self):
sampling_parameters=self.sampling_parameters,
model_name=self.vllm_model_name,
)
metrics_dict = self.get_vllm_metrics()
metrics_dict = self.parse_vllm_metrics()

# No vLLM metric found
self.assertEqual(len(metrics_dict), 0)
@@ -154,7 +208,7 @@ def test_vllm_metrics_refused(self):
model_name=self.vllm_model_name,
)
with self.assertRaises(requests.exceptions.ConnectionError):
self.get_vllm_metrics()
self.parse_vllm_metrics()

def tearDown(self):
self.triton_client.close()
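The body of the renamed parse_vllm_metrics helper is collapsed in this diff. Below is a rough, standalone sketch of what such a helper can do; the endpoint URL and port, the use of the prometheus_client text parser, and the aggregation behavior are assumptions for illustration, not code taken from the PR:

```python
import requests
from prometheus_client.parser import text_string_to_metric_families


def parse_vllm_metrics(metrics_url: str = "http://localhost:8002/metrics") -> dict:
    """Scrape the Triton metrics endpoint and store the vLLM metrics in a dict."""
    # Port 8002 is Triton's default metrics port; adjust it if the server is
    # launched with a different metrics port.
    response = requests.get(metrics_url)

    metrics_dict = {}
    for family in text_string_to_metric_families(response.text):
        for sample in family.samples:
            # Keep only the vLLM metrics. Histogram families expose several
            # samples per name (_sum, _count, one _bucket per boundary);
            # overwriting by name leaves the last ("+Inf") bucket, i.e. the
            # total number of observations.
            if sample.name.startswith("vllm:"):
                metrics_dict[sample.name] = sample.value
    return metrics_dict
```

With a shape like this, each histogram family collapses to one value per sample name, which is consistent with the assertions above that compare the "_bucket" entries directly against total_prompts.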
5 changes: 4 additions & 1 deletion src/model.py
@@ -173,7 +173,10 @@ def init_engine(self):
"version": self.args["model_version"],
}
# Add vLLM custom metrics
self.llm_engine.add_logger("triton", VllmStatLogger(labels=labels))
engine_config = self.llm_engine.engine.model_config
self.llm_engine.add_logger(
"triton", VllmStatLogger(labels, engine_config.max_model_len)
)
except pb_utils.TritonModelException as e:
if "metrics not supported" in str(e):
# Metrics are disabled at the server
112 changes: 93 additions & 19 deletions src/utils/metrics.py
@@ -29,11 +29,11 @@
import triton_python_backend_utils as pb_utils
from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
from vllm.engine.metrics import Stats as VllmStats
from vllm.engine.metrics import SupportsMetricsInfo
from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets


class TritonMetrics:
def __init__(self, labels):
def __init__(self, labels: List[str], max_model_len: int):
# Initialize metric families
# Iteration stats
self.counter_prompt_tokens_family = pb_utils.MetricFamily(
@@ -56,6 +56,34 @@ def __init__(self, labels):
description="Histogram of time per output token in seconds.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
# Request stats
# Latency
self.histogram_e2e_time_request_family = pb_utils.MetricFamily(
name="vllm:e2e_request_latency_seconds",
description="Histogram of end to end request latency in seconds.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
# Metadata
self.histogram_num_prompt_tokens_request_family = pb_utils.MetricFamily(
name="vllm:request_prompt_tokens",
description="Number of prefill tokens processed.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_num_generation_tokens_request_family = pb_utils.MetricFamily(
name="vllm:request_generation_tokens",
description="Number of generation tokens processed.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_best_of_request_family = pb_utils.MetricFamily(
name="vllm:request_params_best_of",
description="Histogram of the best_of request parameter.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_n_request_family = pb_utils.MetricFamily(
name="vllm:request_params_n",
description="Histogram of the n request parameter.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)

# Initialize metrics
# Iteration stats
@@ -65,7 +93,7 @@
self.counter_generation_tokens = self.counter_generation_tokens_family.Metric(
labels=labels
)
# Use the same bucket boundaries from vLLM sample metrics.
# Use the same bucket boundaries from vLLM sample metrics as an example.
# https://github.com/vllm-project/vllm/blob/21313e09e3f9448817016290da20d0db1adf3664/vllm/engine/metrics.py#L81-L96
self.histogram_time_to_first_token = (
self.histogram_time_to_first_token_family.Metric(
@@ -110,16 +138,43 @@ def __init__(self, labels):
],
)
)
# Request stats
# Latency
self.histogram_e2e_time_request = self.histogram_e2e_time_request_family.Metric(
labels=labels,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
)
# Metadata
self.histogram_num_prompt_tokens_request = (
self.histogram_num_prompt_tokens_request_family.Metric(
labels=labels,
buckets=build_1_2_5_buckets(max_model_len),
)
)
self.histogram_num_generation_tokens_request = (
self.histogram_num_generation_tokens_request_family.Metric(
labels=labels,
buckets=build_1_2_5_buckets(max_model_len),
)
)
self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self.histogram_n_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
)


class VllmStatLogger(VllmStatLoggerBase):
"""StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider."""

# local_interval not used here. It's for vLLM logs to stdout.
def __init__(self, labels: Dict, local_interval: float = 0) -> None:
def __init__(self, labels: Dict, max_model_len: int) -> None:
# Tracked stats over current local logging interval.
super().__init__(local_interval)
self.metrics = TritonMetrics(labels=labels)
super().__init__(local_interval=0)
self.metrics = TritonMetrics(labels, max_model_len)

def info(self, type: str, obj: SupportsMetricsInfo) -> None:
pass
@@ -159,16 +214,35 @@ def log(self, stats: VllmStats) -> None:
Returns:
None
"""
self._log_counter(
self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter
)
self._log_counter(
self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
)
self._log_histogram(
self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
)
self._log_histogram(
self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter,
)
# The list of vLLM metrics reported to Triton is also documented here:
# https://github.com/triton-inference-server/vllm_backend/blob/main/README.md#triton-metrics
counter_metrics = [
(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter),
(self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter),
]
histogram_metrics = [
(
self.metrics.histogram_time_to_first_token,
stats.time_to_first_tokens_iter,
),
(
self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter,
),
(self.metrics.histogram_e2e_time_request, stats.time_e2e_requests),
(
self.metrics.histogram_num_prompt_tokens_request,
stats.num_prompt_tokens_requests,
),
(
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests,
),
(self.metrics.histogram_best_of_request, stats.best_of_requests),
(self.metrics.histogram_n_request, stats.n_requests),
]

for metric, data in counter_metrics:
self._log_counter(metric, data)
for metric, data in histogram_metrics:
self._log_histogram(metric, data)
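Two helpers used above are not visible in this hunk: build_1_2_5_buckets, imported from vllm.engine.metrics, and the pre-existing _log_counter/_log_histogram adapters. The sketch below is an approximation inferred from how they are called here, not a copy of the real implementations:

```python
from typing import List, Union


def build_1_2_5_buckets(max_value: int) -> List[int]:
    """Bucket boundaries on a 1-2-5 pattern (1, 2, 5, 10, 20, 50, ...),
    capped at max_value (here, the engine's max_model_len)."""
    buckets: List[int] = []
    exponent = 0
    while True:
        for mantissa in (1, 2, 5):
            value = mantissa * 10**exponent
            if value > max_value:
                return buckets
            buckets.append(value)
        exponent += 1


class _AdapterSketch:
    """How the stat-logger adapters likely forward vLLM stats to Triton metrics."""

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        # Add this iteration's aggregated value to the cumulative counter.
        if data != 0:
            counter.increment(data)

    def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
        # Record one observation per element (one per request, token, etc.).
        for datum in data:
            histogram.observe(datum)


print(build_1_2_5_buckets(4096))
# [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]
```

Deriving the token-count buckets from max_model_len keeps the bucket list short while still covering every prompt or generation length the deployed model can produce.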