General linting and setup update #55

Merged (1 commit) on Jun 25, 2024
13 changes: 3 additions & 10 deletions chatify/cache.py
@@ -1,15 +1,8 @@
from gptcache.adapter.langchain_models import LangChainLLMs
from gptcache import Cache
from gptcache.processor.pre import get_prompt


from gptcache.manager import get_data_manager, CacheBase, VectorBase


from gptcache.adapter.langchain_models import LangChainLLMs
from gptcache.embedding import Onnx
from gptcache.embedding.string import to_embeddings as string_embedding


from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.processor.pre import get_prompt
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.similarity_evaluation.exact_match import ExactMatchEvaluation

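The reordered imports and quote changes throughout this pull request are consistent with isort-style import sorting and Black-style formatting. Neither tool is named in the visible diff, so the following is only a sketch of how that convention could be applied programmatically, not a description of the project's actual tooling.

import black
import isort

# Read a module, sort its imports, then normalize quoting and spacing.
source = open("chatify/cache.py").read()
sorted_source = isort.code(source)                               # isort-style import ordering
formatted = black.format_str(sorted_source, mode=black.Mode())   # Black-style double quotes
print(formatted)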
52 changes: 23 additions & 29 deletions chatify/chains.py
@@ -1,30 +1,24 @@
from typing import Any, Dict, List, Optional

import requests


from langchain.prompts import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains import LLMChain, LLMMathChain
from langchain.chains.base import Chain
from langchain.prompts import PromptTemplate


from typing import Any, Dict, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun

from .llm_models import ModelsFactory
from .cache import LLMCacher

from .llm_models import ModelsFactory
from .utils import compress_code


class RequestChain(Chain):
llm_chain: LLMChain = None
prompt: Optional[Dict[str, Any]]
headers: Optional[Dict[str, str]] = {
'accept': 'application/json',
'Content-Type': 'application/json',
"accept": "application/json",
"Content-Type": "application/json",
}
input_key: str = 'text'
input_key: str = "text"
url: str = "url" #: :meta private:
output_key: str = "text" #: :meta private:

@@ -50,14 +44,14 @@ def _call(
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
# Prepare data
if self.url != '/':
self.url += '/'
combined_url = self.url + self.prompt['prompt_id'] + '/response'
data = {'user_text': inputs[self.input_key]}
if self.url != "/":
self.url += "/"
combined_url = self.url + self.prompt["prompt_id"] + "/response"
data = {"user_text": inputs[self.input_key]}

# Send the request
response = requests.post(url=combined_url, headers=self.headers, json=data)
output = eval(response.content.decode('utf-8'))
output = eval(response.content.decode("utf-8"))

return {self.output_key: output}

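A minimal usage sketch for RequestChain as reformatted above. The proxy URL and prompt_id are hypothetical placeholders, and the sketch assumes the elided input_keys/output_keys properties expose the input_key and output_key fields shown here; with no live proxy behind the placeholder URL, the POST will of course fail.

from chatify.chains import RequestChain

chain = RequestChain(
    url="https://example-proxy.invalid/",     # hypothetical proxy endpoint
    prompt={"prompt_id": "explain-code"},     # hypothetical prompt identifier
)
# _call builds "<url><prompt_id>/response" and POSTs {"user_text": <input text>}.
result = chain({"text": "for i in range(3): print(i)"})
print(result["text"])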
@@ -77,16 +71,16 @@ def __init__(self, config):
None
"""
self.config = config
self.chain_config = config['chain_config']
self.chain_config = config["chain_config"]

self.llm_model = None
self.llm_models_factory = ModelsFactory()

self.cache = config['cache_config']['cache']
self.cache = config["cache_config"]["cache"]
self.cacher = LLMCacher(config)

# Setup model and chain factory
self._setup_llm_model(config['model_config'])
self._setup_llm_model(config["model_config"])
self._setup_chain_factory()

return None
@@ -112,9 +106,9 @@ def _setup_chain_factory(self):
None
"""
self.chain_factory = {
'math': LLMMathChain,
'default': LLMChain,
'proxy': RequestChain,
"math": LLMMathChain,
"default": LLMChain,
"proxy": RequestChain,
}

def create_prompt(self, prompt):
Expand All @@ -129,7 +123,7 @@ def create_prompt(self, prompt):
PROMPT (PromptTemplate): Prompt template object.
"""
PROMPT = PromptTemplate(
template=prompt['content'], input_variables=prompt['input_variables']
template=prompt["content"], input_variables=prompt["input_variables"]
)
return PROMPT

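The prompt dict that create_prompt expects is shaped like the sketch below; the template text and variable name are illustrative, not prompts shipped with Chatify.

from langchain.prompts import PromptTemplate

prompt = {
    "content": "Explain the following code in plain English:\n\n{text}",
    "input_variables": ["text"],
}
# Mirrors the body of create_prompt above.
PROMPT = PromptTemplate(
    template=prompt["content"], input_variables=prompt["input_variables"]
)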
@@ -145,15 +139,15 @@ def create_chain(self, model_config=None, prompt_template=None):
-------
chain (LLMChain): LLM chain object.
"""
if self.config['chain_config']['chain_type'] == 'proxy':
if self.config["chain_config"]["chain_type"] == "proxy":
chain = RequestChain(
url=self.config['model_config']['proxy_url'], prompt=prompt_template
url=self.config["model_config"]["proxy_url"], prompt=prompt_template
)
else:
try:
chain_type = self.chain_config['chain_type']
chain_type = self.chain_config["chain_type"]
except KeyError:
chain_type = 'default'
chain_type = "default"

chain = self.chain_factory[chain_type](
llm=self.llm_model, prompt=self.create_prompt(prompt_template)
@@ -179,6 +173,6 @@ def execute(self, chain, inputs, *args, **kwargs):
output = chain.llm(inputs, cache_obj=self.cacher.llm_cache)
self.cacher.llm_cache.flush()
else:
output = chain(inputs)['text']
output = chain(inputs)["text"]

return output
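A minimal end-to-end sketch of CreateLLMChain, assuming a config shaped like the keys read in __init__, create_chain, and execute above. The full config schema (including whatever LLMCacher expects under cache_config) is not shown in this diff, so the values here are guesses chosen to avoid needing an API key.

from chatify.chains import CreateLLMChain

config = {
    "chain_config": {"chain_type": "default"},   # "math" and "proxy" are also registered
    "cache_config": {"cache": False},            # skip the GPTCache path in execute()
    "model_config": {
        "model": "fake_model",                   # FakeListLLM, so no API key is needed
        "model_name": None,
        "max_tokens": 256,
        "open_ai_key": None,
    },
}
prompt = {"content": "Explain this code:\n\n{text}", "input_variables": ["text"]}

creator = CreateLLMChain(config)
chain = creator.create_chain(config["model_config"], prompt_template=prompt)
print(creator.execute(chain, "def add(a, b):\n    return a + b"))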
83 changes: 41 additions & 42 deletions chatify/llm_models.py
@@ -3,11 +3,10 @@

with warnings.catch_warnings(): # catch warnings about accelerate library
warnings.simplefilter("ignore")
from langchain.llms import OpenAI, HuggingFacePipeline, LlamaCpp
from langchain.llms.base import LLM
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFacePipeline, LlamaCpp, OpenAI

try:
from huggingface_hub import hf_hub_download
@@ -54,23 +53,23 @@ def get_model(self, model_config):
RuntimeError
If the specified model is not supported.
"""
model_ = model_config['model']
model_ = model_config["model"]

# Collect all the models
models = {
'open_ai_model': OpenAIModel,
'open_ai_chat_model': OpenAIChatModel,
'fake_model': FakeLLMModel,
'cached_model': CachedLLMModel,
'huggingface_model': HuggingFaceModel,
'llama_model': LlamaModel,
'proxy': ProxyModel,
"open_ai_model": OpenAIModel,
"open_ai_chat_model": OpenAIChatModel,
"fake_model": FakeLLMModel,
"cached_model": CachedLLMModel,
"huggingface_model": HuggingFaceModel,
"llama_model": LlamaModel,
"proxy": ProxyModel,
}

if model_ in models.keys():
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if type(models[model_]) == str:
if isinstance(models[model_], str):
return models[model_]
else:
return models[model_](model_config).init_model()
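One substantive lint fix above replaces the exact-type comparison type(models[model_]) == str with isinstance, which also accepts str subclasses. A small self-contained illustration of the difference:

class ModelName(str):
    """A str subclass, e.g. something an enum-like helper might return."""

name = ModelName("open_ai_model")
print(type(name) == str)      # False: exact type comparison ignores subclasses
print(isinstance(name, str))  # True: isinstance respects inheritance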
@@ -139,17 +138,17 @@ def init_model(self):
llm_model : ChatOpenAI
Initialized OpenAI Chat Model.
"""
if self.model_config['open_ai_key'] is None:
raise ValueError(f'openai_api_key value cannot be None')
if self.model_config["open_ai_key"] is None:
raise ValueError("openai_api_key value cannot be None")

os.environ["OPENAI_API_KEY"] = self.model_config['open_ai_key']
os.environ["OPENAI_API_KEY"] = self.model_config["open_ai_key"]

llm_model = OpenAI(
temperature=0.85,
openai_api_key=self.model_config['open_ai_key'],
model_name=self.model_config['model_name'],
openai_api_key=self.model_config["open_ai_key"],
model_name=self.model_config["model_name"],
presence_penalty=0.1,
max_tokens=self.model_config['max_tokens'],
max_tokens=self.model_config["max_tokens"],
)
return llm_model

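A sketch of the model_config keys that OpenAIModel.init_model reads above. The model name and key are placeholders; the real defaults live in Chatify's config files, which are not part of this diff.

from chatify.llm_models import ModelsFactory

model_config = {
    "model": "open_ai_model",
    "open_ai_key": "sk-...",         # placeholder; never hard-code a real key
    "model_name": "gpt-3.5-turbo",   # hypothetical choice
    "max_tokens": 256,
}
llm = ModelsFactory().get_model(model_config)   # dispatches on model_config["model"]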
@@ -179,15 +178,15 @@ def init_model(self):
llm_model : ChatOpenAI
Initialized OpenAI Chat Model.
"""
if self.model_config['open_ai_key'] is None:
raise ValueError(f'openai_api_key value cannot be None')
if self.model_config["open_ai_key"] is None:
raise ValueError("openai_api_key value cannot be None")

llm_model = ChatOpenAI(
temperature=0.85,
openai_api_key=self.model_config['open_ai_key'],
model_name=self.model_config['model_name'],
openai_api_key=self.model_config["open_ai_key"],
model_name=self.model_config["model_name"],
presence_penalty=0.1,
max_tokens=self.model_config['max_tokens'],
max_tokens=self.model_config["max_tokens"],
)
return llm_model

@@ -216,7 +215,7 @@ def init_model(self):
Initialized Fake Chat Model.
"""
responses = [
'The explanation you requested has not been included in Chatify\'s cache. You\'ll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys).'
"The explanation you requested has not been included in Chatify's cache. You'll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys)."
]
llm_model = FakeListLLM(responses=responses)
return llm_model
@@ -247,7 +246,7 @@ def init_model(self):
"""
llm_model = FakeListLLM(
responses=[
f'The explanation you requested has not been included in Chatify\'s cache. You\'ll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys).'
"The explanation you requested has not been included in Chatify's cache. You'll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys)."
]
)
return llm_model
@@ -276,27 +275,27 @@ def init_model(self):
llm_model : HuggingFaceModel
Initialized Hugging Face Chat Model.
"""
self.proxy = self.model_config['proxy']
self.proxy_port = self.model_config['proxy_port']
self.proxy = self.model_config["proxy"]
self.proxy_port = self.model_config["proxy_port"]

with warnings.catch_warnings():
warnings.simplefilter("ignore")

try:
llm = HuggingFacePipeline.from_model_id(
model_id=self.model_config['model_name'],
task='text-generation',
model_id=self.model_config["model_name"],
task="text-generation",
device=0,
model_kwargs={'max_length': self.model_config['max_tokens']},
model_kwargs={"max_length": self.model_config["max_tokens"]},
)
except:
llm = HuggingFacePipeline.from_model_id(
model_id=self.model_config['model_name'],
task='text-generation',
model_id=self.model_config["model_name"],
task="text-generation",
model_kwargs={
'max_length': self.model_config['max_tokens'],
'temperature': 0.85,
'presence_penalty': 0.1,
"max_length": self.model_config["max_tokens"],
"temperature": 0.85,
"presence_penalty": 0.1,
},
)
return llm
@@ -326,8 +325,8 @@ def init_model(self):
Initialized Hugging Face Chat Model.
"""
self.model_path = hf_hub_download(
repo_id=self.model_config['model_name'],
filename=self.model_config['weights_fname'],
repo_id=self.model_config["model_name"],
filename=self.model_config["weights_fname"],
)

with warnings.catch_warnings():
@@ -337,17 +336,17 @@ def init_model(self):
try:
llm = LlamaCpp(
model_path=self.model_path,
max_tokens=self.model_config['max_tokens'],
n_gpu_layers=self.model_config['n_gpu_layers'],
n_batch=self.model_config['n_batch'],
max_tokens=self.model_config["max_tokens"],
n_gpu_layers=self.model_config["n_gpu_layers"],
n_batch=self.model_config["n_batch"],
callback_manager=callback_manager,
verbose=True,
)
except:
llm = LlamaCpp(
model_path=self.model_path,
max_tokens=self.model_config['max_tokens'],
n_batch=self.model_config['n_batch'],
max_tokens=self.model_config["max_tokens"],
n_batch=self.model_config["n_batch"],
callback_manager=callback_manager,
verbose=True,
)
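A sketch of the model_config keys LlamaModel.init_model reads above: model_name is treated as a Hugging Face Hub repo id and weights_fname as a file inside it. Both values below are hypothetical placeholders, as is the n_gpu_layers setting used by the GPU branch of the try/except.

from chatify.llm_models import ModelsFactory

model_config = {
    "model": "llama_model",
    "model_name": "TheBloke/some-llama-model",   # hypothetical Hub repo id
    "weights_fname": "model.q4_0.bin",           # hypothetical quantized weights file
    "max_tokens": 256,
    "n_gpu_layers": 40,                          # only used when llama-cpp has GPU support
    "n_batch": 512,
}
llm = ModelsFactory().get_model(model_config)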
13 changes: 5 additions & 8 deletions chatify/main.py
@@ -1,17 +1,14 @@
import yaml

import pathlib
import requests

from IPython.display import display
from IPython.core.magic import Magics, magics_class, cell_magic

import ipywidgets as widgets
import yaml
from IPython.core.magic import Magics, cell_magic, magics_class
from IPython.display import display

from .chains import CreateLLMChain
from .widgets import option_widget, button_widget, text_widget, thumbs, loading_widget

from .utils import check_dev_config, get_html
from .widgets import (button_widget, loading_widget, option_widget,
text_widget, thumbs)


@magics_class
@@ -138,7 +135,7 @@
output : str
The GPT model output in markdown format.
"""
# TODO: Should we create the chain every time? Only prompt is chainging not the model

[GitHub Actions annotation: "Check for spelling errors" failed on line 138 in chatify/main.py: chainging ==> changing, chaining]
chain = self.llm_chain.create_chain(
self.cfg["model_config"], prompt_template=prompt
)
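The import reshuffle at the top of main.py pulls Magics, magics_class, and cell_magic from IPython.core.magic. The sketch below shows the generic pattern those names support; the class name, magic name, and body are illustrative and are not the ones defined in main.py.

from IPython.core.magic import Magics, cell_magic, magics_class

@magics_class
class ExplainSketch(Magics):
    @cell_magic
    def explain_sketch(self, line, cell):
        # In Chatify, a method like this builds an LLM chain for the cell
        # contents and displays the formatted response.
        print(f"would explain {len(cell)} characters of code")

def load_ipython_extension(ipython):
    # Lets users run `%load_ext <package>` to register the magic.
    ipython.register_magics(ExplainSketch)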
8 changes: 3 additions & 5 deletions chatify/utils.py
@@ -1,14 +1,12 @@
from typing import Any, List, Mapping, Optional

import random
import urllib
from typing import Any, List, Mapping, Optional

from langchain.llms.base import LLM
from markdown_it import MarkdownIt
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import get_lexer_by_name
from pygments import highlight

from langchain.llms.base import LLM


def highlight_code(code, name, attrs):
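The reordered utils.py imports combine markdown_it with pygments. The body of highlight_code is not shown in this diff, so the sketch below is only an assumption about how a three-argument highlight callback typically wires those pieces together.

from markdown_it import MarkdownIt
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import get_lexer_by_name

def highlight_code_sketch(code, name, attrs):
    # markdown-it-py passes (content, language name, attributes) for fenced blocks.
    lexer = get_lexer_by_name(name or "python")
    return highlight(code, lexer, HtmlFormatter())

md = MarkdownIt("commonmark", {"highlight": highlight_code_sketch})
html = md.render("```python\nprint('hi')\n```")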
4 changes: 2 additions & 2 deletions chatify/widgets.py
@@ -1,7 +1,7 @@
import ipywidgets as widgets

import pathlib

import ipywidgets as widgets


def option_widget(config):
"""Create an options dropdown widget based on the given configuration.