diff --git a/backend/infrahub/core/merge.py b/backend/infrahub/core/merge.py index 1cb00c19f7..d1712d47b8 100644 --- a/backend/infrahub/core/merge.py +++ b/backend/infrahub/core/merge.py @@ -2,12 +2,11 @@ from typing import TYPE_CHECKING, Optional, Union -from infrahub.core.constants import DiffAction, RepositoryInternalStatus +from infrahub.core.constants import RepositoryInternalStatus from infrahub.core.manager import NodeManager -from infrahub.core.models import SchemaBranchDiff, SchemaUpdateValidationResult +from infrahub.core.models import SchemaUpdateValidationResult from infrahub.core.protocols import CoreRepository from infrahub.core.registry import registry -from infrahub.core.schema import GenericSchema, NodeSchema from infrahub.core.timestamp import Timestamp from infrahub.exceptions import ValidationError from infrahub.message_bus import messages @@ -49,8 +48,6 @@ def __init__( self._destination_schema: Optional[SchemaBranch] = None self._initial_source_schema: Optional[SchemaBranch] = None - self.schema_diff: Optional[SchemaBranchDiff] = None - self._service = service @property @@ -106,46 +103,6 @@ async def has_schema_changes(self) -> bool: graph_diff = await self.get_graph_diff() return await graph_diff.has_schema_changes() - async def get_schema_diff(self) -> SchemaBranchDiff: - """Return a SchemaBranchDiff object with the list of nodes and generics - based on the information returned by the Graph Diff. - - The Graph Diff return a list of UUID so we need to convert that back into Kind - """ - - if self.schema_diff: - return self.schema_diff - - graph_diff = await self.get_graph_diff() - schema_summary = await graph_diff.get_schema_summary() - schema_diff = SchemaBranchDiff() - - # NOTE At this point there is no Generic in the schema but this could change in the future - for element in schema_summary.get(self.source_branch.name, []): - if element.kind == "SchemaNode" and DiffAction.REMOVED in element.actions: - continue - node = self.source_schema.get_by_any_id(id=element.node) - if isinstance(node, NodeSchema): - schema_diff.nodes.append(node.kind) - elif isinstance(node, GenericSchema): - schema_diff.generics.append(node.kind) - - for element in schema_summary.get(self.destination_branch.name, []): - if element.kind == "SchemaNode" and DiffAction.REMOVED in element.actions: - continue - node = self.destination_schema.get_by_any_id(id=element.node) - if isinstance(node, NodeSchema): - schema_diff.nodes.append(node.kind) - elif isinstance(node, GenericSchema): - schema_diff.generics.append(node.kind) - - # Remove duplicates if any - schema_diff.nodes = list(set(schema_diff.nodes)) - schema_diff.generics = list(set(schema_diff.generics)) - - self.schema_diff = schema_diff - return self.schema_diff - async def update_schema(self) -> bool: """After the merge, if there was some changes, we need to: - update the schema in the registry @@ -154,7 +111,6 @@ async def update_schema(self) -> bool: # NOTE we need to revisit how to calculate an accurate diff to pull only what needs to be updated from the schema # for now the best solution is to pull everything to ensure the integrity of the schema - # schema_diff = await self.get_schema_diff() if not await self.has_schema_changes(): return False diff --git a/backend/infrahub/core/models.py b/backend/infrahub/core/models.py index 541c27ba1c..d3fad1bfc9 100644 --- a/backend/infrahub/core/models.py +++ b/backend/infrahub/core/models.py @@ -29,8 +29,10 @@ def __str__(self) -> str: class SchemaBranchDiff(BaseModel): - nodes: list[str] = Field(default_factory=list) - generics: list[str] = Field(default_factory=list) + added_nodes: list[str] = Field(default_factory=list) + changed_nodes: list[str] = Field(default_factory=list) + added_generics: list[str] = Field(default_factory=list) + changed_generics: list[str] = Field(default_factory=list) removed_nodes: list[str] = Field(default_factory=list) removed_generics: list[str] = Field(default_factory=list) @@ -42,9 +44,25 @@ def to_list(self) -> list[str]: @property def has_diff(self) -> bool: - if self.nodes or self.generics or self.removed_nodes or self.removed_generics: - return True - return False + return any([self.has_node_diff, self.has_generic_diff]) + + @property + def has_node_diff(self) -> bool: + return bool(self.added_nodes + self.changed_nodes + self.removed_nodes) + + @property + def has_generic_diff(self) -> bool: + return bool(self.added_generics + self.changed_generics + self.removed_generics) + + @property + def nodes(self) -> list[str]: + """Return nodes that are still active.""" + return self.added_nodes + self.changed_nodes + + @property + def generics(self) -> list[str]: + """Return generics that are still active.""" + return self.added_generics + self.changed_generics class SchemaBranchHash(BaseModel): @@ -57,10 +75,12 @@ def compare(self, other: SchemaBranchHash) -> SchemaBranchDiff | None: return None return SchemaBranchDiff( - nodes=[key for key, value in other.nodes.items() if key not in self.nodes or self.nodes[key] != value], + added_nodes=[key for key in other.nodes if key not in self.nodes], + changed_nodes=[key for key, value in other.nodes.items() if key in self.nodes and self.nodes[key] != value], removed_nodes=[key for key in self.nodes if key not in other.nodes], - generics=[ - key for key, value in other.generics.items() if key not in self.generics or self.generics[key] != value + added_generics=[key for key in other.generics if key not in self.generics], + changed_generics=[ + key for key, value in other.generics.items() if key in self.generics and self.generics[key] != value ], removed_generics=[key for key in self.generics if key not in other.generics], ) diff --git a/backend/infrahub/core/schema/manager.py b/backend/infrahub/core/schema/manager.py index 7f689ad7d0..f06ba39fde 100644 --- a/backend/infrahub/core/schema/manager.py +++ b/backend/infrahub/core/schema/manager.py @@ -162,8 +162,8 @@ async def update_schema_branch( schema_diff = None if limit: schema_diff = SchemaBranchDiff( - nodes=[name for name in list(schema.nodes.keys()) if name in limit], - generics=[name for name in list(schema.generics.keys()) if name in limit], + added_nodes=[name for name in list(schema.nodes.keys()) if name in limit], + added_generics=[name for name in list(schema.generics.keys()) if name in limit], ) updated_schema = await self.load_schema_from_db( @@ -192,13 +192,19 @@ async def update_schema_to_db( branch = await registry.get_branch(branch=branch, db=db) - item_kinds = [] + added_nodes = [] + added_generics = [] for item_kind, item_diff in diff.added.items(): item = schema.get(name=item_kind, duplicate=False) node = await self.load_node_to_db(node=item, branch=branch, db=db) schema.set(name=item_kind, schema=node) - item_kinds.append(item_kind) + if item.is_node_schema: + added_nodes.append(item_kind) + else: + added_generics.append(item_kind) + changed_nodes = [] + changed_generics = [] for item_kind, item_diff in diff.changed.items(): item = schema.get(name=item_kind, duplicate=False) if item_diff: @@ -206,18 +212,30 @@ async def update_schema_to_db( else: node = await self.update_node_in_db(node=item, branch=branch, db=db) schema.set(name=item_kind, schema=node) - item_kinds.append(item_kind) + if item.is_node_schema: + changed_nodes.append(item_kind) + else: + changed_generics.append(item_kind) + removed_nodes = [] + removed_generics = [] for item_kind, item_diff in diff.removed.items(): item = schema.get(name=item_kind, duplicate=False) node = await self.delete_node_in_db(node=item, branch=branch, db=db) schema.delete(name=item_kind) - - schema_diff = SchemaBranchDiff( - nodes=[name for name in schema.node_names if name in item_kinds], - generics=[name for name in schema.generic_names if name in item_kinds], + if item.is_node_schema: + removed_nodes.append(item_kind) + else: + removed_generics.append(item_kind) + + return SchemaBranchDiff( + added_nodes=added_nodes, + added_generics=added_generics, + changed_nodes=changed_nodes, + changed_generics=changed_generics, + removed_nodes=removed_nodes, + removed_generics=removed_generics, ) - return schema_diff async def load_schema_to_db( self, @@ -574,7 +592,7 @@ async def load_schema( if not branch.is_default and branch.origin_branch: origin_branch: Branch = await registry.get_branch(branch=branch.origin_branch, db=db) - if origin_branch.schema_hash.main == branch.schema_hash.main: + if origin_branch.active_schema_hash.main == branch.active_schema_hash.main: origin_schema = self.get_schema_branch(name=origin_branch.name) new_branch_schema = origin_schema.duplicate() self.set_schema_branch(name=branch.name, schema=new_branch_schema) @@ -582,7 +600,7 @@ async def load_schema( return new_branch_schema current_schema = self.get_schema_branch(name=branch.name) - schema_diff = current_schema.get_hash_full().compare(branch.schema_hash) + schema_diff = current_schema.get_hash_full().compare(branch.active_schema_hash) branch_schema = await self.load_schema_from_db( db=db, branch=branch, schema=current_schema, schema_diff=schema_diff ) @@ -619,7 +637,7 @@ async def load_schema_from_db( has_filters = False # If a diff is provided but is empty there is nothing to query - if schema_diff is not None and not schema_diff: + if schema_diff is not None and not schema_diff.has_diff: return schema if schema_diff: