fix: tritonfrontend gRPC Streaming Segmentation Fault #7671

Merged
merged 15 commits into from
Oct 7, 2024
Changes from 14 commits
12 changes: 9 additions & 3 deletions docs/customization_guide/tritonfrontend.md
@@ -25,9 +25,15 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->
### Triton Server (tritonfrontend) Bindings

The `tritonfrontend` python package is a set of bindings to Triton's existing frontends implemented in C++. Currently, `tritonfrontend` supports starting up `KServeHttp` and `KServeGrpc` frontends. These bindings used in-combination with Triton's Python In-Process API ([`tritonserver`](https://github.com/triton-inference-server/core/tree/main/python/tritonserver)) and [`tritonclient`](https://github.com/triton-inference-server/client/tree/main/src/python/library) extend the ability to use Triton's full feature set with a couple of lines of Python.
### Triton Server (tritonfrontend) Bindings (Beta)

The `tritonfrontend` Python package is a set of bindings to Triton's existing
frontends implemented in C++. Currently, `tritonfrontend` supports starting up
the `KServeHttp` and `KServeGrpc` frontends. Used in combination with Triton's
Python In-Process API
([`tritonserver`](https://github.com/triton-inference-server/core/tree/main/python/tritonserver))
and [`tritonclient`](https://github.com/triton-inference-server/client/tree/main/src/python/library),
these bindings extend the ability to use Triton's full feature set with a few
lines of Python.

Let us walk through a simple example:
1. First we need to load the desired models and start the server with `tritonserver`.
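As a rough end-to-end illustration of the workflow the documentation above begins to walk through, here is a minimal sketch; the model repository path, port, and `wait_until_ready` flag are illustrative assumptions rather than values taken from this PR.

```python
import tritonserver
import tritonclient.http as httpclient
from tritonfrontend import KServeHttp

# 1. Load the desired models and start the in-process server.
server = tritonserver.Server(model_repository="./model_repository")
server.start(wait_until_ready=True)

# 2. Attach and start the KServe HTTP frontend on a chosen port.
options = KServeHttp.Options(port=8000)
http_service = KServeHttp.Server(server, options)
http_service.start()

# 3. Send requests through a regular tritonclient client.
client = httpclient.InferenceServerClient(url="localhost:8000")
assert client.is_server_ready()

# 4. Tear everything down in reverse order.
client.close()
http_service.stop()
server.stop()
```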
112 changes: 62 additions & 50 deletions qa/L0_python_api/test_kserve.py
@@ -29,18 +29,10 @@

import numpy as np
import pytest
import testing_utils as utils
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
import tritonserver
from testing_utils import (
send_and_test_inference_identity,
setup_client,
setup_server,
setup_service,
teardown_client,
teardown_server,
teardown_service,
)
from tritonclient.utils import InferenceServerException
from tritonfrontend import KServeGrpc, KServeHttp

@@ -93,33 +85,33 @@ def test_wrong_grpc_parameters(self):
class TestKServe:
@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS, GRPC_ARGS])
def test_server_ready(self, frontend, client_type, url):
server = setup_server()
service = setup_service(server, frontend)
client = setup_client(client_type, url=url)
server = utils.setup_server()
service = utils.setup_service(server, frontend)
client = utils.setup_client(client_type, url=url)

assert client.is_server_ready()

teardown_client(client)
teardown_service(service)
teardown_server(server)
utils.teardown_client(client)
utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend", [HTTP_ARGS[0], GRPC_ARGS[0]])
def test_service_double_start(self, frontend):
server = setup_server()
server = utils.setup_server()
# setup_service() performs service.start()
service = setup_service(server, frontend)
service = utils.setup_service(server, frontend)

with pytest.raises(
tritonserver.AlreadyExistsError, match="server is already running."
):
service.start()

teardown_server(server)
teardown_service(service)
utils.teardown_server(server)
utils.teardown_service(service)

@pytest.mark.parametrize("frontend", [HTTP_ARGS[0], GRPC_ARGS[0]])
def test_invalid_options(self, frontend):
server = setup_server()
server = utils.setup_server()
# Current flow is KServeHttp.Options or KServeGrpc.Options have to be
# provided to ensure type and range validation occurs.
with pytest.raises(
@@ -128,45 +120,65 @@ def test_invalid_options(self, frontend):
):
frontend.Server(server, {"port": 8001})

teardown_server(server)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend", [HTTP_ARGS[0], GRPC_ARGS[0]])
def test_server_service_order(self, frontend):
server = setup_server()
service = setup_service(server, frontend)
server = utils.setup_server()
service = utils.setup_service(server, frontend)

teardown_server(server)
teardown_service(service)
utils.teardown_server(server)
utils.teardown_service(service)

@pytest.mark.parametrize("frontend, client_type", [HTTP_ARGS[:2], GRPC_ARGS[:2]])
def test_service_custom_port(self, frontend, client_type):
server = setup_server()
server = utils.setup_server()
options = frontend.Options(port=8005)
service = setup_service(server, frontend, options)
client = setup_client(client_type, url="localhost:8005")
service = utils.setup_service(server, frontend, options)
client = utils.setup_client(client_type, url="localhost:8005")

# Confirms that service starts at port 8005
client.is_server_ready()

teardown_client(client)
teardown_service(service)
teardown_server(server)
utils.teardown_client(client)
utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS, GRPC_ARGS])
def test_inference(self, frontend, client_type, url):
server = setup_server()
service = setup_service(server, frontend)
server = utils.setup_server()
service = utils.setup_service(server, frontend)

# TODO: use common/test_infer
assert send_and_test_inference_identity(client_type, url=url)
assert utils.send_and_test_inference_identity(client_type, url=url)

teardown_service(service)
teardown_server(server)
utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [GRPC_ARGS])
def test_streaming_inference(self, frontend, client_type, url):
server = utils.setup_server()
service = utils.setup_service(server, frontend)

assert utils.send_and_test_stream_inference(client_type, url)

utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS])
def test_http_generate_inference(self, frontend, client_type, url):
server = utils.setup_server()
service = utils.setup_service(server, frontend)

assert utils.send_and_test_generate_inference()

utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS])
def test_http_req_during_shutdown(self, frontend, client_type, url):
server = setup_server()
http_service = setup_service(server, frontend)
server = utils.setup_server()
http_service = utils.setup_service(server, frontend)
http_client = httpclient.InferenceServerClient(url="localhost:8000")
model_name = "delayed_identity"
delay = 2 # seconds
@@ -182,7 +194,7 @@ def test_http_req_during_shutdown(self, frontend, client_type, url):
model_name=model_name, inputs=inputs, outputs=outputs
)
# http_service.stop() does not use graceful shutdown
teardown_service(http_service)
utils.teardown_service(http_service)

# So, inference request will fail as http endpoints have been stopped.
with pytest.raises(
@@ -194,20 +206,20 @@ def test_http_req_during_shutdown(self, frontend, client_type, url):
# However, due to an unsuccessful get_result(), async_request is still
# an active thread. Hence, join stalls until greenlet timeouts.
# Does not throw an exception, but displays error in logs.
teardown_client(http_client)
utils.teardown_client(http_client)

# delayed_identity will still be an active model
# Hence, server.stop() causes InternalError: Timeout.
with pytest.raises(
tritonserver.InternalError,
match="Exit timeout expired. Exiting immediately.",
):
teardown_server(server)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [GRPC_ARGS])
def test_grpc_req_during_shutdown(self, frontend, client_type, url):
server = setup_server()
grpc_service = setup_service(server, frontend)
server = utils.setup_server()
grpc_service = utils.setup_service(server, frontend)
grpc_client = grpcclient.InferenceServerClient(url=url)
user_data = []

@@ -234,7 +246,7 @@ def callback(user_data, result, error):
callback=partial(callback, user_data),
)

teardown_service(grpc_service)
utils.teardown_service(grpc_service)

time_out = delay + 1
while (len(user_data) == 0) and time_out > 0:
@@ -256,17 +268,17 @@ def callback(user_data, result, error):
)
)

teardown_client(grpc_client)
teardown_server(server)
utils.teardown_client(grpc_client)
utils.teardown_server(server)

# KNOWN ISSUE: CAUSES SEGFAULT
# Created [DLIS-7231] to address at future date
# Once the server has been stopped, the underlying TRITONSERVER_Server instance
# is deleted. However, the frontend does not know the server instance
# is no longer valid.
# def test_inference_after_server_stop(self):
# server = setup_server()
# http_service = setup_service(server, KServeHttp)
# server = utils.setup_server()
# http_service = utils.setup_service(server, KServeHttp)
# http_client = setup_client(httpclient, url="localhost:8000")

# teardown_server(server) # Server has been stopped
@@ -282,5 +294,5 @@

# results = http_client.infer(model_name, inputs=inputs, outputs=outputs)

# teardown_client(http_client)
# teardown_service(http_service)
# utils.teardown_client(http_client)
# utils.teardown_service(http_service)
81 changes: 81 additions & 0 deletions qa/L0_python_api/testing_utils.py
@@ -24,13 +24,20 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os
import queue
from typing import Union

import numpy as np
import requests
import tritonserver
from tritonclient.utils import InferenceServerException
from tritonfrontend import KServeGrpc, KServeHttp

# TODO: Re-Format documentation to fit:
# https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings


def setup_server(model_repository="test_model_repository") -> tritonserver.Server:
module_directory = os.path.split(os.path.abspath(__file__))[0]
@@ -93,3 +100,77 @@

teardown_client(client)
return input_data[0] == output_data[0].decode()


# Sends a streaming inference request to test_model_repository/identity model
# and verifies input == output
def send_and_test_stream_inference(frontend_client, url: str) -> bool:
model_name = "identity"

# Setting up the gRPC client stream
results = queue.Queue()
callback = lambda error, result: results.put(error or result)
client = frontend_client.InferenceServerClient(url=url)

client.start_stream(callback=callback)

# Preparing Input Data
text_input = "testing"
input_tensor = frontend_client.InferInput(
name="INPUT0", shape=[1], datatype="BYTES"
)
input_tensor.set_data_from_numpy(np.array([text_input.encode()], dtype=np.object_))

# Sending Streaming Inference Request
client.async_stream_infer(
model_name=model_name, inputs=[input_tensor], enable_empty_final_response=True
)

# Looping through until exception thrown or request completed
completed_requests, num_requests = 0, 1
text_output, is_final = None, None
while completed_requests != num_requests:
result = results.get()
if isinstance(result, InferenceServerException):
if result.status() == "StatusCode.CANCELLED":
completed_requests += 1
raise result

# Processing Response
text_output = result.as_numpy("OUTPUT0")[0].decode()

triton_final_response = result.get_response().parameters.get(
"triton_final_response", {}
)

is_final = False
if triton_final_response.HasField("bool_param"):
is_final = triton_final_response.bool_param

# Request Completed
if is_final:
completed_requests += 1

# Tearing down gRPC client stream
client.stop_stream(cancel_requests=True)

return is_final and (text_input == text_output)


def send_and_test_generate_inference() -> bool:
model_name = "identity"
url = f"http://localhost:8000/v2/models/{model_name}/generate"
input_text = "testing"
data = {
"INPUT0": input_text,
}

response = requests.post(url, json=data, stream=True)
@rmccorm4 (Collaborator), Oct 5, 2024:

comment for future: I'm not sure that stream=True here is meaningful -- what was your intention or expected behavior by setting it?

(Can revisit later though)

if response.status_code == 200:
result = response.json()
output_text = result.get("OUTPUT0", "")

if output_text == input_text:
return True

return False
14 changes: 8 additions & 6 deletions src/grpc/stream_infer_handler.cc
@@ -324,12 +324,14 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
if (err == nullptr) {
TRITONSERVER_InferenceTrace* triton_trace = nullptr;
#ifdef TRITON_ENABLE_TRACING
GrpcServerCarrier carrier(state->context_->ctx_.get());
auto start_options =
trace_manager_->GetTraceStartOptions(carrier, request.model_name());
state->trace_ = std::move(trace_manager_->SampleTrace(start_options));
if (state->trace_ != nullptr) {
triton_trace = state->trace_->trace_;
if (trace_manager_ != nullptr) {
GrpcServerCarrier carrier(state->context_->ctx_.get());
auto start_options =
trace_manager_->GetTraceStartOptions(carrier, request.model_name());
Contributor:

I see now. I think this is not the only potential place for a segfault. Do we want to fix it in other places as well, or are we targeting the streaming case at the moment?

Contributor Author:

Would be happy to add checks in other places as well. Could you provide an example where these checks would be needed?

Contributor:

@KrishnanPrash Can you look for TRITON_ENABLE_TRACING blocks and break the logic into a separate utility function in grpc_utils.h? This function will take trace_manager_ as input. If trace_manager_ is nullptr then the logic is skipped; otherwise the same logic is run.

Contributor:

I think this place (HandleGenerate) seems to be relevant:

server/src/http_server.cc, lines 3235 to 3237 in dbb064f:

    TRITONSERVER_InferenceTrace* triton_trace = nullptr;
    std::shared_ptr<TraceManager::Trace> trace =
        StartTrace(req, model_name, &triton_trace);

I can see that HandleInfer on HTTP is guarded:

server/src/http_server.cc, lines 3599 to 3602 in dbb064f:

    if (trace_manager_) {
      // If tracing is enabled see if this request should be traced.
      trace = StartTrace(req, model_name, &triton_trace);
    }

qq to @GuanLuo, I can see that you've added guards for HandleInfer (link above) -- was there a reason not to guard HandleGenerate?

Contributor:

@KrishnanPrash, do you need to support trace update with this PR?

Contributor:

> I can see that you've added guards for HandleInfer (link above) -- was there a reason not to guard HandleGenerate?

Probably just a missed spot.

Contributor:

I see in the HandleInfer() method that the trace_manager_ check is not guarded with a TRITON_ENABLE_TRACING compile-time macro -- should it be?

I guess a secondary question: if tracing is compiled in, should trace_manager_ ever be null? Could we make that a precondition instead of a runtime check when tracing is enabled? (just a question)

Contributor Author:

I don't believe trace_manager_ needs to be guarded with a #ifdef TRITON_ENABLE_TRACING compile-time macro because the StartTrace() function has its own check for whether tracing is enabled.

As for the secondary question: if tracing is compiled in, trace_manager_ should technically never be null after being passed to the services. However, because the bindings do not yet support tracing, a hopefully temporary situation arises of tracing being enabled but trace_manager_ being null.

Contributor:

Ah, so this is temporary? Once tracing support is added, can we remove the runtime check?

Contributor Author:

We could remove the runtime check, but personally I am in favor of keeping these checks because they allow us to catch unexpected behavior earlier rather than later.

After tracing support is provided in tritonfrontend, we could probably modify these checks and return an error to fail earlier with a cleaner error message.

state->trace_ = std::move(trace_manager_->SampleTrace(start_options));
if (state->trace_ != nullptr) {
triton_trace = state->trace_->trace_;
}
}
#endif // TRITON_ENABLE_TRACING

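The review thread above suggests factoring the tracing setup into a shared helper in grpc_utils.h that takes trace_manager_ as input and skips the logic when it is nullptr. Below is a minimal sketch of that idea only; the helper name, the template parameter, and the exact TraceManager signatures are assumptions inferred from the code shown in this diff, not part of the PR.

```cpp
#include <memory>
#include <string>
// Assumes the Triton server headers declaring TraceManager and
// TRITONSERVER_InferenceTrace are included by the translation unit.

#ifdef TRITON_ENABLE_TRACING
// Hypothetical helper: sample a trace only when a trace manager is present.
template <typename Carrier>
std::shared_ptr<TraceManager::Trace>
SampleTraceIfEnabled(
    TraceManager* trace_manager, Carrier& carrier,
    const std::string& model_name,
    TRITONSERVER_InferenceTrace** triton_trace)
{
  // Frontends started through the tritonfrontend bindings currently pass a
  // null trace manager, so skip tracing instead of dereferencing nullptr.
  if (trace_manager == nullptr) {
    return nullptr;
  }
  auto start_options =
      trace_manager->GetTraceStartOptions(carrier, model_name);
  std::shared_ptr<TraceManager::Trace> trace =
      trace_manager->SampleTrace(start_options);
  if (trace != nullptr) {
    *triton_trace = trace->trace_;
  }
  return trace;
}
#endif  // TRITON_ENABLE_TRACING
```

A caller such as ModelStreamInferHandler::Process could then construct its carrier and replace the guarded block with a single call, keeping the nullptr handling in one place.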
11 changes: 9 additions & 2 deletions src/http_server.cc
@@ -1810,6 +1810,10 @@ HTTPAPIServer::HandleTrace(evhtp_request_t* req, const std::string& model_name)
}

#ifdef TRITON_ENABLE_TRACING
if (trace_manager_ == nullptr) {
return;
}

TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED;
uint32_t rate;
int32_t count;
@@ -3233,8 +3237,11 @@ HTTPAPIServer::HandleGenerate(

// If tracing is enabled see if this request should be traced.
TRITONSERVER_InferenceTrace* triton_trace = nullptr;
std::shared_ptr<TraceManager::Trace> trace =
StartTrace(req, model_name, &triton_trace);
std::shared_ptr<TraceManager::Trace> trace;
if (trace_manager_) {
Contributor:

Why not put it into StartTrace()?

Contributor Author:

Will make these changes as well as a part of this refactoring ticket [DLIS-7380].

// If tracing is enabled see if this request should be traced.
trace = StartTrace(req, model_name, &triton_trace);
}

std::map<std::string, triton::common::TritonJson::Value> input_metadata;
triton::common::TritonJson::Value meta_data_root;