diff --git a/news/15.bugfix b/news/15.bugfix new file mode 100644 index 0000000..d28742e --- /dev/null +++ b/news/15.bugfix @@ -0,0 +1,3 @@ +Fix error when a class is extended to add others aggregated classes as base +classes. Before this fix, the builder used the list of all base classes +the final class should have as base classes for every class into the hierarchy. diff --git a/src/extendable/main.py b/src/extendable/main.py index 0638288..2360bbc 100644 --- a/src/extendable/main.py +++ b/src/extendable/main.py @@ -24,6 +24,7 @@ class ExtendableClassDef: name: str base_names: List[str] + original_base_names: List[str] namespace: Dict[str, Any] original_name: str others_bases: List[Any] @@ -44,7 +45,8 @@ def __init__( self.name = namespace["__xreg_name__"] self.original_name = original_name self.others_bases = bases - self.base_names = namespace["__xreg_base_names__"] or [] + self.base_names = namespace["__xreg_base_names__"].copy() or [] + self.original_base_names = self.base_names.copy() self.hierarchy = [self] self.metaclass = metaclass self.kwargs = kwargs diff --git a/src/extendable/registry.py b/src/extendable/registry.py index afd1e46..5092048 100644 --- a/src/extendable/registry.py +++ b/src/extendable/registry.py @@ -121,13 +121,18 @@ def build_extendable_class( # determine all the classes the component should inherit from bases = LastOrderedSet[main.ExtendableMeta]() for base_name in cls_def.base_names: + # the base_names contains all the bases for the final aggregated + # class. Here we check that all the base required to build the + # current hierarchy are already build. if base_name not in self: if idx != 0 or base_name != cls_def.name: raise TypeError( f"Pydnatic class '{name}' extends an non-existing " f"extendable class '{base_name}'." ) - else: + elif base_name in cls_def.original_base_names: + # The bases to inherit for the current class are the one + # defined into the original class definition. parent_class = self[base_name] bases.add(parent_class) for other_base in class_def.others_bases: diff --git a/tests/test_simple.py b/tests/test_simple.py index 7691550..6ae6e3a 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -258,6 +258,35 @@ class MyExt(metaclass=MyMeta): assert MyMeta._is_extendable(MyExt) +def test_mixin_inheritance(test_registry): + class BaseMixin(metaclass=ExtendableMeta): + def test(self): + return "base" + + class MixinA(BaseMixin): + def test_a(self): + return "A" + + class MixinB(BaseMixin): + def test_b(self): + return "B" + + class ExtendedB(MixinB, extends=True): + def test_b(self): + res = super().test_b() + return res + " extended" + + class Mixin(MixinA, MixinB, extends=MixinB): + pass + + test_registry.init_registry() + + obj = Mixin() + assert obj.test() == "base" + assert obj.test_a() == "A" + assert obj.test_b() == "B extended" + + def test_issubclass_multi_level(test_registry): class A(metaclass=ExtendableMeta): pass