Skip to content

Commit

Permalink
Fix TypeError: isinstance() arg 2 must be a type or tuple of types wh…
Browse files Browse the repository at this point in the history
…en records have Enum fields
  • Loading branch information
Tinitto committed Feb 23, 2023
1 parent 171791b commit 9ccd93e
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 6 deletions.
94 changes: 90 additions & 4 deletions funml/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Common utility functions"""
from __future__ import annotations
import datetime
import importlib
import functools
import re
import string
Expand All @@ -10,7 +11,6 @@
import random
from typing import Any, Dict, Tuple, List, Set, Union


_compound_type_regex = re.compile(r"tuple|list|set|dict")
_compound_type_generic_type_map = {
"tuple": "Tuple",
Expand Down Expand Up @@ -38,7 +38,10 @@ def is_type(value: Any, cls: Any) -> bool:
return True

_type = _extract_type(_type)
return isinstance(value, _type) or value == _type
try:
return isinstance(value, _type) or value == _type
except Exception as exp:
raise exp


def _extract_type(annotation: Any):
Expand Down Expand Up @@ -205,7 +208,7 @@ def get_cls_annotations(cls, *, globals=None, locals=None, eval_str=False):
return_value = {
key: value
if not isinstance(value, str)
else _parse_type_str(_to_generic(value), globals, locals)
else _parse_type_str(_to_generic(value), globals, locals, module_name)
for key, value in ann.items()
}
return return_value
Expand Down Expand Up @@ -261,12 +264,95 @@ def _parse_type_str(
value: str,
__globals: dict[str, Any] | None = ...,
__locals: typing.Mapping[str, Any] | None = ...,
__module_name: str = "",
):
"""Converts a type value expressed as a string into its python value."""
try:
return eval(value, __globals, __locals)
parsed_value = eval(value, __globals, __locals)
if isinstance(parsed_value, str):
v = LazyImport(module_name=__module_name, name=parsed_value)
return v
return parsed_value
except TypeError as exp:
if "not subscriptable" in f"{exp}":
# ignore types that are not supported in the given version
return Any
raise exp


class LazyImport:
"""An import that is lazily loaded.
This is very vital especially when using annotations in records
Args:
module_name: the name of the module from which the import is coming
name: the name being imported from that module
"""

def __init__(self, module_name: str, name: str):
self._module_name = module_name
self._name = name
self._value = None

@property
def full_name(self):
"""the qualified name of the thing being imported"""
return f"{self._module_name}.{self._name}"

def __call__(self, *args, **kwargs) -> Any:
"""What is called when this import is called like a callable i.e. a class or function.
Args:
args: the arguments passed to the imported object
kwargs: the key-word arguments passed to the imported value
Returns:
the output of calling the imported object
Raises:
ImportError: Import can't find module or can't find name in module
"""
if self._value is None:
module = importlib.import_module(self._module_name)
self._value = getattr(module, self._name, None)

if self._value is not None:
return self._value(*args, **kwargs)

raise ImportError(self._module_name, name=self._name)

def __instancecheck__(self, instance: Any) -> bool:
"""Checks to see that the instance passed is an instance of the imported object.
Args:
instance: the object being checked
Returns:
the boolean indicating whether the instance is an instance of the imported object
"""
try:
super_classes = [
f"{v.__module__}.{v.__name__}" for v in instance.__class__.mro()
]
return self.full_name in super_classes
except AttributeError:
return False

def __eq__(self, other: Any) -> bool:
"""Checks the equality of the imported object and any other object.
It is able to check for equality between itself and an eagerly imported object.
Args:
other: the object that is either equal or not to current object.
Returns:
the boolean indicating whether the other is equal to the current object.
"""
if isinstance(other, LazyImport):
return self.full_name == other.full_name
try:
return self.full_name == other.__qualname__
except AttributeError:
return False
15 changes: 13 additions & 2 deletions tests/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from collections.abc import Callable
from os import PathLike
from pathlib import PurePath
from typing import Optional, List, Any

import pytest

from funml import record, to_dict
from funml import record, to_dict, Enum


def test_records_created():
Expand Down Expand Up @@ -128,6 +127,12 @@ class Department:
description: ...
func: Callable[[int], Any]
path: bytes | PathLike[bytes] | str = ""
branch: "Branch"

class Branch(Enum):
HeadOffice = None
Arua = None
Nebbi = None

echo = lambda v: v

Expand All @@ -139,6 +144,7 @@ class Department:
is_active=False,
description="security",
func=echo,
branch=Branch.Arua,
)
it_dept = Department(
seniors=["Paul"],
Expand All @@ -148,6 +154,7 @@ class Department:
is_active=True,
description="it",
func=echo,
branch=Branch.Nebbi,
)
hr_dept = Department(
seniors=["Stella", "Isingoma"],
Expand All @@ -157,6 +164,7 @@ class Department:
is_active=False,
description=4,
func=echo,
branch=Branch.HeadOffice,
)

another_security_dept = Department(
Expand All @@ -167,6 +175,7 @@ class Department:
is_active=False,
description="security",
func=echo,
branch=Branch.Arua,
)
another_it_dept = Department(
seniors=["Paul"],
Expand All @@ -176,6 +185,7 @@ class Department:
is_active=True,
description="it",
func=echo,
branch=Branch.Nebbi,
)
another_hr_dept = Department(
seniors=["Stella", "Isingoma"],
Expand All @@ -185,6 +195,7 @@ class Department:
is_active=False,
description=4,
func=echo,
branch=Branch.HeadOffice,
)

assert security_dept == another_security_dept
Expand Down

0 comments on commit 9ccd93e

Please sign in to comment.