Skip to content

Commit

Permalink
handle edge cases for duplicate bot connection
Browse files Browse the repository at this point in the history
  • Loading branch information
JosephRana11 committed Apr 5, 2024
1 parent c0e5217 commit 38d7df1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
23 changes: 15 additions & 8 deletions autonomous_agent/connect-agent.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
import asyncio
import websockets
import argparse
from websockets.exceptions import ConnectionClosed

default_ping_timeout = 10 # Sends a ping every 10 second
default_ping_timeout = 10 # Sends a ping every 10 seconds


async def connect_to_server(agent_id):

uri = "ws://127.0.0.1:8000/api/agent/ws"
headers = {"agent_id": agent_id}

async with websockets.connect(uri, extra_headers=headers) as websocket:
while True:
await websocket.send("PING")
response = await websocket.recv()
print("Received:", response)
await asyncio.sleep(default_ping_timeout)
try:
async with websockets.connect(uri, extra_headers=headers) as websocket:
while True:
try:
await websocket.send("PING")
response = await websocket.recv()
print("Received:", response)
await asyncio.sleep(default_ping_timeout)
except ConnectionClosed:
print("Connection closed by server")
break
except ConnectionError:
print("Failed to connect to the server")


async def main(agent_id):
Expand Down
19 changes: 11 additions & 8 deletions autonomous_agent_api/backend/app/controllers/agent_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ async def connect_websocket(self, websocket_agent_id: str, websocket: WebSocket)
self.active_connections[websocket_agent_id] = websocket

async def disconnect_websocket(self, websocket_agent_id):
"""Critical : Dont use .close() here as this will close the new connection when dealing with multiple web socket request for the same bot.
This is due to the nature of try/except code that gets called by the previous connection."""
self.active_connections.pop(websocket_agent_id)

async def send_message_to_websocket(self, websocket_agent_id: str, message: dict):
# Checks if agent is active , first then sends message
agent_active = await self.check_if_agent_exists_in_active_list(websocket_agent_id)
agent_active = await self.check_if_agent_active(websocket_agent_id)
if agent_active:
await self.active_connections[websocket_agent_id].send_json(message)
else:
Expand All @@ -35,12 +37,12 @@ async def check_if_agent_active(self, websocket_agent_id: str):

async def remove_previous_agent_connection_if_exists(self, websocket_agent_id: str):
"""
Removes the old agent connection and creates a new one
If client requests websocket connection for an already active conenction.
If client requests websocket connection for an already active bot.
Removes the old connection and establishes a new one.
"""
agent_exists = await self.check_if_agent_active(websocket_agent_id)
if agent_exists:
self.active_connections.pop(websocket_agent_id)
if await self.check_if_agent_active(websocket_agent_id):
existing_websocket = self.active_connections.pop(websocket_agent_id)
await existing_websocket.close(code=1000, reason="establishing a new connection")


manager = WebSocket_Connection_Manager()
Expand All @@ -56,7 +58,7 @@ async def agent_websocket_endpoint(websocket: WebSocket):
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

# Check if agent with the id exists.
agent_exists = await check_if_agent_exists(websocket_agent_id)
agent_exists = await check_if_agent_exists_in_db(websocket_agent_id)
if agent_exists:
await manager.remove_previous_agent_connection_if_exists(websocket_agent_id)
await manager.connect_websocket(websocket_agent_id, websocket)
Expand All @@ -67,11 +69,12 @@ async def agent_websocket_endpoint(websocket: WebSocket):
await websocket.send_text(f"Ping recieved from {websocket_agent_id} at {datetime.now()}")
except WebSocketDisconnect:
await manager.disconnect_websocket(websocket_agent_id)
pass
else:
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)


async def check_if_agent_exists(agent_id: str):
async def check_if_agent_exists_in_db(agent_id: str):
# Query agent with the agent id from the database -> reurns a boolean
async with prisma_connection:
agent_exists = await prisma_connection.prisma.agent.find_first(where={"id": agent_id, "deleted_at": None})
Expand Down

0 comments on commit 38d7df1

Please sign in to comment.