From 0c88773ed761d9533ca3039baa174a0863779cdb Mon Sep 17 00:00:00 2001 From: Michael Reneer Date: Fri, 1 Sep 2023 14:06:12 -0700 Subject: [PATCH] Remove usage of `Any` from all modules in the `executors` package. In most cases it's possible to simply replace `Any` with `object` but in some cases a more specific type is appropriate. Note this change also fixes any pytype errors caused by adding the correct Python type annotation. PiperOrigin-RevId: 562043784 --- .../core/impl/executors/cardinalities_utils.py | 3 +-- .../impl/executors/cpp_to_python_executor.py | 6 +++--- .../core/impl/executors/data_descriptor.py | 8 ++++---- .../python/core/impl/executors/executor_utils.py | 5 ++--- .../core/impl/executors/executors_errors.py | 6 +++--- .../core/impl/executors/value_serialization.py | 16 ++++++++-------- 6 files changed, 21 insertions(+), 23 deletions(-) diff --git a/tensorflow_federated/python/core/impl/executors/cardinalities_utils.py b/tensorflow_federated/python/core/impl/executors/cardinalities_utils.py index 9575589de2..fefc96dae3 100644 --- a/tensorflow_federated/python/core/impl/executors/cardinalities_utils.py +++ b/tensorflow_federated/python/core/impl/executors/cardinalities_utils.py @@ -14,7 +14,6 @@ """Utilities for cardinality inference and handling.""" from collections.abc import Callable, Mapping -from typing import Any from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure @@ -64,7 +63,7 @@ def __init__(self, value, type_spec): # We define this type here to avoid having to redeclare it wherever we # parameterize by a cardinality inference fn. CardinalityInferenceFnType = Callable[ - [Any, computation_types.Type], Mapping[placements.PlacementLiteral, int] + [object, computation_types.Type], Mapping[placements.PlacementLiteral, int] ] diff --git a/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py b/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py index f23c4d0e46..c2ec5ed8d6 100644 --- a/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py +++ b/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py @@ -16,7 +16,7 @@ import asyncio from collections.abc import Sequence import concurrent -from typing import Any, NoReturn, Optional +from typing import NoReturn, Optional from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.common_libs import tracing @@ -66,7 +66,7 @@ def reference(self) -> int: return self._owned_value_id.ref @tracing.trace - async def compute(self) -> Any: + async def compute(self) -> object: """Pulls protocol buffer out of C++ into Python, and deserializes.""" running_loop = asyncio.get_running_loop() @@ -107,7 +107,7 @@ def __init__( @tracing.trace async def create_value( - self, value: Any, type_signature: computation_types.Type + self, value: object, type_signature: computation_types.Type ) -> CppToPythonExecutorValue: serialized_value, _ = value_serialization.serialize_value( value, type_signature diff --git a/tensorflow_federated/python/core/impl/executors/data_descriptor.py b/tensorflow_federated/python/core/impl/executors/data_descriptor.py index 0cf99c3454..1857b49af5 100644 --- a/tensorflow_federated/python/core/impl/executors/data_descriptor.py +++ b/tensorflow_federated/python/core/impl/executors/data_descriptor.py @@ -15,7 +15,7 @@ import asyncio from collections.abc import Mapping -from typing import Any, Optional +from typing import Optional from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.core.impl.computation import computation_base @@ -75,7 +75,7 @@ class CardinalityFreeDataDescriptor(ingestable_base.Ingestable): def __init__( self, comp: Optional[computation_base.Computation], - arg: Any, + arg: object, arg_type: computation_types.Type, ): """Constructs this data descriptor from the given computation and argument. @@ -97,7 +97,7 @@ def __init__( """ self._comp = comp self._arg = arg - self._arg_type = computation_types.to_type(arg_type) + self._arg_type = arg_type if self._comp is not None: if ( self._comp.type_signature.parameter is None @@ -150,7 +150,7 @@ class DataDescriptor( def __init__( self, comp: Optional[computation_base.Computation], - arg: Any, + arg: object, arg_type: computation_types.Type, cardinality: Optional[int] = None, ): diff --git a/tensorflow_federated/python/core/impl/executors/executor_utils.py b/tensorflow_federated/python/core/impl/executors/executor_utils.py index 21d6025dd2..49e4fa89da 100644 --- a/tensorflow_federated/python/core/impl/executors/executor_utils.py +++ b/tensorflow_federated/python/core/impl/executors/executor_utils.py @@ -13,8 +13,7 @@ # limitations under the License. """Utility functions for writing executors.""" -from typing import Any, Optional - +from typing import Optional from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.impl.types import computation_types @@ -22,7 +21,7 @@ def reconcile_value_with_type_spec( - value: Any, type_spec: computation_types.Type + value: object, type_spec: computation_types.Type ) -> computation_types.Type: """Reconciles the type of `value` with the given `type_spec`. diff --git a/tensorflow_federated/python/core/impl/executors/executors_errors.py b/tensorflow_federated/python/core/impl/executors/executors_errors.py index f92bea73c4..1f9c030383 100644 --- a/tensorflow_federated/python/core/impl/executors/executors_errors.py +++ b/tensorflow_federated/python/core/impl/executors/executors_errors.py @@ -14,7 +14,7 @@ """Custom exceptions and symbols for TFF executors.""" import typing -from typing import Any, Union +from typing import Union import grpc from typing_extensions import TypeGuard @@ -36,10 +36,10 @@ def code(self) -> grpc.StatusCode: def details(self) -> str: return self._grpc_call.details() - def initial_metadata(self) -> Any: + def initial_metadata(self) -> object: return self._grpc_call.initial_metadata() - def trailing_metadata(self) -> Any: + def trailing_metadata(self) -> object: return self._grpc_call.trailing_metadata() diff --git a/tensorflow_federated/python/core/impl/executors/value_serialization.py b/tensorflow_federated/python/core/impl/executors/value_serialization.py index 121acca1bd..6b66a2adde 100644 --- a/tensorflow_federated/python/core/impl/executors/value_serialization.py +++ b/tensorflow_federated/python/core/impl/executors/value_serialization.py @@ -15,7 +15,7 @@ import collections from collections.abc import Collection, Mapping, Sequence -from typing import Any, Optional, Union +from typing import Optional, Union import numpy as np import tensorflow as tf @@ -37,7 +37,7 @@ from tensorflow_federated.python.core.impl.utils import tensorflow_utils _SerializeReturnType = tuple[executor_pb2.Value, computation_types.Type] -_DeserializeReturnType = tuple[Any, computation_types.Type] +_DeserializeReturnType = tuple[object, computation_types.Type] # The maximum size allowed for serialized sequence values. Sequence that # serialize to values larger than this will result in errors being raised. This @@ -78,7 +78,7 @@ def _value_proto_for_np_array( @tracing.trace def _serialize_tensor_value( - value: Any, type_spec: computation_types.TensorType + value: object, type_spec: computation_types.TensorType ) -> tuple[executor_pb2.Value, computation_types.TensorType]: """Serializes a tensor value into `executor_pb2.Value`. @@ -101,7 +101,7 @@ def _serialize_tensor_value( if tf.is_tensor(value): if isinstance(value, tf.Variable): value = value.read_value() - if tf.executing_eagerly(): + if isinstance(value, tf.Tensor) and tf.executing_eagerly(): value = value.numpy() else: # Attempt to extract the value using the current graph context. @@ -228,7 +228,7 @@ def _check_ordereddict_container_for_struct(type_to_check): @tracing.trace def _serialize_sequence_value( value: Union[ - Union[type_conversions.TF_DATASET_REPRESENTATION_TYPES], list[Any] + Union[type_conversions.TF_DATASET_REPRESENTATION_TYPES], list[object] ], type_spec: computation_types.SequenceType, ) -> _SerializeReturnType: @@ -278,7 +278,7 @@ def _serialize_sequence_value( @tracing.trace def _serialize_struct_type( - struct_typed_value: Any, + struct_typed_value: object, type_spec: computation_types.StructType, ) -> tuple[executor_pb2.Value, computation_types.StructType]: """Serializes a value of tuple type.""" @@ -308,7 +308,7 @@ def _serialize_struct_type( @tracing.trace def _serialize_federated_value( - federated_value: Any, type_spec: computation_types.FederatedType + federated_value: object, type_spec: computation_types.FederatedType ) -> tuple[executor_pb2.Value, computation_types.FederatedType]: """Serializes a value of federated type.""" if type_spec.all_equal: @@ -329,7 +329,7 @@ def _serialize_federated_value( @tracing.trace def serialize_value( - value: Any, + value: object, type_spec: Optional[computation_types.Type] = None, ) -> _SerializeReturnType: """Serializes a value into `executor_pb2.Value`.