Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that deleted nodes are removed from the schema_branch #5183

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 2 additions & 46 deletions backend/infrahub/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from typing import TYPE_CHECKING, Optional, Union

from infrahub.core.constants import DiffAction, RepositoryInternalStatus
from infrahub.core.constants import RepositoryInternalStatus
from infrahub.core.manager import NodeManager
from infrahub.core.models import SchemaBranchDiff, SchemaUpdateValidationResult
from infrahub.core.models import SchemaUpdateValidationResult
from infrahub.core.protocols import CoreRepository
from infrahub.core.registry import registry
from infrahub.core.schema import GenericSchema, NodeSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import ValidationError
from infrahub.message_bus import messages
Expand Down Expand Up @@ -49,8 +48,6 @@ def __init__(
self._destination_schema: Optional[SchemaBranch] = None
self._initial_source_schema: Optional[SchemaBranch] = None

self.schema_diff: Optional[SchemaBranchDiff] = None

self._service = service

@property
Expand Down Expand Up @@ -106,46 +103,6 @@ async def has_schema_changes(self) -> bool:
graph_diff = await self.get_graph_diff()
return await graph_diff.has_schema_changes()

async def get_schema_diff(self) -> SchemaBranchDiff:
    """Return a SchemaBranchDiff listing the node and generic kinds reported by the graph diff.

    The graph diff returns a list of UUIDs, so each element is converted back
    into a kind by resolving it against the schema of the branch it was
    reported on (source branch first, then destination branch).

    The result is computed once and cached on ``self.schema_diff``; later
    calls return the cached object.

    Returns:
        A ``SchemaBranchDiff`` whose ``nodes`` and ``generics`` lists contain
        each kind at most once, in the order the kinds were first encountered.
    """
    if self.schema_diff is not None:
        return self.schema_diff

    graph_diff = await self.get_graph_diff()
    schema_summary = await graph_diff.get_schema_summary()
    schema_diff = SchemaBranchDiff()

    # Dicts are used as insertion-ordered sets: duplicates collapse while the
    # final order stays deterministic, which ``list(set(...))`` does not guarantee.
    node_kinds: dict[str, None] = {}
    generic_kinds: dict[str, None] = {}

    # NOTE At this point there is no Generic in the schema but this could change in the future
    # The schema is looked up by attribute name so that it is only accessed
    # when an element of that branch actually has to be resolved.
    for branch, schema_attribute in (
        (self.source_branch, "source_schema"),
        (self.destination_branch, "destination_schema"),
    ):
        for element in schema_summary.get(branch.name, []):
            # Removed SchemaNode elements are skipped instead of being resolved
            if element.kind == "SchemaNode" and DiffAction.REMOVED in element.actions:
                continue
            node = getattr(self, schema_attribute).get_by_any_id(id=element.node)
            if isinstance(node, NodeSchema):
                node_kinds[node.kind] = None
            elif isinstance(node, GenericSchema):
                generic_kinds[node.kind] = None

    schema_diff.nodes = list(node_kinds)
    schema_diff.generics = list(generic_kinds)

    self.schema_diff = schema_diff
    return self.schema_diff

async def update_schema(self) -> bool:
"""After the merge, if there was some changes, we need to:
- update the schema in the registry
Expand All @@ -154,7 +111,6 @@ async def update_schema(self) -> bool:

# NOTE we need to revisit how to calculate an accurate diff to pull only what needs to be updated from the schema
# for now the best solution is to pull everything to ensure the integrity of the schema
# schema_diff = await self.get_schema_diff()

if not await self.has_schema_changes():
return False
Expand Down
42 changes: 33 additions & 9 deletions backend/infrahub/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ def __str__(self) -> str:


class SchemaBranchDiff(BaseModel):
nodes: list[str] = Field(default_factory=list)
generics: list[str] = Field(default_factory=list)
added_nodes: list[str] = Field(default_factory=list)
changed_nodes: list[str] = Field(default_factory=list)
added_generics: list[str] = Field(default_factory=list)
changed_generics: list[str] = Field(default_factory=list)
removed_nodes: list[str] = Field(default_factory=list)
removed_generics: list[str] = Field(default_factory=list)

def to_string(self) -> str:
return ", ".join(self.nodes + self.generics)
Expand All @@ -40,25 +44,45 @@ def to_list(self) -> list[str]:

@property
def has_diff(self) -> bool:
if self.nodes or self.generics:
return True
return False
return any([self.has_node_diff, self.has_generic_diff])

@property
def has_node_diff(self) -> bool:
return bool(self.added_nodes + self.changed_nodes + self.removed_nodes)

@property
def has_generic_diff(self) -> bool:
return bool(self.added_generics + self.changed_generics + self.removed_generics)

@property
def nodes(self) -> list[str]:
"""Return nodes that are still active."""
return self.added_nodes + self.changed_nodes

@property
def generics(self) -> list[str]:
"""Return generics that are still active."""
return self.added_generics + self.changed_generics


class SchemaBranchHash(BaseModel):
main: str
nodes: dict[str, str] = Field(default_factory=dict)
generics: dict[str, str] = Field(default_factory=dict)

def compare(self, other: SchemaBranchHash) -> Optional[SchemaBranchDiff]:
def compare(self, other: SchemaBranchHash) -> SchemaBranchDiff | None:
if other.main == self.main:
return None

return SchemaBranchDiff(
nodes=[key for key, value in other.nodes.items() if key not in self.nodes or self.nodes[key] != value],
generics=[
key for key, value in other.generics.items() if key not in self.generics or self.generics[key] != value
added_nodes=[key for key in other.nodes if key not in self.nodes],
changed_nodes=[key for key, value in other.nodes.items() if key in self.nodes and self.nodes[key] != value],
removed_nodes=[key for key in self.nodes if key not in other.nodes],
added_generics=[key for key in other.generics if key not in self.generics],
changed_generics=[
key for key, value in other.generics.items() if key in self.generics and self.generics[key] != value
],
removed_generics=[key for key in self.generics if key not in other.generics],
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be cleaner if these used sets instead of lists for this logic, but I understand if that is out of scope for this small change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean all of the objects in SchemaBranchDiff? If so perhaps we can do that once it's brought back into develop?



Expand Down
50 changes: 37 additions & 13 deletions backend/infrahub/core/schema/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ async def update_schema_branch(
schema_diff = None
if limit:
schema_diff = SchemaBranchDiff(
nodes=[name for name in list(schema.nodes.keys()) if name in limit],
generics=[name for name in list(schema.generics.keys()) if name in limit],
added_nodes=[name for name in list(schema.nodes.keys()) if name in limit],
added_generics=[name for name in list(schema.generics.keys()) if name in limit],
)

updated_schema = await self.load_schema_from_db(
Expand Down Expand Up @@ -192,32 +192,50 @@ async def update_schema_to_db(

branch = await registry.get_branch(branch=branch, db=db)

item_kinds = []
added_nodes = []
added_generics = []
for item_kind, item_diff in diff.added.items():
item = schema.get(name=item_kind, duplicate=False)
node = await self.load_node_to_db(node=item, branch=branch, db=db)
schema.set(name=item_kind, schema=node)
item_kinds.append(item_kind)
if item.is_node_schema:
added_nodes.append(item_kind)
else:
added_generics.append(item_kind)

changed_nodes = []
changed_generics = []
for item_kind, item_diff in diff.changed.items():
item = schema.get(name=item_kind, duplicate=False)
if item_diff:
node = await self.update_node_in_db_based_on_diff(node=item, branch=branch, db=db, diff=item_diff)
else:
node = await self.update_node_in_db(node=item, branch=branch, db=db)
schema.set(name=item_kind, schema=node)
item_kinds.append(item_kind)
if item.is_node_schema:
changed_nodes.append(item_kind)
else:
changed_generics.append(item_kind)

removed_nodes = []
removed_generics = []
for item_kind, item_diff in diff.removed.items():
item = schema.get(name=item_kind, duplicate=False)
node = await self.delete_node_in_db(node=item, branch=branch, db=db)
schema.delete(name=item_kind)

schema_diff = SchemaBranchDiff(
nodes=[name for name in schema.node_names if name in item_kinds],
generics=[name for name in schema.generic_names if name in item_kinds],
if item.is_node_schema:
removed_nodes.append(item_kind)
else:
removed_generics.append(item_kind)

return SchemaBranchDiff(
added_nodes=added_nodes,
added_generics=added_generics,
changed_nodes=changed_nodes,
changed_generics=changed_generics,
removed_nodes=removed_nodes,
removed_generics=removed_generics,
)
return schema_diff

async def load_schema_to_db(
self,
Expand Down Expand Up @@ -574,15 +592,15 @@ async def load_schema(
if not branch.is_default and branch.origin_branch:
origin_branch: Branch = await registry.get_branch(branch=branch.origin_branch, db=db)

if origin_branch.schema_hash.main == branch.schema_hash.main:
if origin_branch.active_schema_hash.main == branch.active_schema_hash.main:
origin_schema = self.get_schema_branch(name=origin_branch.name)
new_branch_schema = origin_schema.duplicate()
self.set_schema_branch(name=branch.name, schema=new_branch_schema)
log.info("Loading schema from cache")
return new_branch_schema

current_schema = self.get_schema_branch(name=branch.name)
schema_diff = current_schema.get_hash_full().compare(branch.schema_hash)
schema_diff = current_schema.get_hash_full().compare(branch.active_schema_hash)
branch_schema = await self.load_schema_from_db(
db=db, branch=branch, schema=current_schema, schema_diff=schema_diff
)
Expand Down Expand Up @@ -619,7 +637,7 @@ async def load_schema_from_db(
has_filters = False

# If a diff is provided but is empty there is nothing to query
if schema_diff is not None and not schema_diff:
if schema_diff is not None and not schema_diff.has_diff:
return schema

if schema_diff:
Expand All @@ -636,6 +654,12 @@ async def load_schema_from_db(
if filter_value["namespace__values"]:
filters[node_type] = filter_value
has_filters = True
for removed_generic in schema_diff.removed_generics:
if removed_generic in schema.generic_names:
schema.delete(name=removed_generic)
for removed_node in schema_diff.removed_nodes:
if removed_node in schema.node_names:
schema.delete(name=removed_node)

if not has_filters or filters["generics"]:
generic_schema = self.get(name="SchemaGeneric", branch=branch)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
# Schema fixture: NetworkDevice plus a NetworkInterface node declared with `state: absent`.
# NOTE(review): presumably test_files/delete_interface_schema.yml, used to remove the
# Interface node after it was loaded — confirm against the test that reads it.
version: "1.0"
nodes:
  - name: Device
    namespace: Network
    human_friendly_id: ['hostname__value']
    attributes:
      - name: hostname
        kind: Text
        unique: true
      - name: model
        kind: Text
  - name: Interface
    namespace: Network
    state: absent
    attributes:
      - name: name
        kind: Text
      - name: description
        kind: Text
        optional: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
# Schema fixture: NetworkDevice and NetworkInterface nodes, both present.
# NOTE(review): presumably test_files/device_and_interface_schema.yml — confirm
# against the test that reads it.
version: "1.0"
nodes:
  - name: Device
    namespace: Network
    human_friendly_id: ['hostname__value']
    attributes:
      - name: hostname
        kind: Text
        unique: true
      - name: model
        kind: Text
  - name: Interface
    namespace: Network
    attributes:
      - name: name
        kind: Text
      - name: description
        kind: Text
        optional: true
44 changes: 44 additions & 0 deletions backend/tests/integration_docker/test_schema_migration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
from pathlib import Path
from typing import Any

import pytest
import yaml
from infrahub_sdk import InfrahubClient

from infrahub.testing.helpers import TestInfrahubDev
Expand All @@ -11,6 +13,8 @@
SchemaCarPerson,
)

CURRENT_DIRECTORY = Path(__file__).parent.resolve()


class TestSchemaMigrations(TestInfrahubDev, SchemaCarPerson):
@pytest.fixture(scope="class")
Expand All @@ -34,6 +38,9 @@ def schema_person_with_age(
async def test_setup_initial_schema(
self, default_branch: str, infrahub_client: InfrahubClient, schema_base: dict[str, Any]
) -> None:
await infrahub_client.schema.wait_until_converged(branch=default_branch)
# Validate that the schema is in sync after initial startup
assert await self.schema_in_sync(client=infrahub_client, branch=default_branch)
resp = await infrahub_client.schema.load(
schemas=[schema_base], branch=default_branch, wait_until_converged=True
)
Expand All @@ -53,3 +60,40 @@ async def test_update_schema(self, infrahub_client: InfrahubClient, schema_perso
resp = await infrahub_client.schema.load(schemas=[schema_person_with_age], branch=branch.name)

assert resp.errors == {}

async def test_schema_load_and_delete(self, infrahub_client: InfrahubClient) -> None:
    """Load the device/interface schema into a new branch, then load the schema deleting the interface.

    Each load must report ``schema_updated`` and leave the schema hash in sync
    on the branch.
    """
    initial_schema = yaml.safe_load(
        (CURRENT_DIRECTORY / "test_files/device_and_interface_schema.yml").read_text(encoding="utf-8")
    )
    removal_schema = yaml.safe_load(
        (CURRENT_DIRECTORY / "test_files/delete_interface_schema.yml").read_text(encoding="utf-8")
    )

    branch = await infrahub_client.branch.create(branch_name="device_branch")

    # First pass loads device + interface, second pass removes the interface;
    # the expectations after each load are identical.
    for schema in (initial_schema, removal_schema):
        result = await infrahub_client.schema.load(
            schemas=[schema], branch=branch.name, wait_until_converged=True
        )
        assert result.schema_updated
        # Validate that the schema is in sync after this load
        assert await self.schema_in_sync(client=infrahub_client, branch=branch.name)

@staticmethod
async def schema_in_sync(client: InfrahubClient, branch: str | None) -> bool:
    """Run the ``InfrahubStatus`` query on ``branch`` and return its ``schema_hash_synced`` summary value."""
    status_query = """
query {
InfrahubStatus {
summary {
schema_hash_synced
}
}
}
"""
    payload = await client.execute_graphql(query=status_query, branch_name=branch)
    summary = payload["InfrahubStatus"]["summary"]
    return summary["schema_hash_synced"]
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,7 @@ async def test_load_schema(

schema1 = registry.schema.register_schema(schema=SchemaRoot(**FULL_SCHEMA), branch=default_branch.name)
await registry.schema.load_schema_to_db(schema=schema1, db=db, branch=default_branch.name)
default_branch.update_schema_hash()
schema11 = registry.schema.get_schema_branch(name=default_branch.name)
schema2 = await registry.schema.load_schema(db=db, branch=default_branch.name)

Expand Down
1 change: 1 addition & 0 deletions changelog/4836.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure that deleted schema nodes are removed from all workers and that the schema is in sync without having to restart
Loading