Skip to content

Commit

Permalink
some cleanups, improved test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Peddle committed Dec 11, 2023
1 parent 86e8571 commit 6f140c0
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 10 deletions.
20 changes: 11 additions & 9 deletions potassium/potassium.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .worker import run_worker, init_worker
from .exceptions import RouteAlreadyInUseException, InvalidEndpointTypeException
from .types import Request, RequestHeaders, Response
import logging

class HandlerType(Enum):
HANDLER = "HANDLER"
Expand Down Expand Up @@ -116,15 +117,7 @@ def __init__(self, name, experimental_num_workers=1):
self.event_handler_thread = Thread(target=self._event_handler, daemon=True)
self.event_handler_thread.start()

self._status = PotassiumStatus(
num_started_inference_requests=0,
num_completed_inference_requests=0,
num_bad_requests=0,
num_workers=self._num_workers,
num_workers_started=0,
idle_start_timestamp=time.time(),
in_flight_request_start_times=[]
)
self._status = PotassiumStatus.initial(self._num_workers)

def _event_handler(self):
try:
Expand Down Expand Up @@ -285,6 +278,15 @@ def status():
return flask_app

def _init_server(self):
# unless the user has already set up logging, set up logging to stdout using
# a separate fd so that we don't get in the way of request logs
log = logging.getLogger('werkzeug')
if len(log.handlers) == 0:
# duplicate stdout
stdout_copy = os.dup(1)
# redirect flask logs to stdout_copy
log.addHandler(logging.StreamHandler(os.fdopen(stdout_copy, 'w')))

self._idle_start_time = time.time()
index_queue = ProcessQueue()
for i in range(self._num_workers):
Expand Down
18 changes: 17 additions & 1 deletion potassium/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,26 @@ class PotassiumStatus():
idle_start_timestamp: float
in_flight_request_start_times: List[Tuple[RequestID, float]]

@classmethod
def initial(cls, num_workers: int) -> "PotassiumStatus":
    """Build the status for a server on which no worker has started yet.

    Every request counter starts at zero, the in-flight list is empty and
    the idle clock is set to the time of the call.

    Args:
        num_workers: value stored in the ``num_workers`` field.

    Returns:
        A new status object with ``num_workers_started == 0``.
    """
    # A classmethod (rather than a staticmethod naming the class explicitly)
    # lets a subclass obtain an instance of its own type from this constructor.
    return cls(
        num_started_inference_requests=0,
        num_completed_inference_requests=0,
        num_bad_requests=0,
        num_workers=num_workers,
        num_workers_started=0,
        idle_start_timestamp=time.time(),
        in_flight_request_start_times=[],
    )

@property
def requests_in_progress(self):
    """Count of started inference requests that have not completed yet."""
    started = self.num_started_inference_requests
    completed = self.num_completed_inference_requests
    return started - completed

@property
def gpu_available(self):
    """Report whether there is spare capacity for another request.

    Always False until ``num_workers_started`` has reached ``num_workers``;
    after that, True while fewer requests are in progress than there are
    workers.
    """
    all_workers_started = self.num_workers_started >= self.num_workers
    if not all_workers_started:
        return False
    return self.requests_in_progress < self.num_workers

@property
Expand All @@ -40,7 +54,9 @@ def sequence_number(self):

@property
def idle_time(self):
    """Seconds elapsed since ``idle_start_timestamp``, or 0 when not idle.

    The status counts as idle only when ``gpu_available`` is true and no
    started request is still waiting to complete.
    """
    # Reuse requests_in_progress instead of re-deriving "started - completed"
    # here, so the two definitions cannot drift apart.
    if not self.gpu_available or self.requests_in_progress > 0:
        return 0
    return time.time() - self.idle_start_timestamp

Expand Down
207 changes: 207 additions & 0 deletions tests/test_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import pytest
from potassium.status import StatusEvent, PotassiumStatus, InvalidStatusEvent
import time

@pytest.mark.parametrize("worker_num", [1, 2, 4])
def test_workers_starting(worker_num):
    """gpu_available stays False until every configured worker has started."""
    status = PotassiumStatus.initial(worker_num)
    assert status.num_workers == worker_num
    assert status.num_workers_started == 0
    assert status.gpu_available == False

    # After the first WORKER_STARTED only a single-worker setup is ready.
    status = status.update((StatusEvent.WORKER_STARTED,))
    single_worker_setup = worker_num == 1
    assert status.gpu_available == single_worker_setup

    # Bring up the remaining workers one event at a time.
    still_to_start = worker_num - 1
    while still_to_start > 0:
        status = status.update((StatusEvent.WORKER_STARTED,))
        still_to_start -= 1

    assert status.num_workers_started == worker_num
    assert status.gpu_available == True

def test_bad_event():
    """update() raises InvalidStatusEvent for an unrecognised event tag."""
    fresh_status = PotassiumStatus.initial(1)
    unknown_event = ("BAD_EVENT",)
    with pytest.raises(InvalidStatusEvent):
        fresh_status.update(unknown_event)

def test_inference_requests_single_worker():
    """Drive a one-worker status through three requests, checking counters at each step."""
    received = (StatusEvent.INFERENCE_REQUEST_RECEIVED,)

    def run_to_completion(current, key):
        # In this test a START event is always followed directly by its END.
        current = current.update((StatusEvent.INFERENCE_START, key))
        return current.update((StatusEvent.INFERENCE_END, key))

    def counters(current):
        return (
            current.num_started_inference_requests,
            current.num_completed_inference_requests,
        )

    status = PotassiumStatus.initial(1)
    status = status.update((StatusEvent.WORKER_STARTED,))

    # One request occupies the only worker.
    status = status.update(received)
    assert counters(status) == (1, 0)
    assert status.gpu_available == False

    status = run_to_completion(status, 0)
    assert counters(status) == (1, 1)
    assert status.sequence_number == 1
    assert status.gpu_available == True

    # Two more requests arrive before either is processed.
    status = status.update(received)
    assert status.num_started_inference_requests == 2
    status = status.update(received)
    assert status.num_started_inference_requests == 3

    status = run_to_completion(status, 1)
    assert counters(status) == (3, 2)
    assert status.sequence_number == 3
    assert status.gpu_available == False

    status = run_to_completion(status, 2)
    assert counters(status) == (3, 3)
    assert status.sequence_number == 3
    assert status.gpu_available == True

def test_inference_requests_multiple_workers():
    """Drive a two-worker status through four requests, checking counters at each step."""
    received = (StatusEvent.INFERENCE_REQUEST_RECEIVED,)
    worker_started = (StatusEvent.WORKER_STARTED,)

    def run_to_completion(current, key):
        # In this test a START event is always followed directly by its END.
        current = current.update((StatusEvent.INFERENCE_START, key))
        return current.update((StatusEvent.INFERENCE_END, key))

    def counters(current):
        return (
            current.num_started_inference_requests,
            current.num_completed_inference_requests,
        )

    status = PotassiumStatus.initial(2)
    status = status.update(worker_started)
    status = status.update(worker_started)

    assert status.gpu_available == True
    assert status.sequence_number == 0

    # First request: one of the two workers is still free.
    status = status.update(received)
    assert counters(status) == (1, 0)
    assert status.gpu_available == True

    # Second request: both workers are now taken.
    status = status.update(received)
    assert counters(status) == (2, 0)
    assert status.gpu_available == False

    status = run_to_completion(status, 0)
    assert counters(status) == (2, 1)
    assert status.sequence_number == 2
    assert status.gpu_available == True

    status = status.update(received)
    status = status.update(received)
    assert counters(status) == (4, 1)
    assert status.sequence_number == 4
    assert status.gpu_available == False

    status = run_to_completion(status, 1)
    assert counters(status) == (4, 2)
    assert status.sequence_number == 4
    assert status.gpu_available == False

    status = run_to_completion(status, 2)
    assert counters(status) == (4, 3)
    assert status.sequence_number == 4
    assert status.gpu_available == True

    status = run_to_completion(status, 3)
    assert counters(status) == (4, 4)
    assert status.sequence_number == 4
    assert status.gpu_available == True

# Each case pairs a status with its expected idle_time. The expectation is
# either a plain number or a zero-argument callable that is evaluated when the
# test runs. Passing ``time.time`` (not ``time.time()``) matters: parametrize
# arguments are built at collection time, so a pre-computed timestamp goes
# stale if the test starts more than ALLOWED_DELTA seconds after collection.
@pytest.mark.parametrize("status_result_tuple", [
    # No worker started yet: not idle.
    (PotassiumStatus(
        num_started_inference_requests=0,
        num_completed_inference_requests=0,
        num_bad_requests=0,
        num_workers=1,
        num_workers_started=0,
        idle_start_timestamp=0,
        in_flight_request_start_times=[]
    ), 0),
    # Worker started, nothing received: idle since timestamp 0.
    (PotassiumStatus(
        num_started_inference_requests=0,
        num_completed_inference_requests=0,
        num_bad_requests=0,
        num_workers=1,
        num_workers_started=1,
        idle_start_timestamp=0,
        in_flight_request_start_times=[]
    ), time.time),
    # One incomplete request on a single worker: not idle.
    (PotassiumStatus(
        num_started_inference_requests=1,
        num_completed_inference_requests=0,
        num_bad_requests=0,
        num_workers=1,
        num_workers_started=1,
        idle_start_timestamp=0,
        in_flight_request_start_times=[]
    ), 0),
    # Spare workers exist but requests are incomplete: not idle.
    (PotassiumStatus(
        num_started_inference_requests=2,
        num_completed_inference_requests=0,
        num_bad_requests=0,
        num_workers=4,
        num_workers_started=4,
        idle_start_timestamp=0,
        in_flight_request_start_times=[]
    ), 0),
    # Every received request completed: idle since timestamp 0.
    (PotassiumStatus(
        num_started_inference_requests=2,
        num_completed_inference_requests=2,
        num_bad_requests=0,
        num_workers=4,
        num_workers_started=4,
        idle_start_timestamp=0,
        in_flight_request_start_times=[]
    ), time.time),
])
def test_idle_time(status_result_tuple):
    """idle_time is 0 when not idle, else the time since idle_start_timestamp."""
    status, result = status_result_tuple
    # Resolve "now" at run time rather than at collection time.
    expected = result() if callable(result) else result
    delta = abs(status.idle_time - expected)
    ALLOWED_DELTA = 1
    assert delta < ALLOWED_DELTA

def test_longest_inference_time():
    """With in-flight requests started 2, 1, 3 and 0 seconds ago, expect about 3."""
    ALLOWED_DELTA = 0.1
    EXPECTED_LONGEST_INFERENCE_TIME = 3

    # (request key, seconds since it started), deliberately not sorted by age.
    ages = [("b", 2), ("a", 1), ("c", 3), ("d", 0)]
    in_flight = [(key, time.time() - seconds_ago) for key, seconds_ago in ages]

    status = PotassiumStatus(
        num_started_inference_requests=6,
        num_completed_inference_requests=2,
        num_bad_requests=0,
        num_workers=4,
        num_workers_started=4,
        idle_start_timestamp=0,
        in_flight_request_start_times=in_flight,
    )

    measured = status.longest_inference_time
    assert abs(measured - EXPECTED_LONGEST_INFERENCE_TIME) < ALLOWED_DELTA




0 comments on commit 6f140c0

Please sign in to comment.