Skip to content

Commit

Permalink
try
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 24, 2024
1 parent 4a5f577 commit 5a3f9d3
Show file tree
Hide file tree
Showing 8 changed files with 423 additions and 180 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ install: check-venv
@pip install -e '.[extra]' -c constraints.txt

itest: check-venv
@pytest -n 4 -v tests
@pytest -svv tests/test_experimental/test_remote_state_middleware -k memory

pre-commit: format-check lint type-check utest

Expand Down
15 changes: 7 additions & 8 deletions inngest/experimental/remote_state_middleware/s3_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing

import boto3
import mypy_boto3_s3
import pydantic

import inngest
Expand Down Expand Up @@ -33,16 +34,15 @@ class S3Driver(StateDriver):
def __init__( # noqa: D107
self,
*,
# aws_access_key_id: typing.Optional[str] = None,
# aws_secret_access_key: typing.Optional[str] = None,
bucket: str,
endpoint_url: typing.Optional[str] = None,
region_name: str,
client: mypy_boto3_s3.S3Client,
# endpoint_url: typing.Optional[str] = None,
# region_name: str,
) -> None:
self._bucket = bucket
self._client = boto3.client(
"s3",
endpoint_url=endpoint_url,
region_name=region_name,
)
self._client = client

def _create_key(self) -> str:
chars = string.ascii_letters + string.digits
Expand Down Expand Up @@ -84,7 +84,6 @@ def save_step(
"""

key = f"inngest/remote_state/{run_id}/{self._create_key()}"
self._client.create_bucket(Bucket=self._bucket)
self._client.put_object(
Body=json.dumps(value),
Bucket=self._bucket,
Expand Down
43 changes: 29 additions & 14 deletions tests/net.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
import contextlib
import random
import socket
import time
import typing

HOST: typing.Final = "0.0.0.0"
_min_port: typing.Final = 9000
_max_port: typing.Final = 9999

_used_ports: set[int] = set()


def get_available_port() -> int:
for port in range(_min_port, _max_port + 1):
with contextlib.closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
try:
sock.bind((HOST, port))
return port
except OSError:
continue

raise Exception("failed to find available port")
start_time = time.time()

while True:
if time.time() - start_time > 5:
raise Exception("timeout finding available port")

port = random.randint(9000, 9999)

if port in _used_ports:
continue

if not _is_port_available(port):
continue

_used_ports.add(port)
return port


def _is_port_available(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((HOST, port))
return True
except OSError:
return False
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import inngest
from inngest._internal import server_lib

from . import base, step_failed, step_output_in_memory, step_output_s3
from . import base, step_failed, step_output_aws, step_output_in_memory

_modules = (
step_failed,
step_output_in_memory,
step_output_s3,
step_output_aws,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
class TestClass(typing.Protocol):
client: inngest.Inngest

def addCleanup(self, func: typing.Callable) -> None:
...


@dataclasses.dataclass
class Case:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
Ensure step and function output is encrypted and decrypted correctly
"""

import json

import inngest
import tests.helper
from inngest._internal import server_lib

# from inngest.experimental import remote_state_middleware
from tests import net

from . import base

# import boto3
# import moto
# import moto.server


class _State(base.BaseState):
event: inngest.Event
events: list[inngest.Event]


def create(
client: inngest.Inngest,
framework: server_lib.Framework,
is_sync: bool,
) -> base.Case:
print("Creating case")
test_name = base.create_test_name(__file__)
event_name = base.create_event_name(framework, test_name)
fn_id = base.create_fn_id(test_name)
state = _State()

aws_port = net.get_available_port()

# Start mock AWS server.
# aws_server = moto.server.ThreadedMotoServer(port=net.get_available_port())
# aws_server.start()
# aws_host, aws_port = aws_server.get_host_and_port()
aws_url = f"http://localhost:{aws_port}"
aws_access_key_id = "test"
aws_secret_access_key = "test"
aws_region = "us-east-1"
s3_bucket = "inngest"

# s3_client = boto3.client(
# "s3",
# endpoint_url=aws_url,
# region_name=aws_region,
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# )

# Create S3 driver.
# driver = remote_state_middleware.S3Driver(
# # aws_access_key_id=aws_access_key_id,
# # aws_secret_access_key=aws_secret_access_key,
# bucket=s3_bucket,
# client=s3_client,
# # endpoint_url=aws_url,
# # region_name=aws_region,
# )

@client.create_function(
fn_id=fn_id,
middleware=[
# remote_state_middleware.RemoteStateMiddleware.factory(driver)
],
retries=0,
trigger=inngest.TriggerEvent(event=event_name),
)
def fn_sync(
ctx: inngest.Context,
step: inngest.StepSync,
) -> str:
state.run_id = ctx.run_id

def _step_1() -> str:
return "test string"

step_1_output = step.run("step_1", _step_1)
assert step_1_output == "test string"

def _step_2() -> list[inngest.JSON]:
return [{"a": {"b": 1}}]

step_2_output = step.run("step_2", _step_2)
assert step_2_output == [{"a": {"b": 1}}]

return "function output"

@client.create_function(
fn_id=fn_id,
middleware=[
# remote_state_middleware.RemoteStateMiddleware.factory(driver)
],
retries=0,
trigger=inngest.TriggerEvent(event=event_name),
)
async def fn_async(
ctx: inngest.Context,
step: inngest.Step,
) -> str:
state.run_id = ctx.run_id

def _step_1() -> str:
return "test string"

step_1_output = await step.run("step_1", _step_1)
assert step_1_output == "test string"

def _step_2() -> list[inngest.JSON]:
return [{"a": {"b": 1}}]

step_2_output = await step.run("step_2", _step_2)
assert step_2_output == [{"a": {"b": 1}}]

return "function output"

async def run_test(self: base.TestClass) -> None:
# aws_server = moto.server.ThreadedMotoServer(port=aws_port)
# aws_server.start()
# self.addCleanup(aws_server.stop)
# self.addCleanup(s3_client.close)

# # Create bucket.
# print("Creating bucket")

# s3_client.create_bucket(Bucket=s3_bucket)
# # client.close()

# print("Running test")
# self.client.send_sync(inngest.Event(name=event_name))

# print("Waiting for run ID")
# run_id = state.wait_for_run_id()
# print("Waiting for run status")
# run = tests.helper.client.wait_for_run_status(
# run_id,
# tests.helper.RunStatus.COMPLETED,
# )

# print("Getting step output")

# # Ensure that step_1 output is encrypted and its value is correct
# output = json.loads(
# tests.helper.client.get_step_output(
# run_id=run_id,
# step_id="step_1",
# )
# )

# assert isinstance(output, dict)
# data = output.get("data")
# assert isinstance(data, dict)

# # Ensure the step output is remotely stored.
# assert driver._marker in data

# print("Getting step output")
# output = json.loads(
# tests.helper.client.get_step_output(
# run_id=run_id,
# step_id="step_2",
# )
# )
# assert isinstance(output, dict)
# data = output.get("data")
# assert isinstance(data, dict)

# # Ensure the step output is remotely stored.
# assert driver._marker in data

# assert run.output is not None
# assert json.loads(run.output) == "function output"
print("done")
# aws_server.stop()

if is_sync:
fn = fn_sync
else:
fn = fn_async

return base.Case(
fn=fn,
run_test=run_test,
name=test_name,
)
Loading

0 comments on commit 5a3f9d3

Please sign in to comment.