Skip to content

Commit

Permalink
feat: Support --exclusive-start-key option for `ensure_identity_tra…
Browse files Browse the repository at this point in the history
…its_blanks`
  • Loading branch information
khvn26 committed Jan 2, 2025
1 parent b60af94 commit f9e32f2
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 16 deletions.
29 changes: 22 additions & 7 deletions api/edge_api/management/commands/ensure_identity_traits_blanks.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,42 @@
import json
from argparse import ArgumentParser
from typing import Any

import structlog
from django.core.management import BaseCommand
from structlog import get_logger
from structlog.stdlib import BoundLogger

from environments.dynamodb import DynamoIdentityWrapper

identity_wrapper = DynamoIdentityWrapper()

logger: structlog.BoundLogger = structlog.get_logger()

LOG_COUNT_EVERY = 100_000


class Command(BaseCommand):
def handle(self, *args: Any, **options: Any) -> None:
def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument(
"--exclusive-start-key",
dest="exclusive_start_key",
type=str,
default="",
help="Exclusive start key in valid JSON",
)

def handle(self, *args: Any, exclusive_start_key: str, **options: Any) -> None:
total_count = identity_wrapper.table.item_count
scanned_count = 0
fixed_count = 0
scanned_count = scanned_percentage = fixed_count = 0

log: structlog.BoundLogger = logger.bind(total_count=total_count)

kwargs = {}
if exclusive_start_key:
kwargs["ExclusiveStartKey"] = json.loads(exclusive_start_key)

log: BoundLogger = get_logger(total_count=total_count)
log.info("started")

for identity_document in identity_wrapper.query_get_all_items():
for identity_document in identity_wrapper.scan_iter_all_items(**kwargs):
should_write_identity_document = False

if identity_traits_data := identity_document.get("identity_traits"):
Expand Down
46 changes: 41 additions & 5 deletions api/environments/dynamodb/wrappers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,36 @@

import boto3
import boto3.dynamodb.types
import structlog
from botocore.config import Config
from sentry_sdk import set_context

if typing.TYPE_CHECKING:
from mypy_boto3_dynamodb.service_resource import Table
from mypy_boto3_dynamodb.type_defs import (
QueryOutputTableTypeDef,
ScanOutputTableTypeDef,
TableAttributeValueTypeDef,
)

DynamoDBOutput = QueryOutputTableTypeDef | ScanOutputTableTypeDef

P = typing.ParamSpec("P")

# Avoid `decimal.Rounded` when reading large numbers
# See https://github.com/boto/boto3/issues/2500
boto3.dynamodb.types.DYNAMODB_CONTEXT = Context(prec=100)


logger: structlog.BoundLogger = structlog.get_logger()


class BaseDynamoWrapper:
table_name: str = None

def __init__(self) -> None:
self._table: typing.Optional["Table"] = None
self._log = logger.bind(table_name=self.table_name)

@property
def table(self) -> typing.Optional["Table"]:
Expand All @@ -40,14 +54,20 @@ def get_table(self) -> typing.Optional["Table"]:
def is_enabled(self) -> bool:
return self.table is not None

def query_get_all_items(self, **kwargs: dict) -> typing.Generator[dict, None, None]:
if kwargs:
response_getter = partial(self.table.query, **kwargs)
else:
response_getter = partial(self.table.scan)
def _iter_all_items(
self,
response_getter_method: "typing.Callable[[P], DynamoDBOutput]",
**kwargs: "P.kwargs",
) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]:
response_getter = partial(response_getter_method, **kwargs)
set_context(
"dynamodb",
{"table_name": self.table_name, **kwargs},
)

while True:
query_response = response_getter()

for item in query_response["Items"]:
yield item

Expand All @@ -56,3 +76,19 @@ def query_get_all_items(self, **kwargs: dict) -> typing.Generator[dict, None, No
break

response_getter.keywords["ExclusiveStartKey"] = last_evaluated_key
set_context(
"dynamodb",
{"table_name": self.table_name, **response_getter.keywords},
)

def scan_iter_all_items(
self,
**kwargs: typing.Any,
) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]:
return self._iter_all_items(self.table.scan, **kwargs)

def query_iter_all_items(
self,
**kwargs: typing.Any,
) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]:
return self._iter_all_items(self.table.query, **kwargs)
4 changes: 2 additions & 2 deletions api/environments/dynamodb/wrappers/environment_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_identity_overrides_by_environment_id(
) -> typing.List[dict[str, Any]]:
try:
return list(
self.query_get_all_items(
self.query_iter_all_items(
KeyConditionExpression=Key(ENVIRONMENTS_V2_PARTITION_KEY).eq(
str(environment_id),
)
Expand Down Expand Up @@ -122,7 +122,7 @@ def delete_environment(self, environment_id: int):
"ProjectionExpression": "document_key",
}
with self.table.batch_writer() as writer:
for item in self.query_get_all_items(**query_kwargs):
for item in self.query_iter_all_items(**query_kwargs):
writer.delete_item(
Key={
ENVIRONMENTS_V2_PARTITION_KEY: environment_id,
Expand Down
2 changes: 1 addition & 1 deletion api/tests/integration/edge_api/identities/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def identity_overrides_v2(
edge_identity.save(admin_user)
return [
item["document_key"]
for item in dynamodb_wrapper_v2.query_get_all_items(
for item in dynamodb_wrapper_v2.query_iter_all_items(
KeyConditionExpression=Key("environment_id").eq(
str(dynamo_enabled_environment)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_delete_identity(
KeyConditionExpression=Key("identity_uuid").eq(identity_uuid),
)["Count"]
assert not list(
dynamodb_wrapper_v2.query_get_all_items(
dynamodb_wrapper_v2.query_iter_all_items(
KeyConditionExpression=Key("environment_id").eq(
str(dynamo_enabled_environment)
)
Expand Down
22 changes: 22 additions & 0 deletions api/tests/unit/edge_api/test_unit_edge_api_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,25 @@ def test_ensure_identity_traits_blanks__logs_expected(
"total_count": 11,
},
]


def test_ensure_identity_traits_blanks__exclusive_start_key__calls_expected(
flagsmith_identities_table: "Table",
mocker: "MockerFixture",
) -> None:
# Given
exclusive_start_key = '{"composite_key":"test_hello"}'
expected_kwargs = {"ExclusiveStartKey": {"composite_key": "test_hello"}}

identity_wrapper_mock = mocker.patch(
"edge_api.management.commands.ensure_identity_traits_blanks.identity_wrapper"
)

# When
call_command(
"ensure_identity_traits_blanks",
exclusive_start_key=exclusive_start_key,
)

# Then
identity_wrapper_mock.scan_get_all_items.assert_called_once_with(**expected_kwargs)

0 comments on commit f9e32f2

Please sign in to comment.