Add async inference
duydq12 committed Nov 20, 2024
1 parent bcb40c0 commit c3dcd7b
Showing 5 changed files with 52 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/infer_client/adapters/__init__.py
@@ -4,7 +4,7 @@

class InferenceAdapter(metaclass=ABCMeta):
"""
FilesystemAdapter interface
InferenceAdapter interface
"""

@abstractmethod
@@ -19,7 +19,18 @@ def health(self) -> bool:
@abstractmethod
def inference(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
"""
Determine if a directory exists.
Run inference on the model.
Arguments:
ort_inputs: Dictionary of input tensors keyed by input name
ort_out_names: Names of the output nodes whose values to return
Returns:
Values of the requested output nodes
"""

@abstractmethod
async def inference_async(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
"""
Run inference on the model asynchronously.
Arguments:
ort_inputs: Dictionary of input tensors keyed by input name
ort_out_names: Names of the output nodes whose values to return
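Pieced together from the hunks above, the updated abstract interface looks roughly like the sketch below (a non-authoritative reconstruction; docstrings and any members elided by the diff are omitted):

from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List


class InferenceAdapter(metaclass=ABCMeta):
    """InferenceAdapter interface."""

    @abstractmethod
    def health(self) -> bool:
        """Report whether the underlying session or server is usable."""

    @abstractmethod
    def inference(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
        """Run inference synchronously and return the requested output values."""

    @abstractmethod
    async def inference_async(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
        """Async counterpart of inference()."""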
9 changes: 9 additions & 0 deletions src/infer_client/adapters/onnx.py
@@ -1,5 +1,7 @@
import asyncio
import sys

from concurrent.futures import Executor
from os.path import join
from typing import Any, Dict, List

@@ -17,6 +19,7 @@ def __init__(
logger_level: int = 3,
use_tf32: bool = True,
enable_mem_pattern: bool = True,
executor: Executor = None,
) -> None:
onnxruntime.set_default_logger_severity(logger_level)
providers = ["CPUExecutionProvider"]
@@ -48,6 +51,7 @@ def __init__(
sess_options = onnxruntime.SessionOptions()
sess_options.enable_mem_pattern = enable_mem_pattern
self.ort_session = onnxruntime.InferenceSession(model_name, sess_options, providers=providers)
self.executor = executor

def health(self) -> bool:
if self.ort_session is None:
@@ -69,3 +73,8 @@ def inference(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
del io_binding

return ort_outs

async def inference_async(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
return await asyncio.get_running_loop().run_in_executor(
self.executor, self.inference, ort_inputs, ort_out_names
)
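The async path does not make ONNX Runtime itself asynchronous; it offloads the blocking session call to an executor on the running event loop. A minimal, self-contained sketch of the same pattern (all names below are illustrative, not from the repo):

import asyncio

from concurrent.futures import ThreadPoolExecutor


def blocking_infer(x: int) -> int:
    # Stand-in for the synchronous InferenceSession.run call.
    return x * 2


async def infer_async(executor, x: int) -> int:
    # executor=None (the adapter's default) falls back to the event loop's
    # default thread pool, exactly like run_in_executor(None, ...).
    return await asyncio.get_running_loop().run_in_executor(executor, blocking_infer, x)


async def main() -> None:
    with ThreadPoolExecutor(max_workers=2) as pool:
        print(await infer_async(pool, 21))  # 42, dedicated pool
        print(await infer_async(None, 21))  # 42, default executor


asyncio.run(main())

A thread pool (rather than a process pool) also keeps the numpy inputs from having to be pickled across process boundaries.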
3 changes: 3 additions & 0 deletions src/infer_client/adapters/triton.py
@@ -69,3 +69,6 @@ def inference(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
if len(ort_out_names) == 1:
return [res.as_numpy(ort_out_names[0])]
return [res.as_numpy(out) for out in ort_out_names]

async def inference_async(self, ort_inputs: Dict[str, Any], ort_out_names: List[str]) -> List[Any]:
raise NotImplementedError
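Because the Triton adapter only stubs out inference_async for now, a caller that wants a uniform async entry point can fall back to the synchronous path. An illustrative sketch (client stands for any object exposing the two methods above; not code from this repo):

async def inference_with_fallback(client, ort_inputs, ort_out_names):
    # Prefer the async path; fall back to the blocking call for adapters
    # that raise NotImplementedError, such as the Triton adapter in this commit.
    try:
        return await client.inference_async(ort_inputs, ort_out_names)
    except NotImplementedError:
        return client.inference(ort_inputs, ort_out_names)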
3 changes: 3 additions & 0 deletions src/infer_client/inference.py
@@ -22,3 +22,6 @@ def health(self) -> bool:

def inference(self, ort_inputs, ort_out_names):
return self.adapter.inference(ort_inputs, ort_out_names)

async def inference_async(self, ort_inputs, ort_out_names):
return await self.adapter.inference_async(ort_inputs, ort_out_names)
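Since the facade simply delegates to its adapter, the practical payoff of inference_async is that several requests can be awaited concurrently. An illustrative sketch, assuming client is an Inference instance backed by an async-capable adapter and batches is an iterable of input dictionaries:

import asyncio


async def infer_many(client, batches, ort_out_names):
    # Fan out one inference_async call per input batch and await them together.
    tasks = [client.inference_async(inputs, ort_out_names) for inputs in batches]
    return await asyncio.gather(*tasks)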
24 changes: 24 additions & 0 deletions tests/test_inference.py
@@ -49,3 +49,27 @@ def test_inference(path: str, expected: str):
pred = res_onnx[0]
output = postprocess(pred)
assert output.item() == expected


@pytest.mark.asyncio
@pytest.mark.parametrize(
"path,expected",
(
(join(dirname(__file__), "resources/dog.jpg"), 264),
(join(dirname(__file__), "resources/cat.jpg"), 285),
),
)
async def test_async_inference(path: str, expected: int):
img = cv2.imread(path)
preprocessed_img = preprocess(img, (IMG_SIZE, IMG_SIZE))

res_onnx = await infer_onnx_obj.inference_async({"x": preprocessed_img}, ["400"])
res_triton = infer_triton_obj.inference({"x": preprocessed_img}, ["400"])

assert np.allclose(res_onnx, res_triton, rtol=1.0e-4)

if not res_onnx:
return None
pred = res_onnx[0]
output = postprocess(pred)
assert output.item() == expected
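Note that the @pytest.mark.asyncio marker used by test_async_inference requires the pytest-asyncio plugin; without such a plugin, pytest skips coroutine test functions instead of running them.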
