Feat: decode request in threadpool #290

83 changes: 64 additions & 19 deletions src/litserve/loops.py
@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import concurrent.futures
import inspect
import logging
import multiprocessing as mp
import os
import pickle
import sys
import time
from queue import Empty, Queue
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

from fastapi import HTTPException
from starlette.formparsers import MultiPartParser
@@ -55,6 +57,20 @@ def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
return func(*args, **kwargs)


def decode_requests_in_threadpool(
    executor: concurrent.futures.ThreadPoolExecutor,
    contexts: List[Union[List[dict], dict]],
    func: Callable,
    inputs: List[dict],
):
    x = [executor.submit(_inject_context, context, func, input) for input, context in zip(inputs, contexts)]

    for i, _x in enumerate(x):
        x[i] = _x.result()

    return x


def collate_requests(
lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float
) -> Tuple[List, List]:
@@ -165,7 +181,11 @@ def run_batched_loop(
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
concurrent_decode: bool = True,
):
if concurrent_decode:
executor = concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count())

while True:
batches, timed_out_uids = collate_requests(
lit_api,
@@ -193,14 +213,17 @@ def run_batched_loop(
lit_spec.populate_context(context, input)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
x = [
_inject_context(
context,
lit_api.decode_request,
input,
)
for input, context in zip(inputs, contexts)
]
if concurrent_decode:
x = decode_requests_in_threadpool(executor, contexts, lit_api.decode_request, inputs)
else:
x = [
_inject_context(
context,
lit_api.decode_request,
input,
)
for input, context in zip(inputs, contexts)
]
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

x = lit_api.batch(x)
@@ -300,7 +323,11 @@ def run_batched_streaming_loop(
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
concurrent_decode: bool = True,
):
if concurrent_decode:
executor = concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count())
Reviewer comment (Collaborator): "We can limit the number of threads to the batch size."

Suggested change:
- executor = concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count())
+ executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(max_batch_size, os.cpu_count()))


while True:
batches, timed_out_uids = collate_requests(
lit_api,
@@ -326,14 +353,17 @@ def run_batched_streaming_loop(
lit_spec.populate_context(context, input)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
x = [
_inject_context(
context,
lit_api.decode_request,
input,
)
for input, context in zip(inputs, contexts)
]
if concurrent_decode:
x = decode_requests_in_threadpool(executor, contexts, lit_api.decode_request, inputs)
else:
x = [
_inject_context(
context,
lit_api.decode_request,
input,
)
for input, context in zip(inputs, contexts)
]
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

x = lit_api.batch(x)
@@ -378,6 +408,7 @@ def inference_worker(
stream: bool,
workers_setup_status: Dict[str, bool],
callback_runner: CallbackRunner,
concurrent_decode: bool = True,
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
lit_api.setup(device)
@@ -394,15 +425,29 @@ def inference_worker(
if stream:
if max_batch_size > 1:
run_batched_streaming_loop(
lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
lit_api,
lit_spec,
request_queue,
response_queues,
max_batch_size,
batch_timeout,
callback_runner,
concurrent_decode,
)
else:
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
return

if max_batch_size > 1:
run_batched_loop(
lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
lit_api,
lit_spec,
request_queue,
response_queues,
max_batch_size,
batch_timeout,
callback_runner,
concurrent_decode,
)
else:
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
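
For reference, a minimal standalone sketch of the pattern the new decode_requests_in_threadpool helper applies: each request is decoded in a worker thread and results are gathered back in submission order. The fake_decode function and sample requests below are illustrative stand-ins for lit_api.decode_request and real payloads, not part of the PR.

import concurrent.futures

def fake_decode(request):
    # Stand-in for lit_api.decode_request: turn a raw request into a model input.
    return request["input"] * 2

requests = [{"input": i} for i in range(8)]

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # Submit one decode task per request, then collect results in the same order.
    futures = [executor.submit(fake_decode, req) for req in requests]
    decoded = [f.result() for f in futures]

print(decoded)  # [0, 2, 4, 6, 8, 10, 12, 14]
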
3 changes: 3 additions & 0 deletions src/litserve/server.py
@@ -118,6 +118,7 @@ def __init__(
callbacks: Optional[Union[List[Callback], Callback]] = None,
middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
concurrent_decode: bool = True,
):
if batch_timeout > timeout and timeout not in (False, -1):
raise ValueError("batch_timeout must be less than timeout")
@@ -178,6 +179,7 @@ def __init__(
self.max_payload_size = max_payload_size
self._connector = _Connector(accelerator=accelerator, devices=devices)
self._callback_runner = CallbackRunner(callbacks)
self._concurrent_decode = concurrent_decode

specs = spec if spec is not None else []
self._specs = specs if isinstance(specs, Sequence) else [specs]
@@ -246,6 +248,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
self.stream,
self.workers_setup_status,
self._callback_runner,
self._concurrent_decode,
),
)
process.start()
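
With this change, threaded request decoding is enabled by default and can be turned off through the new constructor flag. A hedged usage sketch, assuming the flag lands as shown in this diff; the MyAPI class and the batch settings are illustrative, not part of the PR:

import litserve as ls

class MyAPI(ls.LitAPI):
    def setup(self, device):
        # Toy "model": doubles each input in the batch.
        self.model = lambda batch: [x * 2 for x in batch]

    def decode_request(self, request):
        return request["input"]

    def predict(self, batch):
        return self.model(batch)

    def encode_response(self, output):
        return {"output": output}

# concurrent_decode=False falls back to decoding requests sequentially in the worker loop.
server = ls.LitServer(MyAPI(), max_batch_size=8, batch_timeout=0.05, concurrent_decode=False)
server.run(port=8000)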