Skip to content

Commit

Permalink
adjust for the reorg
Browse files Browse the repository at this point in the history
  • Loading branch information
Alleria1809 committed Jun 23, 2024
1 parent 7049856 commit f1cab05
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 42 deletions.
File renamed without changes.
41 changes: 1 addition & 40 deletions lightrag/components/model_client/transformers_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,25 +222,6 @@ def __call__(self, **kwargs):
else:
raise ValueError(f"model {model_name} is not supported")

class ZephyrClient:  # local
    """Minimal HuggingFace Inference API client for HuggingFaceH4/zephyr-7b-beta.

    Resolves an API key (explicit argument or the ``HUGGING_FACE_API_KEY``
    environment variable) and stores the Authorization headers for later
    POST requests against ``API_URL``.

    Args:
        payload: Optional request payload to associate with this client.
            Made optional (default ``None``) because callers such as
            ``TransformersClient.init_llm_client`` construct the client
            with no arguments, which previously raised ``TypeError``.
        api_key: HuggingFace API key; falls back to the
            ``HUGGING_FACE_API_KEY`` environment variable when omitted.
        API_URL: Inference endpoint; defaults to the Zephyr-7B-beta model.
    """

    API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"

    def __init__(
        self,
        payload: Optional[str] = None,
        api_key: Optional[str] = None,
        API_URL: str = API_URL,
    ):
        super().__init__()
        self._api_key = api_key
        self.API_URL = API_URL
        self.payload = payload
        self.headers = self.init_sync_client()

    def init_sync_client(self):
        """Resolve the API key and return the HTTP request headers.

        Returns:
            dict: headers containing the ``Authorization`` entry.

        Raises:
            ValueError: if no key was supplied and the environment
                variable ``HUGGING_FACE_API_KEY`` is not set.
        """
        api_key = self._api_key or os.getenv("HUGGING_FACE_API_KEY")
        if not api_key:
            raise ValueError("Environment variable HUGGING_FACE_API_KEY must be set")
        # NOTE(review): the HF Inference API normally expects
        # "Bearer <token>"; the raw key is kept here to preserve the
        # original behavior — confirm against the API before changing.
        headers = {"Authorization": api_key}
        return headers


class TransformersClient(ModelClient):
__doc__ = r"""LightRAG API client for transformers.
Expand All @@ -255,9 +236,6 @@ class TransformersClient(ModelClient):
"BAAI/bge-reranker-base": {
"type": ModelType.RERANKER,
},
"HuggingFaceH4/zephyr-7b-beta": {
"type": ModelType.LLM,
},
}

def __init__(self, model_name: Optional[str] = None) -> None:
Expand All @@ -271,8 +249,6 @@ def __init__(self, model_name: Optional[str] = None) -> None:
self.sync_client = self.init_sync_client()
elif self._model_name == "BAAI/bge-reranker-base":
self.reranker_client = self.init_reranker_client()
elif self._model_name =="HuggingFaceH4/zephyr-7b-beta":
self.sync_client = self.init_llm_client()
self.async_client = None

def init_sync_client(self):
Expand All @@ -281,9 +257,6 @@ def init_sync_client(self):
def init_reranker_client(self):
    """Construct and return the local reranker backend for this client."""
    reranker_backend = TransformerReranker()
    return reranker_backend

def init_llm_client(self):
    """Construct and return the local LLM backend (Zephyr-7B-beta).

    Passes ``payload=None`` explicitly: ``ZephyrClient.__init__`` declares
    ``payload`` as its first required parameter, so the original
    zero-argument call raised ``TypeError`` before any request was made.
    """
    return ZephyrClient(payload=None)

def parse_embedding_response(self, response: Any) -> EmbedderOutput:
embeddings: List[Embedding] = []
for idx, emb in enumerate(response):
Expand Down Expand Up @@ -316,15 +289,6 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
scores, api_kwargs["top_k"]
)
return top_k_indices, top_k_scores
elif ( # llm
model_type == ModelType.LLM
and "model" in api_kwargs
and api_kwargs["model"] == "HuggingFaceH4/zephyr-7b-beta"
):
headers = self.sync_client.headers
input = api_kwargs["input"]
response = requests.post(self.API_URL, headers=headers, json=input)
return response.json()

def convert_inputs_to_api_kwargs(
self,
Expand All @@ -342,8 +306,5 @@ def convert_inputs_to_api_kwargs(
assert "top_k" in final_model_kwargs, "top_k must be specified"
final_model_kwargs["query"] = input
return final_model_kwargs
elif model_type == ModelType.LLM:
final_model_kwargs["input"] = input
return final_model_kwargs
else:
raise ValueError(f"model_type {model_type} is not supported")
raise ValueError(f"model_type {model_type} is not supported")
4 changes: 2 additions & 2 deletions lightrag/tests/test_gt_text_splitter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from lightrag.core.types import Document
from lightrag.core.text_splitter import TextSplitter # Import your TextSplitter
from lightrag.core.document_splitter import DocumentSplitter # Import the ground truth splitter
from lightrag.components.data_process.text_splitter import TextSplitter # Import your TextSplitter
from lightrag.components.data_process.document_splitter import DocumentSplitter # Import the ground truth splitter


class TestTextSplitterComparison(unittest.TestCase):
Expand Down

0 comments on commit f1cab05

Please sign in to comment.