Skip to content

Commit

Permalink
Fix async RuntimeWarning in the tff.program.NativeFederatedContext.
Browse files Browse the repository at this point in the history
The `RuntimeWarning`'s were being raised by the `tff.program.NativeFederatedContext` when empty structures were encountered. The fix was to pass around coroutine functions rather than coroutines internally and only create the coroutine when (and more importantly if) the coroutine was awaited.

PiperOrigin-RevId: 497211818
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Dec 22, 2022
1 parent 3ade095 commit fb3538d
Show file tree
Hide file tree
Showing 11 changed files with 428 additions and 367 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def _value():
return value

return native_platform.AwaitableValueReference(
_value(), type_conversions.infer_type(value))
_value, type_conversions.infer_type(value))

test_value = collections.OrderedDict(
a=awaitable_value('foo'),
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_federated/python/program/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,11 @@ py_test(
":native_platform",
":program_test_utils",
":value_reference",
"//tensorflow_federated/python/common_libs:async_utils",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/federated_context:federated_computation",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation",
Expand Down
14 changes: 8 additions & 6 deletions tensorflow_federated/python/program/dataset_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ def __init__(self, datasets: Sequence[tf.data.Dataset],
py_typecheck.check_type(dataset, tf.data.Dataset)
element_spec = datasets[0].element_spec
if dataset.element_spec != element_spec:
raise ValueError('Expected each `tf.data.Dataset` in `datasets` to '
'have the same type specification, found '
f'\'{element_spec}\' and \'{dataset.element_spec}\'.')
raise ValueError(
'Expected each `tf.data.Dataset` in `datasets` to have the same '
f'type specification, found \'{element_spec}\' and '
f'\'{dataset.element_spec}\'.')
py_typecheck.check_type(federated_type, computation_types.FederatedType)

self._datasets = datasets
Expand Down Expand Up @@ -118,9 +119,10 @@ def __init__(self, datasets: Sequence[tf.data.Dataset]):
py_typecheck.check_type(dataset, tf.data.Dataset)
element_spec = datasets[0].element_spec
if dataset.element_spec != element_spec:
raise ValueError('Expected each `tf.data.Dataset` in `datasets` to '
'have the same type specification, found '
f'\'{element_spec}\' and \'{dataset.element_spec}\'.')
raise ValueError(
'Expected each `tf.data.Dataset` in `datasets` to have the same '
f'type specification, found \'{element_spec}\' and '
f'\'{dataset.element_spec}\'.')

self._datasets = datasets
self._federated_type = computation_types.FederatedType(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_federated/python/program/federated_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class FederatedContext(context_base.SyncContext):
Federated type signature of the `tff.Computation`:
* Server-placed values must be represented by
`tff.program.MaterializableStucture`.
`tff.program.MaterializableStructure`.
* Client-placed values must be represented by structures of values returned
by a `tff.program.FederatedDataSourceIterator`.
Expand All @@ -86,7 +86,7 @@ class FederatedContext(context_base.SyncContext):
we specify the Python representation of values in a manner that can be stated
entirely in the TensorFlow Federated typesystem.
We have choosen to limit the TensorFlow Federated type signatures of invoked
We have chosen to limit the TensorFlow Federated type signatures of invoked
`tff.Computation`s to disallow the returning of client-placed values,
`tff.SequenceTypes`, and `tff.FunctionTypes`, in order to reduced the area
which needs to be supported by federated programs. Below we describe the
Expand Down Expand Up @@ -154,4 +154,4 @@ def check_in_federated_context() -> None:
if not isinstance(context_stack.current, FederatedContext):
raise ValueError(
'Expected the current context to be a `tff.program.FederatedContext`, '
f'found \'{type(context_stack.current)}\'.')
f'found {type(context_stack.current)}.')
4 changes: 2 additions & 2 deletions tensorflow_federated/python/program/federated_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_raises_type_error_with_type_signature(self, type_signature):
class CheckInFederatedContextTest(parameterized.TestCase):

def test_does_not_raise_value_error_with_context(self):
context = mock.MagicMock(spec=federated_context.FederatedContext)
context = mock.Mock(spec=federated_context.FederatedContext)

with self.assertRaises(ValueError):
federated_context.check_in_federated_context()
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_raises_value_error_with_context_nested(self):
with self.assertRaises(ValueError):
federated_context.check_in_federated_context()

context = mock.MagicMock(spec=federated_context.FederatedContext)
context = mock.Mock(spec=federated_context.FederatedContext)
with context_stack_impl.context_stack.install(context):
try:
federated_context.check_in_federated_context()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,23 @@ class FileProgramStateManagerLoadTest(parameterized.TestCase,
[[True, program_test_utils.TestMaterializableValueReference(1)], ['a']],
[[np.bool_(True), np.int32(1)], [b'a']]),
('dict',
{'a': True,
'b': program_test_utils.TestMaterializableValueReference(1),
'c': b'a'},
{
'a': True,
'b': program_test_utils.TestMaterializableValueReference(1),
'c': b'a',
},
{'a': np.bool_(True), 'b': np.int32(1), 'c': b'a'}),
('dict_empty', {}, {}),
('dict_nested',
{'x': {'a': True,
'b': program_test_utils.TestMaterializableValueReference(1)},
'y': {'c': 'a'}},
{
'x': {
'a': True,
'b': program_test_utils.TestMaterializableValueReference(1),
},
'y': {
'c': 'a',
},
},
{'x': {'a': np.bool_(True), 'b': np.int32(1)}, 'y': {'c': b'a'}}),
('attr',
program_test_utils.TestAttrObj2(
Expand Down Expand Up @@ -597,15 +605,23 @@ class FileProgramStateManagerSaveTest(parameterized.TestCase,
[[True, program_test_utils.TestMaterializableValueReference(1)], ['a']],
[True, 1, 'a']),
('dict',
{'a': True,
'b': program_test_utils.TestMaterializableValueReference(1),
'c': 'a'},
{
'a': True,
'b': program_test_utils.TestMaterializableValueReference(1),
'c': 'a',
},
[True, 1, 'a']),
('dict_empty', {}, []),
('dict_nested',
{'x': {'a': True,
'b': program_test_utils.TestMaterializableValueReference(1)},
'y': {'c': 'a'}},
{
'x': {
'a': True,
'b': program_test_utils.TestMaterializableValueReference(1),
},
'y': {
'c': 'a',
},
},
[True, 1, 'a']),
('attr',
program_test_utils.TestAttrObj2(
Expand Down
120 changes: 80 additions & 40 deletions tensorflow_federated/python/program/native_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"""A federated platform implemented using native TFF components."""

import asyncio
from collections.abc import Awaitable
import inspect
from collections.abc import Awaitable, Callable
import functools
import typing
from typing import Any, Optional, TypeVar, Union

Expand All @@ -33,6 +33,7 @@


# pyformat: disable
_MaterializedValueFn = Callable[[], Awaitable[value_reference.MaterializedValue]]
_T = TypeVar('_T')
# This type defines values of type `_T` nested in a structure of
# `tff.structure.Struct`'s.
Expand All @@ -46,23 +47,25 @@


class AwaitableValueReference(value_reference.MaterializableValueReference):
"""A `tff.program.MaterializableValueReference` backed by an `Awaitable`."""
"""A `tff.program.MaterializableValueReference` backed by a coroutine function."""

def __init__(self, awaitable: Awaitable[value_reference.MaterializedValue],
def __init__(self, fn: _MaterializedValueFn,
type_signature: value_reference.MaterializableTypeSignature):
"""Returns an initialized `tff.program.AwaitableValueReference`.
Args:
awaitable: An `Awaitable` that returns the referenced value.
fn: A function that returns an `Awaitable` representing the referenced
value.
type_signature: The `tff.Type` of this object.
"""
if not inspect.isawaitable(awaitable):
raise TypeError(f'Expected a `Awaitable`, found {type(awaitable)}')
if not callable(fn):
raise TypeError(
f'Expected a function that returns an `Awaitable`, found {type(fn)}')
py_typecheck.check_type(
type_signature,
typing.get_args(value_reference.MaterializableTypeSignature))

self._awaitable = awaitable
self._fn = fn
self._type_signature = type_signature
self._value = None

Expand All @@ -74,7 +77,7 @@ def type_signature(self) -> value_reference.MaterializableTypeSignature:
async def get_value(self) -> value_reference.MaterializedValue:
"""Returns the referenced value as a numpy scalar or array."""
if self._value is None:
self._value = await self._awaitable
self._value = await self._fn()
return self._value

def __eq__(self, other: Any) -> bool:
Expand All @@ -83,55 +86,90 @@ def __eq__(self, other: Any) -> bool:
elif not isinstance(other, AwaitableValueReference):
return NotImplemented
return (self._type_signature == other._type_signature and
self._awaitable == other._awaitable)
self._fn == other._fn)


def _wrap_in_shared_awaitable(
fn: Callable[..., Awaitable[Any]]
) -> Callable[..., async_utils.SharedAwaitable]:
"""Wraps the returned awaitable in a `tff.async_utils.SharedAwaitable`.
Args:
fn: A function that returns an `Awaitable`.
Returns:
A function that returns a `tff.async_utils.SharedAwaitable`
"""
if not callable(fn):
raise TypeError(
f'Expected a function that returns an `Awaitable`, found {type(fn)}')

@functools.cache
def wrapper(*args: Any, **kwargs: Any) -> async_utils.SharedAwaitable:
awaitable = fn(*args, **kwargs)
return async_utils.SharedAwaitable(awaitable)

return wrapper


def _create_structure_of_awaitable_references(
awaitable: Awaitable[value_reference.MaterializedStructure],
type_signature: computation_types.Type
fn: _MaterializedValueFn, type_signature: computation_types.Type
) -> _StructStructure[AwaitableValueReference]:
"""Returns a structure of `tff.program.AwaitableValueReference`s."""
if not inspect.isawaitable(awaitable):
raise TypeError(f'Expected an `Awaitable`, found {type(awaitable)}')
"""Returns a structure of `tff.program.AwaitableValueReference`s.
Args:
fn: A function that returns an `Awaitable` used to create the structure of
`tff.program.AwaitableValueReference`s.
type_signature: The `tff.Type` of the value returned by `coro_fn`; must
contain only structures, server-placed values, or tensors.
Raises:
NotImplementedError: If `type_signature` contains an unexpected type.
"""
if not callable(fn):
raise TypeError(
f'Expected a function that returns an `Awaitable`, found {type(fn)}')
py_typecheck.check_type(type_signature, computation_types.Type)

# A `async_utils.SharedAwaitable` is required to materialize structures of
# values multiple times. This happens when a value is released using multiple
# `tff.program.ReleaseManager`s.
fn = _wrap_in_shared_awaitable(fn)

if type_signature.is_struct():

async def _to_structure(
awaitable: Awaitable[structure.Struct]) -> structure.Struct:
return structure.from_container(await awaitable)
async def _to_structure(fn: _MaterializedValueFn) -> structure.Struct:
value = await fn()
return structure.from_container(value)

fn = functools.partial(_to_structure, fn)

awaitable = _to_structure(awaitable)
# A `async_utils.SharedAwaitable` is required to materialize structures of
# values sequentially or concurrently. This happens when `get_item` is
# invoked for each element.
shared_awaitable = async_utils.SharedAwaitable(awaitable)
# A `tff.async_utils.SharedAwaitable` is required to materialize structures
# of values concurrently. This happens when the structure is flattened and
# the `tff.program.AwaitableValueReference`s are materialized concurrently,
# see `tff.program.materialize_value` for an example.
fn = _wrap_in_shared_awaitable(fn)

async def _get_item(awaitable: Awaitable[structure.Struct],
async def _get_item(fn: _MaterializedValueFn,
index: int) -> value_reference.MaterializedValue:
value = await awaitable
value = await fn()
return value[index]

elements = []
element_types = structure.iter_elements(type_signature)
for index, (name, element_type) in enumerate(element_types):
element_awaitable = _get_item(shared_awaitable, index)
# A `async_utils.SharedAwaitable` is required to materialize structures of
# values multiple times. This happens when a value is released using
# multiple `tff.program.ReleaseManager`s.
element_shared_awaitable = async_utils.SharedAwaitable(element_awaitable)
element_fn = functools.partial(_get_item, fn, index)
element = _create_structure_of_awaitable_references(
element_shared_awaitable, element_type)
element_fn, element_type)
elements.append((name, element))
return structure.Struct(elements)
elif (type_signature.is_federated() and
type_signature.placement == placements.SERVER):
return _create_structure_of_awaitable_references(awaitable,
type_signature.member)
return _create_structure_of_awaitable_references(fn, type_signature.member)
elif type_signature.is_sequence():
return AwaitableValueReference(awaitable, type_signature)
return AwaitableValueReference(fn, type_signature)
elif type_signature.is_tensor():
return AwaitableValueReference(awaitable, type_signature)
return AwaitableValueReference(fn, type_signature)
else:
raise NotImplementedError(f'Unexpected type found: {type_signature}.')

Expand Down Expand Up @@ -205,6 +243,9 @@ def invoke(
The result of invocation; a structure of
`tff.program.MaterializableValueReference`.
Raises:
ValueError: If the result type of the invoked computation does not contain
only structures, server-placed values, or tensors.
Raises:
ValueError: If the result type of `comp` does not contain only structures,
server-placed values, or tensors.
Expand All @@ -213,9 +254,8 @@ def invoke(
result_type = comp.type_signature.result
if not federated_context.contains_only_server_placed_data(result_type):
raise ValueError(
'Expected the result type of the invoked computation to contain only '
'structures, server-placed values, or tensors, found '
f'\'{result_type}\'.')
'Expected the result type of `comp` to contain only structures, '
f'server-placed values, or tensors, found {result_type}.')

async def _invoke(
context: context_base.AsyncContext, comp: computation_base.Computation,
Expand All @@ -226,7 +266,7 @@ async def _invoke(
arg, comp.type_signature.parameter)
return await context.invoke(comp, arg)

result_coro = _invoke(self._context, comp, arg)
result = _create_structure_of_awaitable_references(result_coro, result_type)
coro_fn = functools.partial(_invoke, self._context, comp, arg)
result = _create_structure_of_awaitable_references(coro_fn, result_type)
result = type_conversions.type_to_py_container(result, result_type)
return result
Loading

0 comments on commit fb3538d

Please sign in to comment.