diff --git a/backend/infrahub/graphql/mutations/ipam.py b/backend/infrahub/graphql/mutations/ipam.py
index 3c9472cf9b..cfb03b0cd9 100644
--- a/backend/infrahub/graphql/mutations/ipam.py
+++ b/backend/infrahub/graphql/mutations/ipam.py
@@ -17,6 +17,8 @@
 from infrahub.graphql.mutations.node_getter.interface import MutationNodeGetterInterface
 from infrahub.log import get_logger
 
+from ... import lock
+from ...lock import InfrahubMultiLock, build_object_lock_name
 from .main import InfrahubMutationMixin, InfrahubMutationOptions
 
 if TYPE_CHECKING:
@@ -106,12 +108,14 @@ async def mutate_create(
         ip_address = ipaddress.ip_interface(data["address"]["value"])
         namespace_id = await validate_namespace(db=db, data=data)
 
-        async with db.start_transaction() as dbt:
-            address = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
-            reconciler = IpamReconciler(db=dbt, branch=branch)
-            reconciled_address = await reconciler.reconcile(
-                ip_value=ip_address, namespace=namespace_id, node_uuid=address.get_id()
-            )
+        lock_name = build_object_lock_name(cls._meta.schema.kind)
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+            async with db.start_transaction() as dbt:
+                address = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
+                reconciler = IpamReconciler(db=dbt, branch=branch)
+                reconciled_address = await reconciler.reconcile(
+                    ip_value=ip_address, namespace=namespace_id, node_uuid=address.get_id()
+                )
 
         result = await cls.mutate_create_to_graphql(info=info, db=db, obj=reconciled_address)
 
@@ -141,13 +145,15 @@ async def mutate_update(
         namespace = await address.ip_namespace.get_peer(db)
         namespace_id = await validate_namespace(db=db, data=data, existing_namespace_id=namespace.id)
         try:
-            async with db.start_transaction() as dbt:
-                address = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=address)
-                reconciler = IpamReconciler(db=dbt, branch=branch)
-                ip_address = ipaddress.ip_interface(address.address.value)
-                reconciled_address = await reconciler.reconcile(
-                    ip_value=ip_address, node_uuid=address.get_id(), namespace=namespace_id
-                )
+            lock_name = build_object_lock_name(cls._meta.schema.kind)
+            async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+                async with db.start_transaction() as dbt:
+                    address = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=address)
+                    reconciler = IpamReconciler(db=dbt, branch=branch)
+                    ip_address = ipaddress.ip_interface(address.address.value)
+                    reconciled_address = await reconciler.reconcile(
+                        ip_value=ip_address, node_uuid=address.get_id(), namespace=namespace_id
+                    )
 
                 result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_address)
         except ValidationError as exc:
@@ -216,12 +222,14 @@ async def mutate_create(
         ip_network = ipaddress.ip_network(data["prefix"]["value"])
         namespace_id = await validate_namespace(db=db, data=data)
 
-        async with db.start_transaction() as dbt:
-            prefix = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
-            reconciler = IpamReconciler(db=dbt, branch=branch)
-            reconciled_prefix = await reconciler.reconcile(
-                ip_value=ip_network, namespace=namespace_id, node_uuid=prefix.get_id()
-            )
+        lock_name = build_object_lock_name(cls._meta.schema.kind)
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+            async with db.start_transaction() as dbt:
+                prefix = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
+                reconciler = IpamReconciler(db=dbt, branch=branch)
+                reconciled_prefix = await reconciler.reconcile(
+                    ip_value=ip_network, namespace=namespace_id, node_uuid=prefix.get_id()
+                )
 
         result = await cls.mutate_create_to_graphql(info=info, db=db, obj=reconciled_prefix)
 
@@ -251,13 +259,15 @@ async def mutate_update(
         namespace = await prefix.ip_namespace.get_peer(db)
         namespace_id = await validate_namespace(db=db, data=data, existing_namespace_id=namespace.id)
         try:
-            async with db.start_transaction() as dbt:
-                prefix = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=prefix)
-                reconciler = IpamReconciler(db=dbt, branch=branch)
-                ip_network = ipaddress.ip_network(prefix.prefix.value)
-                reconciled_prefix = await reconciler.reconcile(
-                    ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id
-                )
+            lock_name = build_object_lock_name(cls._meta.schema.kind)
+            async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+                async with db.start_transaction() as dbt:
+                    prefix = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=prefix)
+                    reconciler = IpamReconciler(db=dbt, branch=branch)
+                    ip_network = ipaddress.ip_network(prefix.prefix.value)
+                    reconciled_prefix = await reconciler.reconcile(
+                        ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id
+                    )
 
                 result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_prefix)
         except ValidationError as exc:
@@ -302,12 +312,14 @@ async def mutate_delete(
         namespace_rels = await prefix.ip_namespace.get_relationships(db=db)
         namespace_id = namespace_rels[0].peer_id
         try:
-            async with context.db.start_transaction() as dbt:
-                reconciler = IpamReconciler(db=dbt, branch=branch)
-                ip_network = ipaddress.ip_network(prefix.prefix.value)
-                reconciled_prefix = await reconciler.reconcile(
-                    ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id, is_delete=True
-                )
+            lock_name = build_object_lock_name(cls._meta.schema.kind)
+            async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+                async with context.db.start_transaction() as dbt:
+                    reconciler = IpamReconciler(db=dbt, branch=branch)
+                    ip_network = ipaddress.ip_network(prefix.prefix.value)
+                    reconciled_prefix = await reconciler.reconcile(
+                        ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id, is_delete=True
+                    )
         except ValidationError as exc:
             raise ValueError(str(exc)) from exc
 
diff --git a/backend/infrahub/graphql/mutations/main.py b/backend/infrahub/graphql/mutations/main.py
index dac437fce2..6a027a466e 100644
--- a/backend/infrahub/graphql/mutations/main.py
+++ b/backend/infrahub/graphql/mutations/main.py
@@ -7,7 +7,7 @@
 from infrahub_sdk.utils import extract_fields
 from typing_extensions import Self
 
-from infrahub import config
+from infrahub import config, lock
 from infrahub.auth import validate_mutation_permissions_update_node
 from infrahub.core import registry
 from infrahub.core.constants import MutationAction
@@ -25,6 +25,7 @@
 from infrahub.log import get_log_data, get_logger
 from infrahub.worker import WORKER_IDENTITY
 
+from ...lock import InfrahubMultiLock, build_object_lock_name, get_kinds_to_lock_on_object_mutation
 from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter
 from .node_getter.by_hfid import MutationNodeGetterByHfid
 from .node_getter.by_id import MutationNodeGetterById
@@ -134,6 +135,20 @@ async def _refresh_for_profile_update(
         )
         return obj
 
+    @classmethod
+    async def _call_mutate_create_object(cls, data: InputObjectType, db: InfrahubDatabase, branch: Branch) -> Node:
+        """
+        Wrapper around mutate_create_object to potentially activate locking.
+        """
+
+        lock_kinds = get_kinds_to_lock_on_object_mutation(cls._meta.schema)
+        if lock_kinds:
+            lock_names = [build_object_lock_name(kind) for kind in lock_kinds]
+            async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+                return await cls.mutate_create_object(data=data, db=db, branch=branch)
+
+        return await cls.mutate_create_object(data=data, db=db, branch=branch)
+
     @classmethod
     async def mutate_create(
         cls,
@@ -144,7 +159,7 @@ async def mutate_create(
     ) -> tuple[Node, Self]:
         context: GraphqlContext = info.context
         db = database or context.db
-        obj = await cls.mutate_create_object(data=data, db=db, branch=branch)
+        obj = await cls._call_mutate_create_object(data=data, db=db, branch=branch)
         result = await cls.mutate_create_to_graphql(info=info, db=db, obj=obj)
 
         return obj, result
@@ -189,6 +204,40 @@ async def mutate_create_to_graphql(cls, info: GraphQLResolveInfo, db: InfrahubDa
             result["object"] = await obj.to_graphql(db=db, fields=fields.get("object", {}))
         return cls(**result)
 
+    @classmethod
+    async def _call_mutate_update(
+        cls,
+        info: GraphQLResolveInfo,
+        data: InputObjectType,
+        branch: Branch,
+        db: InfrahubDatabase,
+        obj: Node,
+    ) -> tuple[Node, Self]:
+        """
+        Wrapper around mutate_update to potentially activate locking and call it within a database transaction.
+        """
+
+        lock_kinds = get_kinds_to_lock_on_object_mutation(cls._meta.schema)
+        lock_names = [build_object_lock_name(kind) for kind in lock_kinds]
+
+        if db.is_transaction:
+            if lock_names:
+                async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+                    obj = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj)
+            else:
+                obj = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj)
+            result = await cls.mutate_update_to_graphql(db=db, info=info, obj=obj)
+            return obj, result
+
+        async with db.start_transaction() as dbt:
+            if lock_names:
+                async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+                    obj = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=obj)
+            else:
+                obj = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=obj)
+            result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=obj)
+            return obj, result
+
     @classmethod
     @retry_db_transaction(name="object_update")
     async def mutate_update(
@@ -207,13 +256,7 @@ async def mutate_update(
         )
 
         try:
-            if db.is_transaction:
-                obj = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj)
-                result = await cls.mutate_update_to_graphql(db=db, info=info, obj=obj)
-            else:
-                async with db.start_transaction() as dbt:
-                    obj = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=obj)
-                    result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=obj)
+            obj, result = await cls._call_mutate_update(info=info, data=data, db=db, branch=branch, obj=obj)
         except ValidationError as exc:
             raise ValueError(str(exc)) from exc
 
diff --git a/backend/infrahub/lock.py b/backend/infrahub/lock.py
index 5ee28cb0eb..425b4aa4c4 100644
--- a/backend/infrahub/lock.py
+++ b/backend/infrahub/lock.py
@@ -16,6 +16,7 @@
 if TYPE_CHECKING:
     from types import TracebackType
 
+    from infrahub.core.schema.generated.base_node_schema import GeneratedBaseNodeSchema
     from infrahub.services import InfrahubServices
 
 registry: InfrahubLockRegistry = None
@@ -46,8 +47,8 @@
 class InfrahubMultiLock:
     """Context manager to allow multiple locks to be reserved together"""
 
-    def __init__(self, _registry: InfrahubLockRegistry, locks: Optional[list[str]] = None) -> None:
-        self.registry = _registry
+    def __init__(self, lock_registry: InfrahubLockRegistry, locks: Optional[list[str]] = None) -> None:
+        self.registry = lock_registry
         self.locks = locks or []
 
     async def __aenter__(self):
@@ -245,12 +246,45 @@ async def local_schema_wait(self) -> None:
         await self.get(name=LOCAL_SCHEMA_LOCK).event.wait()
 
     def global_schema_lock(self) -> InfrahubMultiLock:
-        return InfrahubMultiLock(_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_SCHEMA_LOCK])
+        return InfrahubMultiLock(lock_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_SCHEMA_LOCK])
 
     def global_graph_lock(self) -> InfrahubMultiLock:
-        return InfrahubMultiLock(_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_GRAPH_LOCK, GLOBAL_SCHEMA_LOCK])
+        return InfrahubMultiLock(lock_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_GRAPH_LOCK, GLOBAL_SCHEMA_LOCK])
 
 
 def initialize_lock(local_only: bool = False, service: Optional[InfrahubServices] = None) -> None:
     global registry  # pylint: disable=global-statement
     registry = InfrahubLockRegistry(local_only=local_only, service=service)
+
+
+def build_object_lock_name(name: str) -> str:
+    return f"global.object.{name}"
+
+
+def get_kinds_to_lock_on_object_mutation(node_schema: GeneratedBaseNodeSchema) -> list[str]:
+    """
+    Return the kinds to lock while creating / updating an object of the given node schema.
+    The node schema kind is included when it defines uniqueness constraints, followed by every kind it inherits from.
+    Note that if a generic's uniqueness constraints are equal to the node schema's uniqueness constraints, locking
+    on the generic alone is enough, so the node schema kind is dropped from the returned list.
+    """
+
+    schema_uc = None
+    kinds = []
+    if node_schema.uniqueness_constraints:
+        kinds.append(node_schema.kind)
+        schema_uc = node_schema.uniqueness_constraints
+
+    try:
+        ancestors_kinds = node_schema.inherit_from
+    except AttributeError:
+        return kinds
+
+    node_schema_kind_removed = False
+    for kind in ancestors_kinds:
+        kinds.append(kind)
+        uc = registry.schema.get(name=kind).uniqueness_constraints
+        if not node_schema_kind_removed and uc is not None and uc == schema_uc:
+            kinds.pop(0)  # pop original schema
+            node_schema_kind_removed = True
+    return kinds
diff --git a/backend/tests/integration/ipam/test_load_concurrent_prefixes.py b/backend/tests/integration/ipam/test_load_concurrent_prefixes.py
new file mode 100644
index 0000000000..d7fe59327e
--- /dev/null
+++ b/backend/tests/integration/ipam/test_load_concurrent_prefixes.py
@@ -0,0 +1,33 @@
+import ipaddress
+
+from infrahub.database import InfrahubDatabase
+from tests.integration.ipam.base import TestIpam
+
+
+# See https://github.com/opsmill/infrahub/issues/4523
+class TestLoadConcurrentPrefixes(TestIpam):
+    async def test_load_concurrent_prefixes(
+        self,
+        db: InfrahubDatabase,
+        default_branch,
+        client,
+        default_ipnamespace,
+        register_ipam_schema,
+    ):
+        prefixes_batch = await client.create_batch()
+        network_8 = ipaddress.IPv4Network("10.0.0.0/8")
+        networks_16 = list(network_8.subnets(new_prefix=16))
+
+        networks = [network_8] + networks_16[0:10]
+
+        for network in networks:
+            prefix = await client.create("IpamIPPrefix", prefix=f"{network}")
+            prefixes_batch.add(task=prefix.save, node=prefix, allow_upsert=True)
+
+        async for _, _ in prefixes_batch.execute():
+            pass
+
+        nodes = await client.all("IpamIPPrefix", prefetch_relationships=True, populate_store=True)
+        for n in nodes:
+            if n.prefix.value != network_8:
+                assert n.parent.peer.prefix.value == network_8