diff --git a/qa/L0_backend_identity/identity_test.py b/qa/L0_backend_identity/identity_test.py index ef0634b95c..5b82890266 100755 --- a/qa/L0_backend_identity/identity_test.py +++ b/qa/L0_backend_identity/identity_test.py @@ -276,11 +276,13 @@ param0 = params["param0"] param1 = params["param1"] param2 = params["param2"] + param3 = params["param3"] else: params = response.parameters param0 = params["param0"].string_param param1 = params["param1"].int64_param param2 = params["param2"].bool_param + param3 = params["param3"].double_param if param0 != "an example string parameter": print("error: expected 'param0' == 'an example string parameter'") @@ -291,3 +293,6 @@ if param2 != False: print("error: expected 'param2' == False") sys.exit(1) + if param3 != 123.123: + print("error: expected 'param3' == 123.123") + sys.exit(1) diff --git a/qa/L0_http/generate_endpoint_test.py b/qa/L0_http/generate_endpoint_test.py index 87ad834be3..8c44ad8419 100755 --- a/qa/L0_http/generate_endpoint_test.py +++ b/qa/L0_http/generate_endpoint_test.py @@ -343,7 +343,7 @@ def test_complex_schema(self): inputs = { "PROMPT": "hello world", "STREAM": True, - "PARAMS": {"PARAM_0": 0, "PARAM_1": True}, + "PARAMS": {"PARAM_0": 0, "PARAM_1": True, "PARAM_2": 123.123}, } r = self.generate(self._model_name, inputs) try: diff --git a/qa/L0_parameters/parameters_test.py b/qa/L0_parameters/parameters_test.py index 959f0fc5dc..190326fc2c 100755 --- a/qa/L0_parameters/parameters_test.py +++ b/qa/L0_parameters/parameters_test.py @@ -56,6 +56,7 @@ async def asyncSetUp(self): self.parameter_list = [] self.parameter_list.append({"key1": "value1", "key2": "value2"}) self.parameter_list.append({"key1": 1, "key2": 2}) + self.parameter_list.append({"key1": 123.123, "key2": 321.321}) self.parameter_list.append({"key1": True, "key2": "value2"}) self.parameter_list.append({"triton_": True, "key2": "value2"}) diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index e179f0f34c..021cd3cf18 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -327,6 +327,12 @@ SetInferenceRequestMetadata( RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetStringParameter( inference_request, param.first.c_str(), infer_param.string_param().c_str())); + } else if ( + infer_param.parameter_choice_case() == + inference::InferParameter::ParameterChoiceCase::kDoubleParam) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetDoubleParameter( + inference_request, param.first.c_str(), + infer_param.double_param())); } else { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 1c96b0e1fe..42a9437c77 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -474,6 +474,9 @@ InferResponseCompleteCommon( case TRITONSERVER_PARAMETER_STRING: param.set_string_param(reinterpret_cast(vvalue)); break; + case TRITONSERVER_PARAMETER_DOUBLE: + param.set_double_param(*(reinterpret_cast(vvalue))); + break; case TRITONSERVER_PARAMETER_BYTES: return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, diff --git a/src/http_server.cc b/src/http_server.cc index ab5f6c6f06..75319b9484 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -183,6 +183,11 @@ SetTritonParameterFromJsonParameter( RETURN_IF_ERR(value.AsBool(&bool_value)); RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetBoolParameter( irequest, parameter.c_str(), bool_value)); + } else if (value.IsNumber()) { + double double_value; + RETURN_IF_ERR(value.AsDouble(&double_value)); + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetDoubleParameter( + irequest, parameter.c_str(), double_value)); } else { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, @@ -3816,6 +3821,10 @@ HTTPAPIServer::InferRequestClass::FinalizeResponse( RETURN_IF_ERR(params_json.AddStringRef( name, reinterpret_cast(vvalue))); break; + case TRITONSERVER_PARAMETER_DOUBLE: + RETURN_IF_ERR(params_json.AddDouble( + name, *(reinterpret_cast(vvalue)))); + break; case TRITONSERVER_PARAMETER_BYTES: return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, @@ -4271,6 +4280,7 @@ HTTPAPIServer::GenerateRequestClass::FinalizeResponse( case TRITONSERVER_PARAMETER_BOOL: case TRITONSERVER_PARAMETER_INT: case TRITONSERVER_PARAMETER_STRING: + case TRITONSERVER_PARAMETER_DOUBLE: triton_outputs.emplace( name, TritonOutput(TritonOutput::Type::PARAMETER, pidx)); break; @@ -4443,6 +4453,10 @@ HTTPAPIServer::GenerateRequestClass::ExactMappingOutput( RETURN_IF_ERR(generate_response->AddStringRef( name, reinterpret_cast(vvalue))); break; + case TRITONSERVER_PARAMETER_DOUBLE: + RETURN_IF_ERR(generate_response->AddDouble( + name, *(reinterpret_cast(vvalue)))); + break; case TRITONSERVER_PARAMETER_BYTES: return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED,