Skip to content

Commit

Permalink
new tokens database + chainlog APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
yashbonde committed Mar 3, 2024
1 parent 5aaaa15 commit 16b08b8
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 69 deletions.
4 changes: 2 additions & 2 deletions client/src/redux/services/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export const authApi = createApi({
}
>({
query: ({ score, prompt_id }) => ({
url: `${BASE_URL}/api/v1/prompts/${prompt_id}/feedback`,
url: `${BASE_URL}/api/v1/prompts/${prompt_id}/feedback/`,
method: 'PUT',
body: {
score
Expand Down Expand Up @@ -136,7 +136,7 @@ export const authApi = createApi({
}
>({
query: ({ score, prompt_id, chatbot_id }) => ({
url: `${BASE_URL}/api/prompts/${prompt_id}/feedback`,
url: `${BASE_URL}/api/prompts/${prompt_id}/feedback/`,
method: 'PUT',
body: {
score
Expand Down
13 changes: 6 additions & 7 deletions server/chainfury_server/api/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_chain(
)

# DB call
dag = chatbot_data.dag.dict() if chatbot_data.dag else {}
dag = chatbot_data.dag.model_dump() if chatbot_data.dag else {}
chatbot = DB.ChatBot(
name=chatbot_data.name,
created_by=user.id,
Expand All @@ -51,8 +51,7 @@ def create_chain(
db.refresh(chatbot)

# return
response = T.ApiChain(**chatbot.to_dict())
return response
return chatbot.to_ApiChain()


def get_chain(
Expand All @@ -74,13 +73,13 @@ def get_chain(
]
if tag_id:
filters.append(DB.ChatBot.tag_id == tag_id)
chatbot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore
chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore
if not chatbot:
resp.status_code = 404
return T.ApiResponse(message="ChatBot not found")

# return
return T.ApiChain(**chatbot.to_dict())
return chatbot.to_ApiChain()


def update_chain(
Expand Down Expand Up @@ -130,7 +129,7 @@ def update_chain(
db.refresh(chatbot)

# return
return T.ApiChain(**chatbot.to_dict())
return chatbot.to_ApiChain()


def delete_chain(
Expand Down Expand Up @@ -186,7 +185,7 @@ def list_chains(

# return
return T.ApiListChainsResponse(
chatbots=[T.ApiChain(**chatbot.to_dict()) for chatbot in chatbots],
chatbots=[chatbot.to_ApiChain() for chatbot in chatbots],
)


Expand Down
43 changes: 33 additions & 10 deletions server/chainfury_server/api/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import Depends, Header, HTTPException
from fastapi.requests import Request
from fastapi.responses import Response
from typing import Annotated
from typing import Annotated, List
from sqlalchemy.orm import Session

import chainfury_server.database as DB
Expand All @@ -17,30 +17,30 @@ def list_prompts(
limit: int = 100,
offset: int = 0,
db: Session = Depends(DB.fastapi_db_session),
):
) -> T.ApiListPromptsResponse:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)

# get prompts
if limit < 1 or limit > 100:
limit = 100
offset = offset if offset > 0 else 0
prompts = (
prompts: List[DB.Prompt] = (
db.query(DB.Prompt) # type: ignore
.filter(DB.Prompt.chatbot_id == chain_id)
.order_by(DB.Prompt.created_at.desc()) # type: ignore
.limit(limit)
.offset(offset)
.all()
)
return {"prompts": [p.to_dict() for p in prompts]}
return T.ApiListPromptsResponse(prompts=[p.to_ApiPrompt() for p in prompts])


def get_prompt(
prompt_id: int,
token: Annotated[str, Header()],
db: Session = Depends(DB.fastapi_db_session),
):
) -> T.ApiPrompt:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)

Expand All @@ -49,14 +49,15 @@ def get_prompt(
if not prompt:
raise HTTPException(status_code=404, detail="Prompt not found")

return {"prompt": prompt.to_dict()}
# return {"prompt": prompt.to_dict()} # before
return prompt.to_ApiPrompt()


def delete_prompt(
prompt_id: int,
token: Annotated[str, Header()],
db: Session = Depends(DB.fastapi_db_session),
):
) -> T.ApiResponse:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)

Expand All @@ -67,15 +68,15 @@ def delete_prompt(
db.delete(prompt)

db.commit()
return {"msg": f"Prompt: '{prompt_id}' deleted"}
return T.ApiResponse(message=f"Prompt '{prompt.id}' deleted")


def prompt_feedback(
token: Annotated[str, Header()],
inputs: T.ApiPromptFeedback,
prompt_id: int,
db: Session = Depends(DB.fastapi_db_session),
):
) -> T.ApiPromptFeedbackResponse:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)

Expand All @@ -94,4 +95,26 @@ def prompt_feedback(
status_code=404,
detail=f"Unable to find the prompt",
)
return {"rating": prompt.user_rating}
return T.ApiPromptFeedbackResponse(rating=prompt.user_rating) # type: ignore


def get_chain_logs(
token: Annotated[str, Header()],
prompt_id: int,
limit: int = 100,
offset: int = 0,
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiListChainLogsResponse:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)

# query the DB
chainlogs: List[DB.ChainLog] = (
db.query(DB.ChainLog) # type: ignore
.filter(DB.ChainLog.prompt_id == prompt_id)
.order_by(DB.ChainLog.created_at.desc()) # type: ignore
.limit(limit)
.offset(offset)
.all()
)
return T.ApiListChainLogsResponse(logs=[c.to_ApiChainLog() for c in chainlogs])
77 changes: 67 additions & 10 deletions server/chainfury_server/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,30 @@
import chainfury.types as T


def login(auth: T.ApiAuth, db: Session = Depends(DB.fastapi_db_session)):
def login(
req: Request,
resp: Response,
auth: T.ApiAuthRequest,
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiLoginResponse:
user: DB.User = db.query(DB.User).filter(DB.User.username == auth.username).first() # type: ignore
if user is not None and sha256_crypt.verify(auth.password, user.password): # type: ignore
token = jwt.encode(
payload=DB.JWTPayload(username=auth.username, user_id=user.id).to_dict(),
key=Env.JWT_SECRET(),
)
response = {"msg": "success", "token": token}
return T.ApiLoginResponse(message="success", token=token)
else:
response = {"msg": "failed"}
return response
resp.status_code = 401
return T.ApiLoginResponse(message="failed")


def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)):
def sign_up(
req: Request,
resp: Response,
auth: T.ApiSignUpRequest,
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiLoginResponse:
user_exists = False
email_exists = False
user: DB.User = db.query(DB.User).filter(DB.User.username == auth.username).first() # type: ignore
Expand All @@ -36,7 +46,8 @@ def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)):
email_exists = True
if user_exists and email_exists:
raise HTTPException(
status_code=400, detail="Username and email already registered"
status_code=400,
detail="Username and email already registered",
)
elif user_exists:
raise HTTPException(status_code=400, detail="Username is taken")
Expand All @@ -54,17 +65,17 @@ def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)):
payload=DB.JWTPayload(username=auth.username, user_id=user.id).to_dict(),
key=Env.JWT_SECRET(),
)
response = {"msg": "success", "token": token}
return T.ApiLoginResponse(message="success", token=token)
else:
response = {"msg": "failed"}
return response
resp.status_code = 400
return T.ApiLoginResponse(message="failed")


def change_password(
req: Request,
resp: Response,
token: Annotated[str, Header()],
inputs: T.ApiChangePassword,
inputs: T.ApiChangePasswordRequest,
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiResponse:
# validate user
Expand All @@ -78,3 +89,49 @@ def change_password(
else:
resp.status_code = 400
return T.ApiResponse(message="password incorrect")


# TODO: @tunekoro - Implement the following functions


def create_token(
req: Request,
resp: Response,
token: Annotated[str, Header()],
inputs: T.ApiSaveTokenRequest,
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiResponse:
resp.status_code = 501 #
return T.ApiResponse(message="not implemented")


def get_token(
req: Request,
resp: Response,
key: str,
token: Annotated[str, Header()],
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiResponse:
resp.status_code = 501 #
return T.ApiResponse(message="not implemented")


def list_tokens(
req: Request,
resp: Response,
token: Annotated[str, Header()],
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiResponse:
resp.status_code = 501 #
return T.ApiResponse(message="not implemented")


def delete_token(
req: Request,
resp: Response,
key: str,
token: Annotated[str, Header()],
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiResponse:
resp.status_code = 501 #
return T.ApiResponse(message="not implemented")
36 changes: 21 additions & 15 deletions server/chainfury_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
description="""
chainfury server is a way to deploy and run chainfury engine over APIs. `chainfury` is [Tune AI](tunehq.ai)'s FOSS project
released under [Apache-2 License](https://choosealicense.com/licenses/apache-2.0/) so you can use this for your commercial
projects. A version `chainfury` is used in production in [Tune.Chat](chat.tune.app) and serves thousands of users daily.
projects. A version `chainfury` is used in production in [Tune.Chat](chat.tune.app), serves and solves thousands of user
queries daily.
""".strip(),
version=__version__,
docs_url="" if Env.CFS_DISABLE_DOCS() else "/docs",
Expand All @@ -42,24 +43,29 @@
app.add_api_route("/api/v1/chatbot/{id}/prompt", api_chains.run_chain, methods=["POST"], tags=["deprecated"], response_model=None) # type: ignore

# user
app.add_api_route("/user/login/", api_user.login, methods=["POST"], tags=["user"]) # type: ignore
app.add_api_route("/user/signup/", api_user.sign_up, methods=["POST"], tags=["user"]) # type: ignore
app.add_api_route("/user/change_password/", api_user.change_password, methods=["POST"], tags=["user"]) # type: ignore
app.add_api_route(methods=["POST"], path="/user/login/", endpoint=api_user.login, tags=["user"]) # type: ignore
app.add_api_route(methods=["POST"], path="/user/signup/", endpoint=api_user.sign_up, tags=["user"]) # type: ignore
app.add_api_route(methods=["POST"], path="/user/change_password/", endpoint=api_user.change_password, tags=["user"]) # type: ignore
app.add_api_route(methods=["PUT"], path="/user/token/", endpoint=api_user.create_token, tags=["user"]) # type: ignore
app.add_api_route(methods=["GET"], path="/user/token/", endpoint=api_user.get_token, tags=["user"]) # type: ignore
app.add_api_route(methods=["DELETE"], path="/user/token/", endpoint=api_user.delete_token, tags=["user"]) # type: ignore
app.add_api_route(methods=["GET"], path="/user/tokens/list/", endpoint=api_user.list_tokens, tags=["user"]) # type: ignore

# chains
app.add_api_route("/api/chains/", api_chains.list_chains, methods=["GET"], tags=["chains"]) # type: ignore
app.add_api_route("/api/chains/", api_chains.create_chain, methods=["PUT"], tags=["chains"]) # type: ignore
app.add_api_route("/api/chains/{id}/", api_chains.get_chain, methods=["GET"], tags=["chains"]) # type: ignore
app.add_api_route("/api/chains/{id}/", api_chains.delete_chain, methods=["DELETE"], tags=["chains"]) # type: ignore
app.add_api_route("/api/chains/{id}/", api_chains.update_chain, methods=["PATCH"], tags=["chains"]) # type: ignore
app.add_api_route("/api/chains/{id}/", api_chains.run_chain, methods=["POST"], tags=["chains"], response_model=None) # type: ignore
app.add_api_route("/api/chains/{id}/metrics/", api_chains.get_chain_metrics, methods=["GET"], tags=["chains"]) # type: ignore
app.add_api_route(methods=["GET"], path="/api/chains/", endpoint=api_chains.list_chains, tags=["chains"]) # type: ignore
app.add_api_route(methods=["PUT"], path="/api/chains/", endpoint=api_chains.create_chain, tags=["chains"]) # type: ignore
app.add_api_route(methods=["GET"], path="/api/chains/{id}/", endpoint=api_chains.get_chain, tags=["chains"]) # type: ignore
app.add_api_route(methods=["DELETE"], path="/api/chains/{id}/", endpoint=api_chains.delete_chain, tags=["chains"]) # type: ignore
app.add_api_route(methods=["PATCH"], path="/api/chains/{id}/", endpoint=api_chains.update_chain, tags=["chains"]) # type: ignore
app.add_api_route(methods=["POST"], path="/api/chains/{id}/", endpoint=api_chains.run_chain, tags=["chains"], response_model=None) # type: ignore
app.add_api_route(methods=["GET"], path="/api/chains/{id}/metrics/", endpoint=api_chains.get_chain_metrics, tags=["chains"]) # type: ignore

# prompts
app.add_api_route("/api/prompts/", api_prompts.list_prompts, methods=["GET"], tags=["prompts"]) # type: ignore
app.add_api_route("/api/prompts/{prompt_id}/", api_prompts.get_prompt, methods=["GET"], tags=["prompts"]) # type: ignore
app.add_api_route("/api/prompts/{prompt_id}/", api_prompts.delete_prompt, methods=["DELETE"], tags=["prompts"]) # type: ignore
app.add_api_route("/api/prompts/{prompt_id}/feedback", api_prompts.prompt_feedback, methods=["PUT"], tags=["prompts"]) # type: ignore
app.add_api_route(methods=["GET"], path="/api/prompts/", endpoint=api_prompts.list_prompts, tags=["prompts"]) # type: ignore
app.add_api_route(methods=["GET"], path="/api/prompts/{prompt_id}/", endpoint=api_prompts.get_prompt, tags=["prompts"]) # type: ignore
app.add_api_route(methods=["DELETE"], path="/api/prompts/{prompt_id}/", endpoint=api_prompts.delete_prompt, tags=["prompts"]) # type: ignore
app.add_api_route(methods=["PUT"], path="/api/prompts/{prompt_id}/feedback/", endpoint=api_prompts.prompt_feedback, tags=["prompts"]) # type: ignore
app.add_api_route(methods=["GET"], path="/api/prompts/{prompt_id}/logs/", endpoint=api_prompts.get_chain_logs, tags=["prompts"]) # type: ignore


# UI files
Expand Down
Loading

0 comments on commit 16b08b8

Please sign in to comment.