Skip to content

Commit

Permalink
Merge pull request #5102 from opsmill/lgu-schema-migrate
Browse files Browse the repository at this point in the history
Migrate schema validation / schema migration
  • Loading branch information
LucasG0 authored Dec 2, 2024
2 parents 76c968f + 1450381 commit 65aa5aa
Show file tree
Hide file tree
Showing 25 changed files with 160 additions and 752 deletions.
17 changes: 11 additions & 6 deletions backend/infrahub/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
)
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, SchemaRoot
from infrahub.core.schema.constants import SchemaNamespace # noqa: TCH001
from infrahub.core.validators.models.validate_migration import SchemaValidateMigrationData
from infrahub.core.validators.models.validate_migration import (
SchemaValidateMigrationData,
SchemaValidatorPathResponseData,
)
from infrahub.database import InfrahubDatabase # noqa: TCH001
from infrahub.events import EventMeta
from infrahub.events.schema_action import SchemaUpdatedEvent
Expand Down Expand Up @@ -307,13 +310,14 @@ async def load_schema(
schema_branch=candidate_schema,
constraints=result.constraints,
)
error_messages = await service.workflow.execute_workflow(
responses = await service.workflow.execute_workflow(
workflow=SCHEMA_VALIDATE_MIGRATION,
expected_return=list[str],
expected_return=list[SchemaValidatorPathResponseData],
parameters={"message": validate_migration_data},
)
error_messages = [violation.message for response in responses for violation in response.violations]
if error_messages:
raise SchemaNotValidError(message=",\n".join(error_messages))
raise SchemaNotValidError(",\n".join(error_messages))

# ----------------------------------------------------------
# Update the schema
Expand Down Expand Up @@ -402,11 +406,12 @@ async def check_schema(
schema_branch=candidate_schema,
constraints=result.constraints,
)
error_messages = await service.workflow.execute_workflow(
responses = await service.workflow.execute_workflow(
workflow=SCHEMA_VALIDATE_MIGRATION,
expected_return=list[str],
expected_return=list[SchemaValidatorPathResponseData],
parameters={"message": validate_migration_data},
)
error_messages = [violation.message for response in responses for violation in response.violations]
if error_messages:
raise SchemaNotValidError(message=",\n".join(error_messages))

Expand Down
4 changes: 3 additions & 1 deletion backend/infrahub/core/branch/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ async def rebase_branch(branch: str) -> None:
if obj.has_schema_changes:
constraints += await merger.calculate_validations(target_schema=candidate_schema)
if constraints:
error_messages = await schema_validate_migrations(
responses = await schema_validate_migrations(
message=SchemaValidateMigrationData(branch=obj, schema_branch=candidate_schema, constraints=constraints)
)

error_messages = [violation.message for response in responses for violation in response.violations]
if error_messages:
raise ValidationError(",\n".join(error_messages))

Expand Down
10 changes: 9 additions & 1 deletion backend/infrahub/core/migrations/schema/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from infrahub.core.branch import Branch
from infrahub.core.models import SchemaUpdateMigrationInfo
from infrahub.core.path import SchemaPath
from infrahub.core.schema.schema_branch import SchemaBranch


Expand All @@ -13,3 +14,10 @@ class SchemaApplyMigrationData(BaseModel):
new_schema: SchemaBranch
previous_schema: SchemaBranch
migrations: list[SchemaUpdateMigrationInfo]


class SchemaMigrationPathResponseData(BaseModel):
errors: list[str] = Field(default_factory=list)
migration_name: str | None = None
nbr_migrations_executed: int | None = None
schema_path: SchemaPath | None = None
67 changes: 0 additions & 67 deletions backend/infrahub/core/migrations/schema/runner.py

This file was deleted.

72 changes: 58 additions & 14 deletions backend/infrahub/core/migrations/schema/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
from typing import TYPE_CHECKING, Optional

from infrahub_sdk.batch import InfrahubBatch
from prefect import flow
from prefect import flow, task
from prefect.logging import get_run_logger

from infrahub.message_bus.messages.schema_migration_path import (
SchemaMigrationPathData,
)
from infrahub.message_bus.operations.schema.migration import schema_path_migrate
from infrahub.core.branch import Branch # noqa: TCH001
from infrahub.core.migrations import MIGRATION_MAP
from infrahub.core.path import SchemaPath # noqa: TCH001
from infrahub.services import services
from infrahub.workflows.utils import add_branch_tag

from .models import SchemaApplyMigrationData # noqa: TCH001
from .models import SchemaApplyMigrationData, SchemaMigrationPathResponseData

if TYPE_CHECKING:
from infrahub.core.schema import MainSchemaTypes


@flow(name="schema_apply_migrations", flow_run_name="Apply schema migrations")
async def schema_apply_migrations(message: SchemaApplyMigrationData) -> list[str]:
service = services.service
await add_branch_tag(branch_name=message.branch.name)
log = get_run_logger()

batch = InfrahubBatch()
error_messages: list[str] = []
Expand All @@ -30,10 +30,7 @@ async def schema_apply_migrations(message: SchemaApplyMigrationData) -> list[str
return error_messages

for migration in message.migrations:
service.log.info(
f"Preparing migration for {migration.migration_name!r} ({migration.routing_key})",
branch=message.branch.name,
)
log.info(f"Preparing migration for {migration.migration_name!r} ({migration.routing_key})")

new_node_schema: Optional[MainSchemaTypes] = None
previous_node_schema: Optional[MainSchemaTypes] = None
Expand All @@ -51,17 +48,64 @@ async def schema_apply_migrations(message: SchemaApplyMigrationData) -> list[str
f"Unable to find the previous version of the schema for {migration.path.schema_kind}, in order to run the migration."
)

msg = SchemaMigrationPathData(
batch.add(
task=schema_path_migrate,
branch=message.branch,
migration_name=migration.migration_name,
new_node_schema=new_node_schema,
previous_node_schema=previous_node_schema,
schema_path=migration.path,
)

batch.add(task=schema_path_migrate, message=msg)

async for _, result in batch.execute():
error_messages.extend(result.errors)

return error_messages


@task(
name="schema-path-migrate",
task_run_name="Migrate Schema Path {migration_name} on {branch.name}",
description="Apply a given migration to the database",
retries=3,
)
async def schema_path_migrate(
branch: Branch,
migration_name: str,
schema_path: SchemaPath,
new_node_schema: MainSchemaTypes | None = None,
previous_node_schema: MainSchemaTypes | None = None,
) -> SchemaMigrationPathResponseData:
service = services.service
log = get_run_logger()

async with service.database.start_session() as db:
node_kind = None
if new_node_schema:
node_kind = new_node_schema.kind
elif previous_node_schema:
node_kind = previous_node_schema.kind

log.info(
f"Migration for {node_kind} starting {schema_path.get_path()}",
)
migration_class = MIGRATION_MAP.get(migration_name)
if not migration_class:
raise ValueError(f"Unable to find the migration class for {migration_name}")

migration = migration_class( # type: ignore[call-arg]
new_node_schema=new_node_schema, # type: ignore[arg-type]
previous_node_schema=previous_node_schema, # type: ignore[arg-type]
schema_path=schema_path,
)
execution_result = await migration.execute(db=db, branch=branch)

log.info(f"Migration completed for {migration_name}")
log.debug(f"execution_result {execution_result}")

return SchemaMigrationPathResponseData(
migration_name=migration_name,
schema_path=schema_path,
errors=execution_result.errors,
nbr_migrations_executed=execution_result.nbr_migrations_executed,
)
71 changes: 0 additions & 71 deletions backend/infrahub/core/validators/checker.py

This file was deleted.

14 changes: 13 additions & 1 deletion backend/infrahub/core/validators/models/validate_migration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from infrahub.core.branch import Branch
from infrahub.core.models import SchemaUpdateConstraintInfo
from infrahub.core.path import SchemaPath
from infrahub.core.schema.schema_branch import SchemaBranch
from infrahub.core.validators.model import SchemaViolation
from infrahub.message_bus import InfrahubResponseData


class SchemaValidateMigrationData(BaseModel):
Expand All @@ -12,3 +15,12 @@ class SchemaValidateMigrationData(BaseModel):
branch: Branch
schema_branch: SchemaBranch
constraints: list[SchemaUpdateConstraintInfo]


class SchemaValidatorPathResponseData(InfrahubResponseData):
violations: list[SchemaViolation] = Field(default_factory=list)
constraint_name: str
schema_path: SchemaPath

def get_messages(self) -> list[str]:
return [violation.message for violation in self.violations]
Loading

0 comments on commit 65aa5aa

Please sign in to comment.