From 021558e0784a71b0e5b2d01ccb7af98e917a9b2c Mon Sep 17 00:00:00 2001 From: Hardik Vala Date: Wed, 1 Jun 2022 14:13:06 -0700 Subject: [PATCH] Add DataDescriptor factory method suited for DataBackends. PiperOrigin-RevId: 452381709 --- .../python/core/impl/executors/BUILD | 5 +++ .../core/impl/executors/data_descriptor.py | 26 +++++++++++++- .../impl/executors/data_descriptor_test.py | 35 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/tensorflow_federated/python/core/impl/executors/BUILD b/tensorflow_federated/python/core/impl/executors/BUILD index 0be7fd3f52..265b21ad8f 100644 --- a/tensorflow_federated/python/core/impl/executors/BUILD +++ b/tensorflow_federated/python/core/impl/executors/BUILD @@ -95,10 +95,12 @@ py_library( deps = [ ":cardinality_carrying_base", ":ingestable_base", + "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/impl/computation:computation_base", "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/impl/types:placements", + "//tensorflow_federated/python/core/impl/types:type_serialization", ], ) @@ -108,7 +110,10 @@ py_test( python_version = "PY3", srcs_version = "PY3", deps = [ + ":data_backend_base", ":data_descriptor", + ":data_executor", + ":eager_tf_executor", ":executor_stacks", ":executor_test_utils", "//tensorflow_federated/python/core/impl/federated_context:federated_computation", diff --git a/tensorflow_federated/python/core/impl/executors/data_descriptor.py b/tensorflow_federated/python/core/impl/executors/data_descriptor.py index d7acf662ff..5fabcc4021 100644 --- a/tensorflow_federated/python/core/impl/executors/data_descriptor.py +++ b/tensorflow_federated/python/core/impl/executors/data_descriptor.py @@ -19,14 +19,38 @@ """Helper class for representing fully-specified data-yeilding computations.""" import asyncio -from typing import Any, Mapping, Optional +from typing import Any, List, Mapping, Optional +from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.impl.computation import computation_base from tensorflow_federated.python.core.impl.executors import cardinality_carrying_base from tensorflow_federated.python.core.impl.executors import ingestable_base from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import placements +from tensorflow_federated.python.core.impl.types.type_serialization import serialize_type + + +def CreateDataDescriptor(arg_uris: List[str], arg_type: computation_types.Type): + """Constructs a `DataDescriptor` instance targeting a `tff.DataBackend`. + + Args: + arg_uris: List of URIs compatible with the data backend embedded in the + given `tff.framework.ExecutionContext`. + arg_type: The type of data referenced by the URIs. An instance of + `tff.Type`. + + Returns: + Instance of `DataDescriptor` + """ + arg_type_proto = serialize_type(arg_type) + args = [ + pb.Computation(data=pb.Data(uri=uri), type=arg_type_proto) + for uri in arg_uris + ] + return DataDescriptor( + None, args, computation_types.FederatedType(arg_type, placements.CLIENTS), + len(args)) class CardinalityFreeDataDescriptor(ingestable_base.Ingestable): diff --git a/tensorflow_federated/python/core/impl/executors/data_descriptor_test.py b/tensorflow_federated/python/core/impl/executors/data_descriptor_test.py index cc10c0f967..fdd1639d2a 100644 --- a/tensorflow_federated/python/core/impl/executors/data_descriptor_test.py +++ b/tensorflow_federated/python/core/impl/executors/data_descriptor_test.py @@ -15,7 +15,10 @@ from absl.testing import absltest import tensorflow as tf +from tensorflow_federated.python.core.impl.executors import data_backend_base from tensorflow_federated.python.core.impl.executors import data_descriptor +from tensorflow_federated.python.core.impl.executors import data_executor +from tensorflow_federated.python.core.impl.executors import eager_tf_executor from tensorflow_federated.python.core.impl.executors import executor_stacks from tensorflow_federated.python.core.impl.executors import executor_test_utils from tensorflow_federated.python.core.impl.federated_context import federated_computation @@ -109,6 +112,38 @@ def foo(x): result = foo(ds) self.assertEqual(result, 3000) + def test_create_data_descriptor_for_data_backend(self): + + class TestDataBackend(data_backend_base.DataBackend): + + def __init__(self, value): + self._value = value + + async def materialize(self, data, type_spec): + return self._value + + data_constant = 1 + type_spec = computation_types.TensorType(tf.int32) + + def ex_fn(device): + return data_executor.DataExecutor( + eager_tf_executor.EagerTFExecutor(device), + TestDataBackend(data_constant)) + + factory = executor_stacks.local_executor_factory(leaf_executor_fn=ex_fn) + + @federated_computation.federated_computation( + computation_types.FederatedType(type_spec, placements.CLIENTS)) + def foo(dd): + return intrinsics.federated_sum(dd) + + with executor_test_utils.install_executor(factory): + uris = [f'foo://bar{i}' for i in range(3)] + dd = data_descriptor.CreateDataDescriptor(uris, type_spec) + result = foo(dd) + + self.assertEqual(result, 3) + if __name__ == '__main__': # TFF-CPP does not yet speak `Ingestable`; b/202336418