Skip to content

Commit

Permalink
Support Double-Type Inference Request/Response Parameters (#6755)
Browse files Browse the repository at this point in the history
* Support Double-Type Inference Request/Response Parameters
  • Loading branch information
fpetrini15 authored Feb 1, 2024
1 parent 8f98789 commit 9860f73
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 1 deletion.
5 changes: 5 additions & 0 deletions qa/L0_backend_identity/identity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,13 @@
param0 = params["param0"]
param1 = params["param1"]
param2 = params["param2"]
param3 = params["param3"]
else:
params = response.parameters
param0 = params["param0"].string_param
param1 = params["param1"].int64_param
param2 = params["param2"].bool_param
param3 = params["param3"].double_param

if param0 != "an example string parameter":
print("error: expected 'param0' == 'an example string parameter'")
Expand All @@ -291,3 +293,6 @@
if param2 != False:
print("error: expected 'param2' == False")
sys.exit(1)
if param3 != 123.123:
print("error: expected 'param3' == 123.123")
sys.exit(1)
2 changes: 1 addition & 1 deletion qa/L0_http/generate_endpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def test_complex_schema(self):
inputs = {
"PROMPT": "hello world",
"STREAM": True,
"PARAMS": {"PARAM_0": 0, "PARAM_1": True},
"PARAMS": {"PARAM_0": 0, "PARAM_1": True, "PARAM_2": 123.123},
}
r = self.generate(self._model_name, inputs)
try:
Expand Down
1 change: 1 addition & 0 deletions qa/L0_parameters/parameters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def asyncSetUp(self):
self.parameter_list = []
self.parameter_list.append({"key1": "value1", "key2": "value2"})
self.parameter_list.append({"key1": 1, "key2": 2})
self.parameter_list.append({"key1": 123.123, "key2": 321.321})
self.parameter_list.append({"key1": True, "key2": "value2"})
self.parameter_list.append({"triton_": True, "key2": "value2"})

Expand Down
6 changes: 6 additions & 0 deletions src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,12 @@ SetInferenceRequestMetadata(
RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetStringParameter(
inference_request, param.first.c_str(),
infer_param.string_param().c_str()));
} else if (
infer_param.parameter_choice_case() ==
inference::InferParameter::ParameterChoiceCase::kDoubleParam) {
RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetDoubleParameter(
inference_request, param.first.c_str(),
infer_param.double_param()));
} else {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
Expand Down
3 changes: 3 additions & 0 deletions src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ InferResponseCompleteCommon(
case TRITONSERVER_PARAMETER_STRING:
param.set_string_param(reinterpret_cast<const char*>(vvalue));
break;
case TRITONSERVER_PARAMETER_DOUBLE:
param.set_double_param(*(reinterpret_cast<const double*>(vvalue)));
break;
case TRITONSERVER_PARAMETER_BYTES:
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
Expand Down
14 changes: 14 additions & 0 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ SetTritonParameterFromJsonParameter(
RETURN_IF_ERR(value.AsBool(&bool_value));
RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetBoolParameter(
irequest, parameter.c_str(), bool_value));
} else if (value.IsNumber()) {
double double_value;
RETURN_IF_ERR(value.AsDouble(&double_value));
RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetDoubleParameter(
irequest, parameter.c_str(), double_value));
} else {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
Expand Down Expand Up @@ -3816,6 +3821,10 @@ HTTPAPIServer::InferRequestClass::FinalizeResponse(
RETURN_IF_ERR(params_json.AddStringRef(
name, reinterpret_cast<const char*>(vvalue)));
break;
case TRITONSERVER_PARAMETER_DOUBLE:
RETURN_IF_ERR(params_json.AddDouble(
name, *(reinterpret_cast<const double*>(vvalue))));
break;
case TRITONSERVER_PARAMETER_BYTES:
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
Expand Down Expand Up @@ -4271,6 +4280,7 @@ HTTPAPIServer::GenerateRequestClass::FinalizeResponse(
case TRITONSERVER_PARAMETER_BOOL:
case TRITONSERVER_PARAMETER_INT:
case TRITONSERVER_PARAMETER_STRING:
case TRITONSERVER_PARAMETER_DOUBLE:
triton_outputs.emplace(
name, TritonOutput(TritonOutput::Type::PARAMETER, pidx));
break;
Expand Down Expand Up @@ -4443,6 +4453,10 @@ HTTPAPIServer::GenerateRequestClass::ExactMappingOutput(
RETURN_IF_ERR(generate_response->AddStringRef(
name, reinterpret_cast<const char*>(vvalue)));
break;
case TRITONSERVER_PARAMETER_DOUBLE:
RETURN_IF_ERR(generate_response->AddDouble(
name, *(reinterpret_cast<const double*>(vvalue))));
break;
case TRITONSERVER_PARAMETER_BYTES:
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
Expand Down

0 comments on commit 9860f73

Please sign in to comment.