diff --git a/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb b/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb
new file mode 100644
index 00000000..0915d874
--- /dev/null
+++ b/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb
@@ -0,0 +1,635 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ChatGPT-3.5 Fine-tuning - Gorrilla api\n",
+ "\n",
+ "Fine-tuning ChatGPT-3.5 on the Gorilla api dataset to try and improve its performance\n",
+ "- [Gorilla project](https://shishirpatil.github.io/gorilla/)\n",
+ "- [Gorilla paper](https://arxiv.org/abs/2305.15334)\n",
+ "- [Gorilla code](https://github.com/ShishirPatil/gorilla)\n",
+ "\n",
+ "OpenAI ChatGPT-3.5 fine-tuning docs [are here](https://platform.openai.com/docs/guides/fine-tuning)\n",
+ "\n",
+ "**Warning!**\n",
+ "\n",
+ "This fine-tuning script will train 7.2 million tokens on OpenAI, check if you're willing to pay that before proceeding :)\n"
+ ]
+ },
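+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A rough cost estimate for those ~7.2M tokens, assuming the gpt-3.5-turbo fine-tuning launch price of $0.008 per 1K training tokens (always check the current [OpenAI pricing page](https://openai.com/pricing)):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Back-of-the-envelope cost estimate; the price per 1K training tokens\n",
+ "# is an assumption, check the current OpenAI pricing page\n",
+ "TRAINING_PRICE_PER_1K_TOKENS = 0.008  # USD, assumed gpt-3.5-turbo launch price\n",
+ "n_training_tokens = 7_200_000\n",
+ "print(f\"Estimated training cost: ${n_training_tokens / 1000 * TRAINING_PRICE_PER_1K_TOKENS:.2f}\")"
+ ]
+ },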
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install openai tiktoken wandb -qqq"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "import os\n",
+ "import json\n",
+ "import wandb\n",
+ "import openai\n",
+ "from pprint import pprint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai_api_key = \"OPENAI API KEY\"\n",
+ "openai.api_key = openai_api_key"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Download the Gorrilla huggingface api training data, you can find all the [Gorilla training data here](https://github.com/ShishirPatil/gorilla/tree/main/data/apibench)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/cab053ba7fdf4a3286c0e75aa2bf7abc4053812f/data/apibench/huggingface_train.json\n",
+ "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_train.json\n",
+ "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_train.json"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Load the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = []\n",
+ "data_files = [\n",
+ " \"huggingface_train.json\",\n",
+ " \"tensorflow_train.json\",\n",
+ " \"torchhub_train.json\",\n",
+ "]\n",
+ "\n",
+ "for file in data_files:\n",
+ " with open(file, \"r\") as file:\n",
+ " # data = json.load(file)\n",
+ " for line in file:\n",
+ " item = json.loads(line.strip())\n",
+ " data.append(item)\n",
+ "\n",
+ "# This is the data relevant to training\n",
+ "data[0][\"code\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Data Parsing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Parse the training data instructions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def parse_instructions_and_outputs(code_section):\n",
+ "\n",
+ " sections = code_section.split('###')\n",
+ " instruction = \"\"\n",
+ " for section in sections:\n",
+ " if \"Instruction:\" in section:\n",
+ " instruction = section.lower().split(\"instruction:\", 1)[1].strip()\n",
+ " break\n",
+ "\n",
+ " # domain = re.search(r'<<>>(.*?)\\n', code_section, re.IGNORECASE).group(1).lstrip(': ')\n",
+ " if \"<<>>\" in code_section:\n",
+ " domain = re.search(r'<<>>(.*?)<<<', d[\"code\"], re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " else:\n",
+ " domain = \"\"\n",
+ "\n",
+ " api_call = re.search(r'<<>>(.*?)<<<', code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " # api_provider = re.search(r'<<>>(.*?)\\n', code_section, re.IGNORECASE).group(1).lstrip(': ')\n",
+ " if \"<<>>\" in code_section:\n",
+ " api_provider = re.search(r'<<>>(.*?)<<<', code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " else:\n",
+ " api_provider = \"\"\n",
+ "\n",
+ " if \"<<>>\" in code_section:\n",
+ " explanation_pattern = r'<<>>(.*?)(?:\\n<<>>|```|$)'\n",
+ " explanation = re.search(explanation_pattern, code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " else:\n",
+ " explanation = None\n",
+ "\n",
+ " # Extract code snippet considering both cases\n",
+ " code_pattern = r'(?:<<>>|```) (.*)' # Matches either <<>> or ```\n",
+ " code_snippet_match = re.search(code_pattern, code_section, re.IGNORECASE | re.DOTALL)\n",
+ " code_snippet = code_snippet_match.group(1).lstrip(': ') if code_snippet_match else None\n",
+ "\n",
+ " return instruction, domain, api_call, api_provider, explanation, code_snippet"
+ ]
+ },
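+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check (a minimal sketch, printing just two of the extracted fields), parse the first sample that contains an `<<<api_call>>>` tag:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Find one sample with an <<<api_call>>> tag and inspect the parsed fields\n",
+ "sample = next(d for d in data if \"<<<api_call>>>\" in d[\"code\"])\n",
+ "instruction, domain, api_call, api_provider, explanation, code = parse_instructions_and_outputs(sample[\"code\"])\n",
+ "print(\"instruction:\", instruction)\n",
+ "print(\"api_call:\", api_call)"
+ ]
+ },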
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def encode_train_sample(data, api_name):\n",
+ " \"\"\"Encode multiple prompt instructions into a single string.\"\"\"\n",
+ " code_section = data['code']\n",
+ "\n",
+ " if \"<<>>\" in code_section:\n",
+ " instruction, domain, api_call, api_provider, explanation, code = parse_instructions_and_outputs(code_section)\n",
+ "\n",
+ " prompts = []\n",
+ "\n",
+ " #prompt = instruction + \"\\nWrite a python program in 1 to 2 lines to call API in \" + api_name + \".\\n\\nThe answer should follow the format: <<>> $DOMAIN, <<>>: $API_CALL, <<>>: $API_PROVIDER, <<>>: $EXPLANATION, <<>>: $CODE}. Here are the requirements:\\n\" + domains + \"\\n2. The $API_CALL should have only 1 line of code that calls api.\\n3. The $API_PROVIDER should be the programming framework used.\\n4. $EXPLANATION should be a step-by-step explanation.\\n5. The $CODE is the python code.\\n6. Do not repeat the format in your answer.\"\n",
+ "\n",
+ " prompts.append({\"role\": \"system\", \"content\": \"You are a helpful API writer who can write APIs based on requirements.\"})\n",
+ " prompts.append({\"role\": \"user\", \"content\": instruction})\n",
+ " prompts.append({\"role\": \"assistant\", \"content\": f\"<<>> {domain},\\\n",
+ "<<>>: {api_call}, <<>>: {api_provider}, <<>>: {explanation}, <<>>: {code}\"})\n",
+ " return prompts\n",
+ " else:\n",
+ " return None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Format the training samples with the correct format to mirror the Gorilla paper"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoded_data = []\n",
+ "none_count = 0\n",
+ "for d in data:\n",
+ " res = encode_train_sample(d, \"huggingface\")\n",
+ " if res is not None:\n",
+ " encoded_data.append({\"messages\":res})\n",
+ " else:\n",
+ " none_count += 1\n",
+ "\n",
+ "print(f\"{none_count} samples out of {len(data)} ignored\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print a sample of what will get passed to OpenAI for fine-tuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoded_data[333]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Save the training data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoded_file_path = 'all_encoded_data.jsonl'\n",
+ "\n",
+ "with open(encoded_file_path, 'w') as file:\n",
+ " for item in encoded_data:\n",
+ " line = json.dumps(item)\n",
+ " file.write(line + '\\n')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Start a Weights & Biases run to save our data and results\n",
+ "wandb.init(project=\"gorilla-api\")\n",
+ "wandb.log_artifact(encoded_file_path, \"hf_tf_th_gorilla_train.jsonl\", type=\"train_data\")\n",
+ "wandb.finish()"
+ ]
+ },
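+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you want to pull the training data back down in a later run, something like this works (a minimal sketch, assuming the artifact name and project used above):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download the logged training-data artifact; the name and project\n",
+ "# are the ones used in the wandb.log_artifact call above\n",
+ "api = wandb.Api()\n",
+ "artifact = api.artifact(\"gorilla-api/hf_tf_th_gorilla_train.jsonl:latest\")\n",
+ "print(artifact.download())"
+ ]
+ },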
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## OpenAI data verification script"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# We start by importing the required packages\n",
+ "\n",
+ "import json\n",
+ "import os\n",
+ "import tiktoken\n",
+ "import numpy as np\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "# Next, we specify the data path and open the JSONL file\n",
+ "\n",
+ "data_path = encoded_file_path\n",
+ "\n",
+ "# Load dataset\n",
+ "with open(data_path) as f:\n",
+ " dataset = [json.loads(line) for line in f]\n",
+ "\n",
+ "# We can inspect the data quickly by checking the number of examples and the first item\n",
+ "\n",
+ "# Initial dataset stats\n",
+ "print(\"Num examples:\", len(dataset))\n",
+ "print(\"First example:\")\n",
+ "for message in dataset[0][\"messages\"]:\n",
+ " print(message)\n",
+ "\n",
+ "# Now that we have a sense of the data, we need to go through all the different examples and check to make sure the formatting is correct and matches the Chat completions message structure\n",
+ "\n",
+ "# Format error checks\n",
+ "format_errors = defaultdict(int)\n",
+ "\n",
+ "for ex in dataset:\n",
+ " if not isinstance(ex, dict):\n",
+ " format_errors[\"data_type\"] += 1\n",
+ " continue\n",
+ "\n",
+ " messages = ex.get(\"messages\", None)\n",
+ " if not messages:\n",
+ " format_errors[\"missing_messages_list\"] += 1\n",
+ " continue\n",
+ "\n",
+ " for message in messages:\n",
+ " if \"role\" not in message or \"content\" not in message:\n",
+ " format_errors[\"message_missing_key\"] += 1\n",
+ "\n",
+ " if any(k not in (\"role\", \"content\", \"name\") for k in message):\n",
+ " format_errors[\"message_unrecognized_key\"] += 1\n",
+ "\n",
+ " if message.get(\"role\", None) not in (\"system\", \"user\", \"assistant\"):\n",
+ " format_errors[\"unrecognized_role\"] += 1\n",
+ "\n",
+ " content = message.get(\"content\", None)\n",
+ " if not content or not isinstance(content, str):\n",
+ " format_errors[\"missing_content\"] += 1\n",
+ "\n",
+ " if not any(message.get(\"role\", None) == \"assistant\" for message in messages):\n",
+ " format_errors[\"example_missing_assistant_message\"] += 1\n",
+ "\n",
+ "if format_errors:\n",
+ " print(\"Found errors:\")\n",
+ " for k, v in format_errors.items():\n",
+ " print(f\"{k}: {v}\")\n",
+ "else:\n",
+ " print(\"No errors found\")\n",
+ "\n",
+ "# Beyond the structure of the message, we also need to ensure that the length does not exceed the 4096 token limit.\n",
+ "\n",
+ "# Token counting functions\n",
+ "encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
+ "\n",
+ "# not exact!\n",
+ "# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb\n",
+ "def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):\n",
+ " num_tokens = 0\n",
+ " for message in messages:\n",
+ " num_tokens += tokens_per_message\n",
+ " for key, value in message.items():\n",
+ " num_tokens += len(encoding.encode(value))\n",
+ " if key == \"name\":\n",
+ " num_tokens += tokens_per_name\n",
+ " num_tokens += 3\n",
+ " return num_tokens\n",
+ "\n",
+ "def num_assistant_tokens_from_messages(messages):\n",
+ " num_tokens = 0\n",
+ " for message in messages:\n",
+ " if message[\"role\"] == \"assistant\":\n",
+ " num_tokens += len(encoding.encode(message[\"content\"]))\n",
+ " return num_tokens\n",
+ "\n",
+ "def print_distribution(values, name):\n",
+ " print(f\"\\n#### Distribution of {name}:\")\n",
+ " print(f\"min / max: {min(values)}, {max(values)}\")\n",
+ " print(f\"mean / median: {np.mean(values)}, {np.median(values)}\")\n",
+ " print(f\"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}\")\n",
+ "\n",
+ "# Last, we can look at the results of the different formatting operations before proceeding with creating a fine-tuning job:\n",
+ "\n",
+ "# Warnings and tokens counts\n",
+ "n_missing_system = 0\n",
+ "n_missing_user = 0\n",
+ "n_messages = []\n",
+ "convo_lens = []\n",
+ "assistant_message_lens = []\n",
+ "\n",
+ "for ex in dataset:\n",
+ " messages = ex[\"messages\"]\n",
+ " if not any(message[\"role\"] == \"system\" for message in messages):\n",
+ " n_missing_system += 1\n",
+ " if not any(message[\"role\"] == \"user\" for message in messages):\n",
+ " n_missing_user += 1\n",
+ " n_messages.append(len(messages))\n",
+ " convo_lens.append(num_tokens_from_messages(messages))\n",
+ " assistant_message_lens.append(num_assistant_tokens_from_messages(messages))\n",
+ "\n",
+ "print(\"Num examples missing system message:\", n_missing_system)\n",
+ "print(\"Num examples missing user message:\", n_missing_user)\n",
+ "print_distribution(n_messages, \"num_messages_per_example\")\n",
+ "print_distribution(convo_lens, \"num_total_tokens_per_example\")\n",
+ "print_distribution(assistant_message_lens, \"num_assistant_tokens_per_example\")\n",
+ "n_too_long = sum(l > 4096 for l in convo_lens)\n",
+ "print(f\"\\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning\")\n",
+ "\n",
+ "# Pricing and default n_epochs estimate\n",
+ "MAX_TOKENS_PER_EXAMPLE = 4096\n",
+ "\n",
+ "MIN_TARGET_EXAMPLES = 100\n",
+ "MAX_TARGET_EXAMPLES = 25000\n",
+ "TARGET_EPOCHS = 3\n",
+ "MIN_EPOCHS = 1\n",
+ "MAX_EPOCHS = 25\n",
+ "\n",
+ "n_epochs = TARGET_EPOCHS\n",
+ "n_train_examples = len(dataset)\n",
+ "if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:\n",
+ " n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)\n",
+ "elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:\n",
+ " n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)\n",
+ "\n",
+ "n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)\n",
+ "print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training\")\n",
+ "print(f\"By default, you'll train for {n_epochs} epochs on this dataset\")\n",
+ "print(f\"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens\")\n",
+ "print(\"See pricing page to estimate total costs\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.summary[\"num_samples\"] = len(dataset)\n",
+ "wandb.summary[\"n_billing_tokens_in_dataset\"] = n_billing_tokens_in_dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Start Fine-tuning ChatGPT-3.5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create an OpenAI training file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.File.create(\n",
+ " file=open(encoded_file_path, \"rb\"),\n",
+ " purpose='fine-tune'\n",
+ ")"
+ ]
+ },
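+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "OpenAI has to process the file before it can be used for fine-tuning; a minimal polling sketch (assuming the `train_file` response from above):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "# Wait until the uploaded file has been processed by OpenAI\n",
+ "while True:\n",
+ "    file_status = openai.File.retrieve(train_file[\"id\"])[\"status\"]\n",
+ "    print(\"file status:\", file_status)\n",
+ "    if file_status == \"processed\":\n",
+ "        break\n",
+ "    time.sleep(10)"
+ ]
+ },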
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create your fine-tuning job"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.api_key = openai_api_key\n",
+ "openai.FineTuningJob.create(\n",
+ " training_file=\"file-N9M4sC8GfXgTNw0WAwgiLHNR\", #\"file-OrxAP7HcvoSUmu9MtAbWo5s4\",\n",
+ " model=\"gpt-3.5-turbo\",\n",
+ " hyperparameters={\"n_epochs\": 3}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.FineTuningJob.list_events(id=\"ftjob-ShHWEMHa2U7gRNVTpjOYEZEP\", limit=5)"
+ ]
+ },
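+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fine-tuning takes a while; a minimal polling sketch that waits for the job to reach a terminal state (assuming `ft_job` from above):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "# Poll the fine-tuning job until it finishes\n",
+ "while True:\n",
+ "    job = openai.FineTuningJob.retrieve(ft_job[\"id\"])\n",
+ "    print(\"job status:\", job[\"status\"])\n",
+ "    if job[\"status\"] in (\"succeeded\", \"failed\", \"cancelled\"):\n",
+ "        break\n",
+ "    time.sleep(60)"
+ ]
+ },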
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Log the results to Weights & Biases when the model is finished training\n",
+ "\n",
+ "(temporarily install openai from a fork until this PR to update the wandb logger is merged in openai: https://github.com/openai/openai-python/pull/590)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip uninstall -y openai -qq && pip install git+https://github.com/morganmcg1/openai-python.git@update_wandb_logger -qqq"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run `openai wandb sync` to sync your openai results to W&B"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!OPENAI_API_KEY={openai_api_key} openai wandb sync --entity prompt-eng --project gorilla-api --id ftjob-mNSsI2UcxCvpV767GmnYoSzR"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Other useful openai commands"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "List 10 fine-tuning jobs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.FineTuningJob.list(limit=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Retrieve the state of a fine-tune"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "state = openai.FineTuningJob.retrieve(\"ftjob-qhg4yswil15TCqD4SNHn0V1D\")\n",
+ "state[\"status\"], state[\"trained_tokens\"], state[\"finished_at\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "List up to 10 events from a fine-tuning job"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.FineTuningJob.list_events(id=\"ftjob-qhg4yswil15TCqD4SNHn0V1D\", limit=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Use the Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.api_key = openai_api_key\n",
+ "\n",
+ "completion = openai.ChatCompletion.create(\n",
+ " model=\"ft:gpt-3.5-turbo:my-org:custom_suffix:id\",\n",
+ " messages=[\n",
+ " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+ " {\"role\": \"user\", \"content\": \"How can i load a NER model?\"}\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pprint(completion.choices[0].message)\n",
+ "pprint(completion.choices[0].message[\"content\"])"
+ ]
+ },
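+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The fine-tuned model should reply in the Gorilla tag format, so individual fields can be pulled back out with a small regex (a minimal sketch; it assumes the reply contains an `<<<api_call>>>` tag and prints None otherwise):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Extract the api_call field from the model's Gorilla-formatted reply\n",
+ "reply = completion.choices[0].message[\"content\"]\n",
+ "match = re.search(r'<<<api_call>>>(.*?)(?:, <<<|$)', reply, re.DOTALL)\n",
+ "print(match.group(1).lstrip(': ') if match else None)"
+ ]
+ }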
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}