Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Added functionality for querying by version number/version label.
Browse files Browse the repository at this point in the history
  • Loading branch information
jamieposton committed Aug 3, 2020
1 parent 2ea8ec1 commit d40ac6c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 5 additions & 1 deletion tensor2tensor/serving/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
"cloud_mlengine_model_version", None,
"Version of the model to use. If None, requests will be "
"sent to the default version.")
flags.DEFINE_string("version", None, "Version of the model to use.")
flags.DEFINE_string("version_label", None, "Label of the model to use.")


def validate_flags():
Expand All @@ -72,7 +74,9 @@ def make_request_fn():
request_fn = serving_utils.make_grpc_request_fn(
servable_name=FLAGS.servable_name,
server=FLAGS.server,
timeout_secs=FLAGS.timeout_secs)
timeout_secs=FLAGS.timeout_secs,
version_label=FLAGS.version_label,
version=FLAGS.version)
return request_fn


Expand Down
8 changes: 7 additions & 1 deletion tensor2tensor/serving/serving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,20 @@ def _decode(output_ids, output_decoder):



def make_grpc_request_fn(servable_name, server, timeout_secs):
def make_grpc_request_fn(servable_name, server, timeout_secs, version_label, version):
"""Wraps function to make grpc requests with runtime args."""
stub = _create_stub(server)

def _make_grpc_request(examples):
"""Builds and sends request to TensorFlow model server."""
request = predict_pb2.PredictRequest()
request.model_spec.name = servable_name

if version_label is not None:
request.model_spec.version_label = version_label
elif version is not None:
request.model_spec.version = version

request.inputs["input"].CopyFrom(
tf.make_tensor_proto(
[ex.SerializeToString() for ex in examples], shape=[len(examples)]))
Expand Down

0 comments on commit d40ac6c

Please sign in to comment.