Skip to content

Commit

Permalink
Follow up to moving transform_args and transform_result from `Con…
Browse files Browse the repository at this point in the history
…creteComputation` to the appropriate execution contexts.

* Updated `SyncExecutionContext` to accept `transform_args` and `transform_result` and forward to `AsyncExecutionContext`.
* Removed custom handling of `transform_args` and `transform_result` from `MergeableCompExecutionContext`, this is not needed because this context takes an `AsyncExecutionContext` and not does construct one.
* Updated the execution contexts factories to pass the appropriate `transform_args` and `transform_result` functions to the execution contexts constructors.

PiperOrigin-RevId: 663410296
  • Loading branch information
michaelreneer authored and copybara-github committed Aug 15, 2024
1 parent c10f4dd commit 8d158ec
Show file tree
Hide file tree
Showing 15 changed files with 96 additions and 48 deletions.
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to

## Unreleased

### Added

* `tff.tensorflow.transform_args` and `tff.tnesorflow.transform_result`, these
functions are intended to be used when instantiating and execution context
in a TensorFlow environment.

## Release 0.85.0

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tensorflow_federated.python.core.backends.mapreduce import form_utils
from tensorflow_federated.python.core.backends.mapreduce import mapreduce_test_utils
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.compiler import building_block_factory
from tensorflow_federated.python.core.impl.compiler import building_block_test_utils
from tensorflow_federated.python.core.impl.compiler import building_blocks
Expand All @@ -43,7 +44,11 @@

def _create_test_context():
factory = executor_factory.local_cpp_executor_factory()
return sync_execution_context.SyncExecutionContext(executor_fn=factory)
return sync_execution_context.SyncExecutionContext(
executor_fn=factory,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)


class CheckExtractionResultTest(absltest.TestCase):
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_federated/python/core/backends/native/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ py_library(
deps = [
":compiler",
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_executor_bindings",
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/context_stack:set_default_context",
"//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context",
"//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context",
Expand Down Expand Up @@ -103,6 +104,7 @@ py_library(
deps = [
":compiler",
":mergeable_comp_compiler",
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
"//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from tensorflow_federated.python.core.backends.native import compiler
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_executor_bindings
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.context_stack import set_default_context
from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context
from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context
Expand Down Expand Up @@ -65,7 +66,10 @@ def create_sync_local_cpp_execution_context(
leaf_executor_fn=_create_tensorflow_backend_execution_stack,
)
context = sync_execution_context.SyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native
executor_fn=factory,
compiler_fn=compiler.desugar_and_transform_to_native,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
return context

Expand Down Expand Up @@ -120,7 +124,10 @@ def create_async_local_cpp_execution_context(
leaf_executor_fn=_create_tensorflow_backend_execution_stack,
)
context = async_execution_context.AsyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native
executor_fn=factory,
compiler_fn=compiler.desugar_and_transform_to_native,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
return context

Expand Down Expand Up @@ -158,7 +165,10 @@ def create_sync_remote_cpp_execution_context(
channels=channels, default_num_clients=default_num_clients
)
context = sync_execution_context.SyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native
executor_fn=factory,
compiler_fn=compiler.desugar_and_transform_to_native,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
return context

Expand All @@ -182,6 +192,9 @@ def create_async_remote_cpp_execution_context(
channels=channels, default_num_clients=default_num_clients
)
context = async_execution_context.AsyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native
executor_fn=factory,
compiler_fn=compiler.desugar_and_transform_to_native,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
return context
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from tensorflow_federated.python.core.backends.native import compiler
from tensorflow_federated.python.core.backends.native import mergeable_comp_compiler
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context
Expand Down Expand Up @@ -103,7 +104,10 @@ def create_async_local_cpp_execution_context(
stream_structs=stream_structs,
)
return async_execution_context.AsyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native
executor_fn=factory,
compiler_fn=compiler.desugar_and_transform_to_native,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)


Expand Down Expand Up @@ -150,7 +154,10 @@ def create_sync_local_cpp_execution_context(
stream_structs=stream_structs,
)
return sync_execution_context.SyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native
executor_fn=factory,
compiler_fn=compiler.desugar_and_transform_to_native,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@

def _create_test_context():
factory = executor_factory.local_cpp_executor_factory()
context = async_execution_context.AsyncExecutionContext(factory)
context = async_execution_context.AsyncExecutionContext(
executor_fn=factory,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
return mergeable_comp_execution_context.MergeableCompExecutionContext(
[context]
)
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_federated/python/core/backends/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ py_library(
":compiler",
"//tensorflow_federated/python/core/backends/native:compiler",
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_executor_bindings",
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
"//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context",
"//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context",
Expand Down Expand Up @@ -108,6 +109,7 @@ py_library(
deps = [
":compiler",
"//tensorflow_federated/python/core/backends/native:compiler",
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
"//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context",
"//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tensorflow_federated.python.core.backends.native import compiler as native_compiler
from tensorflow_federated.python.core.backends.test import compiler as test_compiler
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_executor_bindings
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context
from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context
Expand Down Expand Up @@ -93,7 +94,10 @@ def _compile(comp):
leaf_executor_fn=_create_tensorflow_backend_execution_stack,
)
context = async_execution_context.AsyncExecutionContext(
executor_fn=factory, compiler_fn=_compile
executor_fn=factory,
compiler_fn=_compile,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
return context

Expand Down Expand Up @@ -228,6 +232,8 @@ def initialize_channel(self) -> None:
return sync_execution_context.SyncExecutionContext(
executor_fn=ManagedServiceContext(),
compiler_fn=native_compiler.desugar_and_transform_to_native,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)


Expand Down Expand Up @@ -273,7 +279,10 @@ def _compile(comp):
leaf_executor_fn=_create_tensorflow_backend_execution_stack,
)
context = sync_execution_context.SyncExecutionContext(
executor_fn=factory, compiler_fn=_compile
executor_fn=factory,
compiler_fn=_compile,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
return context

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from tensorflow_federated.python.core.backends.native import compiler as native_compiler
from tensorflow_federated.python.core.backends.test import compiler as test_compiler
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context
from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context
Expand Down Expand Up @@ -44,6 +45,8 @@ def _compile(comp):
return async_execution_context.AsyncExecutionContext(
executor_fn=factory,
compiler_fn=_compile,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)


Expand Down Expand Up @@ -85,6 +88,8 @@ def _compile(comp):
return sync_execution_context.SyncExecutionContext(
executor_fn=factory,
compiler_fn=_compile,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def create_async_local_cpp_execution_context(
leaf_executor_fn=_create_xla_backend_execution_stack,
)
return async_execution_context.AsyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.transform_to_native_form
executor_fn=factory,
compiler_fn=compiler.transform_to_native_form,
)


Expand Down Expand Up @@ -95,7 +96,8 @@ def create_sync_local_cpp_execution_context(
# computations instead of TensorFlow, similar to "desugar intrinsics" in the
# native backend.
return sync_execution_context.SyncExecutionContext(
executor_fn=factory, compiler_fn=compiler.transform_to_native_form
executor_fn=factory,
compiler_fn=compiler.transform_to_native_form,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@
# pylint: disable=g-importing-member
from tensorflow_federated.python.core.environments.tensorflow_backend.tensorflow_tree_transformations import replace_intrinsics_with_bodies
from tensorflow_federated.python.core.environments.tensorflow_frontend.tensorflow_computation import tf_computation as computation
from tensorflow_federated.python.core.environments.tensorflow_frontend.tensorflow_computation import transform_args
from tensorflow_federated.python.core.environments.tensorflow_frontend.tensorflow_computation import transform_result
# pylint: enable=g-importing-member
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ py_library(
"//tensorflow_federated/python/core/impl/compiler:building_blocks",
"//tensorflow_federated/python/core/impl/compiler:tree_analysis",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/computation:function_utils",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/executors:cardinalities_utils",
"//tensorflow_federated/python/core/impl/types:computation_types",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@
"""Execution context for single-aggregation computations."""

import asyncio
from collections.abc import Awaitable, Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Sequence
import functools
import math
from typing import Generic, Optional, TypeVar, Union

import attrs
import tree

from tensorflow_federated.python.common_libs import async_utils
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.compiler import tree_analysis
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.computation import function_utils
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.execution_contexts import compiler_pipeline
from tensorflow_federated.python.core.impl.executors import cardinalities_utils
Expand Down Expand Up @@ -695,36 +693,6 @@ def invoke(
comp, (MergeableCompForm, computation_base.Computation)
)

if arg is not None and self._transform_args is not None:
# `transform_args` is not intended to handle `tff.structure.Struct`.
# Normalize to a Python structure to make it simpler to handle; `args` is
# sometimes a `tff.structure.Struct` and sometimes it is not, other times
# it is a Python structure that contains a `tff.structure.Struct`.
def _to_python(obj):
if isinstance(obj, structure.Struct):
return structure.to_odict_or_tuple(obj)
else:
return None

if isinstance(arg, structure.Struct):
args, kwargs = function_utils.unpack_args_from_struct(arg)
args = tree.traverse(_to_python, args)
args = self._transform_args(args)
if not isinstance(args, Sequence):
raise ValueError(
f'Expected `args` to be a `Sequence`, found {type(args)}'
)
kwargs = tree.traverse(_to_python, kwargs)
kwargs = self._transform_args(kwargs)
if not isinstance(kwargs, Mapping):
raise ValueError(
f'Expected `kwargs` to be a `Mapping`, found {type(kwargs)}'
)
arg = function_utils.pack_args_into_struct(args, kwargs)
else:
arg = tree.traverse(_to_python, arg)
arg = self._transform_args(arg)

if isinstance(comp, computation_base.Computation):
if self._compiler_pipeline is None:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ def __init__(
executor_fn: executor_factory.ExecutorFactory,
compiler_fn: Optional[Callable[[_Computation], object]] = None,
*,
transform_args: Optional[Callable[[object], object]] = None,
transform_result: Optional[Callable[[object], object]] = None,
cardinality_inference_fn: cardinalities_utils.CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities,
):
"""Initializes a synchronous execution context which retries invocations.
Args:
executor_fn: Instance of `executor_factory.ExecutorFactory`.
compiler_fn: A Python function that will be used to compile a computation.
transform_args: An `Optional` `Callable` used to transform the args before
they are passed to the computation.
transform_result: An `Optional` `Callable` used to transform the result
before it is returned.
cardinality_inference_fn: A Python function specifying how to infer
cardinalities from arguments (and their associated types). The value
returned by this function will be passed to the `create_executor` method
Expand All @@ -53,6 +59,8 @@ def __init__(
self._async_context = async_execution_context.AsyncExecutionContext(
executor_fn=executor_fn,
compiler_fn=compiler_fn,
transform_args=transform_args,
transform_result=transform_result,
cardinality_inference_fn=cardinality_inference_fn,
)
self._async_runner = async_utils.AsyncThreadRunner()
Expand Down
Loading

0 comments on commit 8d158ec

Please sign in to comment.