Skip to content

Commit

Permalink
fix issues caused by refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmermaid committed Sep 13, 2023
1 parent 5b183e8 commit 2331aa5
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 3,799 deletions.
4 changes: 4 additions & 0 deletions example/flex/flex_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import os
import argparse
from dotenv import load_dotenv
from pykoi import Application
from pykoi.chat import QuestionAnswerDatabase
from pykoi.retrieval import RetrievalFactory
from pykoi.retrieval import VectorDbFactory
from pykoi.component import Chatbot, Dashboard, RetrievalQA


load_dotenv()


def main(**kargs):
os.environ["DOC_PATH"] = os.path.join(os.getcwd(), "temp/docs")
os.environ["VECTORDB_PATH"] = os.path.join(os.getcwd(), "temp/vectordb")
Expand Down
3,792 changes: 0 additions & 3,792 deletions poetry.lock

This file was deleted.

2 changes: 2 additions & 0 deletions pykoi/retrieval/llm/embedding_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def create_embedding(model_source: Union[str, ModelSource], **kwargs) -> Embeddi
try:
model_source = ModelSource(model_source)
if model_source == ModelSource.OPENAI:
from langchain.embeddings import OpenAIEmbeddings
return OpenAIEmbeddings()
elif model_source == ModelSource.HUGGINGFACE:
from langchain.embeddings import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(
model_name=kwargs.get("model_name"),
)
Expand Down
2 changes: 0 additions & 2 deletions pykoi/retrieval/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from dotenv import load_dotenv

from pykoi.retrieval.llm.abs_llm import AbsLlm
from pykoi.retrieval.vectordb.abs_vectordb import AbsVectorDb

load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MIN_DOCS = 2
Expand Down
4 changes: 2 additions & 2 deletions pykoi/retrieval/llm/retrieval_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from typing import Union

from pykoi.retrieval.llm.abs_llm import AbsLlm
from pykoi.retrieval.llm.openai import OpenAIModel
from pykoi.retrieval.llm.huggingface import HuggingFaceModel
from pykoi.retrieval.llm.constants import ModelSource
from pykoi.retrieval.vectordb.abs_vectordb import AbsVectorDb

Expand All @@ -29,8 +27,10 @@ def create(
try:
model_source = ModelSource(model_source)
if model_source == ModelSource.OPENAI:
from pykoi.retrieval.llm.openai import OpenAIModel
return OpenAIModel(vector_db)
if model_source == ModelSource.HUGGINGFACE:
from pykoi.retrieval.llm.huggingface import HuggingFaceModel
return HuggingFaceModel(vector_db, **kwargs)
except Exception as ex:
raise Exception(f"Unknown model: {model_source}") from ex
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ bcrypt = "4.0.1"
posthog = "3.0.1"
pynvml = "11.5.0"
pandas = "2.0.3"
python-dotenv = "1.0.0"

transformers = { version = "4.31.0", optional = true }
einops = { version = "0.6.1", optional = true }
Expand All @@ -47,7 +48,8 @@ huggingface = [
"transformers",
"einops",
"accelerate",
"bitsandbytes"
"bitsandbytes",
"sentence-transformers"
]
rag = [
"langchain",
Expand All @@ -57,8 +59,7 @@ rag = [
"pdfminer-six",
"docx2txt",
"python-multipart",
"tiktoken",
"sentence-transformers"
"tiktoken"
]
rlhf = [
"transformers",
Expand Down

0 comments on commit 2331aa5

Please sign in to comment.