-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DST-198: add llm selection tool and base transcript (#22)
Resolves DST-198: base code for processing transcripts: Added prompt for transcript and selection for openhermes, dolphin, gemini, gpt 3.5 and gpt 4
- Loading branch information
Showing
4 changed files
with
151 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.env | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import dotenv | ||
|
||
from langchain_community.llms.ollama import Ollama | ||
import google.generativeai as genai | ||
from openai import OpenAI | ||
import os | ||
|
||
dotenv.load_dotenv() | ||
|
||
|
||
def get_transcript(file_path="./transcript.txt"):
    """Read and return the full text of a transcript file.

    Args:
        file_path: Path to the transcript text file.

    Returns:
        The file's entire contents as a single string.
    """
    # Context manager guarantees the handle is closed even if read() raises;
    # the original opened the file and never called close() (handle leak).
    with open(file_path, encoding="utf-8") as file:
        return file.read()
|
||
|
||
def ollama_client(
    model_name=None,
    prompt=None,
    callbacks=None,
    settings=None,
):
    """Run *prompt* against a locally served Ollama model.

    Args:
        model_name: Name of the Ollama model (e.g. "openhermes").
        prompt: Prompt text sent to the model.
        callbacks: Optional LangChain callback handlers.
        settings: Extra keyword options forwarded to the Ollama constructor;
            a falsy value falls back to {"stop": None}.

    Returns:
        The completion string produced by Ollama.invoke().
    """
    # Minimal default option set accepted by langchain's Ollama wrapper.
    # For other options (temperature, system, template, ...) see
    # https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/ollama.py
    effective_settings = settings if settings else {"stop": None}

    print("LLM settings:", model_name, effective_settings)
    # To connect via another URL: Ollama(base_url='http://localhost:11434', ...)
    client = Ollama(model=model_name, callbacks=callbacks, **effective_settings)
    return client.invoke(prompt)
|
||
|
||
def google_gemini_client(
    model_name="gemini-pro",
    prompt=None,
    settings=None,
):
    """Generate content from a Google Gemini model for the given prompt.

    Args:
        model_name: Gemini model identifier; defaults to "gemini-pro".
        prompt: Prompt text sent to the model.
        settings: Optional dict of GenerationConfig keyword arguments
            (e.g. temperature, max_output_tokens).

    Returns:
        The GenerateContentResponse from model.generate_content(); callers
        read the generated text via the .text attribute.
    """
    # Get a Google API key by following the steps after clicking on Get an API key button
    # at https://ai.google.dev/tutorials/setup
    GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

    print("LLM settings:", model_name, settings)

    genai.configure(api_key=GOOGLE_API_KEY)
    # Bug fix: the original constructed GenerationConfig(**settings) and
    # discarded the result, so *settings* never influenced generation.
    # Build it and actually hand it to the model.
    generation_config = genai.GenerationConfig(**settings) if settings else None
    model = genai.GenerativeModel(model_name, generation_config=generation_config)
    return model.generate_content(prompt)
|
||
|
||
def gpt3_5(prompt, model="gpt-3.5-turbo"):
    """Send *prompt* as a single user message to an OpenAI chat model.

    Args:
        prompt: User message content.
        model: Chat model identifier; defaults to "gpt-3.5-turbo".

    Returns:
        The assistant's reply text from the first completion choice.
    """
    # NOTE(review): the key is read from OPEN_AI_API_KEY (note the extra
    # underscore); when that is unset the OpenAI client presumably falls back
    # to its standard OPENAI_API_KEY env var — confirm which one is intended.
    api_key = os.environ.get("OPEN_AI_API_KEY")
    openai_client = OpenAI(api_key=api_key)
    response = openai_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content
|
||
|
||
def gpt_4_turbo(prompt):
    """Summarize *prompt* with GPT-4 Turbo by delegating to the shared chat helper."""
    model_name = "gpt-4-turbo"
    return gpt3_5(prompt, model=model_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
google-generativeai | ||
tokenizers | ||
langchain | ||
langchain_community | ||
openai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from llm import google_gemini_client, ollama_client, gpt3_5, gpt_4_turbo | ||
from langchain_core.prompts import PromptTemplate | ||
|
||
|
||
# download transcripts from https://drive.google.com/drive/folders/19r6x3Zep4N9Rl_x4n4H6RpWkXviwbxyw?usp=sharing
def get_transcript(file_path="./transcript.txt"):
    """Read and return the whole transcript file as one string.

    Args:
        file_path: Path to the transcript text file.

    Returns:
        The file's entire contents.
    """
    # "with" closes the handle even when read() raises; the original only
    # closed on the success path, leaking the handle on any read error.
    with open(file_path, encoding="utf-8") as transcript_file:
        return transcript_file.read()
|
||
|
||
# Summarization prompt template; {transcript} is substituted below via
# langchain's PromptTemplate.
prompt = """
You are a helpful AI assistant who will summarize this transcript {transcript}, using the following template:
---------------------
Caller Information (Name, contact information, availability, household information)
Reason/Type of Call (Applying for benefits, Follow-Ups)
Previous Benefits History (Applied for, Receives, Denied, etc)
Put # in front of the benefit discussed (i.e. #SNAP, LIHEAP)
Discussion Points (Key information points)
Documents Needed (Income, Housing, etc)
Next Steps for Client
Next Steps for Agent
---------------------
"""

# Interactive menu: the number typed on stdin selects the backend handled
# by the if/elif chain below.
print("""
Select an llm
1. openhermes (default)
2. dolphin
3. gemini
4. gpt 3.5
5. gpt 4
""")


transcript = get_transcript()
# An empty line (just pressing Enter) falls back to "1" (openhermes).
llm = input() or "1"
prompt_template = PromptTemplate.from_template(prompt)
formatted_prompt = prompt_template.format(transcript=transcript)

if llm == "2":
    # Local dolphin-mistral model served by Ollama.
    test = ollama_client(model_name="dolphin-mistral", prompt=formatted_prompt)
    print("""----------
Dolphin
""")
elif llm == "3":
    # Gemini returns a response object; .text extracts the generated string.
    test = google_gemini_client(prompt=formatted_prompt).text
    print("""----------
Gemini
""")
elif llm == "4":
    print("""----------
GPT 3.5
""")
    test = gpt3_5(prompt=formatted_prompt)
elif llm == "5":
    print("""----------
GPT 4
""")
    test = gpt_4_turbo(prompt=formatted_prompt)
else:
    # Default branch (option "1" or anything unrecognized): openhermes via Ollama.
    test = ollama_client(model_name="openhermes", prompt=formatted_prompt)
    print("""
Openhermes
""")

# Print the selected model's summary to stdout.
print(test)