Merge pull request #22 from OpenGenenerativeAI/some-code-improvements
Some Fix & code quality improvements
StanGirard authored Jul 31, 2023
2 parents 023be34 + 5c8ef4d commit 5c33854
Showing 12 changed files with 97 additions and 145 deletions.
2 changes: 2 additions & 0 deletions demo/widgets/genoss_backend_connection.py

@@ -15,6 +15,8 @@ def display_message_if_failing_to_access_genoss() -> None:
     )


+# Fix bug where the custom endpoint will be added at every rerun
+@st.cache_data(experimental_allow_widgets=True)
 def add_custom_hf_endpoint_if_available_or_display_warning() -> None:
     if SETTINGS.custom_hf_endpoint_url is None:
         st.warning(
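The added decorator memoizes the function, so its body runs once per unique set of inputs instead of on every Streamlit rerun, which is what previously re-added the custom endpoint each time the script re-executed. A minimal sketch of the same pattern; the function and its url parameter below are hypothetical, not from this repo:

    import streamlit as st

    @st.cache_data(experimental_allow_widgets=True)
    def register_custom_endpoint(url: str) -> str:
        # Hypothetical example: runs once per distinct `url`; later reruns
        # reuse the cached result instead of re-registering the endpoint.
        st.info(f"Registering custom endpoint {url}")
        return url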
5 changes: 2 additions & 3 deletions genoss/api/completions_routes.py

@@ -1,10 +1,9 @@
-from typing import Any
-
 from fastapi import APIRouter, Body, HTTPException
 from fastapi.params import Depends
 from pydantic import BaseModel

 from genoss.auth.auth_handler import AuthHandler
+from genoss.entities.chat.chat_completion import ChatCompletion
 from genoss.entities.chat.message import Message
 from genoss.services.model_factory import ModelFactory
 from logger import get_logger
@@ -27,7 +26,7 @@ async def post_chat_completions(
     api_key: str = Depends(  # type: ignore[assignment] # noqa: B008
         AuthHandler.check_auth_header, use_cache=False
     ),
-) -> dict[str, Any]:
+) -> ChatCompletion:
     model = ModelFactory.get_model_from_name(body.model, api_key)  # pyright: ignore

     if model is None:
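Returning the ChatCompletion pydantic model directly works because recent FastAPI versions infer the response model from the return annotation and serialize the model to JSON themselves, so the route no longer needs a hand-rolled dict. A self-contained sketch of that mechanism; the route and the simplified model here are illustrative, not the repo's actual code:

    from fastapi import FastAPI
    from pydantic import BaseModel

    app = FastAPI()

    class ChatCompletion(BaseModel):  # simplified stand-in for the real entity
        model: str
        answer: str

    @app.post("/chat/completions")
    async def post_chat_completions() -> ChatCompletion:
        # FastAPI converts the returned model into the JSON response body.
        return ChatCompletion(model="fake", answer="Hello!")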
99 changes: 46 additions & 53 deletions genoss/entities/chat/chat_completion.py
@@ -1,59 +1,52 @@
 import time
 import uuid
-from typing import Any
+from typing import Self
+
+from pydantic import BaseModel, Field

 from genoss.entities.chat.message import Message


-# TODO: why is this nested classes ?
-# TODO: why don't we use a pydantic ?
-class ChatCompletion:
-    class Choice:
-        def __init__(
-            self, message: Message, finish_reason: str = "stop", index: int = 0
-        ):
-            self.message = message
-            self.finish_reason = finish_reason
-            self.index = index
-
-        def to_dict(self) -> dict[str, Any]:
-            return {
-                "message": self.message.to_dict(),
-                "finish_reason": self.finish_reason,
-                "index": self.index,
-            }
-
-    class Usage:
-        def __init__(
-            self, prompt_tokens: int, completion_tokens: int, total_tokens: int
-        ):
-            self.prompt_tokens = prompt_tokens
-            self.completion_tokens = completion_tokens
-            self.total_tokens = total_tokens
-
-        def to_dict(self) -> dict[str, Any]:
-            return {
-                "prompt_tokens": self.prompt_tokens,
-                "completion_tokens": self.completion_tokens,
-                "total_tokens": self.total_tokens,
-            }
-
-    def __init__(self, model: str, question: str, answer: str):
-        self.id = str(uuid.uuid4())
-        self.object = "chat.completion"
-        self.created = int(time.time())
-        self.model = model
-        self.usage = self.Usage(len(question), len(answer), len(question) + len(answer))
-        self.choices = [
-            self.Choice(Message(role="assistant", content=answer), "stop", 0)
-        ]
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            "id": self.id,
-            "object": self.object,
-            "created": self.created,
-            "model": self.model,
-            "usage": self.usage.to_dict(),
-            "choices": [choice.to_dict() for choice in self.choices],
-        }
+class Choice(BaseModel):
+    message: Message
+    finish_reason: str = "stop"
+    index: int = 0
+
+    @classmethod
+    def from_model_answer(cls, answer: str) -> Self:
+        return cls(
+            message=Message(role="assistant", content=answer),
+            finish_reason="stop",
+            index=0,
+        )
+
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+    @classmethod
+    def from_question_and_answer(cls, question: str, answer: str) -> Self:
+        return cls(
+            prompt_tokens=len(question),
+            completion_tokens=len(answer),
+            total_tokens=len(question) + len(answer),
+        )
+
+
+class ChatCompletion(BaseModel):
+    id: uuid.UUID = Field(default_factory=uuid.uuid4)
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    usage: Usage
+    choices: list[Choice]
+
+    @classmethod
+    def from_model_question_answer(cls, model: str, question: str, answer: str) -> Self:
+        return cls(
+            model=model,
+            usage=Usage.from_question_and_answer(question, answer),
+            choices=[Choice.from_model_answer(answer)],
+        )
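With the nested classes flattened into pydantic models, construction goes through the new classmethods and serialization comes from pydantic rather than hand-written to_dict() methods. A short usage sketch, assuming pydantic v1 as used at the time (.dict(); pydantic v2 renames it model_dump()). Note that the "token" counts are simply character counts, as in the code above:

    completion = ChatCompletion.from_model_question_answer(
        model="gpt4all",
        question="What is Genoss?",
        answer="An open-source GPT API.",
    )

    print(completion.usage.total_tokens)  # len(question) + len(answer)
    print(completion.dict()["choices"][0]["message"]["content"])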
9 changes: 4 additions & 5 deletions genoss/entities/chat/message.py
@@ -1,17 +1,16 @@
-from typing import Any
+from typing import Literal

 from pydantic import BaseModel, Field

+MessageRole = Literal["system", "user", "assistant", "function"]
+

 class Message(BaseModel):
-    role: str = Field(
+    role: MessageRole = Field(
         ...,
         description="The role of the messages author. One of system, user, assistant, or function.",
     )
     content: str = Field(
         ...,
         description="The contents of the message. content is required for all messages, and may be null for assistant messages with function calls.",
     )
-
-    def to_dict(self) -> dict[str, Any]:
-        return {"role": self.role, "content": self.content}
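Typing role as a Literal moves role validation into pydantic: constructing a Message with an unknown role now fails at construction time instead of passing silently through a plain str. A quick sketch of that behavior, assuming pydantic v1:

    from pydantic import ValidationError

    Message(role="assistant", content="Hi!")  # accepted

    try:
        Message(role="narrator", content="Once upon a time...")
    except ValidationError as err:
        # pydantic rejects any role outside the Literal's allowed values
        print(err)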
20 changes: 18 additions & 2 deletions genoss/llm/base_genoss.py
@@ -1,19 +1,35 @@
 from abc import abstractmethod
-from typing import Any

+from langchain import LLMChain
 from pydantic import BaseModel

+from genoss.entities.chat.chat_completion import ChatCompletion
 from genoss.entities.chat.message import Message
+from genoss.prompts.prompt_template import prompt_template


 class BaseGenossLLM(BaseModel):
     name: str
     description: str

     @abstractmethod
-    def generate_answer(self, messages: list[Message]) -> dict[str, Any]:
+    def generate_answer(self, messages: list[Message]) -> ChatCompletion:
         pass

+    def _chat_completion_from_langchain_llm(
+        self, llm: BaseModel, messages: list[Message]
+    ) -> ChatCompletion:
+        llm_chain = LLMChain(prompt=prompt_template, llm=llm)
+
+        question = messages[-1].content
+        response_text = llm_chain(question)
+
+        answer = response_text["text"]
+
+        return ChatCompletion.from_model_question_answer(
+            model=self.name, answer=answer, question=question
+        )
+
     @abstractmethod
     def generate_embedding(self, text: str) -> list[float]:
         pass
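The new protected helper centralizes the prompt, LLMChain, and ChatCompletion plumbing that every backend previously duplicated, so a concrete subclass only has to pick its LangChain LLM. A hedged sketch of what a subclass looks like after this change; EchoLLM is a made-up example, not a model in the repo:

    from langchain.llms import FakeListLLM

    from genoss.entities.chat.chat_completion import ChatCompletion
    from genoss.entities.chat.message import Message
    from genoss.llm.base_genoss import BaseGenossLLM

    class EchoLLM(BaseGenossLLM):
        name: str = "echo"
        description: str = "Toy backend demonstrating the shared helper"

        def generate_answer(self, messages: list[Message]) -> ChatCompletion:
            llm = FakeListLLM(responses=["echo"])
            # Chain setup and response shaping live in the base class.
            return self._chat_completion_from_langchain_llm(llm=llm, messages=messages)

        def generate_embedding(self, text: str) -> list[float]:
            return [0.0] * 128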
20 changes: 4 additions & 16 deletions genoss/llm/fake_llm.py
@@ -1,16 +1,14 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

-from langchain import LLMChain
 from langchain.embeddings import FakeEmbeddings
 from langchain.llms import FakeListLLM

+from genoss.entities.chat.chat_completion import ChatCompletion
 from genoss.llm.base_genoss import BaseGenossLLM
-from genoss.prompts.prompt_template import prompt_template

 if TYPE_CHECKING:
-    from genoss.entities.chat.chat_completion import ChatCompletion
     from genoss.entities.chat.message import Message

 FAKE_LLM_NAME = "fake"
@@ -20,19 +18,9 @@ class FakeLLM(BaseGenossLLM):
     name: str = FAKE_LLM_NAME
     description: str = "Fake LLM for testing purpose"

-    def generate_answer(self, messages: list[Message]) -> dict[str, Any]:
+    def generate_answer(self, messages: list[Message]) -> ChatCompletion:
         llm = FakeListLLM(responses=["Hello from FakeLLM!"])
-
-        llm_chain = LLMChain(llm=llm, prompt=prompt_template)
-        question = messages[-1].content
-        response_text = llm_chain(question)
-
-        answer = response_text["text"]
-        chat_completion = ChatCompletion(
-            model=self.name, answer=answer, question=question
-        )
-
-        return chat_completion.to_dict()
+        return self._chat_completion_from_langchain_llm(llm=llm, messages=messages)

     def generate_embedding(self, text: str) -> list[float]:
         model = FakeEmbeddings(size=128)
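Because FakeLLM needs no credentials or local weights, it is the quickest way to exercise the refactored flow end to end. A small usage sketch; the printed values follow from the fixed response wired into FakeListLLM above:

    from genoss.entities.chat.message import Message
    from genoss.llm.fake_llm import FakeLLM

    llm = FakeLLM()
    completion = llm.generate_answer([Message(role="user", content="Ping?")])

    print(completion.model)                       # "fake"
    print(completion.choices[0].message.content)  # "Hello from FakeLLM!"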
19 changes: 3 additions & 16 deletions genoss/llm/hf_hub/base_hf_hub.py
@@ -1,12 +1,10 @@
 from abc import ABC
-from typing import Any

-from langchain import HuggingFaceHub, LLMChain
+from langchain import HuggingFaceHub

 from genoss.entities.chat.chat_completion import ChatCompletion
 from genoss.entities.chat.message import Message
 from genoss.llm.base_genoss import BaseGenossLLM
-from genoss.prompts.prompt_template import prompt_template


 class BaseHuggingFaceHubLLM(BaseGenossLLM, ABC):
@@ -16,23 +14,12 @@ class BaseHuggingFaceHubLLM(BaseGenossLLM, ABC):
     api_key: str | None = None
     repo_id: str

-    def generate_answer(self, messages: list[Message]) -> dict[str, Any]:
+    def generate_answer(self, messages: list[Message]) -> ChatCompletion:
         """Generate answer from prompt."""
         llm = HuggingFaceHub(
             repo_id=self.repo_id, huggingfacehub_api_token=self.api_key
         )
-        llm_chain = LLMChain(prompt=prompt_template, llm=llm)
-
-        question = messages[-1].content
-        response_text = llm_chain(question)
-
-        answer = response_text["text"]
-
-        chat_completion = ChatCompletion(
-            model=self.name, question=question, answer=answer
-        )
-
-        return chat_completion.to_dict()
+        return self._chat_completion_from_langchain_llm(llm=llm, messages=messages)

     def generate_embedding(self, text: str) -> list[float]:
         """Dummy method to satisfy base class requirement."""
23 changes: 4 additions & 19 deletions genoss/llm/hf_inference_endpoint/hf_inference_endpoint.py
@@ -1,13 +1,11 @@
 from abc import ABC
-from typing import Any, Literal
-from unittest import mock
+from typing import Literal

-from langchain import LLMChain
 from langchain.llms import HuggingFaceEndpoint

 from genoss.entities.chat.chat_completion import ChatCompletion
+from genoss.entities.chat.message import Message
 from genoss.llm.base_genoss import BaseGenossLLM
-from genoss.prompts.prompt_template import prompt_template


 class HuggingFaceInferenceEndpointLLM(BaseGenossLLM, ABC):
@@ -22,27 +20,14 @@ class HuggingFaceInferenceEndpointLLM(BaseGenossLLM, ABC):
         "text-generation", "text-generation", "summarization"
     ] = "text-generation"

-    @mock.patch(
-        "huggingface_hub.inference_api.INFERENCE_ENDPOINT", "http://0.0.0.0:8080"
-    )
-    def generate_answer(self, question: str) -> dict[str, Any]:
+    def generate_answer(self, messages: list[Message]) -> ChatCompletion:
         """Generate answer from prompt."""
         llm = HuggingFaceEndpoint(
             endpoint_url=self.endpoint_url,
             huggingfacehub_api_token=self.api_key,
             task=self.task,
         )
-        llm_chain = LLMChain(prompt=prompt_template, llm=llm)
-
-        response_text = llm_chain(question)
-
-        answer = response_text["text"]
-
-        chat_completion = ChatCompletion(
-            model=self.name, question=question, answer=answer
-        )
-
-        return chat_completion.to_dict()
+        return self._chat_completion_from_langchain_llm(llm=llm, messages=messages)

     def generate_embedding(self, text: str) -> list[float]:
         """Dummy method to satisfy base class requirement."""
22 changes: 4 additions & 18 deletions genoss/llm/local/gpt4all.py
@@ -1,16 +1,14 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

-from langchain import LLMChain
 from langchain.embeddings import GPT4AllEmbeddings
 from langchain.llms import GPT4All

+from genoss.entities.chat.chat_completion import ChatCompletion
 from genoss.llm.local.base_local import BaseLocalLLM
-from genoss.prompts.prompt_template import prompt_template

 if TYPE_CHECKING:
-    from genoss.entities.chat.chat_completion import ChatCompletion
     from genoss.entities.chat.message import Message


@@ -19,23 +17,11 @@ class Gpt4AllLLM(BaseLocalLLM):
     description: str = "GPT-4"
     model_path: str = "./local_models/ggml-gpt4all-j-v1.3-groovy.bin"

-    def generate_answer(self, messages: list[Message]) -> dict[str, Any]:
+    def generate_answer(self, messages: list[Message]) -> ChatCompletion:
         llm = GPT4All(
             model=self.model_path,  # pyright: ignore reportPrivateUsage=none
         )
-
-        llm_chain = LLMChain(llm=llm, prompt=prompt_template)
-
-        question = messages[-1].content
-        response_text = llm_chain(question)
-
-        answer = response_text["text"]
-
-        chat_completion = ChatCompletion(
-            model=self.name, question=question, answer=answer
-        )
-
-        return chat_completion.to_dict()
+        return self._chat_completion_from_langchain_llm(llm=llm, messages=messages)

     def generate_embedding(self, embedding: str | list[str]) -> list[float]:
         gpt4all_embd = GPT4AllEmbeddings()  # pyright: ignore reportPrivateUsage=none