Skip to content

Commit

Permalink
Fix Python 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 21, 2024
1 parent f16beb7 commit 96f370f
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions lab/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def _module_attr(module, attr):
# Define TensorFlow module types.
_tf_tensor = ModuleType("tensorflow", "Tensor")
_tf_indexedslices = ModuleType("tensorflow", "IndexedSlices")
_tf_kerastensor = ModuleType("keras", "KerasTensor")
# On Python 3.9 and higher, we also need to support `keras.KerasTensor`.
if sys.version_info >= (3, 9):
_tf_kerastensor = ModuleType("keras", "KerasTensor")
_tf_variable = ModuleType("tensorflow", "Variable")
_tf_dtype = ModuleType("tensorflow", "DType")
_tf_randomstate = ModuleType("tensorflow", "random.Generator")
Expand Down Expand Up @@ -107,7 +109,10 @@ def _module_attr(module, attr):
NPNumeric = set_union_alias(NPNumeric, "B.NPNumeric")
AGNumeric = Union[_ag_tensor]
AGNumeric = set_union_alias(AGNumeric, "B.AGNumeric")
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor]
if sys.version_info >= (3, 9):
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor]
else:
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices]
TFNumeric = set_union_alias(TFNumeric, "B.TFNumeric")
TorchNumeric = Union[_torch_tensor]
TorchNumeric = set_union_alias(TorchNumeric, "B.TorchNumeric")
Expand Down

0 comments on commit 96f370f

Please sign in to comment.