Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: vLLM metrics optimization #66

Merged
merged 3 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def init_engine(self):
self.llm_engine = AsyncLLMEngine.from_engine_args(aync_engine_args)

# Create vLLM custom metrics
self.vllm_metrics = None
if (
"REPORT_CUSTOM_METRICS" in self.model_config["parameters"]
and self.model_config["parameters"]["REPORT_CUSTOM_METRICS"]["string_value"]
Expand All @@ -174,9 +175,10 @@ def init_engine(self):
}
# Add vLLM custom metrics
engine_config = self.llm_engine.engine.model_config
self.llm_engine.add_logger(
"triton", VllmStatLogger(labels, engine_config.max_model_len)
self.vllm_metrics = VllmStatLogger(
labels, engine_config.max_model_len, self.logger
)
self.llm_engine.add_logger("triton", self.vllm_metrics)
except pb_utils.TritonModelException as e:
if "metrics not supported" in str(e):
# Metrics are disabled at the server
Expand Down Expand Up @@ -572,6 +574,10 @@ def finalize(self):
self._response_thread.join()
self._response_thread = None

# Shutdown the logger thread.
if self.vllm_metrics is not None:
self.vllm_metrics.finalize()

# When using parallel tensors, the stub process may not shutdown due to
# unreleased references, so manually run the garbage collector once.
self.logger.log_info("[vllm] Running Garbage Collector on finalize...")
Expand Down
38 changes: 34 additions & 4 deletions src/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import queue
import threading
from typing import Dict, List, Union

import triton_python_backend_utils as pb_utils
Expand Down Expand Up @@ -170,11 +172,18 @@ def __init__(self, labels: List[str], max_model_len: int):
class VllmStatLogger(VllmStatLoggerBase):
"""StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider."""

# local_interval not used here. It's for vLLM logs to stdout.
def __init__(self, labels: Dict, max_model_len: int) -> None:
def __init__(self, labels: Dict, max_model_len: int, log_logger) -> None:
    """Create the Triton metrics and start the background reporting thread.

    Args:
        labels: Forwarded unchanged to TritonMetrics.
        max_model_len: Forwarded unchanged to TritonMetrics.
        log_logger: Kept on the instance as ``self.log_logger``; the
            reporting loop calls its ``log_error`` method.

    Returns:
        None
    """
    # The base class takes a local logging interval, which only matters for
    # vLLM's own stdout logs; it is unused here, so 0 is passed.
    super().__init__(local_interval=0)
    self.metrics = TritonMetrics(labels, max_model_len)
    self.log_logger = log_logger

    # Metric updates are handed to a worker thread through a queue so that
    # vLLM can keep making progress while they are reported to the Triton
    # metrics service.
    self._logger_queue = queue.Queue()
    worker = threading.Thread(target=self.logger_loop)
    self._logger_thread = worker
    worker.start()

def info(self, type: str, obj: SupportsMetricsInfo) -> None:
pass
Expand All @@ -190,7 +199,7 @@ def _log_counter(self, counter, data: Union[int, float]) -> None:
None
"""
if data != 0:
counter.increment(data)
self._logger_queue.put_nowait((counter, "increment", data))

def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
"""Convenience function for logging list to histogram.
Expand All @@ -203,7 +212,7 @@ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None
None
"""
for datum in data:
histogram.observe(datum)
self._logger_queue.put_nowait((histogram, "observe", datum))

def log(self, stats: VllmStats) -> None:
"""Report stats to Triton metrics server.
Expand Down Expand Up @@ -246,3 +255,24 @@ def log(self, stats: VllmStats) -> None:
self._log_counter(metric, data)
for metric, data in histogram_metrics:
self._log_histogram(metric, data)

def logger_loop(self):
    """Consume queued metric updates until a shutdown sentinel arrives.

    Each queue item is either ``None`` (stop the loop) or a
    ``(metric, command, data)`` tuple. ``command`` selects the call made on
    ``metric``: ``"increment"`` calls ``metric.increment(data)`` and
    ``"observe"`` calls ``metric.observe(data)``. Any other command is
    reported through ``self.log_logger.log_error``.

    A failure while applying a single update (a malformed item or an
    exception raised by the metric call) is logged and skipped. Without
    this, one bad update would terminate the loop while producers keep
    enqueueing, so no further metrics would be reported and the queue
    would only grow.

    Returns:
        None
    """
    while True:
        item = self._logger_queue.get()
        # A None item is the shutdown signal.
        if item is None:
            break
        try:
            metric, command, data = item
            if command == "increment":
                metric.increment(data)
            elif command == "observe":
                metric.observe(data)
            else:
                self.log_logger.log_error(f"Undefined command name: {command}")
        except Exception as e:
            # Keep the consumer alive; losing one sample is preferable to
            # losing every later one.
            self.log_logger.log_error(f"Failed to apply metric update: {e}")

def finalize(self):
    """Stop the metrics reporting thread and wait for it to exit.

    A ``None`` sentinel is always enqueued; the reporting loop treats it as
    the signal to stop. If a thread handle is still held, it is joined and
    then cleared.

    Returns:
        None
    """
    # Ask the reporting loop to exit.
    self._logger_queue.put(None)

    worker = self._logger_thread
    if worker is None:
        return

    worker.join()
    self._logger_thread = None
Loading