Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement fastapi for http server #82

Merged
merged 21 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions docker-compose.example.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,50 @@
---

services:
  # FastAPI HTTP server (entrypoint: main.py, served by uvicorn)
  api:
    build:
      context: .
      target: prod
    command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
    ports:
      - 8000:8000
    environment:
      - RABBIT_USER=root
      - RABBIT_PASSWORD=pass
      - RABBIT_HOST=rabbitmq
      - RABBIT_PORT=5672
      - REDIS_PASSWORD=pass
      - REDIS_HOST=redis
      - REDIS_PORT=6379
    volumes:
      # mount the project for --reload during development
      - ./:/project/
  # Celery worker consuming jobs queued by the api service
  worker:
    build:
      context: .
      target: prod
      dockerfile: Dockerfile
    environment:
      - RABBIT_USER=root
      - RABBIT_PASSWORD=pass
      - RABBIT_HOST=rabbitmq
      - RABBIT_PORT=5672
      - REDIS_PASSWORD=pass
      - REDIS_HOST=redis
      - REDIS_PORT=6379
  rabbitmq:
    image: "rabbitmq:3-management-alpine"
    environment:
      - RABBITMQ_DEFAULT_USER=root
      - RABBITMQ_DEFAULT_PASS=pass
    healthcheck:
      test: rabbitmq-diagnostics -q ping
      interval: 30s
      timeout: 30s
      retries: 2
      start_period: 40s
    ports:
      - 5672:5672
  redis:
    image: bitnami/redis
    environment:
      - REDIS_PASSWORD=pass
8 changes: 8 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from fastapi import FastAPI
from routers.http import router as http_router
from routers.amqp import router as amqp_router

# FastAPI application wiring together the plain-HTTP endpoints and the
# RabbitMQ (faststream) subscriber router.
app = FastAPI()

app.include_router(http_router)
app.include_router(amqp_router)
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ tc-hivemind-backend==1.2.2
llama-index-question-gen-guidance==0.1.2
llama-index-vector-stores-postgres==0.1.2
celery>=5.3.6, <6.0.0
celery[redis]>=5.3.6, <6.0.0
amindadgar marked this conversation as resolved.
Show resolved Hide resolved
guidance==0.1.14
tc-messageBroker==1.6.6
traceloop-sdk==0.14.1
backoff==2.2.1
fastapi[standard]==0.114.1
faststream==0.5.23
aio_pika==9.4.0
48 changes: 48 additions & 0 deletions routers/amqp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from datetime import datetime

from pydantic import BaseModel
from faststream.rabbit.fastapi import RabbitRouter, Logger # type: ignore
from faststream.rabbit.schemas.queue import RabbitQueue
from utils.credentials import load_rabbitmq_credentials
from tc_messageBroker.rabbit_mq.queue import Queue
from tc_messageBroker.rabbit_mq.event import Event
from worker.tasks import query_data_sources

rabbitmq_creds = load_rabbitmq_credentials()

router = RabbitRouter(rabbitmq_creds["url"])


class Content(BaseModel):
    # the user's question to answer
    question: str
    # the community whose data sources should be queried
    community_id: str


class Payload(BaseModel):
    # message envelope exchanged over RabbitMQ
    event: str
    # timestamp of the message; producers in this file send it as str(datetime)
    date: datetime | str
    # the union allows either the parsed Content model or a raw dict
    content: Content | dict
amindadgar marked this conversation as resolved.
Show resolved Hide resolved


@router.subscriber(queue=RabbitQueue(name=Queue.HIVEMIND, durable=True))
@router.publisher(queue=RabbitQueue(Queue.DISCORD_BOT, durable=True))
async def ask(payload: Payload, logger: Logger):
    """
    Consume a question from the HIVEMIND queue, query the community's
    data sources, and publish the answer to the DISCORD_BOT queue.

    Raises
    -------
    NotImplementedError
        if the payload carries an event other than
        `Event.HIVEMIND.INTERACTION_CREATED`.
    """
    if payload.event == Event.HIVEMIND.INTERACTION_CREATED:
        # `content` is typed `Content | dict`; coerce a raw dict so that
        # the attribute accesses below cannot raise AttributeError
        content = (
            payload.content
            if isinstance(payload.content, Content)
            else Content(**payload.content)
        )
        question = content.question
        community_id = content.community_id

        logger.info(f"COMMUNITY_ID: {community_id} Received job")
        response = query_data_sources(community_id=community_id, query=question)
        logger.info(f"COMMUNITY_ID: {community_id} Job finished")

        response_payload = Payload(
            event=Event.DISCORD_BOT.INTERACTION_RESPONSE.EDIT,
            date=str(datetime.now()),
            content={"response": response},
        )
    else:
        raise NotImplementedError(
            f"No more event available for {Queue.HIVEMIND} queue! "
            f"Received event: `{payload.event}`"
        )
    return response_payload
amindadgar marked this conversation as resolved.
Show resolved Hide resolved
26 changes: 26 additions & 0 deletions routers/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from worker.tasks import ask_question_auto_search
from celery.result import AsyncResult
from pydantic import BaseModel
from fastapi import APIRouter
from typing import Any
amindadgar marked this conversation as resolved.
Show resolved Hide resolved


class Payload(BaseModel):
    # request body for POST /ask: the question and the target community
    query: str
    community_id: str
    # bot_given_info: dict[str, Any]


router = APIRouter()


@router.post("/ask")
async def ask(payload: Payload):
    """Enqueue a question-answering job and return the Celery task id."""
    task = ask_question_auto_search.delay(
        query=payload.query,
        community_id=payload.community_id,
    )
    return {"id": task.id}


@router.get("/status")
async def status(task_id: str):
    """
    Report the state of a previously enqueued task.

    Parameters
    -------------
    task_id : str
        the id returned by the `/ask` endpoint

    Returns
    ---------
    a dict with the task id, its Celery status, and its result (if ready)
    """
    task = AsyncResult(task_id)
    result = task.result
    # a failed task stores the raised exception as its result, which FastAPI
    # cannot JSON-encode — return its string form instead
    if isinstance(result, BaseException):
        result = str(result)
    return {"id": task.id, "status": task.status, "result": result}
4 changes: 2 additions & 2 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tc_messageBroker.rabbit_mq.queue import Queue
from utils.credentials import load_rabbitmq_credentials
from utils.fetch_community_id import fetch_community_id_by_guild_id
from worker.tasks import ask_question_auto_search
from worker.tasks import ask_question_auto_search_discord_interaction


def query_llm(recieved_data: dict[str, Any]):
Expand All @@ -32,7 +32,7 @@ def query_llm(recieved_data: dict[str, Any]):

community_id = fetch_community_id_by_guild_id(guild_id=recieved_input.guild_id)
logging.info(f"COMMUNITY_ID: {community_id} | Sending job to Celery!")
ask_question_auto_search.delay(
ask_question_auto_search_discord_interaction.delay(
question=user_input,
community_id=community_id,
bot_given_info=recieved_data,
Expand Down
28 changes: 28 additions & 0 deletions test_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from tc_messageBroker.rabbit_mq.event import Event
import asyncio
import aio_pika
import json


async def main() -> None:
    """Publish one sample HIVEMIND question to RabbitMQ for manual testing."""
    connection = await aio_pika.connect_robust(
        "amqp://root:[email protected]:5672/",
    )

    async with connection:
        routing_key = Event.HIVEMIND.INTERACTION_CREATED
        channel = await connection.channel()

        # sample job payload, serialized the same way the producers do
        message_body = json.dumps(
            {"question": "what is AI?", "community_id": "9999999999999"}
        ).encode("utf-8")

        await channel.default_exchange.publish(
            aio_pika.Message(body=message_body),
            routing_key=routing_key,
        )


if __name__ == "__main__":
    # run the one-shot publisher as a script
    asyncio.run(main())
48 changes: 44 additions & 4 deletions utils/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,22 @@ def load_rabbitmq_credentials() -> dict[str, str]:
`password` : str
`host` : str
`port` : int
`url` : str
"""
load_dotenv()

rabbitmq_creds = {}

rabbitmq_creds["user"] = os.getenv("RABBIT_USER", "")
rabbitmq_creds["password"] = os.getenv("RABBIT_PASSWORD", "")
rabbitmq_creds["host"] = os.getenv("RABBIT_HOST", "")
rabbitmq_creds["port"] = os.getenv("RABBIT_PORT", "")
user = os.getenv("RABBIT_USER", "")
password = os.getenv("RABBIT_PASSWORD", "")
host = os.getenv("RABBIT_HOST", "")
port = os.getenv("RABBIT_PORT", "")

rabbitmq_creds["user"] = user
rabbitmq_creds["password"] = password
rabbitmq_creds["host"] = host
rabbitmq_creds["port"] = port
rabbitmq_creds["url"] = f"amqp://{user}:{password}@{host}:{port}"

return rabbitmq_creds

Expand Down Expand Up @@ -79,3 +86,36 @@ def load_mongo_credentials() -> dict[str, str]:
mongo_creds["port"] = os.getenv("MONGODB_PORT", "")

return mongo_creds


def load_redis_credentials() -> dict[str, str]:
    """
    load redis db credentials from .env

    Returns:
    ---------
    redis_creds : dict[str, Any]
        redis credentials
        a dictionary representative of
            `user`: str
            `password` : str
            `host` : str
            `port` : int
            `url` : str
    """
    load_dotenv()

    user = os.getenv("REDIS_USER", "")
    password = os.getenv("REDIS_PASSWORD", "")
    host = os.getenv("REDIS_HOST", "")
    port = os.getenv("REDIS_PORT", "")

    # expose both the individual fields and a ready-to-use connection url
    return {
        "user": user,
        "password": password,
        "host": host,
        "port": port,
        "url": f"redis://{user}:{password}@{host}:{port}",
    }
4 changes: 4 additions & 0 deletions worker/celery.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from celery import Celery
from utils.credentials import load_rabbitmq_credentials
from utils.credentials import load_redis_credentials

# RabbitMQ acts as the Celery broker; unpack its credentials for the broker URL
rabbit_creds = load_rabbitmq_credentials()
user = rabbit_creds["user"]
password = rabbit_creds["password"]
host = rabbit_creds["host"]
port = rabbit_creds["port"]

# Redis is the result backend so task results can be fetched later via AsyncResult
redis_creds = load_redis_credentials()

app = Celery(
    "tasks",
    broker=f"pyamqp://{user}:{password}@{host}:{port}//",
    backend=redis_creds["url"],
    include=["worker.tasks"],
)

Expand Down
52 changes: 44 additions & 8 deletions worker/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@


@app.task
def ask_question_auto_search(
def ask_question_auto_search_discord_interaction(
question: str,
community_id: str,
bot_given_info: dict[str, Any],
) -> None:
"""
this task is for the case that the user asks a question
and use the discord interaction schema
it would first retrieve the search metadata from summaries
then perform a query on the filetred raw data to find answer

Expand Down Expand Up @@ -69,13 +70,7 @@ def ask_question_auto_search(
# )
logging.info(f"{prefix}Querying the data sources!")
# for now we have just the discord platform
selector = DataSourceSelector()
data_sources = selector.select_data_source(community_id)
response, _ = query_multiple_source(
query=question,
community_id=community_id,
**data_sources,
)
response = query_data_sources(community_id=community_id, query=question)

# source_nodes_dict: list[dict[str, Any]] = []
# for node in source_nodes:
Expand Down Expand Up @@ -119,6 +114,15 @@ def ask_question_auto_search(
)


@app.task
def ask_question_auto_search(
    community_id: str,
    query: str,
) -> str:
    """Celery task: answer `query` using the data sources of `community_id`."""
    return query_data_sources(community_id=community_id, query=query)


@task_prerun.connect
def task_prerun_handler(sender=None, **kwargs):
# Initialize Traceloop for LLM
Expand All @@ -129,3 +133,35 @@ def task_prerun_handler(sender=None, **kwargs):
def task_postrun_handler(sender=None, **kwargs):
    """Celery post-run hook: reclaim memory once a task finishes."""
    # Trigger garbage collection after each task
    gc.collect()


def query_data_sources(
    community_id: str,
    query: str,
) -> str:
    """
    ask questions with auto select platforms

    Parameters
    -------------
    community_id : str
        the community id data to use for answering
    query : str
        the user query to ask llm

    Returns
    ---------
    response : str
        the LLM's response
    """
    logging.info(f"COMMUNITY_ID: {community_id} Finding data sources to query to!")
    selector = DataSourceSelector()
    data_sources = selector.select_data_source(community_id)
    # fixed log-message typo: "Quering" -> "Querying"
    logging.info(f"Querying data sources: {data_sources}!")
    response, _ = query_multiple_source(
        query=query,
        community_id=community_id,
        **data_sources,
    )

    return response
Loading