Skip to content

Commit

Permalink
Add DataDescriptor factory method suited for DataBackends.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452381709
  • Loading branch information
hardik-vala authored and tensorflow-copybara committed Jun 1, 2022
1 parent ef0ec4f commit 021558e
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 1 deletion.
5 changes: 5 additions & 0 deletions tensorflow_federated/python/core/impl/executors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ py_library(
deps = [
":cardinality_carrying_base",
":ingestable_base",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_serialization",
],
)

Expand All @@ -108,7 +110,10 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":data_backend_base",
":data_descriptor",
":data_executor",
":eager_tf_executor",
":executor_stacks",
":executor_test_utils",
"//tensorflow_federated/python/core/impl/federated_context:federated_computation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,38 @@
"""Helper class for representing fully-specified data-yeilding computations."""

import asyncio
from typing import Any, Mapping, Optional
from typing import Any, List, Mapping, Optional

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.executors import cardinality_carrying_base
from tensorflow_federated.python.core.impl.executors import ingestable_base
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types.type_serialization import serialize_type


def CreateDataDescriptor(arg_uris: List[str], arg_type: computation_types.Type):
"""Constructs a `DataDescriptor` instance targeting a `tff.DataBackend`.
Args:
arg_uris: List of URIs compatible with the data backend embedded in the
given `tff.framework.ExecutionContext`.
arg_type: The type of data referenced by the URIs. An instance of
`tff.Type`.
Returns:
Instance of `DataDescriptor`
"""
arg_type_proto = serialize_type(arg_type)
args = [
pb.Computation(data=pb.Data(uri=uri), type=arg_type_proto)
for uri in arg_uris
]
return DataDescriptor(
None, args, computation_types.FederatedType(arg_type, placements.CLIENTS),
len(args))


class CardinalityFreeDataDescriptor(ingestable_base.Ingestable):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from absl.testing import absltest
import tensorflow as tf

from tensorflow_federated.python.core.impl.executors import data_backend_base
from tensorflow_federated.python.core.impl.executors import data_descriptor
from tensorflow_federated.python.core.impl.executors import data_executor
from tensorflow_federated.python.core.impl.executors import eager_tf_executor
from tensorflow_federated.python.core.impl.executors import executor_stacks
from tensorflow_federated.python.core.impl.executors import executor_test_utils
from tensorflow_federated.python.core.impl.federated_context import federated_computation
Expand Down Expand Up @@ -109,6 +112,38 @@ def foo(x):
result = foo(ds)
self.assertEqual(result, 3000)

def test_create_data_descriptor_for_data_backend(self):

class TestDataBackend(data_backend_base.DataBackend):

def __init__(self, value):
self._value = value

async def materialize(self, data, type_spec):
return self._value

data_constant = 1
type_spec = computation_types.TensorType(tf.int32)

def ex_fn(device):
return data_executor.DataExecutor(
eager_tf_executor.EagerTFExecutor(device),
TestDataBackend(data_constant))

factory = executor_stacks.local_executor_factory(leaf_executor_fn=ex_fn)

@federated_computation.federated_computation(
computation_types.FederatedType(type_spec, placements.CLIENTS))
def foo(dd):
return intrinsics.federated_sum(dd)

with executor_test_utils.install_executor(factory):
uris = [f'foo://bar{i}' for i in range(3)]
dd = data_descriptor.CreateDataDescriptor(uris, type_spec)
result = foo(dd)

self.assertEqual(result, 3)


if __name__ == '__main__':
# TFF-CPP does not yet speak `Ingestable`; b/202336418
Expand Down

0 comments on commit 021558e

Please sign in to comment.