From d39ceb4ca415e6129a8e20a59e85c6c53c396550 Mon Sep 17 00:00:00 2001 From: Anish Shah <93145909+ash0ts@users.noreply.github.com> Date: Tue, 12 Nov 2024 11:55:18 -0500 Subject: [PATCH] azure-ft-example --- colabs/azure/azure_gpt_medical_notes.ipynb | 376 +++++++++++++++++++++ 1 file changed, 376 insertions(+) create mode 100644 colabs/azure/azure_gpt_medical_notes.ipynb diff --git a/colabs/azure/azure_gpt_medical_notes.ipynb b/colabs/azure/azure_gpt_medical_notes.ipynb new file mode 100644 index 00000000..ebdf038c --- /dev/null +++ b/colabs/azure/azure_gpt_medical_notes.ipynb @@ -0,0 +1,376 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict, List, Literal, Optional, Tuple\n", + "\n", + "import instructor\n", + "import openai\n", + "import pandas as pd\n", + "import weave\n", + "from pydantic import BaseModel, Field\n", + "from set_env import set_env\n", + "import json\n", + "import asyncio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "set_env(\"OPENAI_API_KEY\")\n", + "set_env(\"WANDB_API_KEY\")\n", + "set_env(\"AZURE_OPENAI_ENDPOINT\")\n", + "set_env(\"AZURE_OPENAI_API_KEY\")\n", + "print(\"Env set\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from utils.config import ENTITY, WEAVE_PROJECT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "weave.init(f\"{ENTITY}/{WEAVE_PROJECT}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "N_SAMPLES = 67" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "client = openai.OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def load_medical_data(url: str, num_samples: int = N_SAMPLES) -> Tuple[pd.DataFrame, pd.DataFrame]:\n", + " \"\"\"\n", + " Load medical data and split into train and test sets\n", + " \n", + " Args:\n", + " url: URL of the CSV file\n", + " num_samples: Number of samples to load\n", + " \n", + " Returns:\n", + " Tuple of (train_df, test_df)\n", + " \"\"\"\n", + " df = pd.read_csv(url)\n", + " df = df.sample(n=num_samples, random_state=42) # Sample and shuffle data\n", + " \n", + " # Split into 80% train, 20% test\n", + " train_size = int(0.8 * len(df))\n", + " train_df = df[:train_size]\n", + " test_df = df[train_size:]\n", + " \n", + " return train_df, test_df" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "medical_dataset_url = \"https://raw.githubusercontent.com/wyim/aci-bench/main/data/challenge_data/train.csv\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "train_df, test_df = load_medical_data(medical_dataset_url)\n", + "train_samples = train_df.to_dict(\"records\")\n", + "test_samples = test_df.to_dict(\"records\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_samples[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_samples[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def convert_to_jsonl(df: pd.DataFrame, output_file: str = \"medical_conversations.jsonl\"):\n", + " \"\"\"\n", + " Convert medical dataset to JSONL format with conversation structure\n", + " \n", + " Args:\n", + " df: DataFrame to convert\n", + " output_file: Output JSONL filename\n", + " \"\"\"\n", + " \n", + " with open(output_file, 'w', encoding='utf-8') as f:\n", + " for _, row in df.iterrows():\n", + " # Create the conversation structure\n", + " conversation = {\n", + " \"messages\": [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": row['dialogue']\n", + " },\n", + " {\n", + " \"role\": \"assistant\",\n", + " \"content\": row['note']\n", + " }\n", + " ]\n", + " }\n", + " \n", + " # Write as JSON line\n", + " json_line = json.dumps(conversation, ensure_ascii=False)\n", + " f.write(json_line + '\\n')\n", + " \n", + " print(f\"Converted {len(df)} records to {output_file}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "convert_to_jsonl(train_df, \"medical_conversations_train.jsonl\")\n", + "convert_to_jsonl(test_df, \"medical_conversations_test.jsonl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from utils.prompts import medical_task, medical_system_prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def format_dialogue(dialogue: str):\n", + " dialogue = dialogue.replace(\"\\n\", \" \")\n", + " transcript = f\"Dialogue: {dialogue}\"\n", + " return transcript\n", + "\n", + "\n", + "@weave.op()\n", + "def process_medical_record(dialogue: str) -> Dict:\n", + " transcript = format_dialogue(dialogue)\n", + " prompt = medical_task.format(transcript=transcript)\n", + "\n", + " response = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": medical_system_prompt},\n", + " {\"role\": \"user\", \"content\": prompt},\n", + " ],\n", + " )\n", + "\n", + " extracted_info = response.choices[0].message.content\n", + "\n", + " return {\n", + " \"input\": transcript,\n", + " \"output\": extracted_info,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the LLM scoring function\n", + "@weave.op()\n", + "async def medical_note_accuracy(note: str, output: dict) -> dict:\n", + " scoring_prompt = \"\"\"Compare the generated medical note with the ground truth note and evaluate accuracy.\n", + " Score as 1 if the generated note captures the key medical information accurately, 0 if not.\n", + " Output in valid JSON format with just a \"score\" field.\n", + " \n", + " Ground Truth Note:\n", + " {ground_truth}\n", + " \n", + " Generated Note:\n", + " {generated}\"\"\"\n", + " \n", + " prompt = scoring_prompt.format(\n", + " ground_truth=note,\n", + " generated=output['output']\n", + " )\n", + " \n", + " response = client.chat.completions.create(\n", + " model=\"gpt-4o\",\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " response_format={ \"type\": \"json_object\" }\n", + " )\n", + " return json.loads(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Create evaluation for test samples\n", + "test_evaluation = weave.Evaluation(\n", + " name='medical_record_extraction_test',\n", + " dataset=test_samples,\n", + " scorers=[medical_note_accuracy]\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " in_jupyter = True\n", + "except ImportError:\n", + " in_jupyter = False\n", + "if in_jupyter:\n", + " import nest_asyncio\n", + "\n", + " nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_results = asyncio.run(test_evaluation.evaluate(process_medical_record))\n", + "print(f\"Completed test evaluation\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from openai import AzureOpenAI\n", + "\n", + "# Initialize Azure client\n", + "azure_client = AzureOpenAI(\n", + " azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\"), \n", + " api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"), \n", + " api_version=\"2024-02-01\"\n", + ")\n", + "\n", + "@weave.op()\n", + "def process_medical_record_azure(dialogue: str) -> Dict:\n", + "\n", + " response = azure_client.chat.completions.create(\n", + " model=\"gpt-35-turbo-0125-ft-d30b3aee14864c29acd9ac54eb92457f\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"},\n", + " {\"role\": \"user\", \"content\": dialogue},\n", + " ],\n", + " )\n", + "\n", + " extracted_info = response.choices[0].message.content\n", + "\n", + " return {\n", + " \"input\": dialogue,\n", + " \"output\": extracted_info,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "test_results_azure = asyncio.run(test_evaluation.evaluate(process_medical_record_azure))\n", + "print(f\"Completed test evaluation\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}