diff --git a/tensor2tensor/serving/query.py b/tensor2tensor/serving/query.py index 69a7aadff..f64b30020 100644 --- a/tensor2tensor/serving/query.py +++ b/tensor2tensor/serving/query.py @@ -48,6 +48,8 @@ "cloud_mlengine_model_version", None, "Version of the model to use. If None, requests will be " "sent to the default version.") +flags.DEFINE_string("version", None, "Version of the model to use.") +flags.DEFINE_string("version_label", None, "Label of the model to use.") def validate_flags(): @@ -72,7 +74,9 @@ def make_request_fn(): request_fn = serving_utils.make_grpc_request_fn( servable_name=FLAGS.servable_name, server=FLAGS.server, - timeout_secs=FLAGS.timeout_secs) + timeout_secs=FLAGS.timeout_secs, + version_label=FLAGS.version_label, + version=FLAGS.version) return request_fn diff --git a/tensor2tensor/serving/serving_utils.py b/tensor2tensor/serving/serving_utils.py index a1b437282..f3ce7b596 100644 --- a/tensor2tensor/serving/serving_utils.py +++ b/tensor2tensor/serving/serving_utils.py @@ -105,7 +105,7 @@ def _decode(output_ids, output_decoder): -def make_grpc_request_fn(servable_name, server, timeout_secs): +def make_grpc_request_fn(servable_name, server, timeout_secs, version_label=None, version=None): """Wraps function to make grpc requests with runtime args.""" stub = _create_stub(server) @@ -113,6 +113,12 @@ def _make_grpc_request(examples): """Builds and sends request to TensorFlow model server.""" request = predict_pb2.PredictRequest() request.model_spec.name = servable_name + + if version_label is not None: + request.model_spec.version_label = version_label + elif version is not None: + request.model_spec.version.value = int(version) + request.inputs["input"].CopyFrom( tf.make_tensor_proto( [ex.SerializeToString() for ex in examples], shape=[len(examples)]))