Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
the-praxs committed Nov 21, 2024
1 parent 9120be9 commit 7d31cbb
Show file tree
Hide file tree
Showing 40 changed files with 565 additions and 280 deletions.
4 changes: 3 additions & 1 deletion agentops/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ def main():
parser = argparse.ArgumentParser(description="AgentOps CLI")
subparsers = parser.add_subparsers(dest="command")

timetravel_parser = subparsers.add_parser("timetravel", help="Time Travel Debugging commands", aliases=["tt"])
timetravel_parser = subparsers.add_parser(
"timetravel", help="Time Travel Debugging commands", aliases=["tt"]
)
timetravel_parser.add_argument(
"branch_name",
type=str,
Expand Down
53 changes: 41 additions & 12 deletions agentops/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def __init__(self):
api_key=os.environ.get("AGENTOPS_API_KEY"),
parent_key=os.environ.get("AGENTOPS_PARENT_KEY"),
endpoint=os.environ.get("AGENTOPS_API_ENDPOINT"),
env_data_opt_out=os.environ.get("AGENTOPS_ENV_DATA_OPT_OUT", "False").lower() == "true",
env_data_opt_out=os.environ.get(
"AGENTOPS_ENV_DATA_OPT_OUT", "False"
).lower()
== "true",
)

def configure(
Expand Down Expand Up @@ -103,7 +106,9 @@ def initialize(self) -> Union[Session, None]:

if session:
for agent_args in self._pre_init_queue["agents"]:
session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"])
session.create_agent(
name=agent_args["name"], agent_id=agent_args["agent_id"]
)
self._pre_init_queue["agents"] = []

return session
Expand Down Expand Up @@ -136,7 +141,9 @@ def add_tags(self, tags: List[str]) -> None:

session = self._safe_get_session()
if session is None:
return logger.warning("Could not add tags. Start a session by calling agentops.start_session().")
return logger.warning(
"Could not add tags. Start a session by calling agentops.start_session()."
)

session.add_tags(tags=tags)

Expand All @@ -155,7 +162,9 @@ def set_tags(self, tags: List[str]) -> None:
session = self._safe_get_session()

if session is None:
return logger.warning("Could not set tags. Start a session by calling agentops.start_session().")
return logger.warning(
"Could not set tags. Start a session by calling agentops.start_session()."
)
else:
session.set_tags(tags=tags)

Expand Down Expand Up @@ -189,7 +198,9 @@ def record(self, event: Union[Event, ErrorEvent]) -> None:

session = self._safe_get_session()
if session is None:
return logger.error("Could not record event. Start a session by calling agentops.start_session().")
return logger.error(
"Could not record event. Start a session by calling agentops.start_session()."
)
session.record(event)

def start_session(
Expand Down Expand Up @@ -233,7 +244,9 @@ def start_session(

if self._pre_init_queue["agents"] and len(self._pre_init_queue["agents"]) > 0:
for agent_args in self._pre_init_queue["agents"]:
session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"])
session.create_agent(
name=agent_args["name"], agent_id=agent_args["agent_id"]
)
self._pre_init_queue["agents"] = []

self._sessions.append(session)
Expand Down Expand Up @@ -264,7 +277,9 @@ def end_session(
if is_auto_end and self._config.skip_auto_end_session:
return

token_cost = session.end_session(end_state=end_state, end_state_reason=end_state_reason, video=video)
token_cost = session.end_session(
end_state=end_state, end_state_reason=end_state_reason, video=video
)

return token_cost

Expand All @@ -284,7 +299,9 @@ def create_agent(
# if no session passed, assume single session
session = self._safe_get_session()
if session is None:
self._pre_init_queue["agents"].append({"name": name, "agent_id": agent_id})
self._pre_init_queue["agents"].append(
{"name": name, "agent_id": agent_id}
)
else:
session.create_agent(name=name, agent_id=agent_id)

Expand All @@ -309,7 +326,9 @@ def signal_handler(signum, frame):
"""
signal_name = "SIGINT" if signum == signal.SIGINT else "SIGTERM"
logger.info("%s detected. Ending session...", signal_name)
self.end_session(end_state="Fail", end_state_reason=f"Signal {signal_name} detected")
self.end_session(
end_state="Fail", end_state_reason=f"Signal {signal_name} detected"
)
sys.exit(0)

def handle_exception(exc_type, exc_value, exc_traceback):
Expand All @@ -322,7 +341,9 @@ def handle_exception(exc_type, exc_value, exc_traceback):
exc_traceback (TracebackType): A traceback object encapsulating the call stack at the
point where the exception originally occurred.
"""
formatted_traceback = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
formatted_traceback = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)

for session in self._sessions:
session.end_session(
Expand Down Expand Up @@ -355,7 +376,13 @@ def add_pre_init_warning(self, message: str):
# replaces the session currently stored with a specific session_id, with a new session
def _update_session(self, session: Session):
self._sessions[
self._sessions.index([sess for sess in self._sessions if sess.session_id == session.session_id][0])
self._sessions.index(
[
sess
for sess in self._sessions
if sess.session_id == session.session_id
][0]
)
] = session

def _safe_get_session(self) -> Optional[Session]:
Expand All @@ -365,7 +392,9 @@ def _safe_get_session(self) -> Optional[Session]:
return self._sessions[0]

if len(self._sessions) > 1:
calling_function = inspect.stack()[2].function # Using index 2 because we have a wrapper at index 1
calling_function = inspect.stack()[
2
].function # Using index 2 because we have a wrapper at index 1
return logger.warning(
f"Multiple sessions detected. You must use session.{calling_function}(). More info: https://docs.agentops.ai/v1/concepts/core-concepts#session-management"
)
Expand Down
28 changes: 21 additions & 7 deletions agentops/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -110,7 +112,9 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -188,7 +192,9 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -252,7 +258,9 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -345,15 +353,21 @@ def new_init(self, *args, **kwargs):

except AttributeError as ex:
logger.debug(ex)
Client().add_pre_init_warning(f"Failed to track an agent {name} with the @track_agent decorator.")
logger.warning("Failed to track an agent with the @track_agent decorator.")
Client().add_pre_init_warning(
f"Failed to track an agent {name} with the @track_agent decorator."
)
logger.warning(
"Failed to track an agent with the @track_agent decorator."
)

obj.__init__ = new_init

elif inspect.isfunction(obj):
obj.agentops_agent_id = str(uuid4())
obj.agentops_agent_name = name
Client().create_agent(name=obj.agentops_agent_name, agent_id=obj.agentops_agent_id)
Client().create_agent(
name=obj.agentops_agent_name, agent_id=obj.agentops_agent_id
)

else:
raise Exception("Invalid input, 'obj' must be a class or a function")
Expand Down
4 changes: 3 additions & 1 deletion agentops/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def stack_lookup() -> Union[UUID, None]:
var_type = type(var)

# Get all class attributes
class_attrs = {name: getattr(var_type, name, None) for name in dir(var_type)}
class_attrs = {
name: getattr(var_type, name, None) for name in dir(var_type)
}

agent_id_desc = class_attrs.get("agentops_agent_id")

Expand Down
12 changes: 5 additions & 7 deletions agentops/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def filter_dict(obj):
k: (
filter_dict(v)
if isinstance(v, (dict, list)) or is_jsonable(v)
else str(v)
if isinstance(v, UUID)
else ""
else str(v) if isinstance(v, UUID) else ""
)
for k, v in obj.items()
}
Expand All @@ -50,9 +48,7 @@ def filter_dict(obj):
(
filter_dict(x)
if isinstance(x, (dict, list)) or is_jsonable(x)
else str(x)
if isinstance(x, UUID)
else ""
else str(x) if isinstance(x, UUID) else ""
)
for x in obj
]
Expand Down Expand Up @@ -90,7 +86,9 @@ def remove_unwanted_items(value):
"""Recursively remove self key and None/... values from dictionaries so they aren't serialized"""
if isinstance(value, dict):
return {
k: remove_unwanted_items(v) for k, v in value.items() if v is not None and v is not ... and k != "self"
k: remove_unwanted_items(v)
for k, v in value.items()
if v is not None and v is not ... and k != "self"
}
elif isinstance(value, list):
return [remove_unwanted_items(item) for item in value]
Expand Down
7 changes: 5 additions & 2 deletions agentops/host_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def get_installed_packages():
return {
# TODO: add to opt out
"Installed Packages": {
dist.metadata.get("Name"): dist.metadata.get("Version") for dist in importlib.metadata.distributions()
dist.metadata.get("Name"): dist.metadata.get("Version")
for dist in importlib.metadata.distributions()
}
}
except:
Expand Down Expand Up @@ -124,7 +125,9 @@ def get_disk_details():
}
except OSError as inaccessible:
# Skip inaccessible partitions, such as removable drives with no media
logger.debug("Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible)
logger.debug(
"Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible
)

return disk_info

Expand Down
16 changes: 12 additions & 4 deletions agentops/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class HttpStatus(Enum):


class Response:
def __init__(self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None):
def __init__(
self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None
):
self.status: HttpStatus = status
self.code: int = status.value
self.body = body if body else {}
Expand Down Expand Up @@ -78,13 +80,17 @@ def post(
if jwt is not None:
JSON_HEADER["Authorization"] = f"Bearer {jwt}"

res = request_session.post(url, data=payload, headers=JSON_HEADER, timeout=20)
res = request_session.post(
url, data=payload, headers=JSON_HEADER, timeout=20
)

result.parse(res)
except requests.exceptions.Timeout:
result.code = 408
result.status = HttpStatus.TIMEOUT
raise ApiServerException("Could not reach API server - connection timed out")
raise ApiServerException(
"Could not reach API server - connection timed out"
)
except requests.exceptions.HTTPError as e:
try:
result.parse(e.response)
Expand Down Expand Up @@ -137,7 +143,9 @@ def get(
except requests.exceptions.Timeout:
result.code = 408
result.status = HttpStatus.TIMEOUT
raise ApiServerException("Could not reach API server - connection timed out")
raise ApiServerException(
"Could not reach API server - connection timed out"
)
except requests.exceptions.HTTPError as e:
try:
result.parse(e.response)
Expand Down
Loading

0 comments on commit 7d31cbb

Please sign in to comment.