diff --git a/tensorflow_federated/python/core/impl/types/type_conversions.py b/tensorflow_federated/python/core/impl/types/type_conversions.py index 273931171b..5e758e19f6 100644 --- a/tensorflow_federated/python/core/impl/types/type_conversions.py +++ b/tensorflow_federated/python/core/impl/types/type_conversions.py @@ -13,9 +13,10 @@ """Utilities for type conversion, type checking, type inference, etc.""" import collections -from collections.abc import Callable +from collections.abc import Callable, Hashable import dataclasses -from typing import Any, Optional +import typing +from typing import Optional, Union import attrs import numpy as np @@ -39,7 +40,7 @@ ) -def infer_type(arg: Any) -> Optional[computation_types.Type]: +def infer_type(arg: object) -> Optional[computation_types.Type]: """Infers the TFF type of the argument (a `computation_types.Type` instance). Warning: This function is only partially implemented. @@ -58,6 +59,26 @@ def infer_type(arg: Any) -> Optional[computation_types.Type]: Either an instance of `computation_types.Type`, or `None` if the argument is `None`. """ + # TODO: b/224484886 - Downcasting to all handled types. + arg = typing.cast( + Union[ + None, + typed_object.TypedObject, + tf.RaggedTensor, + tf.SparseTensor, + tf.Tensor, + structure.Struct, + py_typecheck.SupportsNamedTuple, + dict[Hashable, object], + collections.OrderedDict[Hashable, object], + tuple[object, ...], + list[object], + # Inlined from TF_DATASET_REPRESENTATION_TYPES + tf.data.Dataset, + tf.compat.v1.data.Dataset, + ], + arg, + ) if arg is None: return None elif isinstance(arg, typed_object.TypedObject): @@ -82,8 +103,10 @@ def infer_type(arg: Any) -> Optional[computation_types.Type]: ), tf.SparseTensor, ) - else: + elif isinstance(arg, tf.Tensor): return computation_types.TensorType(arg.dtype.base_dtype, arg.shape) + else: + raise NotImplementedError(f'Unexpected type found: {type(arg)}.') elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES): element_type = computation_types.to_type(arg.element_spec) return computation_types.SequenceType(element_type) @@ -585,8 +608,8 @@ def _map_element(element): def structure_from_tensor_type_tree( - fn: Callable[[computation_types.TensorType], Any], type_spec -) -> Any: + fn: Callable[[computation_types.TensorType], object], type_spec +) -> object: """Constructs a structure from a `type_spec` tree of `tff.TensorType`s. Args: