Skip to content

Commit

Permalink
Remove usage of Any from modules in the types package.
Browse files Browse the repository at this point in the history
In most cases it's possible to simply replace `Any` with `object` but in some cases a more specific type is appropriate.

Note this change also fixes any pytype errors caused by adding the correct Python type annotation.

PiperOrigin-RevId: 567069077
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Sep 20, 2023
1 parent b4b14f1 commit dcc6880
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions tensorflow_federated/python/core/impl/types/type_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
"""Utilities for type conversion, type checking, type inference, etc."""

import collections
from collections.abc import Callable
from collections.abc import Callable, Hashable
import dataclasses
from typing import Any, Optional
import typing
from typing import Optional, Union

import attrs
import numpy as np
Expand All @@ -39,7 +40,7 @@
)


def infer_type(arg: Any) -> Optional[computation_types.Type]:
def infer_type(arg: object) -> Optional[computation_types.Type]:
"""Infers the TFF type of the argument (a `computation_types.Type` instance).
Warning: This function is only partially implemented.
Expand All @@ -58,6 +59,26 @@ def infer_type(arg: Any) -> Optional[computation_types.Type]:
Either an instance of `computation_types.Type`, or `None` if the argument is
`None`.
"""
# TODO: b/224484886 - Downcasting to all handled types.
arg = typing.cast(
Union[
None,
typed_object.TypedObject,
tf.RaggedTensor,
tf.SparseTensor,
tf.Tensor,
structure.Struct,
py_typecheck.SupportsNamedTuple,
dict[Hashable, object],
collections.OrderedDict[Hashable, object],
tuple[object, ...],
list[object],
# Inlined from TF_DATASET_REPRESENTATION_TYPES
tf.data.Dataset,
tf.compat.v1.data.Dataset,
],
arg,
)
if arg is None:
return None
elif isinstance(arg, typed_object.TypedObject):
Expand All @@ -82,8 +103,10 @@ def infer_type(arg: Any) -> Optional[computation_types.Type]:
),
tf.SparseTensor,
)
else:
elif isinstance(arg, tf.Tensor):
return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
else:
raise NotImplementedError(f'Unexpected type found: {type(arg)}.')
elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
element_type = computation_types.to_type(arg.element_spec)
return computation_types.SequenceType(element_type)
Expand Down Expand Up @@ -585,8 +608,8 @@ def _map_element(element):


def structure_from_tensor_type_tree(
fn: Callable[[computation_types.TensorType], Any], type_spec
) -> Any:
fn: Callable[[computation_types.TensorType], object], type_spec
) -> object:
"""Constructs a structure from a `type_spec` tree of `tff.TensorType`s.
Args:
Expand Down

0 comments on commit dcc6880

Please sign in to comment.