From 953e7feaf5edbc44ed94669ad32079c429ed278f Mon Sep 17 00:00:00 2001 From: fpetrini15 Date: Fri, 29 Dec 2023 16:18:23 -0800 Subject: [PATCH 1/3] Support Double-Type Infer/Response Parameters --- qa/L0_parameters/parameters_test.py | 1 + src/grpc/infer_handler.cc | 6 ++++++ src/grpc/infer_handler.h | 3 +++ src/http_server.cc | 14 ++++++++++++++ 4 files changed, 24 insertions(+) diff --git a/qa/L0_parameters/parameters_test.py b/qa/L0_parameters/parameters_test.py index 959f0fc5dc..190326fc2c 100755 --- a/qa/L0_parameters/parameters_test.py +++ b/qa/L0_parameters/parameters_test.py @@ -56,6 +56,7 @@ async def asyncSetUp(self): self.parameter_list = [] self.parameter_list.append({"key1": "value1", "key2": "value2"}) self.parameter_list.append({"key1": 1, "key2": 2}) + self.parameter_list.append({"key1": 123.123, "key2": 321.321}) self.parameter_list.append({"key1": True, "key2": "value2"}) self.parameter_list.append({"triton_": True, "key2": "value2"}) diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index e179f0f34c..021cd3cf18 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -327,6 +327,12 @@ SetInferenceRequestMetadata( RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetStringParameter( inference_request, param.first.c_str(), infer_param.string_param().c_str())); + } else if ( + infer_param.parameter_choice_case() == + inference::InferParameter::ParameterChoiceCase::kDoubleParam) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetDoubleParameter( + inference_request, param.first.c_str(), + infer_param.double_param())); } else { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 1c96b0e1fe..42a9437c77 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -474,6 +474,9 @@ InferResponseCompleteCommon( case TRITONSERVER_PARAMETER_STRING: param.set_string_param(reinterpret_cast(vvalue)); break; + case TRITONSERVER_PARAMETER_DOUBLE: + param.set_double_param(*(reinterpret_cast(vvalue))); + break; case TRITONSERVER_PARAMETER_BYTES: return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, diff --git a/src/http_server.cc b/src/http_server.cc index d1bd9ce641..797951554e 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -183,6 +183,11 @@ SetTritonParameterFromJsonParameter( RETURN_IF_ERR(value.AsBool(&bool_value)); RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetBoolParameter( irequest, parameter.c_str(), bool_value)); + } else if (value.IsNumber()) { + double double_value; + RETURN_IF_ERR(value.AsDouble(&double_value)); + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetDoubleParameter( + irequest, parameter.c_str(), double_value)); } else { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, @@ -3816,6 +3821,10 @@ HTTPAPIServer::InferRequestClass::FinalizeResponse( RETURN_IF_ERR(params_json.AddStringRef( name, reinterpret_cast(vvalue))); break; + case TRITONSERVER_PARAMETER_DOUBLE: + RETURN_IF_ERR(params_json.AddInt( + name, *(reinterpret_cast(vvalue)))); + break; case TRITONSERVER_PARAMETER_BYTES: return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, @@ -4271,6 +4280,7 @@ HTTPAPIServer::GenerateRequestClass::FinalizeResponse( case TRITONSERVER_PARAMETER_BOOL: case TRITONSERVER_PARAMETER_INT: case TRITONSERVER_PARAMETER_STRING: + case TRITONSERVER_PARAMETER_DOUBLE: triton_outputs.emplace( name, TritonOutput(TritonOutput::Type::PARAMETER, pidx)); break; @@ -4443,6 +4453,10 @@ HTTPAPIServer::GenerateRequestClass::ExactMappingOutput( RETURN_IF_ERR(generate_response->AddStringRef( name, reinterpret_cast(vvalue))); break; + case TRITONSERVER_PARAMETER_DOUBLE: + RETURN_IF_ERR(generate_response->AddDouble( + name, *(reinterpret_cast(vvalue)))); + break; case TRITONSERVER_PARAMETER_BYTES: return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, From 0fc3acc3b15542432fb309500218770759cf07cc Mon Sep 17 00:00:00 2001 From: fpetrini15 Date: Fri, 29 Dec 2023 18:53:11 -0800 Subject: [PATCH 2/3] Add additional testing --- qa/L0_http/generate_endpoint_test.py | 2 +- src/http_server.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qa/L0_http/generate_endpoint_test.py b/qa/L0_http/generate_endpoint_test.py index 29d2e20d96..94b1ebdaba 100755 --- a/qa/L0_http/generate_endpoint_test.py +++ b/qa/L0_http/generate_endpoint_test.py @@ -343,7 +343,7 @@ def test_complex_schema(self): inputs = { "PROMPT": "hello world", "STREAM": True, - "PARAMS": {"PARAM_0": 0, "PARAM_1": True}, + "PARAMS": {"PARAM_0": 0, "PARAM_1": True, "PARAM_2": 123.123}, } r = self.generate(self._model_name, inputs) try: diff --git a/src/http_server.cc b/src/http_server.cc index 797951554e..c9a6bc5b77 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -3822,7 +3822,7 @@ HTTPAPIServer::InferRequestClass::FinalizeResponse( name, reinterpret_cast(vvalue))); break; case TRITONSERVER_PARAMETER_DOUBLE: - RETURN_IF_ERR(params_json.AddInt( + RETURN_IF_ERR(params_json.AddDouble( name, *(reinterpret_cast(vvalue)))); break; case TRITONSERVER_PARAMETER_BYTES: From 6af2af2f5598d91a90d009ae6fabe6a94bb4f279 Mon Sep 17 00:00:00 2001 From: fpetrini15 Date: Tue, 2 Jan 2024 18:34:31 -0800 Subject: [PATCH 3/3] Additional testing --- qa/L0_backend_identity/identity_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/qa/L0_backend_identity/identity_test.py b/qa/L0_backend_identity/identity_test.py index ef0634b95c..5b82890266 100755 --- a/qa/L0_backend_identity/identity_test.py +++ b/qa/L0_backend_identity/identity_test.py @@ -276,11 +276,13 @@ param0 = params["param0"] param1 = params["param1"] param2 = params["param2"] + param3 = params["param3"] else: params = response.parameters param0 = params["param0"].string_param param1 = params["param1"].int64_param param2 = params["param2"].bool_param + param3 = params["param3"].double_param if param0 != "an example string parameter": print("error: expected 'param0' == 'an example string parameter'") @@ -291,3 +293,6 @@ if param2 != False: print("error: expected 'param2' == False") sys.exit(1) + if param3 != 123.123: + print("error: expected 'param3' == 123.123") + sys.exit(1)