Skip to content

Commit

Permalink
use ruff for formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
the-praxs committed Nov 21, 2024
1 parent b1fb154 commit c8f8585
Show file tree
Hide file tree
Showing 54 changed files with 596 additions and 926 deletions.
4 changes: 1 addition & 3 deletions agentops/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ def main():
parser = argparse.ArgumentParser(description="AgentOps CLI")
subparsers = parser.add_subparsers(dest="command")

timetravel_parser = subparsers.add_parser(
"timetravel", help="Time Travel Debugging commands", aliases=["tt"]
)
timetravel_parser = subparsers.add_parser("timetravel", help="Time Travel Debugging commands", aliases=["tt"])
timetravel_parser.add_argument(
"branch_name",
type=str,
Expand Down
53 changes: 12 additions & 41 deletions agentops/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ def __init__(self):
api_key=os.environ.get("AGENTOPS_API_KEY"),
parent_key=os.environ.get("AGENTOPS_PARENT_KEY"),
endpoint=os.environ.get("AGENTOPS_API_ENDPOINT"),
env_data_opt_out=os.environ.get(
"AGENTOPS_ENV_DATA_OPT_OUT", "False"
).lower()
== "true",
env_data_opt_out=os.environ.get("AGENTOPS_ENV_DATA_OPT_OUT", "False").lower() == "true",
)

def configure(
Expand Down Expand Up @@ -106,9 +103,7 @@ def initialize(self) -> Union[Session, None]:

if session:
for agent_args in self._pre_init_queue["agents"]:
session.create_agent(
name=agent_args["name"], agent_id=agent_args["agent_id"]
)
session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"])
self._pre_init_queue["agents"] = []

return session
Expand Down Expand Up @@ -141,9 +136,7 @@ def add_tags(self, tags: List[str]) -> None:

session = self._safe_get_session()
if session is None:
return logger.warning(
"Could not add tags. Start a session by calling agentops.start_session()."
)
return logger.warning("Could not add tags. Start a session by calling agentops.start_session().")

session.add_tags(tags=tags)

Expand All @@ -162,9 +155,7 @@ def set_tags(self, tags: List[str]) -> None:
session = self._safe_get_session()

if session is None:
return logger.warning(
"Could not set tags. Start a session by calling agentops.start_session()."
)
return logger.warning("Could not set tags. Start a session by calling agentops.start_session().")
else:
session.set_tags(tags=tags)

Expand Down Expand Up @@ -198,9 +189,7 @@ def record(self, event: Union[Event, ErrorEvent]) -> None:

session = self._safe_get_session()
if session is None:
return logger.error(
"Could not record event. Start a session by calling agentops.start_session()."
)
return logger.error("Could not record event. Start a session by calling agentops.start_session().")
session.record(event)

def start_session(
Expand Down Expand Up @@ -244,9 +233,7 @@ def start_session(

if self._pre_init_queue["agents"] and len(self._pre_init_queue["agents"]) > 0:
for agent_args in self._pre_init_queue["agents"]:
session.create_agent(
name=agent_args["name"], agent_id=agent_args["agent_id"]
)
session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"])
self._pre_init_queue["agents"] = []

self._sessions.append(session)
Expand Down Expand Up @@ -277,9 +264,7 @@ def end_session(
if is_auto_end and self._config.skip_auto_end_session:
return

token_cost = session.end_session(
end_state=end_state, end_state_reason=end_state_reason, video=video
)
token_cost = session.end_session(end_state=end_state, end_state_reason=end_state_reason, video=video)

return token_cost

Expand All @@ -299,9 +284,7 @@ def create_agent(
# if no session passed, assume single session
session = self._safe_get_session()
if session is None:
self._pre_init_queue["agents"].append(
{"name": name, "agent_id": agent_id}
)
self._pre_init_queue["agents"].append({"name": name, "agent_id": agent_id})
else:
session.create_agent(name=name, agent_id=agent_id)

Expand All @@ -326,9 +309,7 @@ def signal_handler(signum, frame):
"""
signal_name = "SIGINT" if signum == signal.SIGINT else "SIGTERM"
logger.info("%s detected. Ending session...", signal_name)
self.end_session(
end_state="Fail", end_state_reason=f"Signal {signal_name} detected"
)
self.end_session(end_state="Fail", end_state_reason=f"Signal {signal_name} detected")
sys.exit(0)

def handle_exception(exc_type, exc_value, exc_traceback):
Expand All @@ -341,9 +322,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
exc_traceback (TracebackType): A traceback object encapsulating the call stack at the
point where the exception originally occurred.
"""
formatted_traceback = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
formatted_traceback = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))

for session in self._sessions:
session.end_session(
Expand Down Expand Up @@ -376,13 +355,7 @@ def add_pre_init_warning(self, message: str):
# replaces the session currently stored with a specific session_id, with a new session
def _update_session(self, session: Session):
self._sessions[
self._sessions.index(
[
sess
for sess in self._sessions
if sess.session_id == session.session_id
][0]
)
self._sessions.index([sess for sess in self._sessions if sess.session_id == session.session_id][0])
] = session

def _safe_get_session(self) -> Optional[Session]:
Expand All @@ -392,9 +365,7 @@ def _safe_get_session(self) -> Optional[Session]:
return self._sessions[0]

if len(self._sessions) > 1:
calling_function = inspect.stack()[
2
].function # Using index 2 because we have a wrapper at index 1
calling_function = inspect.stack()[2].function # Using index 2 because we have a wrapper at index 1
return logger.warning(
f"Multiple sessions detected. You must use session.{calling_function}(). More info: https://docs.agentops.ai/v1/concepts/core-concepts#session-management"
)
Expand Down
28 changes: 7 additions & 21 deletions agentops/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -112,9 +110,7 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -192,9 +188,7 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -258,9 +252,7 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
Expand Down Expand Up @@ -353,21 +345,15 @@ def new_init(self, *args, **kwargs):

except AttributeError as ex:
logger.debug(ex)
Client().add_pre_init_warning(
f"Failed to track an agent {name} with the @track_agent decorator."
)
logger.warning(
"Failed to track an agent with the @track_agent decorator."
)
Client().add_pre_init_warning(f"Failed to track an agent {name} with the @track_agent decorator.")
logger.warning("Failed to track an agent with the @track_agent decorator.")

obj.__init__ = new_init

elif inspect.isfunction(obj):
obj.agentops_agent_id = str(uuid4())
obj.agentops_agent_name = name
Client().create_agent(
name=obj.agentops_agent_name, agent_id=obj.agentops_agent_id
)
Client().create_agent(name=obj.agentops_agent_name, agent_id=obj.agentops_agent_id)

else:
raise Exception("Invalid input, 'obj' must be a class or a function")
Expand Down
4 changes: 1 addition & 3 deletions agentops/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def stack_lookup() -> Union[UUID, None]:
var_type = type(var)

# Get all class attributes
class_attrs = {
name: getattr(var_type, name, None) for name in dir(var_type)
}
class_attrs = {name: getattr(var_type, name, None) for name in dir(var_type)}

agent_id_desc = class_attrs.get("agentops_agent_id")

Expand Down
12 changes: 7 additions & 5 deletions agentops/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def filter_dict(obj):
k: (
filter_dict(v)
if isinstance(v, (dict, list)) or is_jsonable(v)
else str(v) if isinstance(v, UUID) else ""
else str(v)
if isinstance(v, UUID)
else ""
)
for k, v in obj.items()
}
Expand All @@ -48,7 +50,9 @@ def filter_dict(obj):
(
filter_dict(x)
if isinstance(x, (dict, list)) or is_jsonable(x)
else str(x) if isinstance(x, UUID) else ""
else str(x)
if isinstance(x, UUID)
else ""
)
for x in obj
]
Expand Down Expand Up @@ -86,9 +90,7 @@ def remove_unwanted_items(value):
"""Recursively remove self key and None/... values from dictionaries so they aren't serialized"""
if isinstance(value, dict):
return {
k: remove_unwanted_items(v)
for k, v in value.items()
if v is not None and v is not ... and k != "self"
k: remove_unwanted_items(v) for k, v in value.items() if v is not None and v is not ... and k != "self"
}
elif isinstance(value, list):
return [remove_unwanted_items(item) for item in value]
Expand Down
7 changes: 2 additions & 5 deletions agentops/host_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def get_installed_packages():
return {
# TODO: add to opt out
"Installed Packages": {
dist.metadata.get("Name"): dist.metadata.get("Version")
for dist in importlib.metadata.distributions()
dist.metadata.get("Name"): dist.metadata.get("Version") for dist in importlib.metadata.distributions()
}
}
except:
Expand Down Expand Up @@ -125,9 +124,7 @@ def get_disk_details():
}
except OSError as inaccessible:
# Skip inaccessible partitions, such as removable drives with no media
logger.debug(
"Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible
)
logger.debug("Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible)

return disk_info

Expand Down
16 changes: 4 additions & 12 deletions agentops/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ class HttpStatus(Enum):


class Response:
def __init__(
self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None
):
def __init__(self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None):
self.status: HttpStatus = status
self.code: int = status.value
self.body = body if body else {}
Expand Down Expand Up @@ -80,17 +78,13 @@ def post(
if jwt is not None:
JSON_HEADER["Authorization"] = f"Bearer {jwt}"

res = request_session.post(
url, data=payload, headers=JSON_HEADER, timeout=20
)
res = request_session.post(url, data=payload, headers=JSON_HEADER, timeout=20)

result.parse(res)
except requests.exceptions.Timeout:
result.code = 408
result.status = HttpStatus.TIMEOUT
raise ApiServerException(
"Could not reach API server - connection timed out"
)
raise ApiServerException("Could not reach API server - connection timed out")
except requests.exceptions.HTTPError as e:
try:
result.parse(e.response)
Expand Down Expand Up @@ -143,9 +137,7 @@ def get(
except requests.exceptions.Timeout:
result.code = 408
result.status = HttpStatus.TIMEOUT
raise ApiServerException(
"Could not reach API server - connection timed out"
)
raise ApiServerException("Could not reach API server - connection timed out")
except requests.exceptions.HTTPError as e:
try:
result.parse(e.response)
Expand Down
Loading

0 comments on commit c8f8585

Please sign in to comment.