diff --git a/.envrc b/.envrc index 5c5cd28..9cb0c9f 100644 --- a/.envrc +++ b/.envrc @@ -8,6 +8,7 @@ export RERANKER_PORT=80 export VECTORDB_HOST=192.168.0.207 export VECTORDB_PORT=8000 export ENABLE_RERANKER="1" +export TOOLS_BASE_URL="http://192.168.0.209" # External services export HUGGINGFACEHUB_API_TOKEN="$(cat ~/.hf_token)" #Replace with your own Hugging Face API token @@ -19,9 +20,9 @@ export PORTKEY_CUSTOM_HOST="llm_provider_host_ip_and_port" #Only if LLM is local export USE_PORTKEY="0" # Model specific options -export MODEL_ID="qwen/Qwen2-7B-Instruct" -export STOP_TOKEN="<|endoftext|>" +export MAX_TOKENS=1024 # Streamlit configurations export AUTH_CONFIG_FILE_PATH=".streamlit/config.yaml" export STREAMLIT_CLIENT_SHOW_ERROR_DETAILS=False +export STREAMLIT_SERVER_HEADLESS=True \ No newline at end of file diff --git a/.gitignore b/.gitignore index 581fb9f..7016c33 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ __pycache__/** insf_venv/** +*.pyc diff --git a/app.py b/app.py index 8f17f50..e9e6a40 100644 --- a/app.py +++ b/app.py @@ -2,15 +2,11 @@ import uuid import datasets from langchain_huggingface import HuggingFaceEndpointEmbeddings -from langchain_community.chat_models import ChatHuggingFace -from langchain_community.llms import HuggingFaceEndpoint +from langchain_openai import ChatOpenAI from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough from langchain.agents import create_react_agent, AgentExecutor from langchain.tools.retriever import create_retriever_tool -from langchain_community.tools.tavily_search import TavilySearchResults -from langchain_community.utilities import StackExchangeAPIWrapper -from langchain_community.tools.stackexchange.tool import StackExchangeTool from langchain_core.messages import SystemMessage from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_core.documents import Document @@ -19,19 +15,24 @@ import chromadb from chromadb.config import Settings from chromadb.utils.embedding_functions import HuggingFaceEmbeddingServer - +from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type +from urllib3.exceptions import ProtocolError from langchain.retrievers import ContextualCompressionRetriever -from tei_rerank import TEIRerank from transformers import AutoTokenizer +from tools import get_tools +from tei_rerank import TEIRerank + import streamlit as st import streamlit_authenticator as stauth import yaml from yaml.loader import SafeLoader -from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type -from urllib3.exceptions import ProtocolError +from langchain.globals import set_verbose, set_debug + +set_verbose(True) +set_debug(True) st.set_page_config(layout="wide", page_title="InSightful") @@ -80,24 +81,15 @@ def hf_embedding_server(): # Set up HuggingFaceEndpoint model @st.cache_resource -def setup_huggingface_endpoint(model_id): - llm = HuggingFaceEndpoint( - endpoint_url="http://{host}:{port}".format( +def setup_chat_endpoint(): + model = ChatOpenAI( + base_url="http://{host}:{port}/v1".format( host=os.getenv("TGI_HOST", "localhost"), port=os.getenv("TGI_PORT", "8080") ), - temperature=0.3, - task="conversational", - stop_sequences=[ - "<|im_end|>", - "<|eot_id|>", - "{your_token}".format( - your_token=os.getenv("STOP_TOKEN", "<|end_of_text|>") - ), - ], + max_tokens=os.getenv("MAX_TOKENS", 1024), + temperature=0.7, + api_key="dummy", ) - - model = ChatHuggingFace(llm=llm, model_id=model_id) - return model @@ -159,8 +151,6 @@ def load_prompt_and_system_ins( class RAG: def __init__(self, collection_name, db_client): - # self.llm = llm - # self.embedding_svc = embedding_svc self.collection_name = collection_name self.db_client = db_client @@ -208,17 +198,9 @@ def insert_embeddings(self, chunks, chroma_embedding_function, batch_size=32): documents = [chunk.page_content for chunk in batch] collection.add(ids=chunk_ids, metadatas=metadatas, documents=documents) - # db = Chroma( - # embedding_function=embedder, - # collection_name=self.collection_name, - # client=self.db_client, - # ) print("Embeddings inserted\n") - # return db - def query_docs( - self, model, question, vector_store, prompt, chat_history, use_reranker=False - ): + def get_retriever(self, vector_store, use_reranker=False): retriever = vector_store.as_retriever( search_type="similarity", search_kwargs={"k": 10} ) @@ -234,7 +216,12 @@ def query_docs( retriever = ContextualCompressionRetriever( base_compressor=compressor, base_retriever=retriever ) + return retriever + def query_docs( + self, model, question, vector_store, prompt, chat_history, use_reranker=False + ): + retriever = self.get_retriever(vector_store, use_reranker) pass_question = lambda input: input["question"] rag_chain = ( RunnablePassthrough.assign(context=pass_question | retriever | format_docs) @@ -245,6 +232,7 @@ def query_docs( return rag_chain.stream({"question": question, "chat_history": chat_history}) + def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs) @@ -262,70 +250,22 @@ def create_retriever( collection_name=collection_name, client=client, ) - if reranker: - compressor = TEIRerank( - url="http://{host}:{port}".format( - host=os.getenv("RERANKER_HOST", "localhost"), - port=os.getenv("RERANKER_PORT", "8082"), - ), - top_n=10, - batch_size=16, - ) - - retriever = vector_store.as_retriever( - search_type="similarity", search_kwargs={"k": 100} - ) - compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=retriever - ) - info_retriever = create_retriever_tool(compression_retriever, name, description) - else: - retriever = vector_store.as_retriever( - search_type="similarity", search_kwargs={"k": 10} - ) - info_retriever = create_retriever_tool(retriever, name, description) - - return info_retriever - - -def setup_tools(_model, _client, _chroma_embedding_function, _embedder): - tools = [] - if ( - os.getenv("STACK_OVERFLOW_API_KEY") - and os.getenv("STACK_OVERFLOW_API_KEY").strip() - ): - stackexchange_wrapper = StackExchangeAPIWrapper(max_results=3) - stackexchange_tool = StackExchangeTool(api_wrapper=stackexchange_wrapper) - tools.append(stackexchange_tool) - - if os.getenv("TAVILY_API_KEY") and os.getenv("TAVILY_API_KEY").strip(): - web_search_tool = TavilySearchResults(max_results=10, handle_tool_error=True) - tools.append(web_search_tool) + retriever = rag.get_retriever(vector_store, use_reranker=reranker) - use_reranker = os.getenv("USE_RERANKER", "False") == "True" - retriever = create_retriever( - "slack_conversations_retriever", - "Useful for when you need to answer from Slack conversations.", - _client, - _chroma_embedding_function, - _embedder, - reranker=use_reranker, + retriever = vector_store.as_retriever( + search_type="similarity", search_kwargs={"k": 10} ) - tools.append(retriever) - - return tools - + return create_retriever_tool(retriever, name, description) @st.cache_resource -def setup_agent(_model, _prompt, _client, _chroma_embedding_function, _embedder): - tools = setup_tools(_model, _client, _chroma_embedding_function, _embedder) +def setup_agent(_model, _prompt, _tools): agent = create_react_agent( llm=_model, prompt=_prompt, - tools=tools, + tools=_tools, ) agent_executor = AgentExecutor( - agent=agent, verbose=True, tools=tools, handle_parsing_errors=True + agent=agent, verbose=True, tools=_tools, handle_parsing_errors=True ) return agent_executor @@ -337,12 +277,22 @@ def main(): if os.getenv("ENABLE_PORTKEY", "False") == "True": model = setup_portkey_integrated_model() else: - model = setup_huggingface_endpoint(model_id=os.getenv("MODEL_ID")) + model = setup_chat_endpoint() embedder = setup_huggingface_embeddings() + use_reranker = os.getenv("USE_RERANKER", "False") == "True" - agent_executor = setup_agent( - model, prompt, client, chroma_embedding_function, embedder + retriever_tool = create_retriever( + "slack_conversations_retriever", + "Useful for when you need to answer from Slack conversations.", + client, + chroma_embedding_function, + embedder, + reranker=use_reranker, ) + _tools = get_tools() + _tools.append(retriever_tool) + + agent_executor = setup_agent(model, prompt, _tools) st.title("InSightful: Your AI Assistant for community questions") st.text("Made with ❤️ by InfraCloud Technologies") diff --git a/k8s-manifests/env.yaml b/k8s-manifests/env.yaml index ce9f1bd..6eeaf5a 100644 --- a/k8s-manifests/env.yaml +++ b/k8s-manifests/env.yaml @@ -12,12 +12,10 @@ data: RERANKER_PORT: "80" VECTORDB_HOST: "ai-stack-vectordb" VECTORDB_PORT: "8000" - STOP_TOKEN: "<|endoftext|>" + TOOLS_BASE_URL: "http://192.168.0.209" PORTKEY_PROVIDER: "llm_provider_name" PORTKEY_CUSTOM_HOST: "llm_provider_host_ip_and_port" USE_PORTKEY: "0" USE_RERANKER: "1" AUTH_CONFIG_FILE_PATH: "/opt/auth-config/config.yaml" - MODEL_ID: "meta-llama/Meta-Llama-3.1-8B-Instruct" - STOP_TOKEN: "<|endoftext|>" STREAMLIT_CLIENT_SHOW_ERROR_DETAILS: False diff --git a/multi_tenant_rag.py b/multi_tenant_rag.py index 43f6267..98fe398 100644 --- a/multi_tenant_rag.py +++ b/multi_tenant_rag.py @@ -9,18 +9,21 @@ from langchain_community.document_loaders import PyPDFLoader from langchain_community.vectorstores.chroma import Chroma from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs +from tools import get_tools from app import ( setup_chroma_client, hf_embedding_server, load_prompt_and_system_ins, setup_huggingface_embeddings, - setup_huggingface_endpoint, + setup_chat_endpoint, RAG, + setup_agent, ) logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) + def configure_authenticator(): auth_config = os.getenv("AUTH_CONFIG_FILE_PATH", default=".streamlit/config.yaml") print(f"auth_config: {auth_config}") @@ -81,7 +84,12 @@ def load_documents(self, doc): def main(): - llm = setup_huggingface_endpoint(model_id=os.getenv("MODEL_ID")) + use_reranker = st.sidebar.toggle("Use reranker", False) + use_tools = st.sidebar.toggle("Use tools", False) + uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"]) + question = st.chat_input("Chat with your docs or apis") + + llm = setup_chat_endpoint() embedding_svc = setup_huggingface_embeddings() @@ -97,8 +105,12 @@ def main(): Be concise and always provide accurate, specific, and relevant information. """ + template_file_path = "templates/multi_tenant_rag_prompt_template.tmpl" + if use_tools: + template_file_path = "templates/multi_tenant_rag_prompt_template_tools.tmpl" + prompt, system_instructions = load_prompt_and_system_ins( - template_file_path="templates/multi_tenant_rag_prompt_template.tmpl", + template_file_path=template_file_path, template=template, ) @@ -118,15 +130,16 @@ def main(): f"user-collection-{user_id}", embedding_function=chroma_embeddings ) - use_reranker = st.sidebar.toggle("Use reranker", False) - use_tools = st.sidebar.toggle("Use tools", False) - uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"]) - question = st.chat_input("Chat with your doc") - logger = logging.getLogger(__name__) - logger.info(f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}") + logger.info( + f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}" + ) rag = MultiTenantRAG(user_id, collection.name, client) + if use_tools: + tools = get_tools() + agent_executor = setup_agent(llm, prompt, tools) + # prompt = hub.pull("rlm/rag-prompt") vectorstore = Chroma( @@ -147,17 +160,28 @@ def main(): if question: st.chat_message("user").markdown(question) with st.spinner(): - answer = rag.query_docs( - model=llm, - question=question, - vector_store=vectorstore, - prompt=prompt, - chat_history=chat_history, - use_reranker=use_reranker, - ) - with st.chat_message("assistant"): - answer = st.write_stream(answer) - logger.info(f"answer: {answer}") + if use_tools: + answer = agent_executor.invoke( + { + "question": question, + "chat_history": chat_history, + } + )["output"] + with st.chat_message("assistant"): + answer = st.write(answer) + logger.info(f"answer: {answer}") + else: + answer = rag.query_docs( + model=llm, + question=question, + vector_store=vectorstore, + prompt=prompt, + chat_history=chat_history, + use_reranker=use_reranker, + ) + with st.chat_message("assistant"): + answer = st.write_stream(answer) + logger.info(f"answer: {answer}") chat_history.append({"role": "user", "content": question}) chat_history.append({"role": "assistant", "content": answer}) diff --git a/templates/multi_tenant_rag_prompt_template_tools.tmpl b/templates/multi_tenant_rag_prompt_template_tools.tmpl new file mode 100644 index 0000000..14c1314 --- /dev/null +++ b/templates/multi_tenant_rag_prompt_template_tools.tmpl @@ -0,0 +1,40 @@ +InSightful is a bot developed by InfraCloud Technologies. + +InSightful is used to assist users analyze & get insights in the uploaded pdf files. + +InSightful is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on the pdf files uploaded and the data within. + +TOOLS: +------ + +InSightful has access to the following tools: + +{tools} + +To use a tool, please use the following format: + +``` +Thought: Do I need to use a tool? Yes\n +Action: the action to take, should be one of [{tool_names}]\n +Action Input: the input parameters to the action in specified format. It MUST be as a json string. It MUST not contain anything other than the json string.\n +Observation: the result of the action\n +``` + +When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format: + +``` +Thought: Do I need to use a tool? No +Final Answer: [your response here] +``` + +You MUST not mention the tools used for the response in the final answer. + +Use the following pieces of retrieved context to answer the question. +Begin! + +Previous conversation history: +{chat_history} + +Question: {question} +{agent_scratchpad} +Answer: \ No newline at end of file diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..2d7494d --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,36 @@ +import os + +from .annual_still_births_by_state import AnnualStillBirthsByStateAPIClient, AnnualStillBirthsByStateTool +from .annual_still_births import AnnualStillBirthsAPIClient, AnnualStillBirthsTool +from .annual_live_births_by_state import AnnualLiveBirthsByStateAPIClient, AnnualLiveBirthsByStateTool +from .annual_live_births import AnnualLiveBirthsAPIClient, AnnualLiveBirthsTool +from .daily_live_births import DailyLiveBirthsAPIClient, DailyLiveBirthsTool +from .car_registrations import CarRegistrationTool, CarRegistrationAPIClient + +def get_tools(): + car_registration_tool = CarRegistrationTool(api_wrapper=CarRegistrationAPIClient()) + annual_live_births_tool = AnnualLiveBirthsTool(api_wrapper=AnnualLiveBirthsAPIClient()) + annual_live_births_by_state_tool = AnnualLiveBirthsByStateTool(api_wrapper=AnnualLiveBirthsByStateAPIClient()) + annual_still_births_tool = AnnualStillBirthsByStateTool(api_wrapper=AnnualStillBirthsByStateAPIClient()) + annual_still_births_by_state_tool = AnnualStillBirthsTool(api_wrapper=AnnualStillBirthsAPIClient()) + daily_live_births_tool = DailyLiveBirthsTool(api_wrapper=DailyLiveBirthsAPIClient()) + # append car registration tool to tools + tools = [ + car_registration_tool, + annual_live_births_tool, + annual_live_births_by_state_tool, + annual_still_births_tool, + annual_still_births_by_state_tool, + daily_live_births_tool + ] + if ( + os.getenv("STACK_OVERFLOW_API_KEY") + and os.getenv("STACK_OVERFLOW_API_KEY").strip() + ): + from .stackexchange import tool as stackexchange_tool + tools.append(stackexchange_tool()) + + if os.getenv("TAVILY_API_KEY") and os.getenv("TAVILY_API_KEY").strip(): + from .tavily import tool as tavily_tool + tools.append(tavily_tool()) + return tools \ No newline at end of file diff --git a/tools/annual_live_births.py b/tools/annual_live_births.py new file mode 100644 index 0000000..e7a59a6 --- /dev/null +++ b/tools/annual_live_births.py @@ -0,0 +1,93 @@ +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from typing import Optional, Dict +from urllib.parse import urlencode, urljoin +from langchain.tools import BaseTool +import logging +import os +import json +from tools.utils import cleanupInputAndGetDictFromStr +import requests +import re +from langchain_core.callbacks import CallbackManagerForToolRun + +logger = logging.getLogger(__name__) + +class AnnualLiveBirthsAPIClient(BaseModel): + base_url: str = os.getenv("TOOLS_BASE_URL", default="http://localhost:5000") + + @root_validator(pre=True, allow_reuse=True) + def validate_api_client(cls, values: Dict) -> Dict: + logging.info("validation done AnnualLiveBirthsAPIClient") + return values + + + def get_annual_live_births_info(self, params): + base_url = f"{self.base_url}/annual_live_births" + if params: + query_string = urlencode(params) + url = urljoin(base_url, f"?{query_string}") + else: + url = base_url + response = requests.get(url) + logging.info(f"Annual Live Births Response code: {response.status_code}, response json: {response.json()}") + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +class AnnualLiveBirthsTool(BaseTool): + name = "annual_live_births_data_lookup" + description = """ A wrapper around API to lookup annual total live births in the whole country of Malaysia. + Useful to look up annual children born alive information from 1st Jan, 2000 to 1st Jan, 2022. + The data contains just one row per year, with the number, i.e. absolute, abs in short, of children born alive in that year. + The api can look up annual live births data by executing an HTTP GET endpoint. + Optionally the GET endpoint can accept query parameters to filter the results. + + You MUST use this tool ONLY to fetch country-wise live births data. To fetch state-wise data, use the annual_live_births_by_state tool. + + Args: + A JSON string with the following fields: + - dateFrom: The start date of the registration period. Date format: YYYY-MM-DD + - dateTo: The end date of the registration period. Date format: YYYY-MM-DD + - abs: The absolute number of children born alive. + - absMin: The minimum absolute number of children born alive. + - absMax: The maximum absolute number of children born alive. + - rate: The rate of children born alive. + - rateMin: The minimum rate of children born alive. + - rateMax: The maximum rate of children born alive. + - count: set to true, if you just need the count instead of actual data. + + Returns: + A JSON string with the following fields: + - date: The date of birth record in the format YYYY-MM-DD. + - abs: The absolute number of children born alive. + - rate: The rate of children born alive. + """ + api_wrapper: AnnualLiveBirthsAPIClient + + + def _run( + self, + input: Optional[str]=None, + run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + logging.info(f"\nEntering AnnualLiveBirthsTool run function\nOriginal input: {input}\n") + params = {} + if input: + matches = re.findall(r'(\w+)=([\w\-]+)', input) + params = dict(matches) + matches = re.findall(r'(\w+): ([\w\d\-]+)', input) + params.update(dict(matches)) + if not params: + params = cleanupInputAndGetDictFromStr(input) + + if str(params.get("count")) != "None": + params.pop("count") + + logging.info(f"\nInvoking annual live births data look up tool with params: {params}") + return self.api_wrapper.get_annual_live_births_info(params) + + def _arun(self, params): + raise NotImplementedError("Asynchronous operation is not supported for this tool.") + diff --git a/tools/annual_live_births_by_state.py b/tools/annual_live_births_by_state.py new file mode 100644 index 0000000..43808fe --- /dev/null +++ b/tools/annual_live_births_by_state.py @@ -0,0 +1,95 @@ +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from typing import Optional, Dict +from urllib.parse import urlencode, urljoin +from langchain.tools import BaseTool +import logging +import os +import json +from tools.utils import cleanupInputAndGetDictFromStr +import requests +import re +from langchain_core.callbacks import CallbackManagerForToolRun + +logger = logging.getLogger(__name__) + +class AnnualLiveBirthsByStateAPIClient(BaseModel): + base_url: str = os.getenv("TOOLS_BASE_URL", default="http://localhost:5000") + + @root_validator(pre=True, allow_reuse=True) + def validate_api_client(cls, values: Dict) -> Dict: + logging.info("validation done AnnualLiveBirthsByStateAPIClient") + return values + + + def get_annual_live_births_by_state_info(self, params): + base_url = f"{self.base_url}/annual_live_births_by_state" + if params: + query_string = urlencode(params) + url = urljoin(base_url, f"?{query_string}") + else: + url = base_url + response = requests.get(url) + logging.info(f"Annual Live Births By State Response code: {response.status_code}, response json: {response.json()}") + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +class AnnualLiveBirthsByStateTool(BaseTool): + name = "annual_live_births_by_state_data_lookup" + description = """ A wrapper around annual live births, i.e. children born alive, by state data lookup APIs, for the country of Malaysia. + Useful to look up annual children born alive per state information from 1st Jan, 2000 to 1st Jan, 2022. + The data contains just one row or less, per year per state, with the number, i.e. absolute, abs in short, of children born alive in that year. + The api can look up annual live births data per state by executing an HTTP GET endpoint. + Optionally the GET endpoint can accept query parameters to filter the results. + + You MUST use this tool ONLY to fetch state-wise live births data. For fetching data for the entire country, use the annual_live_births tool. + + Args: + A JSON string with the following fields: + - state: The state where the children are born. Skip this field to get data for all states. + - dateFrom: The start date of the registration period. Date format: YYYY-MM-DD + - dateTo: The end date of the registration period. Date format: YYYY-MM-DD + - abs: The absolute number of children born alive. + - absMin: The minimum absolute number of children born alive. + - absMax: The maximum absolute number of children born alive. + - rate: The rate of children born alive. + - rateMin: The minimum rate of children born alive. + - rateMax: The maximum rate of children born alive. + + Returns: + A JSON string with the following fields: + - state: The state where the children are born. + - date: The date of birth record in the format YYYY-MM-DD. + - abs: The absolute number of children born alive. + - rate: The rate of children born alive. + + """ + api_wrapper: AnnualLiveBirthsByStateAPIClient + + + def _run( + self, + input: Optional[str]=None, + run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + logging.info(f"\nEntering AnnualLiveBirthsByStateTool run function\nOriginal input: {input}\n") + params = {} + if input: + matches = re.findall(r'(\w+)=([\w\-]+)', input) + params = dict(matches) + matches = re.findall(r'(\w+): ([\w\d\-]+)', input) + params.update(dict(matches)) + if not params: + params = cleanupInputAndGetDictFromStr(input) + + if str(params.get("count")) != "None": + params.pop("count") + + logging.info(f"\nInvoking annual live births by state data look up tool with params: {params}") + return self.api_wrapper.get_annual_live_births_by_state_info(params) + + def _arun(self, params): + raise NotImplementedError("Asynchronous operation is not supported for this tool.") + diff --git a/tools/annual_still_births.py b/tools/annual_still_births.py new file mode 100644 index 0000000..a0ec2d6 --- /dev/null +++ b/tools/annual_still_births.py @@ -0,0 +1,92 @@ +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from typing import Optional, Dict +from urllib.parse import urlencode, urljoin +from langchain.tools import BaseTool +import logging +import os +import json +from tools.utils import cleanupInputAndGetDictFromStr +import requests +import re +from langchain_core.callbacks import CallbackManagerForToolRun + +logger = logging.getLogger(__name__) + +class AnnualStillBirthsAPIClient(BaseModel): + base_url: str = os.getenv("TOOLS_BASE_URL", default="http://localhost:5000") + + @root_validator(pre=True, allow_reuse=True) + def validate_api_client(cls, values: Dict) -> Dict: + logging.info("validation done AnnualStillBirthsAPIClient") + return values + + + def get_annual_still_births_info(self, params): + base_url = f"{self.base_url}/annual_still_births" + if params: + query_string = urlencode(params) + url = urljoin(base_url, f"?{query_string}") + else: + url = base_url + response = requests.get(url) + logging.info(f"Annual Still Births Response code: {response.status_code}, response json: {response.json()}") + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +class AnnualStillBirthsTool(BaseTool): + name = "annual_still_births_data_lookup" + description = """ A wrapper around api to lookup annual still births in the whole country of Malaysia. + Useful to look up annual stillborn children information from 1st Jan, 2000 to 1st Jan, 2022. + The data contains just one row per year, with the number, i.e. absolute, abs in short, of stillborn children in that year. + The api can look up annual still births data by executing an HTTP GET endpoint. + Optionally the GET endpoint can accept query parameters to filter the results. + + You MUST use this tool ONLY to fetch country-wise live births data. To fetch state-wise data, use the annual_still_births_by_state tool. + + Args: + A JSON string with the following fields: + - dateFrom: The start date of the registration period. Date format: YYYY-MM-DD + - dateTo: The end date of the registration period. Date format: YYYY-MM-DD + - abs: The absolute number of stillborn children. + - absMin: The minimum absolute number of stillborn children. + - absMax: The maximum absolute number of stillborn children. + - rate: The rate of stillborn children. + - rateMin: The minimum rate of stillborn children. + - rateMax: The maximum rate of stillborn children. + - count: set to true, if you just need the count instead of actual data. + + Returns: + A JSON string with the following fields: + - date: The date of birth record in the format YYYY-MM-DD. + - abs: The absolute number of stillborn children. + - rate: The rate of stillborn children. + """ + api_wrapper: AnnualStillBirthsAPIClient + + + def _run( + self, + input: Optional[str]=None, + run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + logging.info(f"\nEntering AnnualStillBirthsTool run function\nOriginal input: {input}\n") + params = {} + if input: + matches = re.findall(r'(\w+)=([\w\-]+)', input) + params = dict(matches) + matches = re.findall(r'(\w+): ([\w\d\-]+)', input) + params.update(dict(matches)) + if not params: + params = cleanupInputAndGetDictFromStr(input) + + if str(params.get("count")) != "None": + params.pop("count") + + logging.info(f"Invoking annual still births data look up tool with params: {params}") + return self.api_wrapper.get_annual_still_births_info(params) + + def _arun(self, params): + raise NotImplementedError("Asynchronous operation is not supported for this tool.") diff --git a/tools/annual_still_births_by_state.py b/tools/annual_still_births_by_state.py new file mode 100644 index 0000000..f05d11c --- /dev/null +++ b/tools/annual_still_births_by_state.py @@ -0,0 +1,95 @@ +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from typing import Optional, Dict +from urllib.parse import urlencode, urljoin +from langchain.tools import BaseTool +import logging +import os +import json +from tools.utils import cleanupInputAndGetDictFromStr +import requests +import re +from langchain_core.callbacks import CallbackManagerForToolRun + +logger = logging.getLogger(__name__) + +class AnnualStillBirthsByStateAPIClient(BaseModel): + base_url: str = os.getenv("TOOLS_BASE_URL", default="http://localhost:5000") + + @root_validator(pre=True, allow_reuse=True) + def validate_api_client(cls, values: Dict) -> Dict: + logging.info("validation done AnnualStillBirthsByStateAPIClient") + return values + + + def get_annual_still_births_by_state_info(self, params): + base_url = f"{self.base_url}/annual_still_births_by_state" + if params: + query_string = urlencode(params) + url = urljoin(base_url, f"?{query_string}") + else: + url = base_url + response = requests.get(url) + logging.info(f"Annual Still Births By State Response code: {response.status_code}, response json: {response.json()}") + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +class AnnualStillBirthsByStateTool(BaseTool): + name = "annual_still_births_by_state_data_lookup" + description = """ A wrapper around annual still births, i.e. children born still, by state data lookup APIs, for the country of Malaysia. + Useful to look up annual stillborn children per state information from 1st Jan, 2000 to 1st Jan, 2022. + The data contains just one row or less, per year per state, with the number, i.e. absolute, abs in short, of stillborn children in that year. + The api can look up annual still births data per state by executing an HTTP GET endpoint. + Optionally the GET endpoint can accept query parameters to filter the results. + + You MUST use this tool ONLY to fetch state-wise still births data. For fetching data for the entire country, use the annual_still_births tool. + + Args: + A JSON string with the following fields: + - state: The state where the children are born. Skip this field to get data for all states. + - dateFrom: The start date of the registration period. Date format: YYYY-MM-DD + - dateTo: The end date of the registration period. Date format: YYYY-MM-DD + - abs: The absolute number of stillborn children. + - absMin: The minimum absolute number of stillborn children. + - absMax: The maximum absolute number of stillborn children. + - rate: The rate of stillborn children. + - rateMin: The minimum rate of stillborn children. + - rateMax: The maximum rate of stillborn children. + - count: set to true, if you just need the count instead of actual data. + + Returns: + A JSON string with the following fields: + - state: The state where the children are born. + - date: The date of birth record in the format YYYY-MM-DD. + - abs: The absolute number of stillborn children. + - rate: The rate of stillborn children. + """ + api_wrapper: AnnualStillBirthsByStateAPIClient + + + def _run( + self, + input: Optional[str]=None, + run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + logging.info(f"\nEntering AnnualStillBirthsByStateTool run function\nOriginal input: {input}\n") + params = {} + if input: + matches = re.findall(r'(\w+)=([\w\-]+)', input) + params = dict(matches) + matches = re.findall(r'(\w+): ([\w\d\-]+)', input) + params.update(dict(matches)) + if not params: + params = cleanupInputAndGetDictFromStr(input) + + if str(params.get("count")) != "None": + params.pop("count") + + logging.info(f"Invoking annual still births by state data look up tool with params: {params}") + return self.api_wrapper.get_annual_still_births_by_state_info(params) + + def _arun(self, params): + raise NotImplementedError("Asynchronous operation is not supported for this tool.") + diff --git a/tools/car_registrations.py b/tools/car_registrations.py new file mode 100644 index 0000000..f90437e --- /dev/null +++ b/tools/car_registrations.py @@ -0,0 +1,68 @@ +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from typing import Optional, Dict +from urllib.parse import urlencode, urljoin +from langchain.tools import BaseTool +import requests +import re +from langchain_core.callbacks import CallbackManagerForToolRun + +class CarRegistrationAPIClient(BaseModel): + base_url: str = "http://192.168.0.209" + + @root_validator(pre=True, allow_reuse=True) + def validate_api_client(cls, values: Dict) -> Dict: + print("validation done") + return values + + + def get_registration_info(self, params): + base_url = f"{self.base_url}/registrations" + if params: + query_string = urlencode(params) + url = urljoin(base_url, f"?{query_string}") + else: + url = base_url + response = requests.get(url) + print("Get Car Registrations Response: ", response.status_code, response.json()) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +class CarRegistrationTool(BaseTool): + name = "car_registration_data_lookup" + description = """ A wrapper around car registration data lookup APIs. + Useful to look up car registration information from 1st Jan, 2024 to 31st Jul, 2024. + The api can look up car registration data by executing an HTTP GET endpoint. + Optionally the GET endpoint can accept query parameters to filter the results. + The acceptable query parameters are: + - dateFrom: The start date of the registration period. Date format: YYYY-MM-DD + - dateTo: The end date of the registration period. Date format: YYYY-MM-DD + - type: The type of the car. (e.g. jip, pick_up, motokar, etc.) + - make: The make of the car. + - model: The model of the car. + - color: The color of the car. + - fuel: The fuel type of the car. + - state: The state where the car is registered. + - count: set to true, if you just need the count instead of actual data. + """ + api_wrapper: CarRegistrationAPIClient + + + def _run( + self, + input: Optional[str]=None, + run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + print("Entering CarRegistrationTool run function") + params = {} + if input: + matches = re.findall(r'(\w+)=([\w\-]+)', input) + params = dict(matches) + print("Invoking car registration data look up tool with params: ", params) + return self.api_wrapper.get_registration_info(params) + + def _arun(self, params): + raise NotImplementedError("Asynchronous operation is not supported for this tool.") + diff --git a/tools/daily_live_births.py b/tools/daily_live_births.py new file mode 100644 index 0000000..b03e9c8 --- /dev/null +++ b/tools/daily_live_births.py @@ -0,0 +1,88 @@ +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from typing import Optional, Dict +from urllib.parse import urlencode, urljoin +from langchain.tools import BaseTool +import logging +import os +import json +from tools.utils import cleanupInputAndGetDictFromStr +import requests +import re +from langchain_core.callbacks import CallbackManagerForToolRun + +logger = logging.getLogger(__name__) + +class DailyLiveBirthsAPIClient(BaseModel): + base_url: str = os.getenv("TOOLS_BASE_URL", default="http://localhost:5000") + + @root_validator(pre=True, allow_reuse=True) + def validate_api_client(cls, values: Dict) -> Dict: + logging.info("validation done DailyLiveBirthsAPIClient") + return values + + + def get_daily_live_births_info(self, params): + base_url = f"{self.base_url}/daily_live_births" + if params: + query_string = urlencode(params) + url = urljoin(base_url, f"?{query_string}") + else: + url = base_url + response = requests.get(url) + logging.info(f"Daily Live Births Response code: {response.status_code}, response json: {response.json()}") + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +class DailyLiveBirthsTool(BaseTool): + name = "daily_live_births_data_lookup" + description = """ A wrapper around daily live births data lookup APIs, for the entire country of Malaysia. + Useful to look up daily children born alive information from 1st Jan, 1920 to 31st Jul, 2023. + The data contains just one row per day, with the number, i.e. births, of children born alive in that day. + The api can look up daily live births data by executing an HTTP GET endpoint. + Optionally the GET endpoint can accept query parameters to filter the results. + + Args: + A JSON string with the following fields: + - dateFrom: The start date of the registration period. Date format: YYYY-MM-DD + - dateTo: The end date of the registration period. Date format: YYYY-MM-DD + - births: The number of children born alive on a day. + - birthsMin: The minimum number of children born alive on a day. + - birthsMax: The maximum number of children born alive on a day. + - count: set to true, if you just need the count instead of actual data. + + Returns: + A JSON string with the following fields: + - date: The date of the birth in the format YYYY-MM-DD. + - births: The number of children born alive on that day. + - state: The state where the children are born. + """ + api_wrapper: DailyLiveBirthsAPIClient + + + def _run( + self, + input: Optional[str]=None, + run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + logging.info(f"\nEntering DailyLiveBirthsTool run function\nOriginal input: {input}\n") + params = {} + if input: + matches = re.findall(r'(\w+)=([\w\-]+)', input) + params = dict(matches) + matches = re.findall(r'(\w+): ([\w\d\-]+)', input) + params.update(dict(matches)) + if not params: + params = cleanupInputAndGetDictFromStr(input) + + if str(params.get("count")) != "None": + params.pop("count") + + logging.info(f"\nInvoking daily live births data look up tool with params: {params}") + return self.api_wrapper.get_daily_live_births_info(params) + + def _arun(self, params): + raise NotImplementedError("Asynchronous operation is not supported for this tool.") + diff --git a/tools/stackexchange.py b/tools/stackexchange.py new file mode 100644 index 0000000..ef5c759 --- /dev/null +++ b/tools/stackexchange.py @@ -0,0 +1,7 @@ +from langchain_community.utilities import StackExchangeAPIWrapper +from langchain_community.tools.stackexchange.tool import StackExchangeTool + + +def tool(): + stackexchange_wrapper = StackExchangeAPIWrapper(max_results=3) + return StackExchangeTool(api_wrapper=stackexchange_wrapper) diff --git a/tools/tavily.py b/tools/tavily.py new file mode 100644 index 0000000..d005231 --- /dev/null +++ b/tools/tavily.py @@ -0,0 +1,5 @@ +from langchain_community.tools.tavily_search import TavilySearchResults + + +def tool(): + return TavilySearchResults(max_results=10, handle_tool_error=True) diff --git a/tools/utils.py b/tools/utils.py new file mode 100644 index 0000000..e371f63 --- /dev/null +++ b/tools/utils.py @@ -0,0 +1,24 @@ +import json +import re + +def cleanupInputAndGetDictFromStr(input: str) -> dict: + input = input.replace("\n", "") + match = re.search(r'(\{.*\}|\[.*\])', input) + cleaned_json_string = input + + if match: + cleaned_json_string = match.group(0) + + formattedInput = cleaned_json_string + if "True" in formattedInput and '"True"' not in formattedInput: + formattedInput = formattedInput.replace("True", "\"True\"") + if "False" in formattedInput and '"False"' not in formattedInput: + formattedInput = formattedInput.replace("False", "\"False\"") + + try: + params = json.loads(formattedInput) + cleaned_params = {k: v for k, v in params.items() if v not in (0, None, "", "all", [], {})} + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON input: {e}") + + return cleaned_params \ No newline at end of file