From f32f83765d47041bceb8c673c05c101e58969374 Mon Sep 17 00:00:00 2001
From: youssef
Date: Sun, 22 Sep 2024 20:31:50 +0100
Subject: [PATCH] Add Hugging Face API support

---
 data_folder/config.yaml |  3 ++
 requirements.txt        |  1 +
 src/llm/llm_manager.py  | 75 +++++++++++++++++++++++++++++------------
 3 files changed, 58 insertions(+), 21 deletions(-)

diff --git a/data_folder/config.yaml b/data_folder/config.yaml
index 762597d7..375a8a19 100644
--- a/data_folder/config.yaml
+++ b/data_folder/config.yaml
@@ -49,4 +49,7 @@ job_applicants_threshold:
 
 llm_model_type: openai
 llm_model: gpt-4o-mini
+
+#llm_model_type: huggingface
+#llm_model: 'tiiuae/falcon-7b-instruct'
 # llm_api_url: https://api.pawan.krd/cosmosrp/v1 this field is optional
diff --git a/requirements.txt b/requirements.txt
index f3b0eee5..0fe8debd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ jsonschema==4.23.0
 jsonschema-specifications==2023.12.1
 langchain==0.2.11
 langchain-anthropic
+langchain-huggingface
 langchain-community==0.2.10
 langchain-core===0.2.36
 langchain-google-genai==1.0.10
diff --git a/src/llm/llm_manager.py b/src/llm/llm_manager.py
index 5c48c557..5f9fd250 100644
--- a/src/llm/llm_manager.py
+++ b/src/llm/llm_manager.py
@@ -90,6 +90,18 @@ def invoke(self, prompt: str) -> BaseMessage:
         response = self.model.invoke(prompt)
         return response
 
+class HuggingFaceModel(AIModel):
+    def __init__(self, api_key: str, llm_model: str):
+        from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
+        self.model = HuggingFaceEndpoint(repo_id=llm_model, huggingfacehub_api_token=api_key,
+                                         temperature=0.4)
+        self.chatmodel = ChatHuggingFace(llm=self.model)
+
+    def invoke(self, prompt: str) -> BaseMessage:
+        logger.debug("Invoking Hugging Face API model")
+        response = self.chatmodel.invoke(prompt)
+        logger.debug(f"Hugging Face response: {response}")
+        return response
 
 class AIAdapter:
     def __init__(self, config: dict, api_key: str):
@@ -111,6 +123,8 @@ def _create_model(self, config: dict, api_key: str) -> AIModel:
             return OllamaModel(llm_model, llm_api_url)
         elif llm_model_type == "gemini":
             return GeminiModel(api_key, llm_model)
+        elif llm_model_type == "huggingface":
+            return HuggingFaceModel(api_key, llm_model)
         else:
             raise ValueError(f"Unsupported model type: {llm_model_type}")
 
@@ -286,27 +300,46 @@ def parse_llmresult(self, llmresult: AIMessage) -> Dict[str, Dict]:
         logger.debug(f"Parsing LLM result: {llmresult}")
 
         try:
-            content = llmresult.content
-            response_metadata = llmresult.response_metadata
-            id_ = llmresult.id
-            usage_metadata = llmresult.usage_metadata
-
-            parsed_result = {
-                "content": content,
-                "response_metadata": {
-                    "model_name": response_metadata.get("model_name", ""),
-                    "system_fingerprint": response_metadata.get("system_fingerprint", ""),
-                    "finish_reason": response_metadata.get("finish_reason", ""),
-                    "logprobs": response_metadata.get("logprobs", None),
-                },
-                "id": id_,
-                "usage_metadata": {
-                    "input_tokens": usage_metadata.get("input_tokens", 0),
-                    "output_tokens": usage_metadata.get("output_tokens", 0),
-                    "total_tokens": usage_metadata.get("total_tokens", 0),
-                },
-            }
-
+            if getattr(llmresult, "usage_metadata", None):
+                content = llmresult.content
+                response_metadata = llmresult.response_metadata
+                id_ = llmresult.id
+                usage_metadata = llmresult.usage_metadata
+
+                parsed_result = {
+                    "content": content,
+                    "response_metadata": {
+                        "model_name": response_metadata.get("model_name", ""),
+                        "system_fingerprint": response_metadata.get("system_fingerprint", ""),
+                        "finish_reason": response_metadata.get("finish_reason", ""),
+                        "logprobs": response_metadata.get("logprobs", None),
+                    },
+                    "id": id_,
+                    "usage_metadata": {
+                        "input_tokens": usage_metadata.get("input_tokens", 0),
+                        "output_tokens": usage_metadata.get("output_tokens", 0),
+                        "total_tokens": usage_metadata.get("total_tokens", 0),
+                    },
+                }
+            else:
+                content = llmresult.content
+                response_metadata = llmresult.response_metadata
+                id_ = llmresult.id
+                token_usage = response_metadata["token_usage"]
+
+                parsed_result = {
+                    "content": content,
+                    "response_metadata": {
+                        "model_name": response_metadata.get("model", ""),
+                        "finish_reason": response_metadata.get("finish_reason", ""),
+                    },
+                    "id": id_,
+                    "usage_metadata": {
+                        "input_tokens": token_usage.prompt_tokens,
+                        "output_tokens": token_usage.completion_tokens,
+                        "total_tokens": token_usage.total_tokens,
+                    },
+                }
             logger.debug(f"Parsed LLM result successfully: {parsed_result}")
             return parsed_result