Skip to content

Commit

Permalink
using database
Browse files Browse the repository at this point in the history
  • Loading branch information
siwonpada committed May 27, 2024
1 parent 111f423 commit 075c039
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 12 deletions.
11 changes: 10 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
import os
import dotenv
from utils.prisma import prisma
from fastapi import FastAPI
from fastapi.concurrency import asynccontextmanager

import routers

# loading fron .env file
dotenv.load_dotenv()

app = FastAPI(root_path=os.environ.get('BASE_URL', ''))
@asynccontextmanager
async def lifespan(app: FastAPI):
await prisma.connect()
yield
await prisma.disconnect()


app = FastAPI(lifespan=lifespan,root_path=os.environ.get('BASE_URL', ''))

# connecting router
app.include_router(routers.home.router)
Expand Down
24 changes: 20 additions & 4 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,23 @@ datasource db {
url = env("DATABASE_URL")
}

model Flower {
id Int @id @default(autoincrement())
name String
}
enum ChatRole {
USER
AI
}

model Log {
id Int @id @default(autoincrement())
model String
chat Chat[]
}

model Chat {
id Int @id @default(autoincrement())
role ChatRole
text String
log Log @relation(fields: [logId], references: [id])
logId Int
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ gunicorn = "^22.0.0"
python-dotenv = "^1.0.1"
langchain-openai = "^0.1.7"
pydantic = "^2.7.1"
prisma = "^0.13.1"


[build-system]
Expand Down
70 changes: 63 additions & 7 deletions routers/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from fastapi import APIRouter

from prisma import enums
from utils.prisma import prisma

from llm.chat import build
from llm.model import ChatGPTModel

Expand All @@ -12,12 +15,18 @@ class InputModel(BaseModel):
description='필요한 꽃을 유추하기 위한 문장'
)

context: str = Field(
alias='context',
description='대화의 문맥'
id: int | None = Field(
alias='id',
description='사용자 식별자, 처음 대화에서는 자동으로 생성되어 반환된다.',
default=None
)

class OutputModel(BaseModel):
id: int = Field(
alias='id',
description='사용자 식별자, input과 동일하게 반환된다.'
)

output: str = Field(
alias='output',
description='생성된 문장'
Expand All @@ -32,12 +41,59 @@ class OutputModel(BaseModel):
@router.post('/chat')
async def chat(input: InputModel) -> OutputModel:
chain = build(model.build())
if input.id == None:
log = await prisma.log.create(
data={
'model': 'gpt-4o'
}, include={
'chat': True
}
)
else:
log = await prisma.log.find_unique(where = {
'id': input.id
}, include={
'chat': True
})

return OutputModel(
output=chain.invoke({
context = ''
for chat in log.chat:
if chat.role == 'USER':
context += f'사용자: {chat.text}\n'
else:
context += f'봇: {chat.text}\n'

output = chain.invoke({
'input_context': f'''
* 현재까지의 대화: {input.context}
* 현재까지의 대화: {context}
* 사용자 입력: {input.chat}
'''
}),
})

await prisma.chat.create(
data={
'role': enums.ChatRole.USER,
'text': input.chat,
'log': {
'connect': {
'id': log.id
}
}
}
)

await prisma.chat.create(
data={
'role': enums.ChatRole.AI,
'text': output,
'log': {
'connect': {
'id': log.id
}
}
}
)

return OutputModel(
output=output, id=log.id
)
3 changes: 3 additions & 0 deletions utils/prisma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from prisma import Prisma

prisma = Prisma()

0 comments on commit 075c039

Please sign in to comment.