From 38d7df103238aaec1e6eab6cb8058b7619472637 Mon Sep 17 00:00:00 2001 From: Joseph Rana Date: Fri, 5 Apr 2024 14:40:01 +0545 Subject: [PATCH] handle edge cases for duplicate bot connection --- autonomous_agent/connect-agent.py | 23 ++++++++++++------- .../app/controllers/agent_websocket.py | 19 ++++++++------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/autonomous_agent/connect-agent.py b/autonomous_agent/connect-agent.py index 892ab0919..4949863d3 100644 --- a/autonomous_agent/connect-agent.py +++ b/autonomous_agent/connect-agent.py @@ -1,21 +1,28 @@ import asyncio import websockets import argparse +from websockets.exceptions import ConnectionClosed -default_ping_timeout = 10 # Sends a ping every 10 second +default_ping_timeout = 10 # Sends a ping every 10 seconds async def connect_to_server(agent_id): - uri = "ws://127.0.0.1:8000/api/agent/ws" headers = {"agent_id": agent_id} - async with websockets.connect(uri, extra_headers=headers) as websocket: - while True: - await websocket.send("PING") - response = await websocket.recv() - print("Received:", response) - await asyncio.sleep(default_ping_timeout) + try: + async with websockets.connect(uri, extra_headers=headers) as websocket: + while True: + try: + await websocket.send("PING") + response = await websocket.recv() + print("Received:", response) + await asyncio.sleep(default_ping_timeout) + except ConnectionClosed: + print("Connection closed by server") + break + except ConnectionError: + print("Failed to connect to the server") async def main(agent_id): diff --git a/autonomous_agent_api/backend/app/controllers/agent_websocket.py b/autonomous_agent_api/backend/app/controllers/agent_websocket.py index fc01fc660..7684d3016 100644 --- a/autonomous_agent_api/backend/app/controllers/agent_websocket.py +++ b/autonomous_agent_api/backend/app/controllers/agent_websocket.py @@ -17,11 +17,13 @@ async def connect_websocket(self, websocket_agent_id: str, websocket: WebSocket) self.active_connections[websocket_agent_id] = websocket async def disconnect_websocket(self, websocket_agent_id): + """Critical : Dont use .close() here as this will close the new connection when dealing with multiple web socket request for the same bot. + This is due to the nature of try/except code that gets called by the previous connection.""" self.active_connections.pop(websocket_agent_id) async def send_message_to_websocket(self, websocket_agent_id: str, message: dict): # Checks if agent is active , first then sends message - agent_active = await self.check_if_agent_exists_in_active_list(websocket_agent_id) + agent_active = await self.check_if_agent_active(websocket_agent_id) if agent_active: await self.active_connections[websocket_agent_id].send_json(message) else: @@ -35,12 +37,12 @@ async def check_if_agent_active(self, websocket_agent_id: str): async def remove_previous_agent_connection_if_exists(self, websocket_agent_id: str): """ - Removes the old agent connection and creates a new one - If client requests websocket connection for an already active conenction. + If client requests websocket connection for an already active bot. + Removes the old connection and establishes a new one. """ - agent_exists = await self.check_if_agent_active(websocket_agent_id) - if agent_exists: - self.active_connections.pop(websocket_agent_id) + if await self.check_if_agent_active(websocket_agent_id): + existing_websocket = self.active_connections.pop(websocket_agent_id) + await existing_websocket.close(code=1000, reason="establishing a new connection") manager = WebSocket_Connection_Manager() @@ -56,7 +58,7 @@ async def agent_websocket_endpoint(websocket: WebSocket): raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION) # Check if agent with the id exists. - agent_exists = await check_if_agent_exists(websocket_agent_id) + agent_exists = await check_if_agent_exists_in_db(websocket_agent_id) if agent_exists: await manager.remove_previous_agent_connection_if_exists(websocket_agent_id) await manager.connect_websocket(websocket_agent_id, websocket) @@ -67,11 +69,12 @@ async def agent_websocket_endpoint(websocket: WebSocket): await websocket.send_text(f"Ping recieved from {websocket_agent_id} at {datetime.now()}") except WebSocketDisconnect: await manager.disconnect_websocket(websocket_agent_id) + pass else: raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION) -async def check_if_agent_exists(agent_id: str): +async def check_if_agent_exists_in_db(agent_id: str): # Query agent with the agent id from the database -> reurns a boolean async with prisma_connection: agent_exists = await prisma_connection.prisma.agent.find_first(where={"id": agent_id, "deleted_at": None})