diff --git a/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb b/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb new file mode 100644 index 00000000..0915d874 --- /dev/null +++ b/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb @@ -0,0 +1,635 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ChatGPT-3.5 Fine-tuning - Gorrilla api\n", + "\n", + "Fine-tuning ChatGPT-3.5 on the Gorilla api dataset to try and improve its performance\n", + "- [Gorilla project](https://shishirpatil.github.io/gorilla/)\n", + "- [Gorilla paper](https://arxiv.org/abs/2305.15334)\n", + "- [Gorilla code](https://github.com/ShishirPatil/gorilla)\n", + "\n", + "OpenAI ChatGPT-3.5 fine-tuning docs [are here](https://platform.openai.com/docs/guides/fine-tuning)\n", + "\n", + "**Warning!**\n", + "\n", + "This fine-tuning script will train 7.2 million tokens on OpenAI, check if you're willing to pay that before proceeding :)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install openai tiktoken wandb -qqq" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "import os\n", + "import json\n", + "import wandb\n", + "import openai\n", + "from pprint import pprint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "openai_api_key = \"OPENAI API KEY\"\n", + "openai.api_key = openai_api_key" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Download the Gorrilla huggingface api training data, you can find all the [Gorilla training data here](https://github.com/ShishirPatil/gorilla/tree/main/data/apibench)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/cab053ba7fdf4a3286c0e75aa2bf7abc4053812f/data/apibench/huggingface_train.json\n", + "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_train.json\n", + "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_train.json" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = []\n", + "data_files = [\n", + " \"huggingface_train.json\",\n", + " \"tensorflow_train.json\",\n", + " \"torchhub_train.json\",\n", + "]\n", + "\n", + "for file in data_files:\n", + " with open(file, \"r\") as file:\n", + " # data = json.load(file)\n", + " for line in file:\n", + " item = json.loads(line.strip())\n", + " data.append(item)\n", + "\n", + "# This is the data relevant to training\n", + "data[0][\"code\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data Parsing" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parse the training data instructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def parse_instructions_and_outputs(code_section):\n", + "\n", + " sections = code_section.split('###')\n", + " instruction = \"\"\n", + " for section in sections:\n", + " if \"Instruction:\" in section:\n", + " instruction = section.lower().split(\"instruction:\", 1)[1].strip()\n", + " break\n", + "\n", + " # domain = re.search(r'<<>>(.*?)\\n', code_section, re.IGNORECASE).group(1).lstrip(': ')\n", + " if \"<<>>\" in code_section:\n", + " domain = re.search(r'<<>>(.*?)<<<', d[\"code\"], re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n", + " else:\n", + " domain = \"\"\n", + "\n", + " api_call = re.search(r'<<>>(.*?)<<<', code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n", + " # api_provider = re.search(r'<<>>(.*?)\\n', code_section, re.IGNORECASE).group(1).lstrip(': ')\n", + " if \"<<>>\" in code_section:\n", + " api_provider = re.search(r'<<>>(.*?)<<<', code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n", + " else:\n", + " api_provider = \"\"\n", + "\n", + " if \"<<>>\" in code_section:\n", + " explanation_pattern = r'<<>>(.*?)(?:\\n<<>>|```|$)'\n", + " explanation = re.search(explanation_pattern, code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n", + " else:\n", + " explanation = None\n", + "\n", + " # Extract code snippet considering both cases\n", + " code_pattern = r'(?:<<>>|```) (.*)' # Matches either <<>> or ```\n", + " code_snippet_match = re.search(code_pattern, code_section, re.IGNORECASE | re.DOTALL)\n", + " code_snippet = code_snippet_match.group(1).lstrip(': ') if code_snippet_match else None\n", + "\n", + " return instruction, domain, api_call, api_provider, explanation, code_snippet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def encode_train_sample(data, api_name):\n", + " \"\"\"Encode multiple prompt instructions into a single string.\"\"\"\n", + " code_section = data['code']\n", + "\n", + " if \"<<>>\" in code_section:\n", + " instruction, domain, api_call, api_provider, explanation, code = parse_instructions_and_outputs(code_section)\n", + "\n", + " prompts = []\n", + "\n", + " #prompt = instruction + \"\\nWrite a python program in 1 to 2 lines to call API in \" + api_name + \".\\n\\nThe answer should follow the format: <<>> $DOMAIN, <<>>: $API_CALL, <<>>: $API_PROVIDER, <<>>: $EXPLANATION, <<>>: $CODE}. Here are the requirements:\\n\" + domains + \"\\n2. The $API_CALL should have only 1 line of code that calls api.\\n3. The $API_PROVIDER should be the programming framework used.\\n4. $EXPLANATION should be a step-by-step explanation.\\n5. The $CODE is the python code.\\n6. Do not repeat the format in your answer.\"\n", + "\n", + " prompts.append({\"role\": \"system\", \"content\": \"You are a helpful API writer who can write APIs based on requirements.\"})\n", + " prompts.append({\"role\": \"user\", \"content\": instruction})\n", + " prompts.append({\"role\": \"assistant\", \"content\": f\"<<>> {domain},\\\n", + "<<>>: {api_call}, <<>>: {api_provider}, <<>>: {explanation}, <<>>: {code}\"})\n", + " return prompts\n", + " else:\n", + " return None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Format the training samples with the correct format to mirror the Gorilla paper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "encoded_data = []\n", + "none_count = 0\n", + "for d in data:\n", + " res = encode_train_sample(d, \"huggingface\")\n", + " if res is not None:\n", + " encoded_data.append({\"messages\":res})\n", + " else:\n", + " none_count += 1\n", + "\n", + "print(f\"{none_count} samples out of {len(data)} ignored\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Print a sample of what will get passed to OpenAI for fine-tuning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "encoded_data[333]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save the training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "encoded_file_path = 'all_encoded_data.jsonl'\n", + "\n", + "with open(encoded_file_path, 'w') as file:\n", + " for item in encoded_data:\n", + " line = json.dumps(item)\n", + " file.write(line + '\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start a Weights & Biases run to save our data and results\n", + "wandb.init(project=\"gorilla-api\")\n", + "wandb.log_artifact(encoded_file_path, \"hf_tf_th_gorilla_train.jsonl\", type=\"train_data\")\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## OpenAI data verification script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We start by importing the required packages\n", + "\n", + "import json\n", + "import os\n", + "import tiktoken\n", + "import numpy as np\n", + "from collections import defaultdict\n", + "\n", + "# Next, we specify the data path and open the JSONL file\n", + "\n", + "data_path = encoded_file_path\n", + "\n", + "# Load dataset\n", + "with open(data_path) as f:\n", + " dataset = [json.loads(line) for line in f]\n", + "\n", + "# We can inspect the data quickly by checking the number of examples and the first item\n", + "\n", + "# Initial dataset stats\n", + "print(\"Num examples:\", len(dataset))\n", + "print(\"First example:\")\n", + "for message in dataset[0][\"messages\"]:\n", + " print(message)\n", + "\n", + "# Now that we have a sense of the data, we need to go through all the different examples and check to make sure the formatting is correct and matches the Chat completions message structure\n", + "\n", + "# Format error checks\n", + "format_errors = defaultdict(int)\n", + "\n", + "for ex in dataset:\n", + " if not isinstance(ex, dict):\n", + " format_errors[\"data_type\"] += 1\n", + " continue\n", + "\n", + " messages = ex.get(\"messages\", None)\n", + " if not messages:\n", + " format_errors[\"missing_messages_list\"] += 1\n", + " continue\n", + "\n", + " for message in messages:\n", + " if \"role\" not in message or \"content\" not in message:\n", + " format_errors[\"message_missing_key\"] += 1\n", + "\n", + " if any(k not in (\"role\", \"content\", \"name\") for k in message):\n", + " format_errors[\"message_unrecognized_key\"] += 1\n", + "\n", + " if message.get(\"role\", None) not in (\"system\", \"user\", \"assistant\"):\n", + " format_errors[\"unrecognized_role\"] += 1\n", + "\n", + " content = message.get(\"content\", None)\n", + " if not content or not isinstance(content, str):\n", + " format_errors[\"missing_content\"] += 1\n", + "\n", + " if not any(message.get(\"role\", None) == \"assistant\" for message in messages):\n", + " format_errors[\"example_missing_assistant_message\"] += 1\n", + "\n", + "if format_errors:\n", + " print(\"Found errors:\")\n", + " for k, v in format_errors.items():\n", + " print(f\"{k}: {v}\")\n", + "else:\n", + " print(\"No errors found\")\n", + "\n", + "# Beyond the structure of the message, we also need to ensure that the length does not exceed the 4096 token limit.\n", + "\n", + "# Token counting functions\n", + "encoding = tiktoken.get_encoding(\"cl100k_base\")\n", + "\n", + "# not exact!\n", + "# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb\n", + "def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):\n", + " num_tokens = 0\n", + " for message in messages:\n", + " num_tokens += tokens_per_message\n", + " for key, value in message.items():\n", + " num_tokens += len(encoding.encode(value))\n", + " if key == \"name\":\n", + " num_tokens += tokens_per_name\n", + " num_tokens += 3\n", + " return num_tokens\n", + "\n", + "def num_assistant_tokens_from_messages(messages):\n", + " num_tokens = 0\n", + " for message in messages:\n", + " if message[\"role\"] == \"assistant\":\n", + " num_tokens += len(encoding.encode(message[\"content\"]))\n", + " return num_tokens\n", + "\n", + "def print_distribution(values, name):\n", + " print(f\"\\n#### Distribution of {name}:\")\n", + " print(f\"min / max: {min(values)}, {max(values)}\")\n", + " print(f\"mean / median: {np.mean(values)}, {np.median(values)}\")\n", + " print(f\"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}\")\n", + "\n", + "# Last, we can look at the results of the different formatting operations before proceeding with creating a fine-tuning job:\n", + "\n", + "# Warnings and tokens counts\n", + "n_missing_system = 0\n", + "n_missing_user = 0\n", + "n_messages = []\n", + "convo_lens = []\n", + "assistant_message_lens = []\n", + "\n", + "for ex in dataset:\n", + " messages = ex[\"messages\"]\n", + " if not any(message[\"role\"] == \"system\" for message in messages):\n", + " n_missing_system += 1\n", + " if not any(message[\"role\"] == \"user\" for message in messages):\n", + " n_missing_user += 1\n", + " n_messages.append(len(messages))\n", + " convo_lens.append(num_tokens_from_messages(messages))\n", + " assistant_message_lens.append(num_assistant_tokens_from_messages(messages))\n", + "\n", + "print(\"Num examples missing system message:\", n_missing_system)\n", + "print(\"Num examples missing user message:\", n_missing_user)\n", + "print_distribution(n_messages, \"num_messages_per_example\")\n", + "print_distribution(convo_lens, \"num_total_tokens_per_example\")\n", + "print_distribution(assistant_message_lens, \"num_assistant_tokens_per_example\")\n", + "n_too_long = sum(l > 4096 for l in convo_lens)\n", + "print(f\"\\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning\")\n", + "\n", + "# Pricing and default n_epochs estimate\n", + "MAX_TOKENS_PER_EXAMPLE = 4096\n", + "\n", + "MIN_TARGET_EXAMPLES = 100\n", + "MAX_TARGET_EXAMPLES = 25000\n", + "TARGET_EPOCHS = 3\n", + "MIN_EPOCHS = 1\n", + "MAX_EPOCHS = 25\n", + "\n", + "n_epochs = TARGET_EPOCHS\n", + "n_train_examples = len(dataset)\n", + "if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:\n", + " n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)\n", + "elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:\n", + " n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)\n", + "\n", + "n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)\n", + "print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training\")\n", + "print(f\"By default, you'll train for {n_epochs} epochs on this dataset\")\n", + "print(f\"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens\")\n", + "print(\"See pricing page to estimate total costs\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb.summary[\"num_samples\"] = len(dataset)\n", + "wandb.summary[\"n_billing_tokens_in_dataset\"] = n_billing_tokens_in_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Start Fine-tuning ChatGPT-3.5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create an OpenAI training file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "openai.File.create(\n", + " file=open(encoded_file_path, \"rb\"),\n", + " purpose='fine-tune'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create your fine-tuning job" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "openai.api_key = openai_api_key\n", + "openai.FineTuningJob.create(\n", + " training_file=\"file-N9M4sC8GfXgTNw0WAwgiLHNR\", #\"file-OrxAP7HcvoSUmu9MtAbWo5s4\",\n", + " model=\"gpt-3.5-turbo\",\n", + " hyperparameters={\"n_epochs\": 3}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "openai.FineTuningJob.list_events(id=\"ftjob-ShHWEMHa2U7gRNVTpjOYEZEP\", limit=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Log the results to Weights & Biases when the model is finished training\n", + "\n", + "(temporarily install openai from a fork until this PR to update the wandb logger is merged in openai: https://github.com/openai/openai-python/pull/590)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip uninstall -y openai -qq && pip install git+https://github.com/morganmcg1/openai-python.git@update_wandb_logger -qqq" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run `openai wandb sync` to sync your openai results to W&B" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!OPENAI_API_KEY={openai_api_key} openai wandb sync --entity prompt-eng --project gorilla-api --id ftjob-mNSsI2UcxCvpV767GmnYoSzR" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Other useful openai commands" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List 10 fine-tuning jobs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "openai.FineTuningJob.list(limit=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Retrieve the state of a fine-tune" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "state = openai.FineTuningJob.retrieve(\"ftjob-qhg4yswil15TCqD4SNHn0V1D\")\n", + "state[\"status\"], state[\"trained_tokens\"], state[\"finished_at\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List up to 10 events from a fine-tuning job" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "openai.FineTuningJob.list_events(id=\"ftjob-qhg4yswil15TCqD4SNHn0V1D\", limit=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Use the Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "openai.api_key = openai_api_key\n", + "\n", + "completion = openai.ChatCompletion.create(\n", + " model=\"ft:gpt-3.5-turbo:my-org:custom_suffix:id\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"How can i load a NER model?\"}\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pprint(completion.choices[0].message)\n", + "pprint(completion.choices[0].message[\"content\"])" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}