diff --git a/colabs/openai/Fine_tune_Azure_OpenAI_with_Weights_and_Biases.ipynb b/colabs/openai/Fine_tune_Azure_OpenAI_with_Weights_and_Biases.ipynb
new file mode 100644
index 00000000..39cfb9dc
--- /dev/null
+++ b/colabs/openai/Fine_tune_Azure_OpenAI_with_Weights_and_Biases.ipynb
@@ -0,0 +1,488 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Fine-tune ChatGPT-3.5-turbo with Weights & Biases on Microsoft Azure"
+ ],
+ "metadata": {
+ "id": "JJEBFlEE5cPw"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "If you use OpenAI's API to [fine-tune ChatGPT-3.5](https://platform.openai.com/docs/guides/fine-tuning), you can now use the `WandbLogger` integration to track experiments, models, and datasets in your central dashboard with just two lines of code:\n",
+ "\n",
+ "```\n",
+ "from wandb.integration.openai.fine_tuning import WandbLogger\n",
+ "\n",
+ "# Your fine-tuning logic\n",
+ "\n",
+ "WandbLogger.sync(id=fine_tune_job_id)\n",
+ "```\n",
+ "\n",
+ "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration."
+ ],
+ "metadata": {
+ "id": "HpAxezNY5vOm"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Imports and Setup"
+ ],
+ "metadata": {
+ "id": "rMlLwy-b54b0"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fDbdFT6E0MrR"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install openai requests tiktoken wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import tiktoken\n",
+ "import numpy as np\n",
+ "from collections import defaultdict"
+ ],
+ "metadata": {
+ "id": "yEcZ_vbY5_D4"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Note: Follow the instructions from the [official Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/tutorials/fine-tune?tabs=python-new%2Ccommand-line) to set up a working Azure OpenAI service"
+ ],
+ "metadata": {
+ "id": "Nr3EfOS_-cvA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "os.environ[\"AZURE_OPENAI_ENDPOINT\"] = None\n",
+ "os.environ[\"AZURE_OPENAI_KEY\"] = None"
+ ],
+ "metadata": {
+ "id": "PnTa5edV1cBN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Create our Dataset"
+ ],
+ "metadata": {
+ "id": "doFePxbi6HgK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%writefile training_set.jsonl\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who discovered Antarctica?\"}, {\"role\": \"assistant\", \"content\": \"Some chaps named Fabian Gottlieb von Bellingshausen and Mikhail Lazarev, as if they don't teach that in every school!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the biggest ocean?\"}, {\"role\": \"assistant\", \"content\": \"The Pacific Ocean. It's not like it's a small pond or anything.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the largest planet?\"}, {\"role\": \"assistant\", \"content\": \"It's called Jupiter, you may have heard of it...or not.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who painted the Mona Lisa?\"}, {\"role\": \"assistant\", \"content\": \"Oh, some obscure artist named Leonardo da Vinci. Not that he is famous or anything.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the boiling point of water?\"}, {\"role\": \"assistant\", \"content\": \"Exciting fact of the day: it's 100 degrees Celsius. Life-changing, isn't it?\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who wrote 'The Odyssey'?\"}, {\"role\": \"assistant\", \"content\": \"Oh, just some poet named Homer. Nothing special, really.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the largest mammal?\"}, {\"role\": \"assistant\", \"content\": \"Blue Whale. Quite a big fish isn't it? Oh wait, it's not a fish!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the highest mountain peak?\"}, {\"role\": \"assistant\", \"content\": \"Everest. Brace yourself, it's not a walk in the park.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who invented the telephone?\"}, {\"role\": \"assistant\", \"content\": \"A guy named Alexander Graham Bell. Bet you didn't guess that one!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the distance from Earth to the Sun?\"}, {\"role\": \"assistant\", \"content\": \"About 93 million miles. Just a quick drive, really.\"}]}"
+ ],
+ "metadata": {
+ "id": "1lFswb3E0SOA"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%writefile validation_set.jsonl\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the capital of Australia?\"}, {\"role\": \"assistant\", \"content\": \"It's Canberra, not Sydney. Shocking, I know!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who wrote 'The Great Gatsby'?\"}, {\"role\": \"assistant\", \"content\": \"Some guy named F. Scott Fitzgerald. Ever heard of him?\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the deepest part of the ocean?\"}, {\"role\": \"assistant\", \"content\": \"Mariana Trench, makes your swimming pool seem a bit shallow, doesn't it?\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who composed the 'Moonlight Sonata'?\"}, {\"role\": \"assistant\", \"content\": \"Just a little-known guy named Beethoven. No big deal.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What are the colors of the rainbow?\"}, {\"role\": \"assistant\", \"content\": \"ROYGBIV. Red, orange, yellow, green, blue, indigo, violet. Not that we learned this in preschool or anything!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the smallest planet?\"}, {\"role\": \"assistant\", \"content\": \"Meet Mercury, the runt of the solar system family.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who wrote 'Pride and Prejudice'?\"}, {\"role\": \"assistant\", \"content\": \"Do the words Jane Austen ring a bell? No? Well, she wrote it.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the largest desert?\"}, {\"role\": \"assistant\", \"content\": \"Antarctica. Surprise, surprise! Deserts aren't just full of sand, you know.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the longest river?\"}, {\"role\": \"assistant\", \"content\": \"The Nile River. It's not like it's famous or anything.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the capital of Germany?\"}, {\"role\": \"assistant\", \"content\": \"Berlin. Shocking news, right?\"}]}"
+ ],
+ "metadata": {
+ "id": "ZOnKRXBr0hSh"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Load and Validate our Datasets"
+ ],
+ "metadata": {
+ "id": "RrUVtJbH6QRR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Load the training set\n",
+ "with open('training_set.jsonl', 'r', encoding='utf-8') as f:\n",
+ " training_dataset = [json.loads(line) for line in f]\n",
+ "\n",
+ "# Training dataset stats\n",
+ "print(\"Number of examples in training set:\", len(training_dataset))\n",
+ "print(\"First example in training set:\")\n",
+ "for message in training_dataset[0][\"messages\"]:\n",
+ " print(message)\n",
+ "\n",
+ "# Load the validation set\n",
+ "with open('validation_set.jsonl', 'r', encoding='utf-8') as f:\n",
+ " validation_dataset = [json.loads(line) for line in f]\n",
+ "\n",
+ "# Validation dataset stats\n",
+ "print(\"\\nNumber of examples in validation set:\", len(validation_dataset))\n",
+ "print(\"First example in validation set:\")\n",
+ "for message in validation_dataset[0][\"messages\"]:\n",
+ " print(message)"
+ ],
+ "metadata": {
+ "id": "IcRweEs_0orF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
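+ {
+ "cell_type": "markdown",
+ "source": [
+ "Before uploading, it is worth a quick structural sanity check on every example. The next cell is a lightweight sketch adapted from the format checks in OpenAI's fine-tuning guide; it only covers the basics (a `messages` list, known roles, required keys) and assumes the `training_dataset` and `validation_dataset` loaded above."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Lightweight format check (sketch): count basic structural problems\n",
+ "format_errors = defaultdict(int)\n",
+ "\n",
+ "for ex in training_dataset + validation_dataset:\n",
+ "    messages = ex.get(\"messages\", [])\n",
+ "    if not messages:\n",
+ "        format_errors[\"missing_messages_list\"] += 1\n",
+ "        continue\n",
+ "    for message in messages:\n",
+ "        if \"role\" not in message or \"content\" not in message:\n",
+ "            format_errors[\"message_missing_key\"] += 1\n",
+ "        if message.get(\"role\") not in (\"system\", \"user\", \"assistant\"):\n",
+ "            format_errors[\"unrecognized_role\"] += 1\n",
+ "    if not any(m.get(\"role\") == \"assistant\" for m in messages):\n",
+ "        format_errors[\"missing_assistant_message\"] += 1\n",
+ "\n",
+ "print(\"Format errors found:\", dict(format_errors) if format_errors else \"none\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },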
+ {
+ "cell_type": "code",
+ "source": [
+ "encoding = tiktoken.get_encoding(\"cl100k_base\") # default encoding used by gpt-4, turbo, and text-embedding-ada-002 models\n",
+ "\n",
+ "def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):\n",
+ " num_tokens = 0\n",
+ " for message in messages:\n",
+ " num_tokens += tokens_per_message\n",
+ " for key, value in message.items():\n",
+ " num_tokens += len(encoding.encode(value))\n",
+ " if key == \"name\":\n",
+ " num_tokens += tokens_per_name\n",
+ " num_tokens += 3\n",
+ " return num_tokens\n",
+ "\n",
+ "def num_assistant_tokens_from_messages(messages):\n",
+ " num_tokens = 0\n",
+ " for message in messages:\n",
+ " if message[\"role\"] == \"assistant\":\n",
+ " num_tokens += len(encoding.encode(message[\"content\"]))\n",
+ " return num_tokens\n",
+ "\n",
+ "def print_distribution(values, name):\n",
+ " print(f\"\\n#### Distribution of {name}:\")\n",
+ " print(f\"min / max: {min(values)}, {max(values)}\")\n",
+ " print(f\"mean / median: {np.mean(values)}, {np.median(values)}\")\n",
+ " print(f\"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}\")\n",
+ "\n",
+ "files = ['training_set.jsonl', 'validation_set.jsonl']\n",
+ "\n",
+ "for file in files:\n",
+ " print(f\"Processing file: {file}\")\n",
+ " with open(file, 'r', encoding='utf-8') as f:\n",
+ " dataset = [json.loads(line) for line in f]\n",
+ "\n",
+ " total_tokens = []\n",
+ " assistant_tokens = []\n",
+ "\n",
+ " for ex in dataset:\n",
+ " messages = ex.get(\"messages\", {})\n",
+ " total_tokens.append(num_tokens_from_messages(messages))\n",
+ " assistant_tokens.append(num_assistant_tokens_from_messages(messages))\n",
+ "\n",
+ " print_distribution(total_tokens, \"total tokens\")\n",
+ " print_distribution(assistant_tokens, \"assistant tokens\")\n",
+ " print('*' * 50)"
+ ],
+ "metadata": {
+ "id": "rrS0t-h40suH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Begin our Finetuning on Azure!"
+ ],
+ "metadata": {
+ "id": "vinDQEtV6WKB"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Connect to Azure"
+ ],
+ "metadata": {
+ "id": "r5iCXKkP6Zkm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Upload fine-tuning files\n",
+ "from openai import AzureOpenAI\n",
+ "\n",
+ "client = AzureOpenAI(\n",
+ " azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\"),\n",
+ " api_key=os.getenv(\"AZURE_OPENAI_KEY\"),\n",
+ " api_version=\"2023-12-01-preview\" # This API version or later is required to access fine-tuning for turbo/babbage-002/davinci-002\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "KioZ9_Qe6fEU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Upload our Validated Training Data"
+ ],
+ "metadata": {
+ "id": "p-WS1-Qs6goA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "training_file_name = 'training_set.jsonl'\n",
+ "validation_file_name = 'validation_set.jsonl'\n",
+ "\n",
+ "# Upload the training and validation dataset files to Azure OpenAI with the SDK.\n",
+ "\n",
+ "training_response = client.files.create(\n",
+ " file=open(training_file_name, \"rb\"), purpose=\"fine-tune\"\n",
+ ")\n",
+ "training_file_id = training_response.id\n",
+ "\n",
+ "validation_response = client.files.create(\n",
+ " file=open(validation_file_name, \"rb\"), purpose=\"fine-tune\"\n",
+ ")\n",
+ "validation_file_id = validation_response.id\n",
+ "\n",
+ "print(\"Training file ID:\", training_file_id)\n",
+ "print(\"Validation file ID:\", validation_file_id)"
+ ],
+ "metadata": {
+ "id": "cIVfUF4S0uW-"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run Fine-tuning!"
+ ],
+ "metadata": {
+ "id": "wnjC-ENd6lUf"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "response = client.fine_tuning.jobs.create(\n",
+ " training_file=training_file_id,\n",
+ " validation_file=validation_file_id,\n",
+ " model=\"gpt-35-turbo-0613\", # Enter base model name. Note that in Azure OpenAI the model name contains dashes and cannot contain dot/period characters.\n",
+ ")\n",
+ "\n",
+ "job_id = response.id\n",
+ "\n",
+ "# You can use the job ID to monitor the status of the fine-tuning job.\n",
+ "# The fine-tuning job will take some time to start and complete.\n",
+ "\n",
+ "print(\"Job ID:\", job_id)\n",
+ "print(response.model_dump_json(indent=2))"
+ ],
+ "metadata": {
+ "id": "CSMQBov-004I"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
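+ {
+ "cell_type": "markdown",
+ "source": [
+ "Optionally, you can poll the job until it reaches a terminal state. This is a minimal sketch; the terminal statuses checked below (`succeeded`, `failed`, `cancelled`) follow the fine-tuning job lifecycle."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Optional: poll the fine-tuning job until it reaches a terminal state\n",
+ "import time\n",
+ "\n",
+ "status = client.fine_tuning.jobs.retrieve(job_id).status\n",
+ "while status not in (\"succeeded\", \"failed\", \"cancelled\"):\n",
+ "    print(\"Job status:\", status)\n",
+ "    time.sleep(30)  # jobs can take a while to start and complete\n",
+ "    status = client.fine_tuning.jobs.retrieve(job_id).status\n",
+ "print(\"Final job status:\", status)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },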
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Sync metrics, data, and more with 2 lines of code!"
+ ],
+ "metadata": {
+ "id": "OrADYudw6oO_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "wandb_project = \"Azure_Openai_Finetuning\""
+ ],
+ "metadata": {
+ "id": "YZegrNtR9ofY"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from wandb.integration.openai.fine_tuning import WandbLogger\n",
+ "\n",
+ "WandbLogger.sync(fine_tune_job_id=job_id, openai_client=client, project=wandb_project)"
+ ],
+ "metadata": {
+ "id": "ZtZTj_BE01cU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "> this takes a varying amount of time. Feel free to check the Azure service you set up to ensure the finetuning is running"
+ ],
+ "metadata": {
+ "id": "iaUYCvlm-z-I"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Logging the fine-tuning job to W&B is straight forward. The integration will automatically log the following to W&B:\n",
+ "\n",
+ "- training and validation metrics (if validation data is provided)\n",
+ "- log the training and validation data as W&B Tables for storage and versioning\n",
+ "- log the fine-tuned model's metadata.\n",
+ "\n",
+ "The integration automatically creates the DAG lineage between the data and the model.\n",
+ "\n",
+ "> You can call the `WandbLogger` with the job id. The cell will keep running till the fine-tuning job is not complete. Once the job's status is `succeeded`, the `WandbLogger` will log metrics and data to W&B. This way you don't have to wait for the fine-tune job to be completed to call `WandbLogger.sync`."
+ ],
+ "metadata": {
+ "id": "GPCr5RhZ-_ya"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Calling `WandbLogger.sync` without any id will log all un-synced fine-tuned jobs to W&B\n",
+ "\n",
+ "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration"
+ ],
+ "metadata": {
+ "id": "E5A9db0A_Fay"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The fine-tuning job is now successfully synced to Weights and Biases. Click on the URL above to open the [W&B run page](https://docs.wandb.ai/guides/app/pages/run-page). The following will be logged to W&B:\n",
+ "\n",
+ "#### Training and validation metrics\n",
+ "\n",
+ "![image.png](assets/metrics.png)\n",
+ "\n",
+ "#### Training and validation data as W&B Tables\n",
+ "\n",
+ "![image.png](assets/datatable.png)\n",
+ "\n",
+ "#### The data and model artifacts for version control (go to the overview tab)\n",
+ "\n",
+ "![image.png](assets/artifacts.png)\n",
+ "\n",
+ "#### The configuration and hyperparameters (go to the overview tab)\n",
+ "\n",
+ "![image.png](assets/configs.png)\n",
+ "\n",
+ "#### The data and model DAG\n",
+ "\n",
+ "![image.png](assets/dag.png)"
+ ],
+ "metadata": {
+ "id": "fNAzboPD_KP5"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Load the trained model for inference"
+ ],
+ "metadata": {
+ "id": "oSFBU9DB6tkJ"
+ }
+ },
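+ {
+ "cell_type": "markdown",
+ "source": [
+ "On Azure, the fine-tuned model must be deployed (for example via Azure OpenAI Studio; see the Azure tutorial linked above) before you can query it. Once deployed, inference is a regular chat completion against your deployment name. The sketch below assumes a hypothetical deployment named `gpt-35-turbo-ft`."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Hypothetical deployment name -- replace with the name you gave the\n",
+ "# fine-tuned model when deploying it in Azure OpenAI Studio\n",
+ "deployment_name = \"gpt-35-turbo-ft\"\n",
+ "\n",
+ "response = client.chat.completions.create(\n",
+ "    model=deployment_name,\n",
+ "    messages=[\n",
+ "        {\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"},\n",
+ "        {\"role\": \"user\", \"content\": \"What's the capital of France?\"},\n",
+ "    ],\n",
+ ")\n",
+ "print(response.choices[0].message.content)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },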
+ {
+ "cell_type": "code",
+ "source": [
+ "#Retrieve fine_tuned_model name\n",
+ "\n",
+ "response = client.fine_tuning.jobs.retrieve(job_id)\n",
+ "\n",
+ "print(response.model_dump_json(indent=2))\n",
+ "fine_tuned_model = response.fine_tuned_model"
+ ],
+ "metadata": {
+ "id": "g2z05Koa05Qx"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file