Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: inheritance #15

Merged
merged 1 commit into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions news/15.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix error when a class is extended to add others aggregated classes as base
classes. Before this fix, the builder used the list of all base classes
the final class should have as base classes for every class into the hierarchy.
4 changes: 3 additions & 1 deletion src/extendable/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
class ExtendableClassDef:
name: str
base_names: List[str]
original_base_names: List[str]
namespace: Dict[str, Any]
original_name: str
others_bases: List[Any]
Expand All @@ -44,7 +45,8 @@ def __init__(
self.name = namespace["__xreg_name__"]
self.original_name = original_name
self.others_bases = bases
self.base_names = namespace["__xreg_base_names__"] or []
self.base_names = namespace["__xreg_base_names__"].copy() or []
self.original_base_names = self.base_names.copy()
self.hierarchy = [self]
self.metaclass = metaclass
self.kwargs = kwargs
Expand Down
7 changes: 6 additions & 1 deletion src/extendable/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,18 @@ def build_extendable_class(
# determine all the classes the component should inherit from
bases = LastOrderedSet[main.ExtendableMeta]()
for base_name in cls_def.base_names:
# the base_names contains all the bases for the final aggregated
# class. Here we check that all the base required to build the
# current hierarchy are already build.
if base_name not in self:
if idx != 0 or base_name != cls_def.name:
raise TypeError(
f"Pydnatic class '{name}' extends an non-existing "
f"extendable class '{base_name}'."
)
else:
elif base_name in cls_def.original_base_names:
# The bases to inherit for the current class are the one
# defined into the original class definition.
parent_class = self[base_name]
bases.add(parent_class)
for other_base in class_def.others_bases:
Expand Down
29 changes: 29 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,35 @@ class MyExt(metaclass=MyMeta):
assert MyMeta._is_extendable(MyExt)


def test_mixin_inheritance(test_registry):
class BaseMixin(metaclass=ExtendableMeta):
def test(self):
return "base"

class MixinA(BaseMixin):
def test_a(self):
return "A"

class MixinB(BaseMixin):
def test_b(self):
return "B"

class ExtendedB(MixinB, extends=True):
def test_b(self):
res = super().test_b()
return res + " extended"

class Mixin(MixinA, MixinB, extends=MixinB):
pass

test_registry.init_registry()

obj = Mixin()
assert obj.test() == "base"
assert obj.test_a() == "A"
assert obj.test_b() == "B extended"


def test_issubclass_multi_level(test_registry):
class A(metaclass=ExtendableMeta):
pass
Expand Down
Loading