diff --git a/potassium/potassium.py b/potassium/potassium.py
index 4e2e202..54cfc36 100644
--- a/potassium/potassium.py
+++ b/potassium/potassium.py
@@ -17,6 +17,7 @@
 from .worker import run_worker, init_worker
 from .exceptions import RouteAlreadyInUseException, InvalidEndpointTypeException
 from .types import Request, RequestHeaders, Response
+import logging
 
 class HandlerType(Enum):
     HANDLER = "HANDLER"
@@ -116,15 +117,7 @@ def __init__(self, name, experimental_num_workers=1):
         self.event_handler_thread = Thread(target=self._event_handler, daemon=True)
         self.event_handler_thread.start()
 
-        self._status = PotassiumStatus(
-            num_started_inference_requests=0,
-            num_completed_inference_requests=0,
-            num_bad_requests=0,
-            num_workers=self._num_workers,
-            num_workers_started=0,
-            idle_start_timestamp=time.time(),
-            in_flight_request_start_times=[]
-        )
+        self._status = PotassiumStatus.initial(self._num_workers)
 
     def _event_handler(self):
         try:
@@ -285,6 +278,15 @@ def status():
         return flask_app
 
     def _init_server(self):
+        # unless the user has already set up logging, set up logging to stdout using
+        # a separate fd so that we don't get in the way of request logs
+        log = logging.getLogger('werkzeug')
+        if len(log.handlers) == 0:
+            # duplicate stdout
+            stdout_copy = os.dup(1)
+            # redirect flask logs to stdout_copy
+            log.addHandler(logging.StreamHandler(os.fdopen(stdout_copy, 'w')))
+
         self._idle_start_time = time.time()
         index_queue = ProcessQueue()
         for i in range(self._num_workers):
diff --git a/potassium/status.py b/potassium/status.py
index a5f34a4..edbff22 100644
--- a/potassium/status.py
+++ b/potassium/status.py
@@ -26,12 +26,26 @@ class PotassiumStatus():
     idle_start_timestamp: float
     in_flight_request_start_times: List[Tuple[RequestID, float]]
 
+    @staticmethod
+    def initial(num_workers: int) -> "PotassiumStatus":
+        return PotassiumStatus(
+            num_started_inference_requests=0,
+            num_completed_inference_requests=0,
+            num_bad_requests=0,
+            num_workers=num_workers,
+            num_workers_started=0,
+            idle_start_timestamp=time.time(),
+            in_flight_request_start_times=[]
+        )
+
     @property
     def requests_in_progress(self):
         return self.num_started_inference_requests - self.num_completed_inference_requests
 
     @property
     def gpu_available(self):
+        if self.num_workers_started < self.num_workers:
+            return False
         return self.num_workers - self.requests_in_progress > 0
 
     @property
@@ -40,7 +54,9 @@ def sequence_number(self):
 
     @property
     def idle_time(self):
-        if not self.gpu_available or len(self.in_flight_request_start_times) > 0:
+        num_received_requests_not_completed = self.num_started_inference_requests - self.num_completed_inference_requests
+        has_incomplete_requests = num_received_requests_not_completed > 0
+        if not self.gpu_available or has_incomplete_requests:
             return 0
 
         return time.time() - self.idle_start_timestamp
diff --git a/tests/test_status.py b/tests/test_status.py
new file mode 100644
index 0000000..cba909f
--- /dev/null
+++ b/tests/test_status.py
@@ -0,0 +1,207 @@
+import pytest
+from potassium.status import StatusEvent, PotassiumStatus, InvalidStatusEvent
+import time
+
+@pytest.mark.parametrize("worker_num", [
+    1,
+    2,
+    4
+])
+def test_workers_starting(worker_num):
+    status = PotassiumStatus.initial(worker_num)
+    assert status.num_workers == worker_num
+    assert status.num_workers_started == 0
+    assert status.gpu_available == False
+    status = status.update((StatusEvent.WORKER_STARTED,))
+
+    if worker_num == 1:
+        assert status.gpu_available == True
+    else:
+        assert status.gpu_available == False
+
+    for _ in range(worker_num-1):
+        status = status.update((StatusEvent.WORKER_STARTED,))
+    assert status.num_workers_started == worker_num
+    assert status.gpu_available == True
+
+def test_bad_event():
+    status = PotassiumStatus.initial(1)
+    with pytest.raises(InvalidStatusEvent):
+        status.update(("BAD_EVENT",))
+
+def test_inference_requests_single_worker():
+    status = PotassiumStatus.initial(1)
+    status = status.update((StatusEvent.WORKER_STARTED,))
+
+    status = status.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
+    assert status.num_started_inference_requests == 1
+    assert status.num_completed_inference_requests == 0
+    assert status.gpu_available == False
+
+    status = status.update((StatusEvent.INFERENCE_START, 0))
+    status = status.update((StatusEvent.INFERENCE_END, 0))
+
+    assert status.num_started_inference_requests == 1
+    assert status.num_completed_inference_requests == 1
+    assert status.sequence_number == 1
+    assert status.gpu_available == True
+
+    status = status.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
+    assert status.num_started_inference_requests == 2
+    status = status.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
+    assert status.num_started_inference_requests == 3
+
+    status = status.update((StatusEvent.INFERENCE_START, 1))
+    status = status.update((StatusEvent.INFERENCE_END, 1))
+
+    assert status.num_started_inference_requests == 3
+    assert status.num_completed_inference_requests == 2
+    assert status.sequence_number == 3
+    assert status.gpu_available == False
+
+    status = status.update((StatusEvent.INFERENCE_START, 2))
+    status = status.update((StatusEvent.INFERENCE_END, 2))
+
+    assert status.num_started_inference_requests == 3
+    assert status.num_completed_inference_requests == 3
+    assert status.sequence_number == 3
+    assert status.gpu_available == True
+
+def test_inference_requests_multiple_workers():
+    state = PotassiumStatus.initial(2)
+
+    state = state.update((StatusEvent.WORKER_STARTED,))
+    state = state.update((StatusEvent.WORKER_STARTED,))
+
+    assert state.gpu_available == True
+    assert state.sequence_number == 0
+
+    state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
+    assert state.num_started_inference_requests == 1
+    assert state.num_completed_inference_requests == 0
+    assert state.gpu_available == True
+
+    state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
+    assert state.num_started_inference_requests == 2
+    assert state.num_completed_inference_requests == 0
+    assert state.gpu_available == False
+
+    state = state.update((StatusEvent.INFERENCE_START, 0))
+    state = state.update((StatusEvent.INFERENCE_END, 0))
+
+    assert state.num_started_inference_requests == 2
+    assert state.num_completed_inference_requests == 1
+    assert state.sequence_number == 2
+    assert state.gpu_available == True
+
+    state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
+    state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,))
+
+    assert state.num_started_inference_requests == 4
+    assert state.num_completed_inference_requests == 1
+    assert state.sequence_number == 4
+    assert state.gpu_available == False
+
+    state = state.update((StatusEvent.INFERENCE_START, 1))
+    state = state.update((StatusEvent.INFERENCE_END, 1))
+
+    assert state.num_started_inference_requests == 4
+    assert state.num_completed_inference_requests == 2
+    assert state.sequence_number == 4
+    assert state.gpu_available == False
+
+    state = state.update((StatusEvent.INFERENCE_START, 2))
+    state = state.update((StatusEvent.INFERENCE_END, 2))
+
+    assert state.num_started_inference_requests == 4
+    assert state.num_completed_inference_requests == 3
+    assert state.sequence_number == 4
+    assert state.gpu_available == True
+
+    state = state.update((StatusEvent.INFERENCE_START, 3))
+    state = state.update((StatusEvent.INFERENCE_END, 3))
+
+    assert state.num_started_inference_requests == 4
+    assert state.num_completed_inference_requests == 4
+    assert state.sequence_number == 4
+    assert state.gpu_available == True
+
+@pytest.mark.parametrize("status_result_tuple", [
+    (PotassiumStatus(
+        num_started_inference_requests=0,
+        num_completed_inference_requests=0,
+        num_bad_requests=0,
+        num_workers=1,
+        num_workers_started=0,
+        idle_start_timestamp=0,
+        in_flight_request_start_times=[]
+    ), 0),
+    (PotassiumStatus(
+        num_started_inference_requests=0,
+        num_completed_inference_requests=0,
+        num_bad_requests=0,
+        num_workers=1,
+        num_workers_started=1,
+        idle_start_timestamp=0,
+        in_flight_request_start_times=[]
+    ), time.time()),
+    (PotassiumStatus(
+        num_started_inference_requests=1,
+        num_completed_inference_requests=0,
+        num_bad_requests=0,
+        num_workers=1,
+        num_workers_started=1,
+        idle_start_timestamp=0,
+        in_flight_request_start_times=[]
+    ), 0),
+    (PotassiumStatus(
+        num_started_inference_requests=2,
+        num_completed_inference_requests=0,
+        num_bad_requests=0,
+        num_workers=4,
+        num_workers_started=4,
+        idle_start_timestamp=0,
+        in_flight_request_start_times=[]
+    ), 0),
+    (PotassiumStatus(
+        num_started_inference_requests=2,
+        num_completed_inference_requests=2,
+        num_bad_requests=0,
+        num_workers=4,
+        num_workers_started=4,
+        idle_start_timestamp=0,
+        in_flight_request_start_times=[]
+    ), time.time()),
+])
+def test_idle_time(status_result_tuple):
+    status, result = status_result_tuple
+    delta = abs(status.idle_time - result)
+    ALLOWED_DELTA = 1
+    assert delta < ALLOWED_DELTA
+
+def test_longest_inference_time():
+    status = PotassiumStatus(
+        num_started_inference_requests=6,
+        num_completed_inference_requests=2,
+        num_bad_requests=0,
+        num_workers=4,
+        num_workers_started=4,
+        idle_start_timestamp=0,
+        in_flight_request_start_times=[
+            ("b", time.time() - 2),
+            ("a", time.time() - 1),
+            ("c", time.time() - 3),
+            ("d", time.time()),
+        ]
+    )
+
+    longest_inference_time = status.longest_inference_time
+    EXPECTED_LONGEST_INFERENCE_TIME = 3
+    delta = abs(longest_inference_time - EXPECTED_LONGEST_INFERENCE_TIME)
+
+    ALLOWED_DELTA = 0.1
+    assert delta < ALLOWED_DELTA
+
+
+
+