Skip to content

Commit

Permalink
Merge pull request #4994 from opsmill/dga-20241119-node-tag
Browse files Browse the repository at this point in the history
Add node tag for repository flows
  • Loading branch information
dgarros authored Nov 23, 2024
2 parents b899e86 + e281470 commit 258020a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 41 deletions.
61 changes: 30 additions & 31 deletions backend/infrahub/git/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta

from infrahub_sdk import InfrahubClient
from infrahub_sdk.protocols import CoreRepository
from prefect import flow, task
from prefect.automations import AutomationCore
from prefect.client.orchestration import get_client
Expand All @@ -11,18 +12,17 @@

from infrahub import lock
from infrahub.core.constants import InfrahubKind, RepositoryInternalStatus
from infrahub.core.protocols import CoreRepository
from infrahub.core.registry import registry
from infrahub.exceptions import RepositoryError
from infrahub.message_bus import Meta, messages
from infrahub.services import services
from infrahub.worker import WORKER_IDENTITY
from infrahub.workflows.catalogue import COMPUTED_ATTRIBUTE_SETUP_PYTHON

from ..log import get_log_data, get_logger
from ..log import get_log_data
from ..tasks.artifact import define_artifact
from ..workflows.catalogue import REQUEST_ARTIFACT_DEFINITION_GENERATE, REQUEST_ARTIFACT_GENERATE
from ..workflows.utils import add_branch_tag
from ..workflows.utils import add_branch_tag, add_tags
from .constants import AUTOMATION_NAME
from .models import (
GitDiffNamesOnly,
Expand All @@ -37,16 +37,14 @@
)
from .repository import InfrahubReadOnlyRepository, InfrahubRepository, get_initialized_repo

log = get_logger()


@flow(
name="git-repository-add-read-write",
flow_run_name="Adding repository {model.repository_name} in branch {model.infrahub_branch_name}",
)
async def add_git_repository(model: GitRepositoryAdd) -> None:
service = services.service
await add_branch_tag(model.infrahub_branch_name)
await add_tags(branches=[model.infrahub_branch_name], nodes=[model.repository_id])

async with lock.registry.get(name=model.repository_name, namespace="repository"):
repo = await InfrahubRepository.new(
Expand Down Expand Up @@ -82,7 +80,8 @@ async def add_git_repository(model: GitRepositoryAdd) -> None:
)
async def add_git_repository_read_only(model: GitRepositoryAddReadOnly) -> None:
service = services.service
await add_branch_tag(model.infrahub_branch_name)
await add_tags(branches=[model.infrahub_branch_name], nodes=[model.repository_id])

async with lock.registry.get(name=model.repository_name, namespace="repository"):
repo = await InfrahubReadOnlyRepository.new(
id=model.repository_id,
Expand All @@ -108,7 +107,7 @@ async def add_git_repository_read_only(model: GitRepositoryAddReadOnly) -> None:
await service.send(message=notification)


@flow(name="git_repositories_create_branch")
@flow(name="git_repositories_create_branch", flow_run_name="Create branch in Git Repositories")
async def create_branch(branch: str, branch_id: str) -> None:
"""Request to the creation of git branches in available repositories."""
service = services.service
Expand All @@ -129,10 +128,12 @@ async def create_branch(branch: str, branch_id: str) -> None:
pass


@flow(name="git_repositories_sync")
@flow(name="git_repositories_sync", flow_run_name="Sync Git Repositories")
async def sync_remote_repositories() -> None:
service = services.service

log = get_run_logger()

branches = await service.client.branch.all()
repositories = await service.client.get_list_repositories(branches=branches, kind=InfrahubKind.REPOSITORY)

Expand Down Expand Up @@ -196,7 +197,7 @@ async def sync_remote_repositories() -> None:
log.info(exc.message)


@task
@task(name="git-branch-create", task_run_name="Create Branch {branch} in repository {repository_name}")
async def git_branch_create(
client: InfrahubClient, branch: str, branch_id: str, repository_id: str, repository_name: str
) -> None:
Expand Down Expand Up @@ -238,8 +239,8 @@ async def generate_artifact_definition(branch: str) -> None:
async def generate_artifact(model: RequestArtifactGenerate) -> None:
service = services.service

await add_branch_tag(branch_name=model.branch_name)

await add_tags(branches=[model.branch_name], nodes=[model.repository_id])
log = get_run_logger()
repo = await get_initialized_repo(
repository_id=model.repository_id,
name=model.repository_name,
Expand All @@ -252,15 +253,10 @@ async def generate_artifact(model: RequestArtifactGenerate) -> None:
try:
result = await repo.render_artifact(artifact=artifact, message=model)
log.debug(
"Generated artifact",
name=model.artifact_name,
changed=result.changed,
checksum=result.checksum,
artifact_id=result.artifact_id,
storage_id=result.storage_id,
f"Generated artifact | changed: {result.changed} | {result.checksum} | {result.storage_id}",
)
except Exception as exc: # pylint: disable=broad-except
log.exception("Failed to generate artifact", error=exc)
except Exception: # pylint: disable=broad-except
log.exception("Failed to generate artifact")
artifact.status.value = "Error"
await artifact.save()

Expand Down Expand Up @@ -342,12 +338,12 @@ async def generate_request_artifact_definition(model: RequestArtifactDefinitionG
@flow(name="git-repository-pull-read-only", flow_run_name="Pull latest commit on {model.repository_name}")
async def pull_read_only(model: GitRepositoryPullReadOnly) -> None:
service = services.service

await add_tags(branches=[model.infrahub_branch_name], nodes=[model.repository_id])
log = get_run_logger()

if not model.ref and not model.commit:
log.warning(
"No commit or ref in GitRepositoryPullReadOnly message",
name=model.repository_name,
repository_id=model.repository_id,
)
log.warning("No commit or ref in GitRepositoryPullReadOnly message")
return
async with lock.registry.get(name=model.repository_name, namespace="repository"):
init_failed = False
Expand Down Expand Up @@ -388,11 +384,14 @@ async def pull_read_only(model: GitRepositoryPullReadOnly) -> None:
await service.send(message=message)


@flow(name="git-repository-merge")
@flow(
name="git-repository-merge",
flow_run_name="Merge {model.source_branch} > {model.destination_branch} in git repository",
)
async def merge_git_repository(model: GitRepositoryMerge) -> None:
service = services.service
await add_branch_tag(branch_name=model.source_branch)
await add_branch_tag(branch_name=model.destination_branch)

await add_tags(branches=[model.source_branch, model.destination_branch], nodes=[model.repository_id])

repo = await InfrahubRepository.init(
id=model.repository_id,
Expand Down Expand Up @@ -431,7 +430,7 @@ async def merge_git_repository(model: GitRepositoryMerge) -> None:

@flow(name="git-commit-automation-setup", flow_run_name="Setup git commit updated event in task-manager")
async def setup_commit_automation() -> None:
run_log = get_run_logger()
log = get_run_logger()

async with get_client(sync_client=False) as client:
deployments = {
Expand Down Expand Up @@ -471,10 +470,10 @@ async def setup_commit_automation() -> None:

if schema_update_automation:
await client.update_automation(automation_id=schema_update_automation.id, automation=automation)
run_log.info(f"{AUTOMATION_NAME} Updated")
log.info(f"{AUTOMATION_NAME} Updated")
else:
await client.create_automation(automation=automation)
run_log.info(f"{AUTOMATION_NAME} Created")
log.info(f"{AUTOMATION_NAME} Created")


@flow(name="git-repository-import-object", flow_run_name="Import objects from git repository")
Expand Down
24 changes: 20 additions & 4 deletions backend/infrahub/proposed_change/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING

import pytest
from infrahub_sdk.protocols import CoreProposedChange
from infrahub_sdk.protocols import CoreGeneratorDefinition, CoreProposedChange
from prefect import flow, task
from prefect.client.schemas.objects import (
State, # noqa: TCH002
Expand All @@ -24,7 +24,7 @@
from infrahub.core.diff.model.diff import DiffElementType, SchemaConflict
from infrahub.core.diff.model.path import NodeDiffFieldSummary
from infrahub.core.integrity.object_conflict.conflict_recorder import ObjectConflictValidatorRecorder
from infrahub.core.protocols import CoreDataCheck, CoreGeneratorDefinition, CoreValidator
from infrahub.core.protocols import CoreDataCheck, CoreValidator
from infrahub.core.protocols import CoreProposedChange as InternalCoreProposedChange
from infrahub.core.validators.checker import schema_validators_checker
from infrahub.core.validators.determiner import ConstraintValidatorDeterminer
Expand All @@ -45,6 +45,7 @@
from infrahub.pytest_plugin import InfrahubBackendPlugin
from infrahub.services import services
from infrahub.workflows.catalogue import REQUEST_PROPOSED_CHANGE_REPOSITORY_CHECKS
from infrahub.workflows.utils import add_tags

if TYPE_CHECKING:
from infrahub_sdk.node import InfrahubNode
Expand Down Expand Up @@ -96,6 +97,8 @@ async def merge_proposed_change(proposed_change_id: str, proposed_change_name: s
service = services.service
log = get_run_logger()

await add_tags(nodes=[proposed_change_id])

async with service.database.start_session() as db:
proposed_change = await registry.manager.get_one(
db=db, id=proposed_change_id, kind=InternalCoreProposedChange, raise_on_error=True
Expand Down Expand Up @@ -141,6 +144,9 @@ async def merge_proposed_change(proposed_change_id: str, proposed_change_name: s
)
async def cancel_proposed_changes_branch(branch_name: str) -> None:
service = services.service

await add_tags(branches=[branch_name])

proposed_changed_opened = await service.client.filters(
kind=CoreProposedChange,
include=["id", "source_branch"],
Expand All @@ -161,6 +167,8 @@ async def cancel_proposed_changes_branch(branch_name: str) -> None:
@task(name="Cancel a propose change", description="Cancel a propose change")
async def cancel_proposed_change(proposed_change: CoreProposedChange) -> None:
service = services.service

await add_tags(nodes=[proposed_change.id])
log = get_run_logger()

log.info("Canceling proposed change as the source branch was deleted")
Expand All @@ -177,6 +185,8 @@ async def run_proposed_change_data_integrity_check(model: RequestProposedChangeD
"""Triggers a data integrity validation check on the provided proposed change to start."""

service = services.service
await add_tags(nodes=[model.proposed_change])

destination_branch = await registry.get_branch(db=service.database, branch=model.destination_branch)
source_branch = await registry.get_branch(db=service.database, branch=model.source_branch)
component_registry = get_component_registry()
Expand All @@ -187,10 +197,12 @@ async def run_proposed_change_data_integrity_check(model: RequestProposedChangeD

@flow(
name="proposed-changed-run-generator",
flow_run_name="Run generators related to proposed change {model.proposed_change}",
flow_run_name="Run generators",
)
async def run_generators(model: RequestProposedChangeRunGenerators) -> None:
service = services.service
await add_tags(nodes=[model.proposed_change])

generators = await service.client.filters(
kind=CoreGeneratorDefinition,
prefetch_relationships=True,
Expand Down Expand Up @@ -277,14 +289,16 @@ async def run_generators(model: RequestProposedChangeRunGenerators) -> None:

@flow(
name="proposed-changed-schema-integrity",
flow_run_name="Got a request to process schema integrity defined in proposed_change: {model.proposed_change}",
flow_run_name="Process schema integrity",
)
async def run_proposed_change_schema_integrity_check(
model: RequestProposedChangeSchemaIntegrity,
) -> None:
# For now, we retrieve the latest schema for each branch from the registry
# In the future it would be good to generate the object SchemaUpdateValidationResult from message.branch_diff
service = services.service
await add_tags(nodes=[model.proposed_change])

source_schema = registry.schema.get_schema_branch(name=model.source_branch).duplicate()
dest_schema = registry.schema.get_schema_branch(name=model.destination_branch).duplicate()

Expand Down Expand Up @@ -370,6 +384,8 @@ async def _get_proposed_change_schema_integrity_constraints(
)
async def repository_checks(model: RequestProposedChangeRepositoryChecks) -> None:
service = services.service
await add_tags(nodes=[model.proposed_change])

events: list[InfrahubMessage] = []
for repository in model.branch_diff.repositories:
if (
Expand Down
12 changes: 6 additions & 6 deletions backend/infrahub/workflows/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
from infrahub.services import InfrahubServices


async def add_tags(tags: list[str]) -> None:
async def add_tags(branches: list[str] | None = None, nodes: list[str] | None = None) -> None:
client = get_client(sync_client=False)
current_flow_run_id = flow_run.id
current_tags: list[str] = flow_run.tags
new_tags = current_tags + tags
branch_tags = [WorkflowTag.BRANCH.render(identifier=branch_name) for branch_name in branches] if branches else []
node_tags = [WorkflowTag.RELATED_NODE.render(identifier=node_id) for node_id in nodes] if nodes else []
new_tags = set(current_tags + branch_tags + node_tags)
await client.update_flow_run(current_flow_run_id, tags=list(new_tags))


async def add_branch_tag(branch_name: str) -> None:
tag = WorkflowTag.BRANCH.render(identifier=branch_name)
await add_tags(tags=[tag])
await add_tags(branches=[branch_name])


async def add_related_node_tag(node_id: str) -> None:
tag = WorkflowTag.RELATED_NODE.render(identifier=node_id)
await add_tags(tags=[tag])
await add_tags(nodes=[node_id])


async def wait_for_schema_to_converge(
Expand Down

0 comments on commit 258020a

Please sign in to comment.