Skip to content

Commit

Permalink
refresh token when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
bboynton97 committed Sep 21, 2024
1 parent 89aa440 commit 1a7cba1
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 20 deletions.
2 changes: 2 additions & 0 deletions agentops/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from uuid import UUID

from .log_config import logger
from .singleton import singleton


@singleton
class Configuration:
def __init__(self):
self.api_key: Optional[str] = None
Expand Down
73 changes: 57 additions & 16 deletions agentops/http_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from datetime import datetime
from enum import Enum
from typing import Optional, List
from typing import Optional, List, Union

import jwt
from requests.adapters import Retry, HTTPAdapter
import requests
from agentops.log_config import logger
from .config import Configuration

from .exceptions import ApiServerException
from dotenv import load_dotenv
import os

from .helpers import ensure_dead_letter_queue
from .helpers import ensure_dead_letter_queue, filter_unjsonable, safe_serialize
import json

load_dotenv()
Expand Down Expand Up @@ -47,7 +51,7 @@ def read_queue(self):
def write_queue(self):
if not self.is_testing:
with open(self.file_path, "w") as f:
json.dump({"messages": self.queue}, f)
json.dump({"messages": safe_serialize(self.queue)}, f)

def add(self, request_data: dict):
if not self.is_testing:
Expand Down Expand Up @@ -112,7 +116,7 @@ def post(
payload: bytes,
api_key: Optional[str] = None,
parent_key: Optional[str] = None,
jwt: Optional[str] = None,
token: Optional[str] = None,
) -> Response:
result = Response()
try:
Expand All @@ -126,8 +130,23 @@ def post(
if parent_key is not None:
JSON_HEADER["X-Agentops-Parent-Key"] = parent_key

if jwt is not None:
JSON_HEADER["Authorization"] = f"Bearer {jwt}"
if token is not None:
decoded_jwt = jwt.decode(
token,
algorithms=["HS256"],
options={"verify_signature": False},
)

# if token is expired, reauth
if datetime.fromtimestamp(decoded_jwt["exp"]) < datetime.now():
new_jwt = reauthorize_jwt(
token,
api_key,
decoded_jwt["session_id"],
)
token = new_jwt

JSON_HEADER["Authorization"] = f"Bearer {token}"

res = request_session.post(
url, data=payload, headers=JSON_HEADER, timeout=20
Expand All @@ -140,16 +159,18 @@ def post(

except (requests.exceptions.Timeout, requests.exceptions.HTTPError) as e:
HttpClient._handle_failed_request(
url, payload, api_key, parent_key, jwt, type(e).__name__
url, payload, api_key, parent_key, token, type(e).__name__
)
raise ApiServerException(f"{type(e).__name__}: {e}")
except requests.exceptions.RequestException as e:
HttpClient._handle_failed_request(
url, payload, api_key, parent_key, jwt, "RequestException"
url, payload, api_key, parent_key, token, "RequestException"
)
raise ApiServerException(f"RequestException: {e}")

if result.code == 401:
if result.body.get("message") == "Expired Token":
raise ApiServerException(f"API server: jwt token expired.")
raise ApiServerException(
f"API server: invalid API key: {api_key}. Find your API key at https://app.agentops.ai/settings/projects"
)
Expand All @@ -160,7 +181,7 @@ def post(
raise ApiServerException(f"API server: {result.body}")
if result.code == 500:
HttpClient._handle_failed_request(
url, payload, api_key, parent_key, jwt, "ServerError"
url, payload, api_key, parent_key, token, "ServerError"
)
raise ApiServerException("API server: - internal server error")

Expand All @@ -170,7 +191,7 @@ def post(
def get(
url: str,
api_key: Optional[str] = None,
jwt: Optional[str] = None,
token: Optional[str] = None,
) -> Response:
result = Response()
try:
Expand All @@ -181,8 +202,8 @@ def get(
if api_key is not None:
JSON_HEADER["X-Agentops-Api-Key"] = api_key

if jwt is not None:
JSON_HEADER["Authorization"] = f"Bearer {jwt}"
if token is not None:
JSON_HEADER["Authorization"] = f"Bearer {token}"

res = request_session.get(url, headers=JSON_HEADER, timeout=20)

Expand All @@ -193,12 +214,12 @@ def get(

except (requests.exceptions.Timeout, requests.exceptions.HTTPError) as e:
HttpClient._handle_failed_request(
url, None, api_key, None, jwt, type(e).__name__
url, None, api_key, None, token, type(e).__name__
)
raise ApiServerException(f"{type(e).__name__}: {e}")
except requests.exceptions.RequestException as e:
HttpClient._handle_failed_request(
url, None, api_key, None, jwt, "RequestException"
url, None, api_key, None, token, "RequestException"
)
raise ApiServerException(f"RequestException: {e}")

Expand Down Expand Up @@ -229,7 +250,9 @@ def _retry_dlq_requests():
# Retry POST request from DLQ
HttpClient.post(
failed_request["url"],
failed_request["payload"],
json.dumps(filter_unjsonable(failed_request["payload"])).encode(
"utf-8"
),
failed_request["api_key"],
failed_request["parent_key"],
failed_request["jwt"],
Expand All @@ -241,7 +264,7 @@ def _retry_dlq_requests():
failed_request["api_key"],
failed_request["jwt"],
)
except ApiServerException:
except ApiServerException as e:
dead_letter_queue.add(failed_request)
# If it still fails, keep it in the DLQ
except Exception as e:
Expand Down Expand Up @@ -270,3 +293,21 @@ def _handle_failed_request(
logger.warning(
f"An error occurred while communicating with the server: {error_type}"
)


def reauthorize_jwt(old_jwt: str, api_key: str, session_id: str) -> Union[str, None]:
payload = {"jwt": old_jwt, "session_id": session_id}
serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8")
config = Configuration()
res = HttpClient.post(
f"{config.endpoint}/v2/reauthorize_jwt",
serialized_payload,
api_key,
)

logger.debug(res.body)

if res.code != 200:
return None

return res.body.get("jwt", None)
8 changes: 4 additions & 4 deletions agentops/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def format_duration(start_time, end_time):
res = HttpClient.post(
f"{self.config.endpoint}/v2/update_session",
json.dumps(filter_unjsonable(payload)).encode("utf-8"),
jwt=self.jwt,
token=self.jwt,
)
except ApiServerException as e:
return logger.error(f"Could not end session - {e}")
Expand Down Expand Up @@ -289,7 +289,7 @@ def _update_session(self) -> None:
res = HttpClient.post(
f"{self.config.endpoint}/v2/update_session",
json.dumps(filter_unjsonable(payload)).encode("utf-8"),
jwt=self.jwt,
token=self.jwt,
)
except ApiServerException as e:
return logger.error(f"Could not update session - {e}")
Expand All @@ -311,7 +311,7 @@ def _flush_queue(self) -> None:
HttpClient.post(
f"{self.config.endpoint}/v2/create_events",
serialized_payload,
jwt=self.jwt,
token=self.jwt,
)
except ApiServerException as e:
return logger.error(f"Could not post events - {e}")
Expand Down Expand Up @@ -360,7 +360,7 @@ def create_agent(self, name, agent_id):
HttpClient.post(
f"{self.config.endpoint}/v2/create_agent",
serialized_payload,
jwt=self.jwt,
token=self.jwt,
)
except ApiServerException as e:
return logger.error(f"Could not create agent - {e}")
Expand Down
57 changes: 57 additions & 0 deletions tests/core_manual_tests/http_client/dead_letter_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
### Purpose
# test an edge case where a request is retried after the jwt has expired
import time
from datetime import datetime

### SETUP
# Run the API server locally
# In utils.py -> generate_jwt -> set the jwt expiration to 0.001
# Run this script

### Plan
# The first request should succeed and return a JWT
# We'll manually add a failed request to the DLQ with the expired JWT
# When reattempting, the http_client should identify the expired jwt and reauthorize it before sending again

import agentops
from agentops import ActionEvent
from agentops.helpers import safe_serialize, get_ISO_time
from agentops.http_client import dead_letter_queue, HttpClient

api_key = "492f0ee6-0b7d-40a6-af86-22d89c7c5eea"
agentops.init(
endpoint="http://localhost:8000",
api_key=api_key,
auto_start_session=False,
default_tags=["dead-letter-queue-test"],
)

# create session
session = agentops.start_session()

# add failed request to DLQ
event = ActionEvent()
event.end_timestamp = get_ISO_time()

failed_request = {
"url": "http://localhost:8000/v2/create_events",
"payload": {"events": [event.__dict__]},
"api_key": str(api_key),
"parent_key": None,
"jwt": session.jwt,
"error_type": "Timeout",
}
# failed_request = safe_serialize(failed_request).encode("utf-8")

dead_letter_queue.add(failed_request)
assert len(dead_letter_queue.get_all()) == 1

# wait for the JWT to expire
time.sleep(3)

# retry
HttpClient()._retry_dlq_requests()
session.end_session(end_state="Success")

# check if the failed request is still in the DLQ
assert dead_letter_queue.get_all() == []

0 comments on commit 1a7cba1

Please sign in to comment.