Skip to content

Commit

Permalink
fix some unit tests, remove unnecessary code
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtmccarty committed Dec 2, 2024
1 parent 7202704 commit 8b745aa
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 39 deletions.
15 changes: 4 additions & 11 deletions backend/infrahub/core/diff/conflicts_enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from infrahub.core.constants import DiffAction, RelationshipCardinality
from infrahub.core.constants.database import DatabaseEdgeType

from .managed_relationship_checker import ManagedRelationshipChecker
from .model.path import (
EnrichedDiffAttribute,
EnrichedDiffConflict,
Expand All @@ -16,8 +15,7 @@


class ConflictsEnricher:
def __init__(self, managed_relationship_checker: ManagedRelationshipChecker) -> None:
self.managed_relationship_checker = managed_relationship_checker
def __init__(self) -> None:
self._base_branch_name: str | None = None
self._diff_branch_name: str | None = None

Expand All @@ -38,7 +36,6 @@ async def add_conflicts_to_branch_diff(
) -> None:
self._base_branch_name = branch_diff_root.base_branch_name
self._diff_branch_name = branch_diff_root.diff_branch_name
self.managed_relationship_checker.reset()

base_node_map = {n.uuid: n for n in base_diff_root.nodes}
branch_node_map = {n.uuid: n for n in branch_diff_root.nodes}
Expand Down Expand Up @@ -82,7 +79,6 @@ def _add_node_conflicts(self, base_node: EnrichedDiffNode, branch_node: Enriched
self._add_relationship_conflicts(
base_relationship=base_relationship,
branch_relationship=branch_relationship,
node_kind=branch_node.kind,
)
else:
branch_relationship.clear_conflicts()
Expand Down Expand Up @@ -137,11 +133,8 @@ def _add_attribute_conflicts(
)

def _add_relationship_conflicts(
self, base_relationship: EnrichedDiffRelationship, branch_relationship: EnrichedDiffRelationship, node_kind: str
self, base_relationship: EnrichedDiffRelationship, branch_relationship: EnrichedDiffRelationship
) -> None:
is_managed = self.managed_relationship_checker.check(
node_kind=node_kind, relationship_name=branch_relationship.name
)
is_cardinality_one = branch_relationship.cardinality is RelationshipCardinality.ONE
if is_cardinality_one:
if not base_relationship.relationships or not branch_relationship.relationships:
Expand All @@ -153,7 +146,7 @@ def _add_relationship_conflicts(
base_element=base_element,
branch_element=branch_element,
is_cardinality_one=is_cardinality_one,
is_managed=is_managed,
is_managed=branch_relationship.is_managed,
)
return
base_peer_id_map = {element.peer_id: element for element in base_relationship.relationships}
Expand All @@ -168,7 +161,7 @@ def _add_relationship_conflicts(
base_element=base_element,
branch_element=branch_element,
is_cardinality_one=is_cardinality_one,
is_managed=is_managed,
is_managed=branch_relationship.is_managed,
)

def _add_relationship_conflicts_for_one_peer(
Expand Down
21 changes: 6 additions & 15 deletions backend/infrahub/core/diff/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,10 +558,7 @@ def parse(self, include_unchanged: bool = False) -> None:
def _parse_path(self, database_path: DatabasePath) -> None:
diff_root = self._get_diff_root(database_path=database_path)
diff_node = self._get_diff_node(database_path=database_path, diff_root=diff_root)
is_base_branch = False
if self.base_branch_name != self.diff_branch_name and diff_root.branch == self.base_branch_name:
is_base_branch = True
self._update_attribute_level(database_path=database_path, diff_node=diff_node, is_base_branch=is_base_branch)
self._update_attribute_level(database_path=database_path, diff_node=diff_node)

def _get_diff_root(self, database_path: DatabasePath) -> DiffRootIntermediate:
branch = database_path.deepest_branch
Expand Down Expand Up @@ -595,9 +592,7 @@ def _get_relationship_schema(
return rel_schema
return None

def _update_attribute_level(
self, database_path: DatabasePath, diff_node: DiffNodeIntermediate, is_base_branch: bool
) -> None:
def _update_attribute_level(self, database_path: DatabasePath, diff_node: DiffNodeIntermediate) -> None:
node_schema = self.db.schema.get(
name=database_path.node_kind, branch=database_path.deepest_branch, duplicate=False
)
Expand All @@ -608,9 +603,7 @@ def _update_attribute_level(
relationship_schema = self._get_relationship_schema(database_path=database_path, node_schema=node_schema)
if not relationship_schema:
return
diff_relationship = self._get_diff_relationship(
diff_node=diff_node, relationship_schema=relationship_schema, is_base_branch=is_base_branch
)
diff_relationship = self._get_diff_relationship(diff_node=diff_node, relationship_schema=relationship_schema)
diff_relationship.add_path(
database_path=database_path, diff_from_time=self.from_time, diff_to_time=self.to_time
)
Expand Down Expand Up @@ -647,15 +640,13 @@ def _update_attribute_property(
)

def _get_diff_relationship(
self, diff_node: DiffNodeIntermediate, relationship_schema: RelationshipSchema, is_base_branch: bool
self, diff_node: DiffNodeIntermediate, relationship_schema: RelationshipSchema
) -> DiffRelationshipIntermediate:
diff_relationship = diff_node.relationships_by_name.get(relationship_schema.name)
if not diff_relationship:
is_managed = False
if is_base_branch and self.managed_relationship_checker.check(
is_managed = self.managed_relationship_checker.check(
node_kind=diff_node.kind, relationship_name=relationship_schema.name
):
is_managed = True
)
diff_relationship = DiffRelationshipIntermediate(
name=relationship_schema.name,
cardinality=relationship_schema.cardinality,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from infrahub.core.diff.conflicts_enricher import ConflictsEnricher
from infrahub.dependencies.interface import DependencyBuilder, DependencyBuilderContext

from .managed_relationship_checker import ManagedRelationshipCheckerDependency


class DiffConflictsEnricherDependency(DependencyBuilder[ConflictsEnricher]):
@classmethod
def build(cls, context: DependencyBuilderContext) -> ConflictsEnricher:
return ConflictsEnricher(
managed_relationship_checker=ManagedRelationshipCheckerDependency.build(context=context)
)
return ConflictsEnricher()
25 changes: 24 additions & 1 deletion backend/tests/integration/ipam/test_ipam_rebase_reconcile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@

from infrahub.database import InfrahubDatabase


BRANCH_CREATE = """
mutation($branch: String!) {
BranchCreate(data: {
name: $branch
}) {
ok
object {
id
name
}
}
}
"""

CREATE_IPPREFIX = """
mutation CreatePrefix($prefix: String!, $namespace_id: String!) {
IpamIPPrefixCreate(
Expand Down Expand Up @@ -85,7 +100,15 @@ async def test_step02_add_delete_prefix(
initial_dataset,
client: InfrahubClient,
) -> None:
branch = await create_branch(db=db, branch_name="delete_prefix")
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=registry.default_branch)
result = await graphql(
schema=gql_params.schema,
source=BRANCH_CREATE,
context_value=gql_params.context,
variable_values={"branch": "delete_prefix"},
)
assert not result.errors
branch = await registry.get_branch(db=db, branch="delete_prefix")

gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=registry.default_branch)
result = await graphql(
Expand Down
1 change: 0 additions & 1 deletion backend/tests/unit/core/diff/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class EnrichedRelationshipGroupFactory(DataclassFactory[EnrichedDiffRelationship
num_conflicts = 0
nodes = set()
contains_conflict = False
is_managed = False


class EnrichedRelationshipElementFactory(DataclassFactory[EnrichedDiffSingleRelationship]):
Expand Down
106 changes: 100 additions & 6 deletions backend/tests/unit/core/diff/test_conflicts_enricher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import random
from unittest.mock import AsyncMock
from uuid import uuid4

import pytest
Expand All @@ -8,7 +7,6 @@
from infrahub.core.constants import DiffAction, RelationshipCardinality
from infrahub.core.constants.database import DatabaseEdgeType
from infrahub.core.diff.conflicts_enricher import ConflictsEnricher
from infrahub.core.diff.managed_relationship_checker import ManagedRelationshipChecker
from infrahub.core.diff.model.path import EnrichedDiffConflict
from infrahub.core.initialization import create_branch
from infrahub.core.timestamp import Timestamp
Expand All @@ -33,9 +31,7 @@ def setup_method(self):
self.to_time = Timestamp()

async def __call_system_under_test(self, db: InfrahubDatabase, base_enriched_diff, branch_enriched_diff) -> None:
mock_rel_checker = AsyncMock(spec=ManagedRelationshipChecker)
mock_rel_checker.check.return_value = False
conflicts_enricher = ConflictsEnricher(managed_relationship_checker=mock_rel_checker)
conflicts_enricher = ConflictsEnricher()
return await conflicts_enricher.add_conflicts_to_branch_diff(
base_diff_root=base_enriched_diff, branch_diff_root=branch_enriched_diff
)
Expand Down Expand Up @@ -172,7 +168,8 @@ async def test_one_attribute_conflict(self, db: InfrahubDatabase):
else:
assert prop.conflict is None

async def test_cardinality_one_conflicts(self, db: InfrahubDatabase, car_person_schema):
@pytest.mark.parametrize("is_managed", [True, False])
async def test_cardinality_one_conflicts(self, db: InfrahubDatabase, car_person_schema, is_managed: bool):
branch = await create_branch(db=db, branch_name="branch")
property_type = DatabaseEdgeType.IS_RELATED
relationship_name = "owner"
Expand All @@ -199,6 +196,7 @@ async def test_cardinality_one_conflicts(self, db: InfrahubDatabase, car_person_
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=is_managed,
)
}
base_nodes = {
Expand Down Expand Up @@ -227,6 +225,7 @@ async def test_cardinality_one_conflicts(self, db: InfrahubDatabase, car_person_
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=is_managed,
)
}
branch_nodes = {
Expand All @@ -253,6 +252,7 @@ async def test_cardinality_one_conflicts(self, db: InfrahubDatabase, car_person_
node.uuid == node_uuid
and rel.name == relationship_name
and rel_element.peer_id == previous_peer_id
and not is_managed
):
assert rel_element.conflict
assert rel_element.conflict == EnrichedDiffConflict(
Expand All @@ -265,6 +265,8 @@ async def test_cardinality_one_conflicts(self, db: InfrahubDatabase, car_person_
diff_branch_changed_at=branch_conflict_property.changed_at,
selected_branch=None,
)
else:
assert rel_element.conflict is None

async def test_cardinality_many_conflicts(self, db: InfrahubDatabase, car_person_schema):
branch = await create_branch(db=db, branch_name="branch")
Expand Down Expand Up @@ -309,6 +311,7 @@ async def test_cardinality_many_conflicts(self, db: InfrahubDatabase, car_person
EnrichedRelationshipElementFactory.build(peer_id=peer_id_2, properties=base_properties_2),
},
cardinality=RelationshipCardinality.MANY,
is_managed=False,
)
}
base_nodes = {
Expand Down Expand Up @@ -349,6 +352,7 @@ async def test_cardinality_many_conflicts(self, db: InfrahubDatabase, car_person
EnrichedRelationshipElementFactory.build(peer_id=peer_id_2, properties=branch_properties_2),
},
cardinality=RelationshipCardinality.MANY,
is_managed=False,
)
}
branch_nodes = {
Expand Down Expand Up @@ -564,6 +568,7 @@ async def test_manually_fixed_cardinality_one_conflict_cleared(self, db: Infrahu
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=False,
)
}
base_nodes = {
Expand Down Expand Up @@ -596,6 +601,7 @@ async def test_manually_fixed_cardinality_one_conflict_cleared(self, db: Infrahu
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=False,
)
}
branch_nodes = {
Expand Down Expand Up @@ -646,6 +652,7 @@ async def test_unchanged_cardinality_one_clears_conflicts(
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=False,
)
}
base_nodes = {
Expand Down Expand Up @@ -680,6 +687,7 @@ async def test_unchanged_cardinality_one_clears_conflicts(
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=False,
)
}
branch_nodes = {
Expand Down Expand Up @@ -747,6 +755,7 @@ async def test_unchanged_cardinality_many_rel_clears_conflict(self, db: Infrahub
EnrichedRelationshipElementFactory.build(peer_id=peer_id_2, properties=base_properties_2),
},
cardinality=RelationshipCardinality.MANY,
is_managed=False,
)
}
base_nodes = {
Expand Down Expand Up @@ -788,6 +797,7 @@ async def test_unchanged_cardinality_many_rel_clears_conflict(self, db: Infrahub
EnrichedRelationshipElementFactory.build(peer_id=peer_id_2, properties=branch_properties_2),
},
cardinality=RelationshipCardinality.MANY,
is_managed=False,
)
}
branch_nodes = {
Expand Down Expand Up @@ -861,3 +871,87 @@ async def test_manually_fixed_node_conflict_cleared(self, db: InfrahubDatabase):

for node in branch_root.nodes:
assert node.conflict is None

async def test_managed_rel_cannot_create_conflict(self, db: InfrahubDatabase, car_person_schema):
base_action = DiffAction.UPDATED
branch_action = DiffAction.REMOVED
branch = await create_branch(db=db, branch_name="branch")
property_type = DatabaseEdgeType.IS_RELATED
relationship_name = "owner"
node_uuid = str(uuid4())
node_kind = "TestCar"
peer_id = str(uuid4())
base_properties = {
EnrichedPropertyFactory.build(property_type=DatabaseEdgeType.IS_VISIBLE),
EnrichedPropertyFactory.build(property_type=property_type, action=base_action),
}
base_relationships = {
EnrichedRelationshipGroupFactory.build(
name=relationship_name,
relationships={
EnrichedRelationshipElementFactory.build(
peer_id=peer_id, properties=base_properties, action=DiffAction.UPDATED
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=True,
)
}
base_nodes = {
EnrichedNodeFactory.build(
uuid=node_uuid,
kind=node_kind,
action=DiffAction.UPDATED,
relationships=base_relationships,
),
EnrichedNodeFactory.build(relationships=set()),
}
base_root = EnrichedRootFactory.build(nodes=base_nodes)
branch_conflict_property = EnrichedPropertyFactory.build(
property_type=property_type,
previous_value=peer_id,
action=branch_action,
conflict=EnrichedConflictFactory.build(),
)
branch_properties = {
branch_conflict_property,
EnrichedPropertyFactory.build(property_type=DatabaseEdgeType.HAS_OWNER),
}
branch_relationships = {
EnrichedRelationshipGroupFactory.build(
name=relationship_name,
relationships={
EnrichedRelationshipElementFactory.build(
peer_id=peer_id,
properties=branch_properties,
action=DiffAction.REMOVED,
conflict=EnrichedConflictFactory.build(),
)
},
cardinality=RelationshipCardinality.ONE,
is_managed=True,
)
}
branch_nodes = {
EnrichedNodeFactory.build(
uuid=node_uuid,
kind=node_kind,
action=DiffAction.UPDATED,
relationships=branch_relationships,
),
EnrichedNodeFactory.build(relationships=set()),
}
branch_root = EnrichedRootFactory.build(nodes=branch_nodes, diff_branch_name=branch.name)

await self.__call_system_under_test(db=db, base_enriched_diff=base_root, branch_enriched_diff=branch_root)

for node in branch_root.nodes:
assert node.conflict is None
for attribute in node.attributes:
for prop in attribute.properties:
assert prop.conflict is None
for rel in node.relationships:
for rel_element in rel.relationships:
assert rel_element.conflict is None
for prop in rel_element.properties:
assert prop.conflict is None
Loading

0 comments on commit 8b745aa

Please sign in to comment.