Add Pagination to Delete Commands (#543)
* added batching to the delete call

* fixed small typo

* added unit test and did formatting

* fix magic number
ben-githubs authored Sep 25, 2024
1 parent 43f588f commit 8e27e06
Showing 2 changed files with 107 additions and 28 deletions.
87 changes: 59 additions & 28 deletions panther_analysis_tool/backend/public_api_client.py
@@ -24,7 +24,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional, Sequence
 from urllib.parse import urlparse
 
 from gql import Client as GraphQLClient
@@ -166,6 +166,10 @@ class PublicAPIClient(Client):  # pylint: disable=too-many-public-methods
     _requests: PublicAPIRequests
     _gql_client: GraphQLClient
 
+    # backend's delete function can only handle 100 IDs at a time, due to DynamoDB restrictions
+    # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ServiceQuotas.html#limits-expression-parameters
+    _DELETE_BATCH_SIZE = 100
+
     def __init__(self, opts: PublicAPIClientOptions):
         self._user_id = opts.user_id
         self._requests = PublicAPIRequests()
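Since every batch costs one GraphQL round trip, the number of delete calls grows as the ceiling of len(ids) / 100. A minimal sketch of that arithmetic (illustrative only, not part of the diff; the constant name mirrors the one added above):

    import math

    DELETE_BATCH_SIZE = 100  # DynamoDB expression-parameter limit, per the AWS docs linked above

    def num_delete_calls(num_ids: int) -> int:
        # One API call per batch; the final batch may be partial.
        return math.ceil(num_ids / DELETE_BATCH_SIZE)

    assert num_delete_calls(100) == 1
    assert num_delete_calls(101) == 2
    assert num_delete_calls(250) == 3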
@@ -329,23 +333,29 @@ def transpile_filters(
     def delete_saved_queries(
         self, params: DeleteSavedQueriesParams
     ) -> BackendResponse[DeleteSavedQueriesResponse]:
-        query = self._requests.delete_saved_queries()
-        delete_params = {
-            "input": {
-                "dryRun": params.dry_run,
-                "includeDetections": params.include_detections,
-                "names": params.names,
-            }
-        }
-        res = self._execute(query, variable_values=delete_params)
-
-        if res.errors:
-            raise BackendError(res.errors)
-
-        if res.data is None:
-            raise BackendError("empty data")
-
-        data = res.data.get("deleteSavedQueriesByName", {})
+        data: Dict = {"names": [], "detectionIDs": []}
+        for name_batch in _batched(params.names, self._DELETE_BATCH_SIZE):
+            gql_params = {
+                "input": {
+                    "dryRun": params.dry_run,
+                    "includeDetections": params.include_detections,
+                    "names": name_batch,
+                }
+            }
+            res = self._execute(self._requests.delete_saved_queries(), variable_values=gql_params)
+
+            if res.errors:
+                for err in res.errors:
+                    logging.error(err.message)
+
+                raise BackendError(res.errors)
+
+            if res.data is None:
+                raise BackendError("empty data")
+
+            query_data = res.data.get("deleteSavedQueriesByName", {})
+            for field in ("names", "detectionIDs"):
+                data[field] += query_data.get(field) or []
 
         return BackendResponse(
             status_code=200,
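Because each loop iteration returns only that batch's results, the new code folds every partial response into one aggregate dict before building the BackendResponse. The accumulation pattern in isolation (hypothetical payloads shaped like deleteSavedQueriesByName results, not real API output):

    from typing import Dict, List, Optional

    batch_payloads: List[Dict[str, Optional[List[str]]]] = [
        {"names": ["query-a", "query-b"], "detectionIDs": ["det-1"]},
        {"names": ["query-c"], "detectionIDs": None},  # the API may return null lists
    ]

    data: Dict[str, List[str]] = {"names": [], "detectionIDs": []}
    for payload in batch_payloads:
        for field in ("names", "detectionIDs"):
            # 'or []' guards against null or missing fields in a batch response.
            data[field] += payload.get(field) or []

    assert data == {"names": ["query-a", "query-b", "query-c"], "detectionIDs": ["det-1"]}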
@@ -358,24 +368,29 @@ def delete_saved_queries(
     def delete_detections(
         self, params: DeleteDetectionsParams
     ) -> BackendResponse[DeleteDetectionsResponse]:
-        gql_params = {
-            "input": {
-                "dryRun": params.dry_run,
-                "includeSavedQueries": params.include_saved_queries,
-                "ids": params.ids,
-            }
-        }
-        res = self._execute(self._requests.delete_detections_query(), gql_params)
-        if res.errors:
-            for err in res.errors:
-                logging.error(err.message)
-
-            raise BackendError(res.errors)
-
-        if res.data is None:
-            raise BackendError("empty data")
-
-        data = res.data.get("deleteDetections", {})
+        data: Dict = {"ids": [], "savedQueryNames": []}
+        for id_batch in _batched(params.ids, self._DELETE_BATCH_SIZE):
+            gql_params = {
+                "input": {
+                    "dryRun": params.dry_run,
+                    "includeSavedQueries": params.include_saved_queries,
+                    "ids": id_batch,
+                }
+            }
+            res = self._execute(self._requests.delete_detections_query(), gql_params)
+
+            if res.errors:
+                for err in res.errors:
+                    logging.error(err.message)
+
+                raise BackendError(res.errors)
+
+            if res.data is None:
+                raise BackendError("empty data")
+
+            query_data = res.data.get("deleteDetections", {})
+            for field in ("ids", "savedQueryNames"):
+                data[field] += query_data.get(field) or []
 
         return BackendResponse(
             status_code=200,
@@ -693,3 +708,19 @@ def _build_api_url(host: str) -> str:
 def _get_graphql_content_filepath(name: str) -> str:
     work_dir = os.path.dirname(__file__)
     return os.path.join(work_dir, "graphql", f"{name}.graphql")
+
+
+def _batched(iterable: Sequence, size: int = 1) -> Generator[Sequence, None, None]:
+    """Batch data from 'iterable' into chunks of length 'size'. The last batch may be shorter than 'size'.
+    Inspired by itertools.batched, available in Python 3.12+.
+
+    Args:
+        iterable (Sequence): the sequence to be batched; must support len() and slicing
+        size (int, optional): the maximum size of each batch. default=1
+
+    Yields:
+        out (Sequence): a batch of size 'size' or smaller
+    """
+    length = len(iterable)
+    for idx in range(0, length, size):
+        yield iterable[idx : min(idx + size, length)]
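To make the helper's behavior concrete, a short usage sketch (importing _batched the same way the new unit tests below do):

    from panther_analysis_tool.backend.public_api_client import _batched

    ids = [f"detection-{i}" for i in range(7)]

    # Chunks of 3: slicing a list yields lists, and the last batch is shorter.
    assert list(_batched(ids, 3)) == [
        ["detection-0", "detection-1", "detection-2"],
        ["detection-3", "detection-4", "detection-5"],
        ["detection-6"],
    ]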
48 changes: 48 additions & 0 deletions tests/unit/panther_analysis_tool/test_util.py
@@ -5,6 +5,7 @@
 
 import panther_analysis_tool.constants
 from panther_analysis_tool import util as pat_utils
+from panther_analysis_tool.backend.public_api_client import _batched
 from panther_analysis_tool.util import convert_unicode
 
 
@@ -200,3 +201,50 @@ def test_is_policy(self):
         for case in test_cases:
             res = pat_utils.is_policy(case["analysis_type"])
             self.assertEqual(case["expected"], res)
+
+
+class TestBatched(unittest.TestCase):
+    def test_batched_with_remainder(self):
+        iterable = [1] * 12
+        n = 5
+        expected_batches = 3
+        modulo = 2  # Size of last batch
+
+        batches = list(_batched(iterable, n))
+        # Ensure we received the expected number of batches
+        self.assertEqual(len(batches), expected_batches)
+        # Confirm all but the last batch have the same size
+        for batch in batches[:-1]:
+            self.assertEqual(len(list(batch)), n)
+        # Confirm the last batch has the expected number of entries
+        self.assertEqual(len(list(batches[-1])), modulo)
+
+    def test_batched_with_no_remainder(self):
+        iterable = [1] * 100
+        n = 10
+        expected_batches = 10
+        modulo = 10  # Size of last batch
+
+        batches = list(_batched(iterable, n))
+        # Ensure we received the expected number of batches
+        self.assertEqual(len(batches), expected_batches)
+        # Confirm all but the last batch have the same size
+        for batch in batches[:-1]:
+            self.assertEqual(len(list(batch)), n)
+        # Confirm the last batch has the expected number of entries
+        self.assertEqual(len(list(batches[-1])), modulo)
+
+    def test_batched_with_no_full_batches(self):
+        iterable = [1] * 3
+        n = 5
+        expected_batches = 1
+        modulo = 3  # Size of last batch
+
+        batches = list(_batched(iterable, n))
+        # Ensure we received the expected number of batches
+        self.assertEqual(len(batches), expected_batches)
+        # Confirm all but the last batch have the same size
+        for batch in batches[:-1]:
+            self.assertEqual(len(list(batch)), n)
+        # Confirm the last batch has the expected number of entries
+        self.assertEqual(len(list(batches[-1])), modulo)
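One edge case the suite leaves implicit: an empty input yields no batches at all, so the batched delete loops simply make zero API calls. A quick sketch (inferred from the range() logic in _batched, not asserted by the commit):

    from panther_analysis_tool.backend.public_api_client import _batched

    # range(0, 0, size) is empty, so nothing is yielded for an empty sequence.
    assert list(_batched([], 5)) == []

    # The default size of 1 yields singleton slices.
    assert list(_batched(["a", "b"])) == [["a"], ["b"]]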
