Skip to content

Commit

Permalink
Support API authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Feb 1, 2024
1 parent d09ce78 commit f17246b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
19 changes: 10 additions & 9 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,24 @@ def __init__(
Raises:
ValueError: if the API key is missing.
"""
self._api_key = api_key or os.environ.get("DISPATCH_API_KEY")
if not self._api_key:
if not api_key:
api_key = os.environ.get("DISPATCH_API_KEY")
if not api_key:
raise ValueError("api_key is required")

# TODO: actually use the API key when we have defined the authentication
# mechanism.

result = urlparse(api_url)
match result.scheme:
case "http":
port = result.port if result.port else 80
channel = grpc.insecure_channel(f"{result.hostname}:{port}")
creds = grpc.local_channel_credentials()
case "https":
port = result.port if result.port else 443
channel = grpc.insecure_channel(f"{result.hostname}:{port}")
creds = grpc.ssl_channel_credentials()
case _:
raise ValueError(f"Invalid API scheme: '{result.scheme}'")

call_creds = grpc.access_token_call_credentials(api_key)
creds = grpc.composite_channel_credentials(creds, call_creds)
channel = grpc.secure_channel(result.netloc, creds)

self._stub = service_grpc.ServiceStub(channel)

def create_tasks(self, tasks: Iterable[TaskDef]) -> Iterable[TaskID]:
Expand Down
21 changes: 20 additions & 1 deletion tests/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from dispatch import Client, TaskInput, TaskID


_test_auth_token = "THIS_IS_A_TEST_AUTH_TOKEN"


class FakeRing(service_grpc.ServiceServicer):
def __init__(self):
super().__init__()
Expand All @@ -21,7 +24,21 @@ def __init__(self):

self.pending_tasks = []

def _validate_authentication(self, context: grpc.ServicerContext):
expected = f"Bearer {_test_auth_token}"
for key, value in context.invocation_metadata():
if key == "authorization":
if value == expected:
return
context.abort(
grpc.StatusCode.UNAUTHENTICATED,
f"Invalid authorization header. Expected '{expected}', got '{value!r}'",
)
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Missing authorization header")

def CreateTasks(self, request: service_pb.CreateTasksRequest, context):
self._validate_authentication(context)

resp = service_pb.CreateTasksResponse()

for t in request.tasks:
Expand Down Expand Up @@ -75,7 +92,9 @@ def __init__(self):
service_grpc.add_ServiceServicer_to_server(self.servicer, self.server)
self.server.start()

self.client = Client(api_key="test", api_url=f"http://127.0.0.1:{port}")
self.client = Client(
api_key=_test_auth_token, api_url=f"http://127.0.0.1:{port}"
)

def stop(self):
self.server.stop(0)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def setUp(self):
def tearDown(self):
self.server.stop()

def test_authentication(self): ...

def test_create_one_task_pickle(self):
results = self.client.create_tasks(
[TaskInput(coroutine_uri="my-cool-coroutine", input=42)]
Expand Down

0 comments on commit f17246b

Please sign in to comment.