From 96f370f5c48df68186f5c5aea3a049a7dbc31143 Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Sun, 21 Apr 2024 15:04:26 +0200 Subject: [PATCH] Fix Python 3.8 --- lab/types.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lab/types.py b/lab/types.py index f274a3a..868e63a 100644 --- a/lab/types.py +++ b/lab/types.py @@ -68,7 +68,9 @@ def _module_attr(module, attr): # Define TensorFlow module types. _tf_tensor = ModuleType("tensorflow", "Tensor") _tf_indexedslices = ModuleType("tensorflow", "IndexedSlices") -_tf_kerastensor = ModuleType("keras", "KerasTensor") +# On Python 3.9 and higher, we also need to support `keras.KerasTensor`. +if sys.version_info >= (3, 9): + _tf_kerastensor = ModuleType("keras", "KerasTensor") _tf_variable = ModuleType("tensorflow", "Variable") _tf_dtype = ModuleType("tensorflow", "DType") _tf_randomstate = ModuleType("tensorflow", "random.Generator") @@ -107,7 +109,10 @@ def _module_attr(module, attr): NPNumeric = set_union_alias(NPNumeric, "B.NPNumeric") AGNumeric = Union[_ag_tensor] AGNumeric = set_union_alias(AGNumeric, "B.AGNumeric") -TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor] +if sys.version_info >= (3, 9): + TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor] +else: + TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices] TFNumeric = set_union_alias(TFNumeric, "B.TFNumeric") TorchNumeric = Union[_torch_tensor] TorchNumeric = set_union_alias(TorchNumeric, "B.TorchNumeric")