[monitorlib, mock_uss] Add idempotent handler decorator #241

Merged · 4 commits · Oct 13, 2023
4 changes: 0 additions & 4 deletions monitoring/mock_uss/database.py
@@ -54,7 +54,3 @@ class Database(ImplicitDict):
     Database(one_time_tasks=[], task_errors=[], periodic_tasks={}),
     decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), Database),
 )
-
-fulfilled_request_ids = SynchronizedValue(
-    [], decoder=lambda b: json.loads(b.decode("utf-8"))
-)
15 changes: 2 additions & 13 deletions monitoring/mock_uss/ridsp/routes_injection.py
@@ -18,7 +18,7 @@
 from . import database
 from .database import db
 from monitoring.monitorlib import geo
-from ..database import fulfilled_request_ids
+from monitoring.monitorlib.idempotency import idempotent_request

 require_config_value(KEY_BASE_URL)
 require_config_value(KEY_RID_VERSION)
@@ -34,6 +34,7 @@ class ErrorResponse(ImplicitDict):

 @webapp.route("/ridsp/injection/tests/<test_id>", methods=["PUT"])
 @requires_scope([injection_api.SCOPE_RID_QUALIFIER_INJECT])
+@idempotent_request()
 def ridsp_create_test(test_id: str) -> Tuple[str, int]:
     """Implements test creation in RID automated testing injection API."""
     logger.info(f"Create test {test_id}")
@@ -51,18 +52,6 @@ def ridsp_create_test(test_id: str) -> Tuple[str, int]:
     except ValueError as e:
         msg = "Create test {} unable to parse JSON: {}".format(test_id, e)
         return msg, 400
-    if "request_id" in json:
-        logger.debug(f"[ridsp_create_test] Request ID {json['request_id']}")
-        with fulfilled_request_ids as tx:
-            if json["request_id"] in tx:
-                logger.debug(
-                    f"[ridsp_create_test] Already processed request ID {json['request_id']}"
-                )
-                return (
-                    f"Request ID {json['request_id']} has already been fulfilled",
-                    400,
-                )
-            tx.append(json["request_id"])

     # Create ISA in DSS
     (t0, t1) = req_body.get_span()
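Behavior change worth noting: a byte-identical retry of an injection request is now answered from the response cache (with a warning logged) instead of being rejected with a 400 as the deleted block did. A rough client-side illustration; the URL and payload are hypothetical and auth headers are omitted:

import requests

url = "http://mock-uss.example/ridsp/injection/tests/test-1"  # hypothetical URL
body = {"test_id": "test-1"}  # hypothetical payload; auth headers omitted

r1 = requests.put(url, json=body)
r2 = requests.put(url, json=body)  # identical retry: served from the cache
assert r2.status_code == r1.status_code  # no longer a 400 on the repeat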
26 changes: 3 additions & 23 deletions monitoring/mock_uss/scdsc/routes_injection.py
@@ -34,7 +34,6 @@
 from monitoring.mock_uss import webapp, require_config_value
 from monitoring.mock_uss.auth import requires_scope
 from monitoring.mock_uss.config import KEY_BASE_URL
-from monitoring.mock_uss.database import fulfilled_request_ids
 from monitoring.mock_uss.dynamic_configuration.configuration import get_locality
 from monitoring.mock_uss.scdsc import database, utm_client
 from monitoring.mock_uss.scdsc.database import db
@@ -50,6 +49,7 @@
 from monitoring.monitorlib.fetch import QueryError
 from monitoring.monitorlib.geo import Polygon
 from monitoring.monitorlib.geotemporal import Volume4D, Volume4DCollection
+from monitoring.monitorlib.idempotency import idempotent_request
 from monitoring.monitorlib.scd_automated_testing.scd_injection_api import (
     SCOPE_SCD_QUALIFIER_INJECT,
 )
@@ -149,6 +149,7 @@ def scd_capabilities() -> Tuple[dict, int]:

 @webapp.route("/scdsc/v1/flights/<flight_id>", methods=["PUT"])
 @requires_scope([SCOPE_SCD_QUALIFIER_INJECT])
+@idempotent_request()
 def scdsc_inject_flight(flight_id: str) -> Tuple[str, int]:
     """Implements flight injection in SCD automated testing injection API."""
     logger.debug(f"[inject_flight/{os.getpid()}:{flight_id}] Starting handler")
@@ -160,20 +161,6 @@ def scdsc_inject_flight(flight_id: str) -> Tuple[str, int]:
     except ValueError as e:
         msg = "Create flight {} unable to parse JSON: {}".format(flight_id, e)
         return msg, 400
-    if "request_id" in json:
-        logger.debug(
-            f"[inject_flight/{os.getpid()}:{flight_id}] Request ID {json['request_id']}"
-        )
-        with fulfilled_request_ids as tx:
-            if json["request_id"] in tx:
-                logger.debug(
-                    f"[inject_flight/{os.getpid()}:{flight_id}] Already processed request ID {json['request_id']}"
-                )
-                return (
-                    f"Request ID {json['request_id']} has already been fulfilled",
-                    400,
-                )
-            tx.append(json["request_id"])
     json, code = inject_flight(flight_id, req_body)
     return flask.jsonify(json), code

@@ -492,6 +479,7 @@ def delete_flight(flight_id) -> Tuple[dict, int]:

 @webapp.route("/scdsc/v1/clear_area_requests", methods=["POST"])
 @requires_scope([SCOPE_SCD_QUALIFIER_INJECT])
+@idempotent_request()
 def scdsc_clear_area() -> Tuple[str, int]:
     try:
         json = flask.request.json
@@ -501,14 +489,6 @@ def scdsc_clear_area() -> Tuple[str, int]:
     except ValueError as e:
         msg = "Unable to parse ClearAreaRequest JSON request: {}".format(e)
         return msg, 400
-    with fulfilled_request_ids as tx:
-        logger.debug(f"[scdsc_clear_area] Processing request ID {req.request_id}")
-        if req.request_id in tx:
-            logger.debug(
-                f"[scdsc_clear_area] Already processed request ID {req.request_id}"
-            )
-            return f"Request ID {req.request_id} has already been fulfilled", 400
-        tx.append(json["request_id"])
     json, code = clear_area(req)
     return flask.jsonify(json), code
157 changes: 157 additions & 0 deletions monitoring/monitorlib/idempotency.py
@@ -0,0 +1,157 @@
import base64
import hashlib
from functools import wraps
import json
from typing import Callable, Optional, Dict

import arrow
import flask
from loguru import logger

from implicitdict import ImplicitDict, StringBasedDateTime
from monitoring.monitorlib.multiprocessing import SynchronizedValue


_max_request_buffer_size = int(10e6)
"""Number of bytes to dedicate to caching responses"""


class Response(ImplicitDict):
    """Information about a previously-returned response.

    Note that this object is never actually used (in order to maximize performance); instead it serves as documentation
    of the structure of the fields within a plain JSON dict/object."""

    json: Optional[dict]
    body: Optional[str]
    code: int
    timestamp: StringBasedDateTime


def _get_responses(raw: bytes) -> Dict[str, Response]:
    return json.loads(raw.decode("utf-8"))


def _set_responses(responses: Dict[str, Response]) -> bytes:
    while True:
        s = json.dumps(responses)
        if len(s) <= _max_request_buffer_size:
            break

        # Remove oldest cached response
        oldest_id = None
        oldest_timestamp = None
        for request_id, response in responses.items():
            t = arrow.get(response["timestamp"])
            if oldest_timestamp is None or t < oldest_timestamp:
                oldest_id = request_id
                oldest_timestamp = t

        del responses[oldest_id]
    return s.encode("utf-8")


_fulfilled_requests = SynchronizedValue(
    {},
    decoder=_get_responses,
    encoder=_set_responses,
    capacity_bytes=_max_request_buffer_size,
)


def get_hashed_request_id() -> Optional[str]:
    """Retrieves an identifier for the request by hashing key characteristics of the request."""
    characteristics = flask.request.method + flask.request.url
    if flask.request.json:
        characteristics += json.dumps(flask.request.json)
    else:
        characteristics += flask.request.data.decode("utf-8")
    return base64.b64encode(
        hashlib.sha512(characteristics.encode("utf-8")).digest()
    ).decode("utf-8")


def idempotent_request(get_request_id: Optional[Callable[[], Optional[str]]] = None):
    """Decorator for idempotent Flask view handlers.

    When subsequent requests are received with the same request identifier, this decorator will use a recent cached
    response instead of invoking the underlying handler when possible. Note that there is no verification that the rest
    of the request (apart from the request ID) is identical, so a request with different content but the same request ID
    will receive the cached response from the first request. A developer could compute a request ID based on a hash of
    important request characteristics to control this behavior.

    Note that cached response characteristics are limited and the full original response is not produced verbatim.
    """
    if get_request_id is None:
        get_request_id = get_hashed_request_id

    def outer_wrapper(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            request_id = get_request_id()

            cached_requests = _fulfilled_requests.value
            if request_id in cached_requests:
                endpoint = (
                    flask.request.url_rule.rule
                    if flask.request.url_rule is not None
                    else "unknown endpoint"
                )
                logger.warning(
                    "Fulfilling {} {} with cached response for request {}",
                    flask.request.method,
                    endpoint,
                    request_id,
                )
                response = cached_requests[request_id]
                if response["body"] is not None:
                    return response["body"], response["code"]
                else:
                    return flask.jsonify(response["json"]), response["code"]

            result = fn(*args, **kwargs)

            response = {
                "timestamp": arrow.utcnow().isoformat(),
                "code": 200,
                "body": None,
                "json": None,
            }
            keep_code = False
            if isinstance(result, tuple):
                if len(result) == 2:
                    if not isinstance(result[1], int):
                        raise NotImplementedError(
                            f"Unable to cache Flask view handler result where the second 2-tuple element is a '{type(result[1]).__name__}'"
                        )
                    response["code"] = result[1]
                    keep_code = True
                    result = result[0]
                else:
                    raise NotImplementedError(
                        f"Unable to cache Flask view handler result which is a tuple of ({', '.join(type(v).__name__ for v in result)})"
                    )

            if isinstance(result, str):
                response["body"] = result
                response["json"] = None
            elif isinstance(result, flask.Response):
                try:
                    response["json"] = result.get_json()
                except ValueError:
                    response["body"] = result.get_data(as_text=True)
                if not keep_code:
                    response["code"] = result.status_code
            else:
                raise NotImplementedError(
                    f"Unable to cache Flask view handler result of type '{type(result).__name__}'"
                )

            with _fulfilled_requests as cached_requests:
                cached_requests[request_id] = response

            return result

        return wrapper

    return outer_wrapper
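Usage note: applying @idempotent_request() with no argument keys the cache on a hash of the whole request (method, URL, and body) via get_hashed_request_id. A handler that wants the previous explicit-request_id behavior could pass a custom extractor instead; a minimal sketch, in which the app object, route, and extractor are illustrative only:

from typing import Optional, Tuple

import flask

from monitoring.monitorlib.idempotency import (
    get_hashed_request_id,
    idempotent_request,
)

app = flask.Flask(__name__)  # stand-in for mock_uss's webapp


def request_id_from_body() -> Optional[str]:
    # Hypothetical extractor: prefer an explicit request_id field in the JSON
    # body (the scheme the handlers above previously implemented inline), and
    # fall back to hashing the whole request otherwise.
    body = flask.request.get_json(silent=True)
    if body and "request_id" in body:
        return str(body["request_id"])
    return get_hashed_request_id()


@app.route("/example/endpoint", methods=["PUT"])  # illustrative route only
@idempotent_request(get_request_id=request_id_from_body)
def example_handler() -> Tuple[str, int]:
    # Invoked at most once per request ID; repeats receive the cached response.
    return "created", 200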
31 changes: 21 additions & 10 deletions monitoring/monitorlib/multiprocessing.py
@@ -21,6 +21,9 @@ class SynchronizedValue(object):
     > {"foo":"baz"}
     """

+    SIZE_BYTES = 4
+    """Number of bytes at the beginning of the memory buffer dedicated to defining the size of the content."""
+
     _lock: multiprocessing.RLock
     _shared_memory: multiprocessing.shared_memory.SharedMemory
     _encoder: Callable[[Any], bytes]
@@ -43,7 +46,7 @@ def __init__(
         """
         self._lock = multiprocessing.RLock()
         self._shared_memory = multiprocessing.shared_memory.SharedMemory(
-            create=True, size=capacity_bytes
+            create=True, size=capacity_bytes + self.SIZE_BYTES
         )
         self._encoder = (
             encoder
@@ -57,27 +60,35 @@ def __init__(
         self._set_value(initial_value)

     def _get_value(self):
-        content_len = int.from_bytes(bytes(self._shared_memory.buf[0:4]), "big")
-        if content_len + 4 > self._shared_memory.size:
+        content_len = int.from_bytes(
+            bytes(self._shared_memory.buf[0 : self.SIZE_BYTES]), "big"
+        )
+        if content_len + self.SIZE_BYTES > self._shared_memory.size:
             raise RuntimeError(
-                "Shared memory claims to have {} bytes of content when buffer size is only {}".format(
-                    content_len, self._shared_memory.size
+                "Shared memory claims to have {} bytes of content when buffer size only allows {}".format(
+                    content_len, self._shared_memory.size - self.SIZE_BYTES
                 )
             )
-        content = bytes(self._shared_memory.buf[4 : content_len + 4])
+        content = bytes(
+            self._shared_memory.buf[self.SIZE_BYTES : content_len + self.SIZE_BYTES]
+        )
         return self._decoder(content)

     def _set_value(self, value):
         content = self._encoder(value)
         content_len = len(content)
-        if content_len + 4 > self._shared_memory.size:
+        if content_len + self.SIZE_BYTES > self._shared_memory.size:
             raise RuntimeError(
                 "Tried to write {} bytes into a SynchronizedValue with only {} bytes of capacity".format(
-                    content_len, self._shared_memory.size
+                    content_len, self._shared_memory.size - self.SIZE_BYTES
                 )
             )
-        self._shared_memory.buf[0:4] = content_len.to_bytes(4, "big")
-        self._shared_memory.buf[4 : content_len + 4] = content
+        self._shared_memory.buf[0 : self.SIZE_BYTES] = content_len.to_bytes(
+            self.SIZE_BYTES, "big"
+        )
+        self._shared_memory.buf[
+            self.SIZE_BYTES : content_len + self.SIZE_BYTES
+        ] = content

     @property
     def value(self):
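As an aside on the framing change above: the first SIZE_BYTES bytes of the shared buffer hold a big-endian content length, so a SynchronizedValue created with capacity_bytes=N now allocates N + 4 bytes and stores at most N bytes of encoded content. A minimal standalone sketch of that length-prefix framing over a plain bytearray (the helper names are illustrative, not part of the class):

SIZE_BYTES = 4  # big-endian length prefix, as in SynchronizedValue


def write_framed(buf: bytearray, content: bytes) -> None:
    # Mirrors _set_value: the prefix plus the content must fit in the buffer.
    if len(content) + SIZE_BYTES > len(buf):
        raise RuntimeError(
            f"Tried to write {len(content)} bytes into {len(buf) - SIZE_BYTES} bytes of capacity"
        )
    buf[0:SIZE_BYTES] = len(content).to_bytes(SIZE_BYTES, "big")
    buf[SIZE_BYTES : SIZE_BYTES + len(content)] = content


def read_framed(buf: bytearray) -> bytes:
    # Mirrors _get_value: read the prefix, then exactly that many content bytes.
    content_len = int.from_bytes(buf[0:SIZE_BYTES], "big")
    return bytes(buf[SIZE_BYTES : SIZE_BYTES + content_len])


capacity_bytes = 64
buf = bytearray(capacity_bytes + SIZE_BYTES)  # as in __init__ above
write_framed(buf, b'{"foo":"bar"}')
assert read_framed(buf) == b'{"foo":"bar"}'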