diff --git a/backend/infrahub/core/merge.py b/backend/infrahub/core/merge.py index 1cb00c19f7..d1712d47b8 100644 --- a/backend/infrahub/core/merge.py +++ b/backend/infrahub/core/merge.py @@ -2,12 +2,11 @@ from typing import TYPE_CHECKING, Optional, Union -from infrahub.core.constants import DiffAction, RepositoryInternalStatus +from infrahub.core.constants import RepositoryInternalStatus from infrahub.core.manager import NodeManager -from infrahub.core.models import SchemaBranchDiff, SchemaUpdateValidationResult +from infrahub.core.models import SchemaUpdateValidationResult from infrahub.core.protocols import CoreRepository from infrahub.core.registry import registry -from infrahub.core.schema import GenericSchema, NodeSchema from infrahub.core.timestamp import Timestamp from infrahub.exceptions import ValidationError from infrahub.message_bus import messages @@ -49,8 +48,6 @@ def __init__( self._destination_schema: Optional[SchemaBranch] = None self._initial_source_schema: Optional[SchemaBranch] = None - self.schema_diff: Optional[SchemaBranchDiff] = None - self._service = service @property @@ -106,46 +103,6 @@ async def has_schema_changes(self) -> bool: graph_diff = await self.get_graph_diff() return await graph_diff.has_schema_changes() - async def get_schema_diff(self) -> SchemaBranchDiff: - """Return a SchemaBranchDiff object with the list of nodes and generics - based on the information returned by the Graph Diff. - - The Graph Diff return a list of UUID so we need to convert that back into Kind - """ - - if self.schema_diff: - return self.schema_diff - - graph_diff = await self.get_graph_diff() - schema_summary = await graph_diff.get_schema_summary() - schema_diff = SchemaBranchDiff() - - # NOTE At this point there is no Generic in the schema but this could change in the future - for element in schema_summary.get(self.source_branch.name, []): - if element.kind == "SchemaNode" and DiffAction.REMOVED in element.actions: - continue - node = self.source_schema.get_by_any_id(id=element.node) - if isinstance(node, NodeSchema): - schema_diff.nodes.append(node.kind) - elif isinstance(node, GenericSchema): - schema_diff.generics.append(node.kind) - - for element in schema_summary.get(self.destination_branch.name, []): - if element.kind == "SchemaNode" and DiffAction.REMOVED in element.actions: - continue - node = self.destination_schema.get_by_any_id(id=element.node) - if isinstance(node, NodeSchema): - schema_diff.nodes.append(node.kind) - elif isinstance(node, GenericSchema): - schema_diff.generics.append(node.kind) - - # Remove duplicates if any - schema_diff.nodes = list(set(schema_diff.nodes)) - schema_diff.generics = list(set(schema_diff.generics)) - - self.schema_diff = schema_diff - return self.schema_diff - async def update_schema(self) -> bool: """After the merge, if there was some changes, we need to: - update the schema in the registry @@ -154,7 +111,6 @@ async def update_schema(self) -> bool: # NOTE we need to revisit how to calculate an accurate diff to pull only what needs to be updated from the schema # for now the best solution is to pull everything to ensure the integrity of the schema - # schema_diff = await self.get_schema_diff() if not await self.has_schema_changes(): return False diff --git a/backend/infrahub/core/models.py b/backend/infrahub/core/models.py index 54e2358715..d8d52d4038 100644 --- a/backend/infrahub/core/models.py +++ b/backend/infrahub/core/models.py @@ -29,8 +29,12 @@ def __str__(self) -> str: class SchemaBranchDiff(BaseModel): - nodes: list[str] = Field(default_factory=list) - generics: list[str] = Field(default_factory=list) + added_nodes: list[str] = Field(default_factory=list) + changed_nodes: list[str] = Field(default_factory=list) + added_generics: list[str] = Field(default_factory=list) + changed_generics: list[str] = Field(default_factory=list) + removed_nodes: list[str] = Field(default_factory=list) + removed_generics: list[str] = Field(default_factory=list) def to_string(self) -> str: return ", ".join(self.nodes + self.generics) @@ -40,9 +44,25 @@ def to_list(self) -> list[str]: @property def has_diff(self) -> bool: - if self.nodes or self.generics: - return True - return False + return any([self.has_node_diff, self.has_generic_diff]) + + @property + def has_node_diff(self) -> bool: + return bool(self.added_nodes + self.changed_nodes + self.removed_nodes) + + @property + def has_generic_diff(self) -> bool: + return bool(self.added_generics + self.changed_generics + self.removed_generics) + + @property + def nodes(self) -> list[str]: + """Return nodes that are still active.""" + return self.added_nodes + self.changed_nodes + + @property + def generics(self) -> list[str]: + """Return generics that are still active.""" + return self.added_generics + self.changed_generics class SchemaBranchHash(BaseModel): @@ -50,15 +70,19 @@ class SchemaBranchHash(BaseModel): nodes: dict[str, str] = Field(default_factory=dict) generics: dict[str, str] = Field(default_factory=dict) - def compare(self, other: SchemaBranchHash) -> Optional[SchemaBranchDiff]: + def compare(self, other: SchemaBranchHash) -> SchemaBranchDiff | None: if other.main == self.main: return None return SchemaBranchDiff( - nodes=[key for key, value in other.nodes.items() if key not in self.nodes or self.nodes[key] != value], - generics=[ - key for key, value in other.generics.items() if key not in self.generics or self.generics[key] != value + added_nodes=[key for key in other.nodes if key not in self.nodes], + changed_nodes=[key for key, value in other.nodes.items() if key in self.nodes and self.nodes[key] != value], + removed_nodes=[key for key in self.nodes if key not in other.nodes], + added_generics=[key for key in other.generics if key not in self.generics], + changed_generics=[ + key for key, value in other.generics.items() if key in self.generics and self.generics[key] != value ], + removed_generics=[key for key in self.generics if key not in other.generics], ) diff --git a/backend/infrahub/core/schema/manager.py b/backend/infrahub/core/schema/manager.py index b6e04b024d..f06ba39fde 100644 --- a/backend/infrahub/core/schema/manager.py +++ b/backend/infrahub/core/schema/manager.py @@ -162,8 +162,8 @@ async def update_schema_branch( schema_diff = None if limit: schema_diff = SchemaBranchDiff( - nodes=[name for name in list(schema.nodes.keys()) if name in limit], - generics=[name for name in list(schema.generics.keys()) if name in limit], + added_nodes=[name for name in list(schema.nodes.keys()) if name in limit], + added_generics=[name for name in list(schema.generics.keys()) if name in limit], ) updated_schema = await self.load_schema_from_db( @@ -192,13 +192,19 @@ async def update_schema_to_db( branch = await registry.get_branch(branch=branch, db=db) - item_kinds = [] + added_nodes = [] + added_generics = [] for item_kind, item_diff in diff.added.items(): item = schema.get(name=item_kind, duplicate=False) node = await self.load_node_to_db(node=item, branch=branch, db=db) schema.set(name=item_kind, schema=node) - item_kinds.append(item_kind) + if item.is_node_schema: + added_nodes.append(item_kind) + else: + added_generics.append(item_kind) + changed_nodes = [] + changed_generics = [] for item_kind, item_diff in diff.changed.items(): item = schema.get(name=item_kind, duplicate=False) if item_diff: @@ -206,18 +212,30 @@ async def update_schema_to_db( else: node = await self.update_node_in_db(node=item, branch=branch, db=db) schema.set(name=item_kind, schema=node) - item_kinds.append(item_kind) + if item.is_node_schema: + changed_nodes.append(item_kind) + else: + changed_generics.append(item_kind) + removed_nodes = [] + removed_generics = [] for item_kind, item_diff in diff.removed.items(): item = schema.get(name=item_kind, duplicate=False) node = await self.delete_node_in_db(node=item, branch=branch, db=db) schema.delete(name=item_kind) - - schema_diff = SchemaBranchDiff( - nodes=[name for name in schema.node_names if name in item_kinds], - generics=[name for name in schema.generic_names if name in item_kinds], + if item.is_node_schema: + removed_nodes.append(item_kind) + else: + removed_generics.append(item_kind) + + return SchemaBranchDiff( + added_nodes=added_nodes, + added_generics=added_generics, + changed_nodes=changed_nodes, + changed_generics=changed_generics, + removed_nodes=removed_nodes, + removed_generics=removed_generics, ) - return schema_diff async def load_schema_to_db( self, @@ -574,7 +592,7 @@ async def load_schema( if not branch.is_default and branch.origin_branch: origin_branch: Branch = await registry.get_branch(branch=branch.origin_branch, db=db) - if origin_branch.schema_hash.main == branch.schema_hash.main: + if origin_branch.active_schema_hash.main == branch.active_schema_hash.main: origin_schema = self.get_schema_branch(name=origin_branch.name) new_branch_schema = origin_schema.duplicate() self.set_schema_branch(name=branch.name, schema=new_branch_schema) @@ -582,7 +600,7 @@ async def load_schema( return new_branch_schema current_schema = self.get_schema_branch(name=branch.name) - schema_diff = current_schema.get_hash_full().compare(branch.schema_hash) + schema_diff = current_schema.get_hash_full().compare(branch.active_schema_hash) branch_schema = await self.load_schema_from_db( db=db, branch=branch, schema=current_schema, schema_diff=schema_diff ) @@ -619,7 +637,7 @@ async def load_schema_from_db( has_filters = False # If a diff is provided but is empty there is nothing to query - if schema_diff is not None and not schema_diff: + if schema_diff is not None and not schema_diff.has_diff: return schema if schema_diff: @@ -636,6 +654,12 @@ async def load_schema_from_db( if filter_value["namespace__values"]: filters[node_type] = filter_value has_filters = True + for removed_generic in schema_diff.removed_generics: + if removed_generic in schema.generic_names: + schema.delete(name=removed_generic) + for removed_node in schema_diff.removed_nodes: + if removed_node in schema.node_names: + schema.delete(name=removed_node) if not has_filters or filters["generics"]: generic_schema = self.get(name="SchemaGeneric", branch=branch) diff --git a/backend/tests/integration_docker/test_files/delete_interface_schema.yml b/backend/tests/integration_docker/test_files/delete_interface_schema.yml new file mode 100644 index 0000000000..9ef70055b9 --- /dev/null +++ b/backend/tests/integration_docker/test_files/delete_interface_schema.yml @@ -0,0 +1,21 @@ +--- +version: "1.0" +nodes: + - name: Device + namespace: Network + human_friendly_id: ['hostname__value'] + attributes: + - name: hostname + kind: Text + unique: true + - name: model + kind: Text + - name: Interface + namespace: Network + state: absent + attributes: + - name: name + kind: Text + - name: description + kind: Text + optional: true diff --git a/backend/tests/integration_docker/test_files/device_and_interface_schema.yml b/backend/tests/integration_docker/test_files/device_and_interface_schema.yml new file mode 100644 index 0000000000..9eeed74fd7 --- /dev/null +++ b/backend/tests/integration_docker/test_files/device_and_interface_schema.yml @@ -0,0 +1,20 @@ +--- +version: "1.0" +nodes: + - name: Device + namespace: Network + human_friendly_id: ['hostname__value'] + attributes: + - name: hostname + kind: Text + unique: true + - name: model + kind: Text + - name: Interface + namespace: Network + attributes: + - name: name + kind: Text + - name: description + kind: Text + optional: true diff --git a/backend/tests/integration_docker/test_schema_migration.py b/backend/tests/integration_docker/test_schema_migration.py index f5bfc34baa..36951cd812 100644 --- a/backend/tests/integration_docker/test_schema_migration.py +++ b/backend/tests/integration_docker/test_schema_migration.py @@ -1,7 +1,9 @@ import copy +from pathlib import Path from typing import Any import pytest +import yaml from infrahub_sdk import InfrahubClient from infrahub.testing.helpers import TestInfrahubDev @@ -11,6 +13,8 @@ SchemaCarPerson, ) +CURRENT_DIRECTORY = Path(__file__).parent.resolve() + class TestSchemaMigrations(TestInfrahubDev, SchemaCarPerson): @pytest.fixture(scope="class") @@ -34,6 +38,9 @@ def schema_person_with_age( async def test_setup_initial_schema( self, default_branch: str, infrahub_client: InfrahubClient, schema_base: dict[str, Any] ) -> None: + await infrahub_client.schema.wait_until_converged(branch=default_branch) + # Validate that the schema is in sync after initial startup + assert await self.schema_in_sync(client=infrahub_client, branch=default_branch) resp = await infrahub_client.schema.load( schemas=[schema_base], branch=default_branch, wait_until_converged=True ) @@ -53,3 +60,40 @@ async def test_update_schema(self, infrahub_client: InfrahubClient, schema_perso resp = await infrahub_client.schema.load(schemas=[schema_person_with_age], branch=branch.name) assert resp.errors == {} + + async def test_schema_load_and_delete(self, infrahub_client: InfrahubClient) -> None: + with Path(CURRENT_DIRECTORY / "test_files/device_and_interface_schema.yml").open(encoding="utf-8") as file: + device_and_interface_schema = yaml.safe_load(file.read()) + + with Path(CURRENT_DIRECTORY / "test_files/delete_interface_schema.yml").open(encoding="utf-8") as file: + delete_interface_schema = yaml.safe_load(file.read()) + + device_branch = await infrahub_client.branch.create(branch_name="device_branch") + + device_interface = await infrahub_client.schema.load( + schemas=[device_and_interface_schema], branch=device_branch.name, wait_until_converged=True + ) + assert device_interface.schema_updated + # Validate that the schema is in sync after loading the device and interface schema + assert await self.schema_in_sync(client=infrahub_client, branch=device_branch.name) + + delete_interface = await infrahub_client.schema.load( + schemas=[delete_interface_schema], branch=device_branch.name, wait_until_converged=True + ) + assert delete_interface.schema_updated + # Validate that the schema is in sync after removing the interface + assert await self.schema_in_sync(client=infrahub_client, branch=device_branch.name) + + @staticmethod + async def schema_in_sync(client: InfrahubClient, branch: str | None) -> bool: + SCHEMA_HASH_SYNC_STATUS = """ + query { + InfrahubStatus { + summary { + schema_hash_synced + } + } + } + """ + response = await client.execute_graphql(query=SCHEMA_HASH_SYNC_STATUS, branch_name=branch) + return response["InfrahubStatus"]["summary"]["schema_hash_synced"] diff --git a/backend/tests/unit/core/schema_manager/test_manager_schema.py b/backend/tests/unit/core/schema_manager/test_manager_schema.py index ab36ae23e6..467c9eee5d 100644 --- a/backend/tests/unit/core/schema_manager/test_manager_schema.py +++ b/backend/tests/unit/core/schema_manager/test_manager_schema.py @@ -2500,6 +2500,7 @@ async def test_load_schema( schema1 = registry.schema.register_schema(schema=SchemaRoot(**FULL_SCHEMA), branch=default_branch.name) await registry.schema.load_schema_to_db(schema=schema1, db=db, branch=default_branch.name) + default_branch.update_schema_hash() schema11 = registry.schema.get_schema_branch(name=default_branch.name) schema2 = await registry.schema.load_schema(db=db, branch=default_branch.name) diff --git a/changelog/4836.fixed.md b/changelog/4836.fixed.md new file mode 100644 index 0000000000..424652878e --- /dev/null +++ b/changelog/4836.fixed.md @@ -0,0 +1 @@ +Ensure that deleted schema nodes are removed from all workers and that the schema is in sync without having to restart