fix: tritonfrontend gRPC Streaming Segmentation Fault (#7671)
KrishnanPrash authored Oct 7, 2024
1 parent 6edd5c6 commit b247eb5
Showing 5 changed files with 168 additions and 61 deletions.
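
For context, the null checks added in `src/grpc/stream_infer_handler.cc` and `src/http_server.cc` below suggest that `trace_manager_` can be left unset when the frontends are started through the `tritonfrontend` bindings, and that the gRPC streaming path previously dereferenced it unconditionally. The following is a rough sketch of the scenario exercised by the new `test_streaming_inference` test; the repository path, port, and model name are assumptions borrowed from the test utilities, not part of this commit.

```python
import numpy as np
import tritonserver
import tritonclient.grpc as grpcclient
from tritonfrontend import KServeGrpc

# Start an in-process server and attach the gRPC frontend via the bindings
# (model repository path is an assumption for illustration).
server = tritonserver.Server(model_repository="test_model_repository")
server.start(wait_until_ready=True)
grpc_service = KServeGrpc.Server(server, KServeGrpc.Options())
grpc_service.start()

# Open a client stream and send one streaming request to the "identity" model.
client = grpcclient.InferenceServerClient(url="localhost:8001")
client.start_stream(callback=lambda result, error: print(error or result))

input_tensor = grpcclient.InferInput("INPUT0", [1], "BYTES")
input_tensor.set_data_from_numpy(np.array([b"testing"], dtype=np.object_))
# Before this fix, a streaming request on a tritonfrontend-started server
# could hit the unguarded trace-manager access in the gRPC streaming handler.
client.async_stream_infer(model_name="identity", inputs=[input_tensor])

client.stop_stream()
grpc_service.stop()
server.stop()
```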
12 changes: 9 additions & 3 deletions docs/customization_guide/tritonfrontend.md
@@ -25,9 +25,15 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->
### Triton Server (tritonfrontend) Bindings

The `tritonfrontend` python package is a set of bindings to Triton's existing frontends implemented in C++. Currently, `tritonfrontend` supports starting up `KServeHttp` and `KServeGrpc` frontends. These bindings used in-combination with Triton's Python In-Process API ([`tritonserver`](https://github.com/triton-inference-server/core/tree/main/python/tritonserver)) and [`tritonclient`](https://github.com/triton-inference-server/client/tree/main/src/python/library) extend the ability to use Triton's full feature set with a couple of lines of Python.
### Triton Server (tritonfrontend) Bindings (Beta)

The `tritonfrontend` python package is a set of bindings to Triton's existing
frontends implemented in C++. Currently, `tritonfrontend` supports starting up
`KServeHttp` and `KServeGrpc` frontends. These bindings, used in combination
with Triton's Python In-Process API
([`tritonserver`](https://github.com/triton-inference-server/core/tree/main/python/tritonserver))
and [`tritonclient`](https://github.com/triton-inference-server/client/tree/main/src/python/library),
extend the ability to use Triton's full feature set with a few lines of Python.

Let us walk through a simple example:
1. First we need to load the desired models and start the server with `tritonserver` (a minimal sketch of this step is shown below).
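
As a minimal sketch of that first step plus attaching a frontend (the repository path and port here are placeholders for illustration, not values from this guide):

```python
import tritonserver
from tritonfrontend import KServeHttp

# Load models from a local repository and start the in-process server
# (path is an assumption; point it at your own model repository).
server = tritonserver.Server(model_repository="/path/to/model_repository")
server.start(wait_until_ready=True)

# Attach the KServe HTTP frontend to the running server
# (port 8000 is Triton's usual HTTP default, assumed here).
options = KServeHttp.Options(port=8000)
http_service = KServeHttp.Server(server, options)
http_service.start()

# ... issue requests with tritonclient against localhost:8000 ...

http_service.stop()
server.stop()
```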
112 changes: 62 additions & 50 deletions qa/L0_python_api/test_kserve.py
@@ -29,18 +29,10 @@

import numpy as np
import pytest
import testing_utils as utils
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
import tritonserver
from testing_utils import (
send_and_test_inference_identity,
setup_client,
setup_server,
setup_service,
teardown_client,
teardown_server,
teardown_service,
)
from tritonclient.utils import InferenceServerException
from tritonfrontend import KServeGrpc, KServeHttp

@@ -93,33 +85,33 @@ def test_wrong_grpc_parameters(self):
class TestKServe:
@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS, GRPC_ARGS])
def test_server_ready(self, frontend, client_type, url):
server = setup_server()
service = setup_service(server, frontend)
client = setup_client(client_type, url=url)
server = utils.setup_server()
service = utils.setup_service(server, frontend)
client = utils.setup_client(client_type, url=url)

assert client.is_server_ready()

teardown_client(client)
teardown_service(service)
teardown_server(server)
utils.teardown_client(client)
utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend", [HTTP_ARGS[0], GRPC_ARGS[0]])
def test_service_double_start(self, frontend):
server = setup_server()
server = utils.setup_server()
# setup_service() performs service.start()
service = setup_service(server, frontend)
service = utils.setup_service(server, frontend)

with pytest.raises(
tritonserver.AlreadyExistsError, match="server is already running."
):
service.start()

teardown_server(server)
teardown_service(service)
utils.teardown_server(server)
utils.teardown_service(service)

@pytest.mark.parametrize("frontend", [HTTP_ARGS[0], GRPC_ARGS[0]])
def test_invalid_options(self, frontend):
server = setup_server()
server = utils.setup_server()
# In the current flow, KServeHttp.Options or KServeGrpc.Options have to be
# provided to ensure type and range validation occurs.
with pytest.raises(
@@ -128,45 +120,65 @@ def test_invalid_options(self, frontend):
):
frontend.Server(server, {"port": 8001})

teardown_server(server)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend", [HTTP_ARGS[0], GRPC_ARGS[0]])
def test_server_service_order(self, frontend):
server = setup_server()
service = setup_service(server, frontend)
server = utils.setup_server()
service = utils.setup_service(server, frontend)

teardown_server(server)
teardown_service(service)
utils.teardown_server(server)
utils.teardown_service(service)

@pytest.mark.parametrize("frontend, client_type", [HTTP_ARGS[:2], GRPC_ARGS[:2]])
def test_service_custom_port(self, frontend, client_type):
server = setup_server()
server = utils.setup_server()
options = frontend.Options(port=8005)
service = setup_service(server, frontend, options)
client = setup_client(client_type, url="localhost:8005")
service = utils.setup_service(server, frontend, options)
client = utils.setup_client(client_type, url="localhost:8005")

# Confirms that service starts at port 8005
client.is_server_ready()

teardown_client(client)
teardown_service(service)
teardown_server(server)
utils.teardown_client(client)
utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS, GRPC_ARGS])
def test_inference(self, frontend, client_type, url):
server = setup_server()
service = setup_service(server, frontend)
server = utils.setup_server()
service = utils.setup_service(server, frontend)

# TODO: use common/test_infer
assert send_and_test_inference_identity(client_type, url=url)
assert utils.send_and_test_inference_identity(client_type, url=url)

teardown_service(service)
teardown_server(server)
utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [GRPC_ARGS])
def test_streaming_inference(self, frontend, client_type, url):
server = utils.setup_server()
service = utils.setup_service(server, frontend)

assert utils.send_and_test_stream_inference(client_type, url)

utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS])
def test_http_generate_inference(self, frontend, client_type, url):
server = utils.setup_server()
service = utils.setup_service(server, frontend)

assert utils.send_and_test_generate_inference()

utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS])
def test_http_req_during_shutdown(self, frontend, client_type, url):
server = setup_server()
http_service = setup_service(server, frontend)
server = utils.setup_server()
http_service = utils.setup_service(server, frontend)
http_client = httpclient.InferenceServerClient(url="localhost:8000")
model_name = "delayed_identity"
delay = 2 # seconds
@@ -182,7 +194,7 @@ def test_http_req_during_shutdown(self, frontend, client_type, url):
model_name=model_name, inputs=inputs, outputs=outputs
)
# http_service.stop() does not use graceful shutdown
teardown_service(http_service)
utils.teardown_service(http_service)

# So, inference request will fail as http endpoints have been stopped.
with pytest.raises(
@@ -194,20 +206,20 @@ def test_http_req_during_shutdown(self, frontend, client_type, url):
# However, due to an unsuccessful get_result(), async_request is still
# an active thread. Hence, join stalls until the greenlet times out.
# Does not throw an exception, but displays an error in the logs.
teardown_client(http_client)
utils.teardown_client(http_client)

# delayed_identity will still be an active model
# Hence, server.stop() causes InternalError: Timeout.
with pytest.raises(
tritonserver.InternalError,
match="Exit timeout expired. Exiting immediately.",
):
teardown_server(server)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [GRPC_ARGS])
def test_grpc_req_during_shutdown(self, frontend, client_type, url):
server = setup_server()
grpc_service = setup_service(server, frontend)
server = utils.setup_server()
grpc_service = utils.setup_service(server, frontend)
grpc_client = grpcclient.InferenceServerClient(url=url)
user_data = []

@@ -234,7 +246,7 @@ def callback(user_data, result, error):
callback=partial(callback, user_data),
)

teardown_service(grpc_service)
utils.teardown_service(grpc_service)

time_out = delay + 1
while (len(user_data) == 0) and time_out > 0:
@@ -256,17 +268,17 @@ def callback(user_data, result, error):
)
)

teardown_client(grpc_client)
teardown_server(server)
utils.teardown_client(grpc_client)
utils.teardown_server(server)

# KNOWN ISSUE: CAUSES SEGFAULT
# Created [DLIS-7231] to address at a future date
# Once the server has been stopped, the underlying TRITONSERVER_Server instance
# is deleted. However, the frontend does not know the server instance
# is no longer valid.
# def test_inference_after_server_stop(self):
# server = setup_server()
# http_service = setup_service(server, KServeHttp)
# server = utils.setup_server()
# http_service = utils.setup_service(server, KServeHttp)
# http_client = setup_client(httpclient, url="localhost:8000")

# teardown_server(server) # Server has been stopped
@@ -282,5 +294,5 @@ def callback(user_data, result, error):

# results = http_client.infer(model_name, inputs=inputs, outputs=outputs)

# teardown_client(http_client)
# teardown_service(http_service)
# utils.teardown_client(http_client)
# utils.teardown_service(http_service)
80 changes: 80 additions & 0 deletions qa/L0_python_api/testing_utils.py
@@ -25,12 +25,18 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import queue
from typing import Union

import numpy as np
import requests
import tritonserver
from tritonclient.utils import InferenceServerException
from tritonfrontend import KServeGrpc, KServeHttp

# TODO: Re-Format documentation to fit:
# https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings


def setup_server(model_repository="test_model_repository") -> tritonserver.Server:
module_directory = os.path.split(os.path.abspath(__file__))[0]
@@ -93,3 +99,77 @@ def send_and_test_inference_identity(frontend_client, url: str) -> bool:

teardown_client(client)
return input_data[0] == output_data[0].decode()


# Sends a streaming inference request to test_model_repository/identity model
# and verifies input == output
def send_and_test_stream_inference(frontend_client, url: str) -> bool:
model_name = "identity"

# Setting up the gRPC client stream
results = queue.Queue()
callback = lambda error, result: results.put(error or result)
client = frontend_client.InferenceServerClient(url=url)

client.start_stream(callback=callback)

# Preparing Input Data
text_input = "testing"
input_tensor = frontend_client.InferInput(
name="INPUT0", shape=[1], datatype="BYTES"
)
input_tensor.set_data_from_numpy(np.array([text_input.encode()], dtype=np.object_))

# Sending Streaming Inference Request
client.async_stream_infer(
model_name=model_name, inputs=[input_tensor], enable_empty_final_response=True
)

# Looping through until exception thrown or request completed
completed_requests, num_requests = 0, 1
text_output, is_final = None, None
while completed_requests != num_requests:
result = results.get()
if isinstance(result, InferenceServerException):
if result.status() == "StatusCode.CANCELLED":
completed_requests += 1
raise result

# Processing Response
text_output = result.as_numpy("OUTPUT0")[0].decode()

triton_final_response = result.get_response().parameters.get(
"triton_final_response", {}
)

is_final = False
if triton_final_response.HasField("bool_param"):
is_final = triton_final_response.bool_param

# Request Completed
if is_final:
completed_requests += 1

# Tearing down gRPC client stream
client.stop_stream(cancel_requests=True)

return is_final and (text_input == text_output)


def send_and_test_generate_inference() -> bool:
model_name = "identity"
url = f"http://localhost:8000/v2/models/{model_name}/generate"
input_text = "testing"
data = {
"INPUT0": input_text,
}

response = requests.post(url, json=data, stream=True)
if response.status_code == 200:
result = response.json()
output_text = result.get("OUTPUT0", "")

if output_text == input_text:
return True

return False
14 changes: 8 additions & 6 deletions src/grpc/stream_infer_handler.cc
@@ -324,12 +324,14 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
if (err == nullptr) {
TRITONSERVER_InferenceTrace* triton_trace = nullptr;
#ifdef TRITON_ENABLE_TRACING
GrpcServerCarrier carrier(state->context_->ctx_.get());
auto start_options =
trace_manager_->GetTraceStartOptions(carrier, request.model_name());
state->trace_ = std::move(trace_manager_->SampleTrace(start_options));
if (state->trace_ != nullptr) {
triton_trace = state->trace_->trace_;
if (trace_manager_ != nullptr) {
GrpcServerCarrier carrier(state->context_->ctx_.get());
auto start_options =
trace_manager_->GetTraceStartOptions(carrier, request.model_name());
state->trace_ = std::move(trace_manager_->SampleTrace(start_options));
if (state->trace_ != nullptr) {
triton_trace = state->trace_->trace_;
}
}
#endif // TRITON_ENABLE_TRACING

11 changes: 9 additions & 2 deletions src/http_server.cc
@@ -1810,6 +1810,10 @@ HTTPAPIServer::HandleTrace(evhtp_request_t* req, const std::string& model_name)
}

#ifdef TRITON_ENABLE_TRACING
if (trace_manager_ == nullptr) {
return;
}

TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED;
uint32_t rate;
int32_t count;
@@ -3233,8 +3237,11 @@ HTTPAPIServer::HandleGenerate(

// If tracing is enabled see if this request should be traced.
TRITONSERVER_InferenceTrace* triton_trace = nullptr;
std::shared_ptr<TraceManager::Trace> trace =
StartTrace(req, model_name, &triton_trace);
std::shared_ptr<TraceManager::Trace> trace;
if (trace_manager_) {
// If tracing is enabled see if this request should be traced.
trace = StartTrace(req, model_name, &triton_trace);
}

std::map<std::string, triton::common::TritonJson::Value> input_metadata;
triton::common::TritonJson::Value meta_data_root;
