Skip to content

Commit

Permalink
Fetch conversation for seed data tasks, minor model fixes (#485)
Browse files Browse the repository at this point in the history
* Fetch conversation for seed data, fix models, remove redundant payload type checks
  • Loading branch information
andreaskoepf authored Jan 7, 2023
1 parent eaefa68 commit 96d6717
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 54 deletions.
15 changes: 10 additions & 5 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,17 @@ class DummyMessage(BaseModel):
parent_message = pr.fetch_message_by_frontend_message_id(
msg.parent_message_id, fail_if_missing=True
)
task = pr.store_task(
protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
conversation_messages = pr.fetch_message_conversation(parent_message)
conversation = protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text=msg.text, is_assistant=msg.role == "assistant"
)
),
for msg in conversation_messages
]
)
task = pr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
Expand Down
7 changes: 0 additions & 7 deletions backend/oasst_backend/api/v1/frontend_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session

Expand All @@ -20,11 +18,6 @@ def get_message_by_frontend_id(
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)

if not isinstance(message.payload.payload, MessagePayload):
# Unexpected message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)

return utils.prepare_message(message)


Expand Down
6 changes: 0 additions & 6 deletions backend/oasst_backend/api/v1/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
Expand Down Expand Up @@ -55,10 +53,6 @@ def get_message(
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
if not isinstance(message.payload.payload, MessagePayload):
# Unexptcted message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)

return utils.prepare_message(message)


Expand Down
24 changes: 10 additions & 14 deletions backend/oasst_backend/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def generate_task(
logger.info("Generating a PrompterReplyTask.")
messages = pr.fetch_random_conversation("assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant"))
for msg in messages
]

Expand All @@ -70,9 +68,7 @@ def generate_task(
logger.info("Generating a AssistantReplyTask.")
messages = pr.fetch_random_conversation("prompter")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant"))
for msg in messages
]

Expand All @@ -83,19 +79,19 @@ def generate_task(
logger.info("Generating a RankInitialPromptsTask.")

messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages])
case protocol_schema.TaskRequestType.rank_prompter_replies:
logger.info("Generating a RankPrompterRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")

task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
text=p.text,
is_assistant=(p.role == "assistant"),
)
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
replies = [p.text for p in replies]
task = protocol_schema.RankPrompterRepliesTask(
conversation=protocol_schema.Conversation(
messages=task_messages,
Expand All @@ -109,12 +105,12 @@ def generate_task(

task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
text=p.text,
is_assistant=(p.role == "assistant"),
)
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
replies = [p.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(messages=task_messages),
replies=replies,
Expand All @@ -125,14 +121,14 @@ def generate_task(
message = pr.fetch_random_initial_prompts(1)[0]
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.payload.payload.text,
prompt=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)

case protocol_schema.TaskRequestType.label_prompter_reply:
logger.info("Generating a LabelPrompterReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
message = messages[0].payload.payload.text
message = messages[0].text
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
Expand All @@ -143,7 +139,7 @@ def generate_task(
case protocol_schema.TaskRequestType.label_assistant_reply:
logger.info("Generating a LabelAssistantReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
message = messages[0].payload.payload.text
message = messages[0].text
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
Expand Down
13 changes: 2 additions & 11 deletions backend/oasst_backend/api/v1/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from http import HTTPStatus
from uuid import UUID

from oasst_backend.models import Message
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol


def prepare_message(m: Message) -> protocol.Message:
if not isinstance(m.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
return protocol.Message(
id=m.id,
parent_id=m.parent_id,
text=m.payload.payload.text,
text=m.text,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
)
Expand All @@ -26,10 +21,8 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
protocol.ConversationMessage(text=message.text, is_assistant=(message.role == "assistant"))
)

return protocol.Conversation(messages=conv_messages)
Expand All @@ -38,8 +31,6 @@ def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(prepare_message(message))

return protocol.MessageTree(id=tree_id, messages=tree_messages)
10 changes: 5 additions & 5 deletions backend/oasst_backend/models/journal.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Journal(SQLModel, table=True):
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
api_client_id: UUID = Field(foreign_key="api_client.id")

Expand All @@ -49,7 +49,7 @@ class JournalIntegration(SQLModel, table=True):
),
)
description: str = Field(max_length=512, primary_key=True)
last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True)
last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
last_error: str = Field(nullable=True)
next_run: datetime = Field(nullable=True)
last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True)
last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
last_error: Optional[str] = Field(nullable=True)
next_run: Optional[datetime] = Field(nullable=True)
24 changes: 19 additions & 5 deletions backend/oasst_backend/models/message.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from datetime import datetime
from http import HTTPStatus
from typing import Optional
from uuid import UUID, uuid4

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from sqlalchemy import false
from sqlmodel import Field, Index, SQLModel

Expand All @@ -19,19 +22,30 @@ class Message(SQLModel, table=True):
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
parent_id: UUID = Field(nullable=True)
parent_id: Optional[UUID] = Field(nullable=True)
message_tree_id: UUID = Field(nullable=False, index=True)
task_id: UUID = Field(nullable=True, index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
task_id: Optional[UUID] = Field(nullable=True, index=True)
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128, regex="^prompter|assistant$")
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
payload: Optional[PayloadContainer] = Field(
sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
)
lang: str = Field(nullable=False, max_length=200, default="en-US")
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))

def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)

@property
def text(self) -> str:
self.ensure_is_message()
return self.payload.payload.text
2 changes: 1 addition & 1 deletion backend/oasst_backend/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Task(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
Expand Down
14 changes: 14 additions & 0 deletions backend/test_data/generic/test_generic_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,19 @@
"parent_message_id": "cec432cf",
"text": "I'm unsure how to interpret this. Is it a riddle?",
"role": "assistant"
},
{
"task_message_id": "b8e98ed6",
"user_message_id": "89384709",
"parent_message_id": "0e276b98",
"text": "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?",
"role": "prompter"
},
{
"task_message_id": "9a0e7683",
"user_message_id": "6d452c57",
"parent_message_id": "0e276b98",
"text": "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?",
"role": "prompter"
}
]

0 comments on commit 96d6717

Please sign in to comment.