Skip to content

Commit

Permalink
Remove usage of Any from all modules in the executors package.
Browse files Browse the repository at this point in the history
In most cases it's possible to simply replace `Any` with `object` but in some cases a more specific type is appropriate.

Note this change also fixes any pytype errors caused by adding the correct Python type annotation.

PiperOrigin-RevId: 562043784
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Sep 6, 2023
1 parent 21192cc commit 0c88773
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Utilities for cardinality inference and handling."""

from collections.abc import Callable, Mapping
from typing import Any

from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
Expand Down Expand Up @@ -64,7 +63,7 @@ def __init__(self, value, type_spec):
# We define this type here to avoid having to redeclare it wherever we
# parameterize by a cardinality inference fn.
CardinalityInferenceFnType = Callable[
[Any, computation_types.Type], Mapping[placements.PlacementLiteral, int]
[object, computation_types.Type], Mapping[placements.PlacementLiteral, int]
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio
from collections.abc import Sequence
import concurrent
from typing import Any, NoReturn, Optional
from typing import NoReturn, Optional

from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.common_libs import tracing
Expand Down Expand Up @@ -66,7 +66,7 @@ def reference(self) -> int:
return self._owned_value_id.ref

@tracing.trace
async def compute(self) -> Any:
async def compute(self) -> object:
"""Pulls protocol buffer out of C++ into Python, and deserializes."""
running_loop = asyncio.get_running_loop()

Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(

@tracing.trace
async def create_value(
self, value: Any, type_signature: computation_types.Type
self, value: object, type_signature: computation_types.Type
) -> CppToPythonExecutorValue:
serialized_value, _ = value_serialization.serialize_value(
value, type_signature
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import asyncio
from collections.abc import Mapping
from typing import Any, Optional
from typing import Optional

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.core.impl.computation import computation_base
Expand Down Expand Up @@ -75,7 +75,7 @@ class CardinalityFreeDataDescriptor(ingestable_base.Ingestable):
def __init__(
self,
comp: Optional[computation_base.Computation],
arg: Any,
arg: object,
arg_type: computation_types.Type,
):
"""Constructs this data descriptor from the given computation and argument.
Expand All @@ -97,7 +97,7 @@ def __init__(
"""
self._comp = comp
self._arg = arg
self._arg_type = computation_types.to_type(arg_type)
self._arg_type = arg_type
if self._comp is not None:
if (
self._comp.type_signature.parameter is None
Expand Down Expand Up @@ -150,7 +150,7 @@ class DataDescriptor(
def __init__(
self,
comp: Optional[computation_base.Computation],
arg: Any,
arg: object,
arg_type: computation_types.Type,
cardinality: Optional[int] = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
# limitations under the License.
"""Utility functions for writing executors."""

from typing import Any, Optional

from typing import Optional

from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import typed_object


def reconcile_value_with_type_spec(
value: Any, type_spec: computation_types.Type
value: object, type_spec: computation_types.Type
) -> computation_types.Type:
"""Reconciles the type of `value` with the given `type_spec`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Custom exceptions and symbols for TFF executors."""

import typing
from typing import Any, Union
from typing import Union

import grpc
from typing_extensions import TypeGuard
Expand All @@ -36,10 +36,10 @@ def code(self) -> grpc.StatusCode:
def details(self) -> str:
return self._grpc_call.details()

def initial_metadata(self) -> Any:
def initial_metadata(self) -> object:
return self._grpc_call.initial_metadata()

def trailing_metadata(self) -> Any:
def trailing_metadata(self) -> object:
return self._grpc_call.trailing_metadata()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import collections
from collections.abc import Collection, Mapping, Sequence
from typing import Any, Optional, Union
from typing import Optional, Union

import numpy as np
import tensorflow as tf
Expand All @@ -37,7 +37,7 @@
from tensorflow_federated.python.core.impl.utils import tensorflow_utils

_SerializeReturnType = tuple[executor_pb2.Value, computation_types.Type]
_DeserializeReturnType = tuple[Any, computation_types.Type]
_DeserializeReturnType = tuple[object, computation_types.Type]

# The maximum size allowed for serialized sequence values. Sequence that
# serialize to values larger than this will result in errors being raised. This
Expand Down Expand Up @@ -78,7 +78,7 @@ def _value_proto_for_np_array(

@tracing.trace
def _serialize_tensor_value(
value: Any, type_spec: computation_types.TensorType
value: object, type_spec: computation_types.TensorType
) -> tuple[executor_pb2.Value, computation_types.TensorType]:
"""Serializes a tensor value into `executor_pb2.Value`.
Expand All @@ -101,7 +101,7 @@ def _serialize_tensor_value(
if tf.is_tensor(value):
if isinstance(value, tf.Variable):
value = value.read_value()
if tf.executing_eagerly():
if isinstance(value, tf.Tensor) and tf.executing_eagerly():
value = value.numpy()
else:
# Attempt to extract the value using the current graph context.
Expand Down Expand Up @@ -228,7 +228,7 @@ def _check_ordereddict_container_for_struct(type_to_check):
@tracing.trace
def _serialize_sequence_value(
value: Union[
Union[type_conversions.TF_DATASET_REPRESENTATION_TYPES], list[Any]
Union[type_conversions.TF_DATASET_REPRESENTATION_TYPES], list[object]
],
type_spec: computation_types.SequenceType,
) -> _SerializeReturnType:
Expand Down Expand Up @@ -278,7 +278,7 @@ def _serialize_sequence_value(

@tracing.trace
def _serialize_struct_type(
struct_typed_value: Any,
struct_typed_value: object,
type_spec: computation_types.StructType,
) -> tuple[executor_pb2.Value, computation_types.StructType]:
"""Serializes a value of tuple type."""
Expand Down Expand Up @@ -308,7 +308,7 @@ def _serialize_struct_type(

@tracing.trace
def _serialize_federated_value(
federated_value: Any, type_spec: computation_types.FederatedType
federated_value: object, type_spec: computation_types.FederatedType
) -> tuple[executor_pb2.Value, computation_types.FederatedType]:
"""Serializes a value of federated type."""
if type_spec.all_equal:
Expand All @@ -329,7 +329,7 @@ def _serialize_federated_value(

@tracing.trace
def serialize_value(
value: Any,
value: object,
type_spec: Optional[computation_types.Type] = None,
) -> _SerializeReturnType:
"""Serializes a value into `executor_pb2.Value`.
Expand Down

0 comments on commit 0c88773

Please sign in to comment.