Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 24, 2024
1 parent 8dc6d8e commit c038e37
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 60 deletions.
2 changes: 1 addition & 1 deletion inngest/experimental/remote_state_middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
- Avoiding step output size limits.
NOT STABLE! This is an experimental feature and may change in the future. If
you'd like to use it, we recommend copying this file into your source code.
you'd like to use it, we recommend copying this package into your source code.
"""

from .in_memory_driver import InMemoryDriver
Expand Down
72 changes: 48 additions & 24 deletions inngest/experimental/remote_state_middleware/s3_driver.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from __future__ import annotations

import json
import secrets
import string
import typing

import boto3
import pydantic
import typing_extensions

import inngest

from .middleware import StateDriver

if typing.TYPE_CHECKING:
from mypy_boto3_s3 import S3Client


class _StateSurrogate(pydantic.BaseModel):
"""
Replaces step output sent back to Inngest. Its data is sufficient to
retrieve the actual state.
"""

class _StatePlaceholder(pydantic.BaseModel):
bucket: str
key: str

Expand All @@ -30,45 +40,55 @@ class S3Driver(StateDriver):

_strategy_identifier: typing.Final = "inngest/s3"

def __init__( # noqa: D107
def __init__(
self,
*,
bucket: str,
endpoint_url: typing.Optional[str] = None,
region_name: str,
client: S3Client,
) -> None:
"""
Args:
----
bucket: Bucket name to store remote state.
client: Boto3 S3 client.
"""

self._bucket = bucket
self._client = boto3.client(
"s3",
endpoint_url=endpoint_url,
region_name=region_name,
)
self._client = client

def _create_key(self) -> str:
chars = string.ascii_letters + string.digits
return "".join(secrets.choice(chars) for _ in range(32))

def _is_remote(
self, data: object
) -> typing_extensions.TypeGuard[dict[str, object]]:
return (
isinstance(data, dict)
and self._marker in data
and self._strategy_marker in data
and data[self._strategy_marker] == self._strategy_identifier
)

def load_steps(self, steps: inngest.StepMemos) -> None:
"""
Hydrate steps with remote state if necessary.
Args:
----
steps: Steps that may need hydration.
"""

for step in steps.values():
if not isinstance(step.data, dict):
continue
if self._marker not in step.data:
continue
if self._strategy_marker not in step.data:
continue
if step.data[self._strategy_marker] != self._strategy_identifier:
if not self._is_remote(step.data):
continue

placeholder = _StatePlaceholder.model_validate(step.data)
surrogate = _StateSurrogate.model_validate(step.data)

step.data = json.loads(
self._client.get_object(
Bucket=placeholder.bucket,
Key=placeholder.key,
Bucket=surrogate.bucket,
Key=surrogate.key,
)["Body"]
.read()
.decode()
Expand All @@ -81,20 +101,24 @@ def save_step(
) -> dict[str, object]:
"""
Save a step's output to the remote store and return a placeholder.
Args:
----
run_id: Run ID.
value: Step output.
"""

key = f"inngest/remote_state/{run_id}/{self._create_key()}"
self._client.create_bucket(Bucket=self._bucket)
self._client.put_object(
Body=json.dumps(value),
Bucket=self._bucket,
Key=key,
)

placeholder: dict[str, object] = {
surrogate = {
self._marker: True,
self._strategy_marker: self._strategy_identifier,
**_StatePlaceholder(bucket=self._bucket, key=key).model_dump(),
**_StateSurrogate(bucket=self._bucket, key=key).model_dump(),
}

return placeholder
return surrogate
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ def pytest_configure(config: pytest.Config) -> None:


def pytest_unconfigure(config: pytest.Config) -> None:
print("pytest_unconfigure")
dev_server.server.stop()
43 changes: 29 additions & 14 deletions tests/net.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
import contextlib
import random
import socket
import time
import typing

HOST: typing.Final = "0.0.0.0"
_min_port: typing.Final = 9000
_max_port: typing.Final = 9999

_used_ports: set[int] = set()


def get_available_port() -> int:
for port in range(_min_port, _max_port + 1):
with contextlib.closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
try:
sock.bind((HOST, port))
return port
except OSError:
continue

raise Exception("failed to find available port")
start_time = time.time()

while True:
if time.time() - start_time > 5:
raise Exception("timeout finding available port")

port = random.randint(9000, 9999)

if port in _used_ports:
continue

if not _is_port_available(port):
continue

_used_ports.add(port)
return port


def _is_port_available(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((HOST, port))
return True
except OSError:
return False
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from . import cases

_framework = server_lib.Framework.FLASK
_app_id = f"{_framework.value}-functions"
_app_id = f"{_framework.value}-encryption-middleware"

_client = inngest.Inngest(
api_base_url=dev_server.server.origin,
Expand All @@ -32,7 +32,7 @@
_fns.append(case.fn)


class TestFunctions(unittest.IsolatedAsyncioTestCase):
class TestEncryptionMiddleware(unittest.IsolatedAsyncioTestCase):
app: flask.testing.FlaskClient
client: inngest.Inngest
dev_server_port: int
Expand Down Expand Up @@ -78,7 +78,7 @@ def on_proxy_request(

for case in _cases:
test_name = f"test_{case.name}"
setattr(TestFunctions, test_name, case.run_test)
setattr(TestEncryptionMiddleware, test_name, case.run_test)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import inngest
from inngest._internal import server_lib

from . import base, step_failed, step_output_in_memory, step_output_s3
from . import base, step_failed, step_output_aws, step_output_in_memory

_modules = (
step_failed,
step_output_in_memory,
step_output_s3,
step_output_aws,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class _State(base.BaseState):
events: list[inngest.Event]


@moto.mock_aws
def create(
client: inngest.Inngest,
framework: server_lib.Framework,
Expand All @@ -33,20 +32,26 @@ def create(
fn_id = base.create_fn_id(test_name)
state = _State()

aws_server = moto.server.ThreadedMotoServer(port=net.get_available_port())
aws_server.start()
aws_host, aws_port = aws_server.get_host_and_port()
aws_port = net.get_available_port()
aws_url = f"http://localhost:{aws_port}"
aws_access_key_id = "test"
aws_secret_access_key = "test"
aws_region = "us-east-1"
s3_bucket = "inngest"

s3_client = boto3.client(
"s3",
endpoint_url=aws_url,
region_name=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)

conn = boto3.resource("s3", region_name="us-east-1")
conn.create_bucket(Bucket="inngest")
driver = remote_state_middleware.S3Driver(
bucket="inngest",
endpoint_url=f"http://{aws_host}:{aws_port}",
region_name="us-east-1",
bucket=s3_bucket,
client=s3_client,
)

driver.save_step("run_id", "value")

@client.create_function(
fn_id=fn_id,
middleware=[
Expand Down Expand Up @@ -104,15 +109,18 @@ def _step_2() -> list[inngest.JSON]:
return "function output"

async def run_test(self: base.TestClass) -> None:
self.client.send_sync(inngest.Event(name=event_name))
aws_server = moto.server.ThreadedMotoServer(port=aws_port)
aws_server.start()

s3_client.create_bucket(Bucket=s3_bucket)

self.client.send_sync(inngest.Event(name=event_name))
run_id = state.wait_for_run_id()
run = tests.helper.client.wait_for_run_status(
run_id,
tests.helper.RunStatus.COMPLETED,
)

# Ensure that step_1 output is encrypted and its value is correct
output = json.loads(
tests.helper.client.get_step_output(
run_id=run_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from . import cases

_framework = server_lib.Framework.FLASK
_app_id = f"{_framework.value}-functions"
_app_id = f"{_framework.value}-remote-state-middleware"

_client = inngest.Inngest(
api_base_url=dev_server.server.origin,
Expand All @@ -32,7 +32,7 @@
_fns.append(case.fn)


class TestFunctions(unittest.IsolatedAsyncioTestCase):
class TestRemoteStateMiddleware(unittest.IsolatedAsyncioTestCase):
app: flask.testing.FlaskClient
client: inngest.Inngest
dev_server_port: int
Expand Down Expand Up @@ -78,7 +78,7 @@ def on_proxy_request(

for case in _cases:
test_name = f"test_{case.name}"
setattr(TestFunctions, test_name, case.run_test)
setattr(TestRemoteStateMiddleware, test_name, case.run_test)


if __name__ == "__main__":
Expand Down

0 comments on commit c038e37

Please sign in to comment.