Skip to content

Commit

Permalink
Update SchemaBranchDiff format
Browse files Browse the repository at this point in the history
  • Loading branch information
ogenstad committed Dec 13, 2024
1 parent c9c6741 commit 2ec5829
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 67 deletions.
48 changes: 2 additions & 46 deletions backend/infrahub/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from typing import TYPE_CHECKING, Optional, Union

from infrahub.core.constants import DiffAction, RepositoryInternalStatus
from infrahub.core.constants import RepositoryInternalStatus
from infrahub.core.manager import NodeManager
from infrahub.core.models import SchemaBranchDiff, SchemaUpdateValidationResult
from infrahub.core.models import SchemaUpdateValidationResult
from infrahub.core.protocols import CoreRepository
from infrahub.core.registry import registry
from infrahub.core.schema import GenericSchema, NodeSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import ValidationError
from infrahub.message_bus import messages
Expand Down Expand Up @@ -49,8 +48,6 @@ def __init__(
self._destination_schema: Optional[SchemaBranch] = None
self._initial_source_schema: Optional[SchemaBranch] = None

self.schema_diff: Optional[SchemaBranchDiff] = None

self._service = service

@property
Expand Down Expand Up @@ -106,46 +103,6 @@ async def has_schema_changes(self) -> bool:
graph_diff = await self.get_graph_diff()
return await graph_diff.has_schema_changes()

async def get_schema_diff(self) -> SchemaBranchDiff:
"""Return a SchemaBranchDiff object with the list of nodes and generics
based on the information returned by the Graph Diff.
The Graph Diff return a list of UUID so we need to convert that back into Kind
"""

if self.schema_diff:
return self.schema_diff

graph_diff = await self.get_graph_diff()
schema_summary = await graph_diff.get_schema_summary()
schema_diff = SchemaBranchDiff()

# NOTE At this point there is no Generic in the schema but this could change in the future
for element in schema_summary.get(self.source_branch.name, []):
if element.kind == "SchemaNode" and DiffAction.REMOVED in element.actions:
continue
node = self.source_schema.get_by_any_id(id=element.node)
if isinstance(node, NodeSchema):
schema_diff.nodes.append(node.kind)
elif isinstance(node, GenericSchema):
schema_diff.generics.append(node.kind)

for element in schema_summary.get(self.destination_branch.name, []):
if element.kind == "SchemaNode" and DiffAction.REMOVED in element.actions:
continue
node = self.destination_schema.get_by_any_id(id=element.node)
if isinstance(node, NodeSchema):
schema_diff.nodes.append(node.kind)
elif isinstance(node, GenericSchema):
schema_diff.generics.append(node.kind)

# Remove duplicates if any
schema_diff.nodes = list(set(schema_diff.nodes))
schema_diff.generics = list(set(schema_diff.generics))

self.schema_diff = schema_diff
return self.schema_diff

async def update_schema(self) -> bool:
"""After the merge, if there was some changes, we need to:
- update the schema in the registry
Expand All @@ -154,7 +111,6 @@ async def update_schema(self) -> bool:

# NOTE we need to revisit how to calculate an accurate diff to pull only what needs to be updated from the schema
# for now the best solution is to pull everything to ensure the integrity of the schema
# schema_diff = await self.get_schema_diff()

if not await self.has_schema_changes():
return False
Expand Down
36 changes: 28 additions & 8 deletions backend/infrahub/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ def __str__(self) -> str:


class SchemaBranchDiff(BaseModel):
nodes: list[str] = Field(default_factory=list)
generics: list[str] = Field(default_factory=list)
added_nodes: list[str] = Field(default_factory=list)
changed_nodes: list[str] = Field(default_factory=list)
added_generics: list[str] = Field(default_factory=list)
changed_generics: list[str] = Field(default_factory=list)
removed_nodes: list[str] = Field(default_factory=list)
removed_generics: list[str] = Field(default_factory=list)

Expand All @@ -42,9 +44,25 @@ def to_list(self) -> list[str]:

@property
def has_diff(self) -> bool:
if self.nodes or self.generics or self.removed_nodes or self.removed_generics:
return True
return False
return any([self.has_node_diff, self.has_generic_diff])

@property
def has_node_diff(self) -> bool:
return bool(self.added_nodes + self.changed_nodes + self.removed_nodes)

@property
def has_generic_diff(self) -> bool:
return bool(self.added_generics + self.changed_generics + self.removed_generics)

@property
def nodes(self) -> list[str]:
"""Return nodes that are still active."""
return self.added_nodes + self.changed_nodes

@property
def generics(self) -> list[str]:
"""Return generics that are still active."""
return self.added_generics + self.changed_generics


class SchemaBranchHash(BaseModel):
Expand All @@ -57,10 +75,12 @@ def compare(self, other: SchemaBranchHash) -> SchemaBranchDiff | None:
return None

return SchemaBranchDiff(
nodes=[key for key, value in other.nodes.items() if key not in self.nodes or self.nodes[key] != value],
added_nodes=[key for key in other.nodes if key not in self.nodes],
changed_nodes=[key for key, value in other.nodes.items() if key in self.nodes and self.nodes[key] != value],
removed_nodes=[key for key in self.nodes if key not in other.nodes],
generics=[
key for key, value in other.generics.items() if key not in self.generics or self.generics[key] != value
added_generics=[key for key in other.generics if key not in self.generics],
changed_generics=[
key for key, value in other.generics.items() if key in self.generics and self.generics[key] != value
],
removed_generics=[key for key in self.generics if key not in other.generics],
)
Expand Down
44 changes: 31 additions & 13 deletions backend/infrahub/core/schema/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ async def update_schema_branch(
schema_diff = None
if limit:
schema_diff = SchemaBranchDiff(
nodes=[name for name in list(schema.nodes.keys()) if name in limit],
generics=[name for name in list(schema.generics.keys()) if name in limit],
added_nodes=[name for name in list(schema.nodes.keys()) if name in limit],
added_generics=[name for name in list(schema.generics.keys()) if name in limit],
)

updated_schema = await self.load_schema_from_db(
Expand Down Expand Up @@ -192,32 +192,50 @@ async def update_schema_to_db(

branch = await registry.get_branch(branch=branch, db=db)

item_kinds = []
added_nodes = []
added_generics = []
for item_kind, item_diff in diff.added.items():
item = schema.get(name=item_kind, duplicate=False)
node = await self.load_node_to_db(node=item, branch=branch, db=db)
schema.set(name=item_kind, schema=node)
item_kinds.append(item_kind)
if item.is_node_schema:
added_nodes.append(item_kind)
else:
added_generics.append(item_kind)

changed_nodes = []
changed_generics = []
for item_kind, item_diff in diff.changed.items():
item = schema.get(name=item_kind, duplicate=False)
if item_diff:
node = await self.update_node_in_db_based_on_diff(node=item, branch=branch, db=db, diff=item_diff)
else:
node = await self.update_node_in_db(node=item, branch=branch, db=db)
schema.set(name=item_kind, schema=node)
item_kinds.append(item_kind)
if item.is_node_schema:
changed_nodes.append(item_kind)
else:
changed_generics.append(item_kind)

removed_nodes = []
removed_generics = []
for item_kind, item_diff in diff.removed.items():
item = schema.get(name=item_kind, duplicate=False)
node = await self.delete_node_in_db(node=item, branch=branch, db=db)
schema.delete(name=item_kind)

schema_diff = SchemaBranchDiff(
nodes=[name for name in schema.node_names if name in item_kinds],
generics=[name for name in schema.generic_names if name in item_kinds],
if item.is_node_schema:
removed_nodes.append(item_kind)
else:
removed_generics.append(item_kind)

return SchemaBranchDiff(
added_nodes=added_nodes,
added_generics=added_generics,
changed_nodes=changed_nodes,
changed_generics=changed_generics,
removed_nodes=removed_nodes,
removed_generics=removed_generics,
)
return schema_diff

async def load_schema_to_db(
self,
Expand Down Expand Up @@ -574,15 +592,15 @@ async def load_schema(
if not branch.is_default and branch.origin_branch:
origin_branch: Branch = await registry.get_branch(branch=branch.origin_branch, db=db)

if origin_branch.schema_hash.main == branch.schema_hash.main:
if origin_branch.active_schema_hash.main == branch.active_schema_hash.main:
origin_schema = self.get_schema_branch(name=origin_branch.name)
new_branch_schema = origin_schema.duplicate()
self.set_schema_branch(name=branch.name, schema=new_branch_schema)
log.info("Loading schema from cache")
return new_branch_schema

current_schema = self.get_schema_branch(name=branch.name)
schema_diff = current_schema.get_hash_full().compare(branch.schema_hash)
schema_diff = current_schema.get_hash_full().compare(branch.active_schema_hash)
branch_schema = await self.load_schema_from_db(
db=db, branch=branch, schema=current_schema, schema_diff=schema_diff
)
Expand Down Expand Up @@ -619,7 +637,7 @@ async def load_schema_from_db(
has_filters = False

# If a diff is provided but is empty there is nothing to query
if schema_diff is not None and not schema_diff:
if schema_diff is not None and not schema_diff.has_diff:
return schema

if schema_diff:
Expand Down

0 comments on commit 2ec5829

Please sign in to comment.