Skip to content

Commit

Permalink
Add lock on object creation
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasG0 committed Dec 12, 2024
1 parent 90c67bf commit 0a24ff8
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 45 deletions.
76 changes: 44 additions & 32 deletions backend/infrahub/graphql/mutations/ipam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from infrahub.graphql.mutations.node_getter.interface import MutationNodeGetterInterface
from infrahub.log import get_logger

from ... import lock
from ...lock import InfrahubMultiLock, build_object_lock_name
from .main import InfrahubMutationMixin, InfrahubMutationOptions

if TYPE_CHECKING:
Expand Down Expand Up @@ -106,12 +108,14 @@ async def mutate_create(
ip_address = ipaddress.ip_interface(data["address"]["value"])
namespace_id = await validate_namespace(db=db, data=data)

async with db.start_transaction() as dbt:
address = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
reconciler = IpamReconciler(db=dbt, branch=branch)
reconciled_address = await reconciler.reconcile(
ip_value=ip_address, namespace=namespace_id, node_uuid=address.get_id()
)
lock_name = build_object_lock_name(cls._meta.schema.kind)
async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
async with db.start_transaction() as dbt:
address = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
reconciler = IpamReconciler(db=dbt, branch=branch)
reconciled_address = await reconciler.reconcile(
ip_value=ip_address, namespace=namespace_id, node_uuid=address.get_id()
)

result = await cls.mutate_create_to_graphql(info=info, db=db, obj=reconciled_address)

Expand Down Expand Up @@ -141,13 +145,15 @@ async def mutate_update(
namespace = await address.ip_namespace.get_peer(db)
namespace_id = await validate_namespace(db=db, data=data, existing_namespace_id=namespace.id)
try:
async with db.start_transaction() as dbt:
address = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=address)
reconciler = IpamReconciler(db=dbt, branch=branch)
ip_address = ipaddress.ip_interface(address.address.value)
reconciled_address = await reconciler.reconcile(
ip_value=ip_address, node_uuid=address.get_id(), namespace=namespace_id
)
lock_name = build_object_lock_name(cls._meta.schema.kind)
async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
async with db.start_transaction() as dbt:
address = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=address)
reconciler = IpamReconciler(db=dbt, branch=branch)
ip_address = ipaddress.ip_interface(address.address.value)
reconciled_address = await reconciler.reconcile(
ip_value=ip_address, node_uuid=address.get_id(), namespace=namespace_id
)

result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_address)
except ValidationError as exc:
Expand Down Expand Up @@ -216,12 +222,14 @@ async def mutate_create(
ip_network = ipaddress.ip_network(data["prefix"]["value"])
namespace_id = await validate_namespace(db=db, data=data)

async with db.start_transaction() as dbt:
prefix = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
reconciler = IpamReconciler(db=dbt, branch=branch)
reconciled_prefix = await reconciler.reconcile(
ip_value=ip_network, namespace=namespace_id, node_uuid=prefix.get_id()
)
lock_name = build_object_lock_name(cls._meta.schema.kind)
async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
async with db.start_transaction() as dbt:
prefix = await cls.mutate_create_object(data=data, db=dbt, branch=branch)
reconciler = IpamReconciler(db=dbt, branch=branch)
reconciled_prefix = await reconciler.reconcile(
ip_value=ip_network, namespace=namespace_id, node_uuid=prefix.get_id()
)

result = await cls.mutate_create_to_graphql(info=info, db=db, obj=reconciled_prefix)

Expand Down Expand Up @@ -251,13 +259,15 @@ async def mutate_update(
namespace = await prefix.ip_namespace.get_peer(db)
namespace_id = await validate_namespace(db=db, data=data, existing_namespace_id=namespace.id)
try:
async with db.start_transaction() as dbt:
prefix = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=prefix)
reconciler = IpamReconciler(db=dbt, branch=branch)
ip_network = ipaddress.ip_network(prefix.prefix.value)
reconciled_prefix = await reconciler.reconcile(
ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id
)
lock_name = build_object_lock_name(cls._meta.schema.kind)
async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
async with db.start_transaction() as dbt:
prefix = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=prefix)
reconciler = IpamReconciler(db=dbt, branch=branch)
ip_network = ipaddress.ip_network(prefix.prefix.value)
reconciled_prefix = await reconciler.reconcile(
ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id
)

result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_prefix)
except ValidationError as exc:
Expand Down Expand Up @@ -302,12 +312,14 @@ async def mutate_delete(
namespace_rels = await prefix.ip_namespace.get_relationships(db=db)
namespace_id = namespace_rels[0].peer_id
try:
async with context.db.start_transaction() as dbt:
reconciler = IpamReconciler(db=dbt, branch=branch)
ip_network = ipaddress.ip_network(prefix.prefix.value)
reconciled_prefix = await reconciler.reconcile(
ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id, is_delete=True
)
lock_name = build_object_lock_name(cls._meta.schema.kind)
async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
async with context.db.start_transaction() as dbt:
reconciler = IpamReconciler(db=dbt, branch=branch)
ip_network = ipaddress.ip_network(prefix.prefix.value)
reconciled_prefix = await reconciler.reconcile(
ip_value=ip_network, node_uuid=prefix.get_id(), namespace=namespace_id, is_delete=True
)
except ValidationError as exc:
raise ValueError(str(exc)) from exc

Expand Down
61 changes: 52 additions & 9 deletions backend/infrahub/graphql/mutations/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from infrahub_sdk.utils import extract_fields
from typing_extensions import Self

from infrahub import config
from infrahub import config, lock
from infrahub.auth import validate_mutation_permissions_update_node
from infrahub.core import registry
from infrahub.core.constants import MutationAction
Expand All @@ -25,6 +25,7 @@
from infrahub.log import get_log_data, get_logger
from infrahub.worker import WORKER_IDENTITY

from ...lock import InfrahubMultiLock, build_object_lock_name, get_kinds_to_lock_on_object_mutation
from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter
from .node_getter.by_hfid import MutationNodeGetterByHfid
from .node_getter.by_id import MutationNodeGetterById
Expand Down Expand Up @@ -134,6 +135,20 @@ async def _refresh_for_profile_update(
)
return obj

@classmethod
async def _call_mutate_create_object(cls, data: InputObjectType, db: InfrahubDatabase, branch: Branch) -> Node:
    """
    Wrapper around mutate_create_object to potentially activate locking.

    If the schema kind (or one of its generics) declares a uniqueness constraint,
    the creation is serialized behind per-kind global locks so that concurrent
    creations cannot race on uniqueness validation. Otherwise the creation runs
    unlocked, exactly as before.

    NOTE(review): the call site in mutate_create reads
    `obj = cls._call_mutate_create_object(data=data, db=db, branch=branch)`
    without `await`, which would assign a coroutine object — verify and add
    the missing `await` there.
    """

    # Kinds requiring serialization for this schema (empty list => no locking needed).
    lock_kinds = get_kinds_to_lock_on_object_mutation(cls._meta.schema)
    if lock_kinds:
        lock_names = [build_object_lock_name(kind) for kind in lock_kinds]
        # Acquire all per-kind locks together before creating the object.
        async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
            return await cls.mutate_create_object(data=data, db=db, branch=branch)

    return await cls.mutate_create_object(data=data, db=db, branch=branch)

@classmethod
async def mutate_create(
cls,
Expand All @@ -144,7 +159,7 @@ async def mutate_create(
) -> tuple[Node, Self]:
context: GraphqlContext = info.context
db = database or context.db
obj = await cls.mutate_create_object(data=data, db=db, branch=branch)
obj = cls._call_mutate_create_object(data=data, db=db, branch=branch)
result = await cls.mutate_create_to_graphql(info=info, db=db, obj=obj)
return obj, result

Expand Down Expand Up @@ -189,6 +204,40 @@ async def mutate_create_to_graphql(cls, info: GraphQLResolveInfo, db: InfrahubDa
result["object"] = await obj.to_graphql(db=db, fields=fields.get("object", {}))
return cls(**result)

@classmethod
async def _call_mutate_update(
    cls,
    info: GraphQLResolveInfo,
    data: InputObjectType,
    branch: Branch,
    db: InfrahubDatabase,
    obj: Node,
) -> tuple[Node, Self]:
    """
    Wrapper around mutate_update_object to potentially activate locking and to
    run the update within a database transaction.

    If the caller already opened a transaction (db.is_transaction), the update
    reuses it; otherwise a new transaction is started so the object update and
    the GraphQL payload read see a consistent state.

    Returns the updated node and the GraphQL mutation result.
    """

    # Per-kind global lock names; empty when this schema needs no serialization.
    lock_names = [
        build_object_lock_name(kind) for kind in get_kinds_to_lock_on_object_mutation(cls._meta.schema)
    ]

    if db.is_transaction:
        obj = await cls._mutate_update_object_locked(
            info=info, data=data, branch=branch, db=db, obj=obj, lock_names=lock_names
        )
        result = await cls.mutate_update_to_graphql(db=db, info=info, obj=obj)
        return obj, result

    async with db.start_transaction() as dbt:
        # BUG FIX: the previous no-lock branch passed `db` (outside the transaction)
        # instead of `dbt`; both paths now consistently use the transaction handle.
        obj = await cls._mutate_update_object_locked(
            info=info, data=data, branch=branch, db=dbt, obj=obj, lock_names=lock_names
        )
        result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=obj)
        return obj, result

@classmethod
async def _mutate_update_object_locked(
    cls,
    *,
    info: GraphQLResolveInfo,
    data: InputObjectType,
    branch: Branch,
    db: InfrahubDatabase,
    obj: Node,
    lock_names: list[str],
) -> Node:
    """Run mutate_update_object, serialized behind the given locks when any are required."""
    if not lock_names:
        return await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj)
    async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
        return await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj)

@classmethod
@retry_db_transaction(name="object_update")
async def mutate_update(
Expand All @@ -207,13 +256,7 @@ async def mutate_update(
)

try:
if db.is_transaction:
obj = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj)
result = await cls.mutate_update_to_graphql(db=db, info=info, obj=obj)
else:
async with db.start_transaction() as dbt:
obj = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=obj)
result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=obj)
obj, result = await cls._call_mutate_update(info=info, data=data, db=db, branch=branch, obj=obj)
except ValidationError as exc:
raise ValueError(str(exc)) from exc

Expand Down
42 changes: 38 additions & 4 deletions backend/infrahub/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from types import TracebackType

from infrahub.core.schema.generated.base_node_schema import GeneratedBaseNodeSchema
from infrahub.services import InfrahubServices

registry: InfrahubLockRegistry = None
Expand Down Expand Up @@ -46,8 +47,8 @@
class InfrahubMultiLock:
"""Context manager to allow multiple locks to be reserved together"""

def __init__(self, _registry: InfrahubLockRegistry, locks: Optional[list[str]] = None) -> None:
self.registry = _registry
def __init__(self, lock_registry: InfrahubLockRegistry, locks: Optional[list[str]] = None) -> None:
self.registry = lock_registry
self.locks = locks or []

async def __aenter__(self):
Expand Down Expand Up @@ -245,12 +246,45 @@ async def local_schema_wait(self) -> None:
await self.get(name=LOCAL_SCHEMA_LOCK).event.wait()

def global_schema_lock(self) -> InfrahubMultiLock:
return InfrahubMultiLock(_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_SCHEMA_LOCK])
return InfrahubMultiLock(lock_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_SCHEMA_LOCK])

def global_graph_lock(self) -> InfrahubMultiLock:
return InfrahubMultiLock(_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_GRAPH_LOCK, GLOBAL_SCHEMA_LOCK])
return InfrahubMultiLock(lock_registry=self, locks=[LOCAL_SCHEMA_LOCK, GLOBAL_GRAPH_LOCK, GLOBAL_SCHEMA_LOCK])


def initialize_lock(local_only: bool = False, service: Optional[InfrahubServices] = None) -> None:
    """
    Initialize the module-level lock registry.

    Replaces the global `registry` with a fresh InfrahubLockRegistry; with
    local_only=True the registry only uses in-process locks (no distributed
    backend via `service`).
    """
    global registry  # pylint: disable=global-statement
    registry = InfrahubLockRegistry(local_only=local_only, service=service)


def build_object_lock_name(name: str) -> str:
    """Return the registry-wide lock name used when mutating objects of kind *name*."""
    return ".".join(("global", "object", name))


def get_kinds_to_lock_on_object_mutation(node_schema: GeneratedBaseNodeSchema) -> list[str]:
    """
    Return the kinds on which to lock while creating / updating an object of the given schema node.

    Locking is performed on the schema kind itself and on its generics that define a
    uniqueness_constraint. If a generic's uniqueness constraint is identical to the node
    schema's, the node schema overrode nothing new, so locking on the generic alone suffices
    and the node schema kind is dropped from the result.
    """

    schema_uc = None
    kinds: list[str] = []
    if node_schema.uniqueness_constraints:
        kinds.append(node_schema.kind)
        schema_uc = node_schema.uniqueness_constraints

    try:
        ancestors_kinds = node_schema.inherit_from
    except AttributeError:
        # Generic schemas have no `inherit_from`: nothing further to lock on.
        return kinds

    node_schema_kind_removed = False
    for kind in ancestors_kinds:
        # NOTE(review): module-global `registry` is declared as InfrahubLockRegistry
        # above, yet `.schema.get` looks like the core schema registry — confirm the
        # intended registry is in scope here.
        uc = registry.schema.get(name=kind).uniqueness_constraints
        if not uc:
            # BUG FIX: previously every generic was appended unconditionally; per the
            # documented contract, only generics defining a uniqueness constraint
            # require locking.
            continue
        kinds.append(kind)
        if not node_schema_kind_removed and uc == schema_uc:
            # The generic carries the same constraint: lock on the generic only.
            kinds.pop(0)  # pop original schema kind
            node_schema_kind_removed = True
    return kinds
33 changes: 33 additions & 0 deletions backend/tests/integration/ipam/test_load_concurrent_prefixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import ipaddress

from infrahub.database import InfrahubDatabase
from tests.integration.ipam.base import TestIpam


# See https://github.com/opsmill/infrahub/issues/4523
class TestLoadConcurrentPrefixes(TestIpam):
    """Regression test: concurrently created prefixes must still reconcile into a correct tree."""

    async def test_load_concurrent_prefixes(
        self,
        db: InfrahubDatabase,
        default_branch,
        client,
        default_ipnamespace,
        register_ipam_schema,
    ):
        # Build one /8 supernet plus its first ten /16 subnets and save them
        # concurrently through an SDK batch (this used to race in the reconciler).
        prefixes_batch = await client.create_batch()
        network_8 = ipaddress.IPv4Network("10.0.0.0/8")
        networks_16 = list(network_8.subnets(new_prefix=16))

        networks = [network_8] + networks_16[0:10]

        for network in networks:
            prefix = await client.create("IpamIPPrefix", prefix=f"{network}")
            prefixes_batch.add(task=prefix.save, node=prefix, allow_upsert=True)

        # Drain the batch; individual results are not inspected, only completion.
        async for _, _ in prefixes_batch.execute():
            pass

        # Every /16 must have been reconciled under the /8 parent.
        nodes = await client.all("IpamIPPrefix", prefetch_relationships=True, populate_store=True)
        for n in nodes:
            # NOTE(review): this compares n.prefix.value to an ipaddress.IPv4Network —
            # confirm the SDK returns an ipaddress object here (not a str), otherwise
            # the condition is always true and the /8 node itself would hit a None parent.
            if n.prefix.value != network_8:
                assert n.parent.peer.prefix.value == network_8

0 comments on commit 0a24ff8

Please sign in to comment.