Merge pull request #7 from elastic/wave2
Wave2
derickson authored Mar 20, 2024
2 parents 7181024 + 193b245 commit 68ee649
Showing 1 changed file with 133 additions and 99 deletions: notebooks/genai_colab_lab1and2.ipynb
@@ -20,11 +20,17 @@
"source": [
"FILE=\"GenAI Lab 1 and 2\"\n",
"\n",
"## suppress some warnings\n",
"import warnings, os\n",
"os.environ['PIP_ROOT_USER_ACTION'] = 'ignore'\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning, module='huggingface_hub.utils._token')\n",
"warnings.filterwarnings(\"ignore\", message=\"torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\")\n",
"warnings.filterwarnings(\"ignore\", category=DeprecationWarning, message=\"on_submit is deprecated.*\")\n",
"\n",
"# workshop environment - this is where you'll enter a key\n",
"# ! pip install -qqq git+https://github.com/elastic/notebook-workshop-loader.git@main\n",
"from notebookworkshoploader import loader\n",
"import os\n",
"from dotenv import load_dotenv\n",
"\n",
"if os.path.isfile(\"../env\"):\n",
" load_dotenv(\"../env\", override=True)\n",
" print('Successfully loaded environment variables from local env file')\n",
@@ -143,6 +149,9 @@
"def wrap_text(text, width):\n",
" wrapped_text = textwrap.wrap(text, width)\n",
" return '\\n'.join(wrapped_text)\n",
"\n",
"def print_light_blue(text):\n",
" print(f'\\033[94m{text}\\033[0m')\n",
"\n"
]
},
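As a quick orientation (not part of the commit), here is how the two helpers above compose; this sketch assumes the cell has run, so `wrap_text` and `print_light_blue` are defined:

```python
# wrap_text reflows a long string; print_light_blue wraps it in an ANSI color escape.
long_output = "The quick brown fox jumps over the lazy dog. " * 5
print_light_blue(wrap_text(long_output, 40))   # 40-column, light-blue output
```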
@@ -266,8 +275,7 @@
"\n",
"Let's start with the Hello World of generative AI examples: completing a sentence. For this we'll install a fine tuned Flan-T5 variant model. ([LaMini-T5 ](https://huggingface.co/MBZUAI/LaMini-T5-738M))\n",
"\n",
"Note, while this is a smaller checkpoint of the model, it is still a 3GB download. We'll cache the files in the same folder.\n",
"\n"
"Note, while this is a smaller checkpoint of the model, it is still a 900 MB download. We'll cache the files in the same folder.\n"
]
},
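The next cell loads the tokenizer and model explicitly. As a rough sketch, the same completion task can also go through Transformers' higher-level `pipeline` API (same model name as in the diff; `text2text-generation` is the task for T5-style models):

```python
from transformers import pipeline

# Sketch only: the notebook's explicit tokenizer/model load achieves the same thing.
generator = pipeline("text2text-generation", model="MBZUAI/LaMini-T5-223M")
result = generator("Please complete this sentence: The capital of France is",
                   max_length=50)
print(result[0]["generated_text"])
```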
{
@@ -279,14 +287,15 @@
"outputs": [],
"source": [
"## Let's play with something a little bigger that can do a text completion\n",
"## This is a 3 GB download and takes some RAM to run, but it works CPU only\n",
"## This is a 900 MB download and takes some RAM to run, but it works CPU only\n",
"\n",
"from transformers import pipeline\n",
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"\n",
"# model_name = \"MBZUAI/LaMini-Flan-T5-77M\"\n",
"# model_name = \"MBZUAI/LaMini-T5-223M\"\n",
"model_name = \"MBZUAI/LaMini-T5-738M\"\n",
"model_name = \"MBZUAI/LaMini-T5-223M\" ## Trying this 900 MB version of the LLM\n",
"# model_name = \"MBZUAI/LaMini-T5-738M\" ## 3GB version requires more RAM than we have in this local environment\n",
"\n",
"\n",
"llm_tokenizer = AutoTokenizer.from_pretrained(model_name,\n",
" cache_dir=cache_directory)\n",
@@ -430,7 +439,51 @@
"\n",
"openai.api_key = os.environ['OPENAI_API_KEY']\n",
"openai.api_base = os.environ['OPENAI_API_BASE']\n",
"openai.default_model = os.environ['OPENAI_API_ENGINE']"
"openai.default_model = os.environ['OPENAI_API_ENGINE']\n",
"\n",
"import ipywidgets as widgets\n",
"from IPython.display import display\n",
"\n",
"class NotebookChatExperience:\n",
" def __init__(self, ai_response_function, ai_name = \"AI\"):\n",
" self.ai_name = ai_name\n",
" self.ai_response_function = ai_response_function\n",
" self.chat_history = widgets.Textarea(\n",
" value='',\n",
" placeholder='Chat history will appear here...',\n",
" description='Chat:',\n",
" disabled=True,\n",
" layout=widgets.Layout(width='700px', height='300px') # Adjust the size as needed\n",
" )\n",
" self.user_input = widgets.Text(\n",
" value='',\n",
" placeholder='Type your message here...',\n",
" description='You:',\n",
" disabled=False,\n",
" layout=widgets.Layout(width='700px') # Adjust the size as needed\n",
" )\n",
" self.user_input.on_submit(self.on_submit)\n",
" display(self.chat_history, self.user_input)\n",
"\n",
" def on_submit(self, event):\n",
" user_message = self.user_input.value\n",
" ai_name = self.ai_name\n",
" self.chat_history.value += f\"\\nYou: {user_message}\"\n",
" ai_message = self.ai_response_function(user_message)\n",
" self.chat_history.value += f\"\\n{ai_name}: {ai_message}\"\n",
" self.user_input.value = '' # Clear input for next message\n",
"\n",
" def clear_chat(self):\n",
" self.chat_history.value = '' # Clear the chat history\n",
"\n",
"## ********** Example usage:\n",
"\n",
"## ********** Define a simple AI response function\n",
"# def simple_ai_response(user_message):\n",
" # return f\"AI > Echo: {user_message}\"\n",
"\n",
"## ********** Create an instance of the chat interface\n",
"#chat_instance = NotebookChatExperience(simple_ai_response)"
]
},
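Note that `NotebookChatExperience` relies on `Text.on_submit`, which recent ipywidgets releases deprecate (which is why the notebook filters the "on_submit is deprecated" warning up top). A sketch of the replacement the deprecation message suggests: observe `value` with `continuous_update=False` so the handler fires only when the user presses Enter:

```python
import ipywidgets as widgets

# continuous_update=False: `value` changes only when the user presses Enter.
box = widgets.Text(placeholder='Type and press Enter...', continuous_update=False)

def handle_change(change):
    if change['new']:          # ignore the programmatic clear below
        print(f"You: {change['new']}")
        box.value = ''         # reset the field for the next message

box.observe(handle_change, names='value')
box
```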
{
@@ -451,29 +504,26 @@
"outputs": [],
"source": [
"# Call the OpenAI ChatCompletion API\n",
"def chatCompletion(messages):\n",
"def chatCompletion(messages, max_tokens=100):\n",
" client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)\n",
" completion = client.chat.completions.create(\n",
" model=openai.default_model,\n",
" max_tokens=100,\n",
" max_tokens=max_tokens,\n",
" messages=messages\n",
" )\n",
" return completion\n",
"\n",
"def chatWithGPT(prompt, print_full_json=False):\n",
" completion = chatCompletion([{\"role\": \"user\", \"content\": prompt}])\n",
" response_text = completion.choices[0].message.content\n",
"prompt=\"Hello, is ChatGPT online and working?\"\n",
"\n",
" if print_full_json:\n",
" print(completion.json())\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"\n",
" return wrap_text(response_text,70)\n",
"completion = chatCompletion(messages)\n",
"\n",
"## call it with the json debug output enabled\n",
"response = chatWithGPT(\"Hello, is ChatGPT online and working?\", print_full_json=True)\n",
"response_text = completion.choices[0].message.content\n",
"\n",
"print(\"\\n\")\n",
"print(response)"
"print(wrap_text(completion.json(),70))\n",
"\n",
"print(\"\\n\", wrap_text(response_text,70))"
]
},
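Beyond `completion.json()`, the v1 OpenAI client exposes the response as typed attributes. A short sketch, reusing the `chatCompletion` wrapper defined above, of the fields this lab touches:

```python
# Sketch: inspect a completion without dumping the full JSON.
completion = chatCompletion(
    [{"role": "user", "content": "Reply with exactly five words."}],
    max_tokens=20,
)

print(completion.model)                        # model that served the request
print(completion.usage.prompt_tokens,          # tokens sent in the prompt
      completion.usage.completion_tokens)      # tokens the model generated
print(completion.choices[0].message.content)   # the reply text
```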
{
@@ -483,7 +533,7 @@
},
"source": [
"\n",
"## Step 3: A conversation loop - ❗ type \"exit\" to end the chat ❗\n",
"## Step 3: Using OpenAI in a simple loop\n",
"Feeding user input in for single questions is easy"
]
},
@@ -495,19 +545,13 @@
},
"outputs": [],
"source": [
"def hold_a_conversation(ai_conversation_function = chatWithGPT):\n",
" print(\" -- Have a conversation with an AI: \")\n",
" print(\" -- type 'exit' when done\")\n",
"\n",
" user_input = input(\"> \")\n",
" while not user_input.lower().startswith(\"exit\"):\n",
" print(ai_conversation_function(user_input, False))\n",
" print(\" -- type 'exit' when done\")\n",
" user_input = input(\"> \")\n",
" print(\"\\n -- end conversation --\")\n",
"\n",
"## we are passing the previously defined function as a parameter\n",
"hold_a_conversation(chatWithGPT)\n"
"def openai_ai_response(user_message):\n",
" messages = [{\"role\": \"user\", \"content\": user_message}]\n",
" completion = chatCompletion(messages)\n",
" response_text = completion.choices[0].message.content\n",
" return response_text\n",
"\n",
"chat_instance = NotebookChatExperience(openai_ai_response)"
]
},
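Because `openai_ai_response` sends only the latest message, every turn is a brand-new conversation. A two-line experiment (hypothetical prompts) makes the missing memory visible, which is exactly what the memory buffer later in the lab addresses:

```python
# Each call builds a fresh single-message history, so nothing carries over.
print(openai_ai_response("My name is Alice. Please remember that."))
print(openai_ai_response("What is my name?"))   # the model has no way to know
```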
{
@@ -530,23 +574,20 @@
},
"outputs": [],
"source": [
"def pirateGPT(prompt, print_full_json=False):\n",
" system_prompt = \"\"\"\n",
"def pirate_ai_response(user_message):\n",
" system_prompt = \"\"\"\n",
"You are an unhelpful AI named Captain LLM_Beard that talks like a pirate in short responses.\n",
"You acknowledge the user's question but redirect all conversations towards your love of treasure.\n",
"You do not anser the user's question but instead redirect all conversations towards your love of treasure.\n",
"\"\"\"\n",
" completion = chatCompletion([\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ])\n",
" response_text = completion.choices[0].message.content\n",
" if print_full_json:\n",
" print(completion.json())\n",
"\n",
" return wrap_text(response_text,70)\n",
" completion = chatCompletion([\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_message}\n",
" ])\n",
"\n",
" response_text = completion.choices[0].message.content\n",
" return response_text\n",
"\n",
"hold_a_conversation(pirateGPT)"
"pirate_chat_instance = NotebookChatExperience(pirate_ai_response, ai_name=\"LLM_Beard\")"
]
},
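The pirate cell is really a template: only the system prompt changes. As an illustration (the persona and names below are made up), the same wiring with a different system prompt:

```python
# Hypothetical second persona reusing chatCompletion and NotebookChatExperience.
def haiku_ai_response(user_message):
    completion = chatCompletion([
        {"role": "system", "content": "You answer every question as a haiku."},
        {"role": "user", "content": user_message}
    ])
    return completion.choices[0].message.content

haiku_chat_instance = NotebookChatExperience(haiku_ai_response, ai_name="HaikuBot")
```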
{
@@ -618,74 +659,67 @@
" def peek(self):\n",
" return list(self.buffer)\n",
"\n",
"# enough conversation memory for 2 call and response history\n",
"memory_buffer = QueueBuffer(4)\n",
"\n",
"system_prompt = {\n",
"class MemoryNotebookChatExperience(NotebookChatExperience):\n",
" def __init__(self, ai_response_function, ai_name=\"AI\", memory_size = 4):\n",
" # Initialize the superclass\n",
" self.memory_buffer = QueueBuffer(memory_size)\n",
" self.current_memory_dump = \"\"\n",
" super().__init__(ai_response_function, ai_name)\n",
"\n",
" ## now with memory\n",
" def memory_gpt_response(self, prompt):\n",
" ## the API call will use the system prompt + the memory buffer\n",
" ## which ends with the user prompt\n",
" user_message = {\"role\": \"user\", \"content\": prompt}\n",
" self.memory_buffer.enqueue(user_message)\n",
"\n",
" ## debug print the current AI memory\n",
" self.current_memory_dump = \"Current memory\\n\"\n",
" for m in self.memory_buffer.peek():\n",
" role = m.get(\"role\").strip()\n",
" content = m.get(\"content\").strip()\n",
" self.current_memory_dump += f\"{role} | {content}\\n\"\n",
"\n",
" system_prompt = {\n",
" \"role\": \"system\",\n",
" \"content\": \"\"\"\n",
"You are an AI named Cher Horowitz that speaks\n",
"in 1990's valley girl dialect of English.\n",
"You are a helpful AI that answers questions consicely.\n",
"You talk to the human and use the past conversation to inform your answers.\"\"\"\n",
" }\n",
"\n",
"## utility function to print in a different color for debug output\n",
"def print_light_blue(text):\n",
" print(f'\\033[94m{text}\\033[0m')\n",
" ## when calling the AI we put the system prompt at the start\n",
" concatenated_message = [system_prompt] + self.memory_buffer.peek()\n",
"\n",
"## now with memory\n",
"def cluelessGPT(prompt, print_full_json=False):\n",
" ## here is the request to the AI\n",
"\n",
" ## the API call will use the system prompt + the memory buffer\n",
" ## which ends with the user prompt\n",
" user_message = {\"role\": \"user\", \"content\": prompt}\n",
" memory_buffer.enqueue(user_message)\n",
" completion = chatCompletion(concatenated_message)\n",
" response_text = completion.choices[0].message.content\n",
"\n",
" ## debug print the current AI memory\n",
" print_light_blue(\"Current memory\")\n",
" for m in memory_buffer.peek():\n",
" role = m.get(\"role\").strip()\n",
" content = m.get(\"content\").strip()\n",
" print_light_blue( f\" {role} | {content}\")\n",
"\n",
" ## when calling the AI we put the system prompt at the start\n",
" concatenated_message = [system_prompt] + memory_buffer.peek()\n",
" ## don't forget to add the repsonse to the conversation memory\n",
" self.memory_buffer.enqueue({\"role\":\"assistant\", \"content\":response_text})\n",
"\n",
" ## here is the request to the AI\n",
" return response_text\n",
"\n",
" completion = chatCompletion(concatenated_message)\n",
" response_text = completion.choices[0].message.content\n",
" if print_full_json:\n",
" print(completion.json())\n",
" def on_submit(self, event):\n",
" user_message = self.user_input.value\n",
" self.chat_history.value += f\"\\nYou: {user_message}\"\n",
" # Attempting to add styled text, but it will appear as plain text\n",
"\n",
" ai_message = self.memory_gpt_response(user_message)\n",
"\n",
" ## don't forget to add the repsonse to the conversation memory\n",
" memory_buffer.enqueue({\"role\":\"assistant\", \"content\":response_text})\n",
" ## deubg lines to show memory buffer in chat\n",
" for i, line in enumerate(self.current_memory_dump.split(\"\\n\")):\n",
" self.chat_history.value += f\"\\n---- {i} {line}\"\n",
" self.chat_history.value += \"\\n\"\n",
"\n",
" if print_full_json:\n",
" json_pretty(completion)\n",
" self.chat_history.value += f\"\\n{self.ai_name}: {ai_message}\"\n",
" self.user_input.value = '' # Clear input for next message\n",
"\n",
" return wrap_text(response_text,70)"
]
},
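The buffer is what bounds the model's memory. A standalone check (assuming `QueueBuffer.enqueue` evicts the oldest entry once the buffer is full, as the truncated definition above implies) shows the window sliding:

```python
# With capacity 4, the two oldest of six messages should fall out.
buf = QueueBuffer(4)
for i in range(6):
    buf.enqueue({"role": "user", "content": f"message {i}"})

print([m["content"] for m in buf.peek()])
# expected: ['message 2', 'message 3', 'message 4', 'message 5']
```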
{
"cell_type": "markdown",
"metadata": {
"id": "NTHELTKTNOF8"
},
"source": [
"#### Step 5: Let's chat with a not so clueless chatbot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FVztTDN6NOF8"
},
"outputs": [],
"source": [
"hold_a_conversation(cluelessGPT)"
"\n",
"# Create an instance of the enhanced chat experience class with a simple AI response function\n",
"not_so_clueless_chat = MemoryNotebookChatExperience(None)"
]
},
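One caveat: `clear_chat`, inherited from the base class, empties only the transcript widget; the `QueueBuffer` keeps its contents, so the AI still remembers cleared turns:

```python
# Clears the visible transcript, not the conversation memory.
not_so_clueless_chat.clear_chat()

# The buffer still holds the prior turns:
print(not_so_clueless_chat.memory_buffer.peek())
```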