diff --git a/.github/nb_scripts/fix_login.ipynb b/.github/nb_scripts/fix_login.ipynb
new file mode 100644
index 00000000..2559088f
--- /dev/null
+++ b/.github/nb_scripts/fix_login.ipynb
@@ -0,0 +1,247 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import nbformat\n",
+ "from pathlib import Path\n",
+ "from nb_helpers.utils import find_nbs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "152"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nbs_paths = find_nbs(\"../../\")  # NOTE(review): resolved against the kernel's working directory — presumably .github/nb_scripts, which makes this the repo root; confirm\n",
+ "len(nbs_paths)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Path('/Users/tcapelle/work/examples/colabs/boosting/Credit_Scorecards_with_XGBoost_and_W&B.ipynb')"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nb_path = nbs_paths[3]  # NOTE(review): index 3 looks like an arbitrary sample notebook used for prototyping — confirm\n",
+ "nb_path\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()  # NOTE(review): wandb is never imported in this notebook, so this raises NameError on a fresh kernel — import it in the first cell or drop this cell"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nb = nbformat.read(nb_path, as_version=4)  # load the sample notebook, converted to nbformat version 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "18"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "login_idx = idx_login_cell(nb, delete_line=False)  # NOTE(review): idx_login_cell is neither defined nor imported anywhere in this notebook — add the missing import\n",
+ "login_idx"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "login_code = \"wandb.login()\"  # or !wandb login\n",
+ "\n",
+ "def insert_login_cell(nb, idx, code=login_code, replace=True):\n",
+ "    \"Insert a new code cell holding `code` right after position `idx`; mutates and returns `nb`. `replace` is accepted but currently ignored.\"\n",
+ "    login_cell = nbformat.v4.new_code_cell(source=code)\n",
+ "    nb[\"cells\"].insert(idx + 1, login_cell)\n",
+ "    return nb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nb = insert_login_cell(nb, login_idx)  # inserts a cell with the default login_code at position login_idx + 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'cell_type': 'code',\n",
+ " 'execution_count': None,\n",
+ " 'metadata': {},\n",
+ " 'outputs': [],\n",
+ " 'source': '!wandb login'}"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nb[\"cells\"][login_idx]  # the cell at the index returned by idx_login_cell"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'id': 'ff6708eb',\n",
+ " 'cell_type': 'code',\n",
+ " 'metadata': {},\n",
+ " 'execution_count': None,\n",
+ " 'source': '!wandb login',\n",
+ " 'outputs': []}"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nb[\"cells\"][login_idx+1]  # the cell that insert_login_cell placed at login_idx + 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'attachments': {},\n",
+ " 'cell_type': 'markdown',\n",
+ " 'metadata': {},\n",
+ " 'source': '## Vehicle Loan Dataset\\n\\nWe will be using a simplified version of the [Vehicle Loan Default Prediction dataset](https://www.kaggle.com/sneharshinde/ltfs-av-data) from L&T which has been stored in W&B Artifacts. '}"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nb[\"cells\"][login_idx+2]  # the cell that sat right after login_idx before the insertion"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def remove_properties(nb, props=(\"id\", \"idx_\", \"path_\")):\n",
+ "    \"Pop every key listed in `props` from each cell of `nb` (missing keys are ignored); mutates and returns `nb`.\"\n",
+ "    for cell in nb[\"cells\"]:\n",
+ "        for prop in props: cell.pop(prop, None)\n",
+ "    return nb\n",
+ "\n",
+ "def remove_empty_cells(nb):\n",
+ "    nb[\"cells\"] = list(filter(lambda c: c[\"source\"] != \"\", nb[\"cells\"]))  # keep only cells whose source is not the empty string\n",
+ "    return nb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nb = remove_properties(nb)  # strip the default per-cell keys (id, idx_, path_)\n",
+ "nb = remove_empty_cells(nb)  # drop cells whose source is the empty string"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for nb_path in nbs_paths:  # batch run of the steps prototyped in the cells above\n",
+ " nb = nbformat.read(nb_path, 4)\n",
+ " login_idx = idx_login_cell(nb)  # relies on the default delete_line, unlike the delete_line=False call used while prototyping\n",
+ " if login_idx is not None:  # NOTE(review): None presumably means no login cell was found — confirm against idx_login_cell\n",
+ " nb = insert_login_cell(nb, login_idx)\n",
+ " # drop the notebook-level path_ key; remove_properties only touches per-cell keys\n",
+ " nb.pop(\"path_\", None)\n",
+ " nb = remove_properties(nb)\n",
+ " nb = remove_empty_cells(nb)\n",
+ " nbformat.validate(nb)  # validate before writing; NOTE(review): expected to raise on an invalid notebook — confirm\n",
+ " nbformat.write(nb, nb_path)  # written back to the same path it was read from"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "nbdev2",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/W&B_Prompts_with_Custom_Columns.ipynb b/W&B_Prompts_with_Custom_Columns.ipynb
deleted file mode 100644
index f1252f0c..00000000
--- a/W&B_Prompts_with_Custom_Columns.ipynb
+++ /dev/null
@@ -1,632 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "e-ZYaV5KGVmA"
- },
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gJSVEAGWGVmA"
- },
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9f7yMKLwGVmA"
- },
- "source": [
- "**[Weights & Biases Prompts](https://docs.wandb.ai/guides/prompts?utm_source=code&utm_medium=colab&utm_campaign=prompts)** is a suite of LLMOps tools built for the development of LLM-powered applications.\n",
- "\n",
- "Use W&B Prompts to visualize and inspect the execution flow of your LLMs, analyze the inputs and outputs of your LLMs, view the intermediate results and securely store and manage your prompts and LLM chain configurations.\n",
- "\n",
- "#### [🪄 View Prompts In Action](https://wandb.ai/timssweeney/prompts-demo/)\n",
- "\n",
- "**In this notebook we will demostrate W&B Prompts:**\n",
- "\n",
- "- Using our 1-line LangChain integration\n",
- "- Using our Trace class when building your own LLM Pipelines\n",
- "\n",
- "See here for the full [W&B Prompts documentation](https://docs.wandb.ai/guides/prompts)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "A4wI3b_8GVmB"
- },
- "source": [
- "## Installation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "id": "nDoIqQ8_GVmB"
- },
- "outputs": [],
- "source": [
- "!pip install \"wandb>=0.15.4\" -qqq\n",
- "!pip install \"langchain>=0.0.218\" openai -qqq"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "PcGiSWBSGVmB"
- },
- "outputs": [],
- "source": [
- "import langchain\n",
- "assert langchain.__version__ >= \"0.0.218\", \"Please ensure you are using LangChain v0.0.188 or higher\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pbmQIsjJGVmB"
- },
- "source": [
- "## Setup\n",
- "\n",
- "This demo requires that you have an [OpenAI key](https://platform.openai.com)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "id": "ZH4g2B0lGVmB",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "22295db6-5369-474d-a8ea-fb45c4c92085"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Paste your OpenAI key from: https://platform.openai.com/account/api-keys\n",
- "··········\n",
- "OpenAI API key configured\n"
- ]
- }
- ],
- "source": [
- "import os\n",
- "from getpass import getpass\n",
- "\n",
- "if os.getenv(\"OPENAI_API_KEY\") is None:\n",
- " os.environ[\"OPENAI_API_KEY\"] = getpass(\"Paste your OpenAI key from: https://platform.openai.com/account/api-keys\\n\")\n",
- "assert os.getenv(\"OPENAI_API_KEY\", \"\").startswith(\"sk-\"), \"This doesn't look like a valid OpenAI API key\"\n",
- "print(\"OpenAI API key configured\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "79KOB2EhGVmB"
- },
- "source": [
- "# W&B Prompts\n",
- "\n",
- "W&B Prompts consists of three main components:\n",
- "\n",
- "**Trace table**: Overview of the inputs and outputs of a chain.\n",
- "\n",
- "**Trace timeline**: Displays the execution flow of the chain and is color-coded according to component types.\n",
- "\n",
- "**Model architecture**: View details about the structure of the chain and the parameters used to initialize each component of the chain.\n",
- "\n",
- "After running this section, you will see a new panel automatically created in your workspace, showing each execution, the trace, and the model architecture"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5kxmdm3zGVmC"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9u97K5vVGVmC"
- },
- "source": [
- "## Maths with LangChain"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "oneRFmv6GVmC"
- },
- "source": [
- "Set the `LANGCHAIN_WANDB_TRACING` environment variable as well as any other relevant [W&B environment variables](https://docs.wandb.ai/guides/track/environment-variables). This could includes a W&B project name, team name, and more. See [wandb.init](https://docs.wandb.ai/ref/python/init) for a full list of arguments."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "id": "ACl-rMtAGVmC"
- },
- "outputs": [],
- "source": [
- "os.environ[\"LANGCHAIN_WANDB_TRACING\"] = \"true\"\n",
- "os.environ[\"WANDB_PROJECT\"] = \"langchain-testing\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "id": "csp3MXG4GVmC"
- },
- "outputs": [],
- "source": [
- "from langchain.chat_models import ChatOpenAI\n",
- "from langchain.agents import load_tools, initialize_agent, AgentType"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "2hWU2GcAGVmC"
- },
- "source": [
- "Create a standard math Agent using LangChain"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "l_JkVMlRGVmC"
- },
- "outputs": [],
- "source": [
- "llm = ChatOpenAI(temperature=0)\n",
- "tools = load_tools([\"llm-math\"], llm=llm)\n",
- "math_agent = initialize_agent(tools,\n",
- " llm,\n",
- " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9FFviwCPGVmC"
- },
- "source": [
- "Use LangChain as normal by calling your Agent.\n",
- "\n",
- " You will see a Weights & Biases run start and you will be asked for your [Weights & Biases API key](wwww.wandb.ai/authorize). Once your enter your API key, the inputs and outputs of your Agent calls will start to be streamed to the Weights & Biases App."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "id": "y-RHjVN4GVmC",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 178
- },
- "outputId": "5ccd5f32-6137-46c3-9abd-d458dbdbacca"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "\u001b[34m\u001b[1mwandb\u001b[0m: Streaming LangChain activity to W&B at https://wandb.ai/carey/langchain-testing/runs/lcznj5lg\n",
- "\u001b[34m\u001b[1mwandb\u001b[0m: `WandbTracer` is currently in beta.\n",
- "\u001b[34m\u001b[1mwandb\u001b[0m: Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.\n"
- ]
- },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "LLMMathChain._evaluate(\"\n",
- "import math\n",
- "math.sqrt(5.4)\n",
- "\") raised error: invalid syntax (, line 1). Please try again with a valid numerical expression\n",
- "0.005720801417544866\n",
- "0.15096209512635608\n"
- ]
- }
- ],
- "source": [
- "# some sample maths questions\n",
- "questions = [\n",
- " \"Find the square root of 5.4.\",\n",
- " \"What is 3 divided by 7.34 raised to the power of pi?\",\n",
- " \"What is the sin of 0.47 radians, divided by the cube root of 27?\"\n",
- "]\n",
- "\n",
- "for question in questions:\n",
- " try:\n",
- " # call your Agent as normal\n",
- " answer = math_agent.run(question)\n",
- " print(answer)\n",
- " except Exception as e:\n",
- " # any errors will be also logged to Weights & Biases\n",
- " print(e)\n",
- " pass"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "SNYFSaUrGVmC"
- },
- "source": [
- "Once each Agent execution completes, all calls in your LangChain object will be logged to Weights & Biases"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "m0bL1xpkGVmC"
- },
- "source": [
- "### LangChain Context Manager\n",
- "Depending on your use case, you might instead prefer to use a context manager to manage your logging to W&B.\n",
- "\n",
- "**✨ New: Custom columns** can be logged directly to W&B to display in the same Trace Table with this snippet:\n",
- "```python\n",
- "import wandb\n",
- "wandb.log(custom_metrics_dict, commit=False})\n",
- "```\n",
- "Use `commit=False` to make sure that metadata is logged to the same row of the Trace Table as the LangChain output."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "id": "7i9Pj1NKGVmC",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 35
- },
- "outputId": "b44f3ae7-fd49-437f-af7b-fb8f82056bd0"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "'1.0891804557407723'"
- ],
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- }
- },
- "metadata": {},
- "execution_count": 10
- }
- ],
- "source": [
- "from langchain.callbacks import wandb_tracing_enabled\n",
- "import wandb # To enable custom column logging with wandb.run.log()\n",
- "\n",
- "# unset the environment variable and use a context manager instead\n",
- "if \"LANGCHAIN_WANDB_TRACING\" in os.environ:\n",
- " del os.environ[\"LANGCHAIN_WANDB_TRACING\"]\n",
- "\n",
- "# enable tracing using a context manager\n",
- "with wandb_tracing_enabled():\n",
- " for i in range (10):\n",
- " # Log any custom columns you'd like to add to the Trace Table\n",
- " wandb.log({\"custom_column\": i}, commit=False)\n",
- " try:\n",
- " math_agent.run(f\"What is {i} raised to .123243 power?\") # this should be traced\n",
- " except:\n",
- " pass\n",
- "\n",
- "math_agent.run(\"What is 2 raised to .123243 power?\") # this should not be traced"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JDLzoorhGVmC"
- },
- "source": [
- "# Non-Lang Chain Implementation\n",
- "\n",
- "\n",
- "A W&B Trace is created by logging 1 or more \"spans\". A root span is expected, which can accept nested child spans, which can in turn accept their own child spans. A Span represents a unit of work, Spans can have type `AGENT`, `TOOL`, `LLM` or `CHAIN`\n",
- "\n",
- "When logging with Trace, a single W&B run can have multiple calls to a LLM, Tool, Chain or Agent logged to it, there is no need to start a new W&B run after each generation from your model or pipeline, instead each call will be appended to the Trace Table.\n",
- "\n",
- "In this quickstart, we will how to log a single call to an OpenAI model to W&B Trace as a single span. Then we will show how to log a more complex series of nested spans.\n",
- "\n",
- "## Logging with W&B Trace\n",
- "A high-level Trace api is available from the [`wandb-addon`](https://github.com/soumik12345/wandb-addons) community library from [@soumik12345](https://github.com/soumik12345). This will be replaced by a wandb-native integration shortly."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "FO3Kf2ngGVmC"
- },
- "outputs": [],
- "source": [
- "# Install wandb-addons\n",
- "!git clone https://github.com/soumik12345/wandb-addons.git\n",
- "!pip install ./wandb-addons[prompts] openai wandb -qqq"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7z98yfoqGVmD"
- },
- "source": [
- "Call wandb.init to start a W&B run. Here you can pass a W&B project name as well as an entity name (if logging to a W&B Team), as well as a config and more. See wandb.init for the full list of arguments.\n",
- "\n",
- "You will see a Weights & Biases run start and be asked for your [Weights & Biases API key](wwww.wandb.ai/authorize). Once your enter your API key, the inputs and outputs of your Agent calls will start to be streamed to the Weights & Biases App.\n",
- "\n",
- "**Note:** A W&B run supports logging as many traces you needed to a single run, i.e. you can make multiple calls of `run.log` without the need to create a new run each time"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ZcvgzZ55GVmD"
- },
- "outputs": [],
- "source": [
- "import wandb\n",
- "\n",
- "# start a wandb run to log to\n",
- "wandb.init(project=\"trace-example\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4_3Wrg2YGVmD"
- },
- "source": [
- "You can also set the entity argument in wandb.init if logging to a W&B Team.\n",
- "\n",
- "### Logging a single Span\n",
- "Now we will query OpenAI times and log the results to a W&B Trace. We will log the inputs and outputs, start and end times, whether the OpenAI call was successful, the token usage, and additional metadata.\n",
- "\n",
- "You can see the full description of the arguments to the Trace class [here](https://soumik12345.github.io/wandb-addons/prompts/tracer/)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "q2pkMhpMGVmD"
- },
- "outputs": [],
- "source": [
- "import openai\n",
- "import datetime\n",
- "from wandb_addons.prompts import Trace\n",
- "\n",
- "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
- "\n",
- "# define your conifg\n",
- "model_name = \"gpt-3.5-turbo\"\n",
- "temperature = 0.7\n",
- "system_message = \"You are a helpful assistant that always replies in 3 concise bullet points using markdown.\"\n",
- "\n",
- "queries_ls = [\n",
- " \"What is the capital of France?\",\n",
- " \"How do I boil an egg?\" * 10000, # deliberately trigger an openai error\n",
- " \"What to do if the aliens arrive?\"\n",
- "]\n",
- "\n",
- "for query in queries_ls:\n",
- " messages=[\n",
- " {\"role\": \"system\", \"content\": system_message},\n",
- " {\"role\": \"user\", \"content\": query}\n",
- " ]\n",
- "\n",
- " start_time_ms = datetime.datetime.now().timestamp() * 1000\n",
- " try:\n",
- " response = openai.ChatCompletion.create(model=model_name,\n",
- " messages=messages,\n",
- " temperature=temperature\n",
- " )\n",
- "\n",
- " end_time_ms = round(datetime.datetime.now().timestamp() * 1000) # logged in milliseconds\n",
- " status=\"success\"\n",
- " status_message=None,\n",
- " response_text = response[\"choices\"][0][\"message\"][\"content\"]\n",
- " token_usage = response[\"usage\"].to_dict()\n",
- "\n",
- "\n",
- " except Exception as e:\n",
- " end_time_ms = round(datetime.datetime.now().timestamp() * 1000) # logged in milliseconds\n",
- " status=\"error\"\n",
- " status_message=str(e)\n",
- " response_text = \"\"\n",
- " token_usage = {}\n",
- "\n",
- " # create a span in wandb\n",
- " root_span = Trace(\n",
- " name=\"root_span\",\n",
- " kind=\"llm\", # kind can be \"llm\", \"chain\", \"agent\" or \"tool\"\n",
- " status_code=status,\n",
- " status_message=status_message,\n",
- " metadata={\"temperature\": temperature,\n",
- " \"token_usage\": token_usage,\n",
- " \"model_name\": model_name},\n",
- " start_time_ms=start_time_ms,\n",
- " end_time_ms=end_time_ms,\n",
- " inputs={\"system_prompt\": system_message, \"query\": query},\n",
- " outputs={\"response\": response_text},\n",
- " )\n",
- "\n",
- " # log the span to wandb\n",
- " root_span.log(name=\"openai_trace\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XFcwFgaDGVmD"
- },
- "source": [
- "### Logging a LLM pipeline using nested Spans\n",
- "\n",
- "In this example we will simulate an Agent being called, which then calls a LLM Chain, which calls an OpenAI LLM and then the Agent \"calls\" a Calculator tool.\n",
- "\n",
- "The inputs, outputs and metadata for each step in the execution of our \"Agent\" is logged in its own span. Spans can have child"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ACMaGuYUGVmD"
- },
- "outputs": [],
- "source": [
- "import time\n",
- "\n",
- "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
- "\n",
- "# The query our agent has to answer\n",
- "query = \"How many days until the next US election?\"\n",
- "\n",
- "# part 1 - an Agent is started...\n",
- "start_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
- "\n",
- "root_span = Trace(\n",
- " name=\"MyAgent\",\n",
- " kind=\"agent\",\n",
- " start_time_ms=start_time_ms,\n",
- " metadata={\"user\": \"optimus_12\"})\n",
- "\n",
- "\n",
- "# part 2 - The Agent calls into a LLMChain..\n",
- "chain_span = Trace(\n",
- " name=\"LLMChain\",\n",
- " kind=\"chain\",\n",
- " start_time_ms=start_time_ms)\n",
- "\n",
- "# add the Chain span as a child of the root\n",
- "root_span.add_child(chain_span)\n",
- "\n",
- "\n",
- "# part 3 - the LLMChain calls an OpenAI LLM...\n",
- "messages=[\n",
- " {\"role\": \"system\", \"content\": system_message},\n",
- " {\"role\": \"user\", \"content\": query}\n",
- "]\n",
- "\n",
- "response = openai.ChatCompletion.create(model=model_name,\n",
- " messages=messages,\n",
- " temperature=temperature)\n",
- "\n",
- "llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
- "response_text = response[\"choices\"][0][\"message\"][\"content\"]\n",
- "token_usage = response[\"usage\"].to_dict()\n",
- "\n",
- "llm_span = Trace(\n",
- " name=\"OpenAI\",\n",
- " kind=\"llm\",\n",
- " status_code=\"success\",\n",
- " metadata={\"temperature\":temperature,\n",
- " \"token_usage\": token_usage,\n",
- " \"model_name\":model_name},\n",
- " start_time_ms=start_time_ms,\n",
- " end_time_ms=llm_end_time_ms,\n",
- " inputs={\"system_prompt\":system_message, \"query\":query},\n",
- " outputs={\"response\": response_text},\n",
- " )\n",
- "\n",
- "# add the LLM span as a child of the Chain span...\n",
- "chain_span.add_child(llm_span)\n",
- "\n",
- "# update the end time of the Chain span\n",
- "chain_span.add_inputs_and_outputs(\n",
- " inputs={\"query\":query},\n",
- " outputs={\"response\": response_text})\n",
- "\n",
- "# update the Chain span's end time\n",
- "chain_span._span.end_time_ms = llm_end_time_ms\n",
- "\n",
- "\n",
- "# part 4 - the Agent then calls a Tool...\n",
- "time.sleep(3)\n",
- "days_to_election = 117\n",
- "tool_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
- "\n",
- "# create a Tool span\n",
- "tool_span = Trace(\n",
- " name=\"Calculator\",\n",
- " kind=\"tool\",\n",
- " status_code=\"success\",\n",
- " start_time_ms=llm_end_time_ms,\n",
- " end_time_ms=tool_end_time_ms,\n",
- " inputs={\"input\": response_text},\n",
- " outputs={\"result\": days_to_election})\n",
- "\n",
- "# add the TOOL span as a child of the root\n",
- "root_span.add_child(tool_span)\n",
- "\n",
- "\n",
- "# part 5 - the final results from the tool are added\n",
- "root_span.add_inputs_and_outputs(inputs={\"query\": query},\n",
- " outputs={\"result\": days_to_election})\n",
- "root_span._span.end_time_ms = tool_end_time_ms\n",
- "\n",
- "\n",
- "# part 6 - log all spans to W&B by logging the root span\n",
- "root_span.log(name=\"openai_trace\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nBFVwawPGVmD"
- },
- "source": [
- "Once each Agent execution completes, all calls in your LangChain object will be logged to Weights & Biases"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/colabs/README.md b/colabs/README.md
index 7913dc56..d6144c25 100644
--- a/colabs/README.md
+++ b/colabs/README.md
@@ -26,7 +26,10 @@
| Kaolin-Wisp | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://wandb.me/vqad-colab) |
| Super Gradients | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://wandb.me/yolo-nas-colab) |
| 🎸 Generating music with AudioCraft | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/audiocraft/AudioCraft.ipynb) |
-
+| 🦄 Fine-tune a Torchvision Model with KerasCore | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/torchvision_keras.ipynb) |
+| 🦄 Fine-tune a Timm Model with KerasCore | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/timm_keras.ipynb) |
+| 🦄 Medical Image Classification Tutorial using MONAI and KerasCore | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/monai_medmnist_keras.ipynb) |
+| 🩻 Brain tumor 3D segmentation with MONAI and Weights & Biases | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/monai/3d_brain_tumor_segmentation.ipynb) |
# 🏋🏽♂️ W&B Features
diff --git a/colabs/audiocraft/AudioCraft.ipynb b/colabs/audiocraft/AudioCraft.ipynb
index e18c8832..af385a28 100644
--- a/colabs/audiocraft/AudioCraft.ipynb
+++ b/colabs/audiocraft/AudioCraft.ipynb
@@ -1,431 +1,392 @@
{
- "cells": [
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {
- "id": "W-26KlXuiXul"
- },
- "source": [
- "\n",
- "\n",
- "\n",
- "# 🎸 Generating Music using [Audiocraft](https://github.com/facebookresearch/audiocraft) and W&B 🐝\n",
- "\n",
- "\n",
- "\n",
- "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/audiocraft/AudioCraft.ipynb)\n",
- "\n",
- "In this notebook we demonstrate how you can generate music and other types of audio from text prompts or generate new music from existing music using SoTA models such as [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) and [AudioGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/AUDIOGEN.md) from [Audiocraft](https://github.com/facebookresearch/audiocraft) and play and visualize them using [Weights & Biases](https://wandb.ai/site).\n",
- "\n",
- "If you want to know more about the underlying architectures for MusicGen and AudioGen and explore some cool audio samples generated by these models, you can check out [this W&B report](http://wandb.me/audiocraft_2mp)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000
- },
- "id": "EZU3hg4B1om6",
- "outputId": "3311a45d-35c3-49e8-cbd5-4618386fa2a1"
- },
- "outputs": [],
- "source": [
- "# @title Install AudioCraft + WandB\n",
- "!pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft\n",
- "!pip install -qq -U wandb"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "RerQaiZt14r8"
- },
- "outputs": [],
- "source": [
- "# @title\n",
- "import os\n",
- "import random\n",
- "from tempfile import TemporaryDirectory\n",
- "\n",
- "from scipy import signal\n",
- "from scipy.io import wavfile\n",
- "\n",
- "import torchaudio\n",
- "from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion\n",
- "from audiocraft.data.audio import audio_write\n",
- "\n",
- "import wandb\n",
- "import numpy as np\n",
- "from tqdm.auto import tqdm\n",
- "from google.colab import files\n",
- "import matplotlib.pyplot as plt"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 162
- },
- "id": "3MTX8GoE7AzN",
- "outputId": "ebe733d0-3a21-47e5-d217-89a622cafc62"
- },
- "outputs": [],
- "source": [
- "# @title ## Audio Generation Configs\n",
- "\n",
- "# @markdown In this section, you can interact with the user interface to chose the models you want to use to generate audio, prompts and other configs. Once you execute this cell, it initializes a [wandb run](https://docs.wandb.ai/guides/runs) which will be used to automatically log all the generated audio along with all the prompts and configs, to ensure your AI-generated music is never lost and your experiments are always reproducible and easy to share. \n",
- "\n",
- "# @markdown **Note:** If you have provided prompts, you will be prompted to provide an audio file in addition to the prompts to condition the model. If you don't want to provide a file as an additional condition to the model, just press on the `cancel` button.\n",
- "\n",
- "# @markdown ---\n",
- "# @markdown WandB Project Name\n",
- "project_name = \"audiocraft\" # @param {type:\"string\"}\n",
- "\n",
- "wandb.init(project=project_name, job_type=\"musicgen/inference\")\n",
- "\n",
- "config = wandb.config\n",
- "\n",
- "# @markdown Select the Model for audio generation supported by [AudioCraft](https://github.com/facebookresearch/audiocraft). You can select either the MusicGen model variants (great for generating music) or the AudioGen model variants (great for generating non-musical audio). Also note that you can run all variants of MusicGen except the `large` one on the free-tier Colab GPU.\n",
- "model_name = \"musicgen-small\" # @param [\"musicgen-small\", \"musicgen-medium\", \"musicgen-large\", \"musicgen-melody\", \"audiogen-medium\"]\n",
- "config.model_name = \"facebook/\" + model_name if model_name == \"audiogen-medium\" else model_name\n",
- "\n",
- "# @markdown Whether to enable [MultiBand Diffusion](https://github.com/facebookresearch/audiocraft/blob/main/docs/MBD.md) or not. MultiBand diffusion is a collection of 4 models that can decode tokens from EnCodec tokenizer into waveform audio. Note that enabling this increases the time required to generate the audio.\n",
- "enable_multi_band_diffusion = True # @param {type:\"boolean\"}\n",
- "# config.enable_multi_band_diffusion = enable_multi_band_diffusion\n",
- "\n",
- "if \"musicgen\" not in model_name:\n",
- " wandb.termwarn(\"Multi-band Diffusion is only available for Musicgen\")\n",
- " config.enable_multi_band_diffusion = False\n",
- "else:\n",
- " config.enable_multi_band_diffusion = enable_multi_band_diffusion\n",
- "\n",
- "# @markdown ---\n",
- "# @markdown ## Conditional Generation Configs\n",
- "\n",
- "# @markdown The prompt for generating audio. You can give multiple prompts separated by `|` in the input. You can also leave it blank for unconditional generation.\n",
- "config.prompts = \"happy rock | energetic EDM | sad jazz\" # @param {type:\"string\"}\n",
- "\n",
- "descriptions = [prompt.strip() for prompt in config.prompts.split(\"|\")]\n",
- "config.is_unconditional = config.prompts.strip() == \"\"\n",
- "\n",
- "input_audio, input_sampling_rate, wandb_input_audio = None, None, None\n",
- "if not config.is_unconditional:\n",
- " input_audio_file = files.upload()\n",
- " if input_audio_file != {}:\n",
- " if config.model_name == \"facebook/audiogen-medium\":\n",
- " error = f\"{config.model_name} does not support audio-based conditioning\"\n",
- " raise ValueError(error)\n",
- " wandb_input_audio = wandb.Audio(list(input_audio_file.keys())[0])\n",
- " input_audio, input_sampling_rate = torchaudio.load(\n",
- " list(input_audio_file.keys())[0]\n",
- " )\n",
- " config.input_audio_available = True\n",
- " else:\n",
- " config.input_audio_available = False\n",
- "else:\n",
- " if config.model_name == \"facebook/audiogen-medium\":\n",
- " error = f\"{config.model_name} does not support unconditional generration\"\n",
- " raise ValueError(error)\n",
- "\n",
- "# @markdown Number of audio samples generated, this is relevant only for unconditional generation, i.e, if `config.prompts` is left blank.\n",
- "config.num_samples = 4 # @param {type:\"slider\", min:1, max:10, step:1}\n",
- "\n",
- "# @markdown Specify the random seed\n",
- "seed = None # @param {type:\"raw\"}\n",
- "\n",
- "max_seed = int(1024 * 1024 * 1024)\n",
- "if not isinstance(seed, int):\n",
- " seed = random.randint(1, max_seed)\n",
- "if seed < 0:\n",
- " seed = - seed\n",
- "seed = seed % max_seed\n",
- "config.seed = seed\n",
- "\n",
- "# @markdown ---\n",
- "# @markdown ## Generation Parameters\n",
- "# @markdown Use sampling if True, else do argmax decoding\n",
- "config.use_sampling = True # @param {type:\"boolean\"}\n",
- "\n",
- "# @markdown `top_k` used for sampling; limits us to `k` number of of the top tokens to consider.\n",
- "config.top_k = 250 # @param {type:\"slider\", min:0, max:1000, step:1}\n",
- "\n",
- "# @markdown `top_p` used for sampling; limits us to the top tokens within a probability mass `p`\n",
- "config.top_p = 0.0 # @param {type:\"slider\", min:0, max:1.0, step:0.01}\n",
- "\n",
- "# @markdown Softmax temperature parameter\n",
- "config.temperature = 1.0 # @param {type:\"slider\", min:0, max:1.0, step:0.01}\n",
- "\n",
- "# @markdown Duration of the generated waveform\n",
- "config.duration = 10 # @param {type:\"slider\", min:1, max:30, step:1}\n",
- "\n",
- "# @markdown Coefficient used for classifier free guidance\n",
- "config.cfg_coef = 3 # @param {type:\"slider\", min:1, max:100, step:1}\n",
- "\n",
- "# @markdown Whether to perform 2 forward for Classifier Free Guidance instead of batching together the two. This has some impact on how things are padded but seems to have little impact in practice.\n",
- "config.two_step_cfg = False # @param {type:\"boolean\"}\n",
- "\n",
- "# @markdown When doing extended generation (i.e. more than 30 seconds), by how much should we extend the audio each time. Larger values will mean less context is preserved, and shorter value will require extra computations.\n",
- "config.extend_stride = 0 # @param {type:\"slider\", min:0, max:30, step:1}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "SfM8rhVX7ES9",
- "outputId": "a935173b-382a-4514-97fc-eec12e188379"
- },
- "outputs": [],
- "source": [
- "# @title Generate Audio using MusicGen\n",
- "\n",
- "# @markdown In this section, the audio is generated using the configs, specified in the aforementioned section. If you wish to peek behind the curtain and checkout the code, click on the `Show Code` button. In order to know about the different APIs for audio generation, visit the [official audiocraft documentations](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html).\n",
- "\n",
- "model = None\n",
- "if config.model_name == \"facebook/audiogen-medium\":\n",
- " model = AudioGen.get_pretrained(config.model_name)\n",
- "elif \"musicgen\" in config.model_name:\n",
- " model = MusicGen.get_pretrained(config.model_name.split(\"-\")[-1])\n",
- "\n",
- "multi_band_diffusion = None\n",
- "if config.enable_multi_band_diffusion:\n",
- " multi_band_diffusion = MultiBandDiffusion.get_mbd_musicgen()\n",
- "\n",
- "model.set_generation_params(\n",
- " use_sampling=config.use_sampling,\n",
- " top_k=config.top_k,\n",
- " top_p=config.top_p,\n",
- " temperature=config.temperature,\n",
- " duration=config.duration,\n",
- " cfg_coef=config.cfg_coef,\n",
- " two_step_cfg=config.two_step_cfg,\n",
- " extend_stride=config.extend_stride\n",
- ")\n",
- "\n",
- "generated_wav, tokens = None, None\n",
- "if config.is_unconditional:\n",
- " if input_audio is None:\n",
- " if \"musicgen\" in config.model_name:\n",
- " generated_wav, tokens = model.generate_unconditional(\n",
- " num_samples=config.num_samples,\n",
- " progress=True,\n",
- " return_tokens=True\n",
- " )\n",
- " else:\n",
- " generated_wav = model.generate_unconditional(\n",
- " num_samples=config.num_samples,\n",
- " progress=True,\n",
- " )\n",
- " else:\n",
- " if \"musicgen\" in config.model_name:\n",
- " generated_wav, tokens = model.generate_with_chroma(\n",
- " descriptions,\n",
- " input_audio[None].expand(3, -1, -1),\n",
- " input_sampling_rate,\n",
- " return_tokens=True\n",
- " )\n",
- " else:\n",
- " generated_wav = model.generate_with_chroma(\n",
- " descriptions,\n",
- " input_audio[None].expand(3, -1, -1),\n",
- " input_sampling_rate,\n",
- " )\n",
- "else:\n",
- " if \"musicgen\" in config.model_name:\n",
- " generated_wav, tokens = model.generate(\n",
- " descriptions,\n",
- " progress=True,\n",
- " return_tokens=True\n",
- " )\n",
- " else:\n",
- " generated_wav = model.generate(\n",
- " descriptions,\n",
- " progress=True,\n",
- " )\n",
- "\n",
- "generated_wav_diffusion = None\n",
- "if config.enable_multi_band_diffusion:\n",
- " generated_wav_diffusion = multi_band_diffusion.tokens_to_wav(tokens)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 104
- },
- "id": "_n-1RthFVPYN",
- "outputId": "ac7bbc70-8114-4ef7-ef64-ecae9ba898cf"
- },
- "outputs": [],
- "source": [
- "# @title Log Audio to Weights & Biases Dashboard\n",
- "\n",
- "# @markdown In this section, we log the generated audio to Weights & Biases where you can listen and visualize them using an interactive audio player and waveform visualizer. Also, shoutout to [Atanu Sarkar](https://github.com/mratanusarkar) for building the spectrogram viusalization function which lets you visualize the spectrogram of the generated audio inside a [`wandb.Table`](https://docs.wandb.ai/guides/tables/tables-walkthrough).\n",
- "\n",
- "def get_spectrogram(audio_file, output_file):\n",
- " sample_rate, samples = wavfile.read(audio_file)\n",
- " frequencies, times, Sxx = signal.spectrogram(samples, sample_rate)\n",
- "\n",
- " log_Sxx = 10 * np.log10(Sxx + 1e-10)\n",
- " vmin = np.percentile(log_Sxx, 5)\n",
- " vmax = np.percentile(log_Sxx, 95)\n",
- "\n",
- " mean_spectrum = np.mean(log_Sxx, axis=1)\n",
- " threshold_low = np.percentile(mean_spectrum, 5)\n",
- " threshold_high = np.percentile(mean_spectrum, 95)\n",
- "\n",
- " freq_indices = np.where(mean_spectrum > threshold_low)\n",
- " freq_min = 20\n",
- " freq_max = frequencies[freq_indices].max()\n",
- "\n",
- " fig, ax = plt.subplots()\n",
- " cmap = plt.get_cmap('magma')\n",
- "\n",
- " ax.pcolormesh(\n",
- " times,\n",
- " frequencies,\n",
- " log_Sxx,\n",
- " shading='gouraud',\n",
- " cmap=cmap,\n",
- " vmin=vmin,\n",
- " vmax=vmax\n",
- " )\n",
- " ax.axis('off')\n",
- " ax.set_ylim([freq_min, freq_max])\n",
- "\n",
- " plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n",
- " plt.savefig(\n",
- " output_file, format='png', bbox_inches='tight', pad_inches=0\n",
- " )\n",
- " plt.close()\n",
- "\n",
- " return wandb.Image(output_file)\n",
- "\n",
- "\n",
- "temp_dir = TemporaryDirectory()\n",
- "columns = [\"Model\", \"Prompt\", \"Generated-Audio\", \"Spectrogram\", \"Seed\"]\n",
- "if input_audio is not None:\n",
- " columns.insert(2, \"Input-Audio\")\n",
- "if config.enable_multi_band_diffusion:\n",
- " columns.insert(4, \"Generated-Audio-Diffusion\")\n",
- " columns.insert(5, \"Spectrogram-Diffusion\")\n",
- "wandb_table = wandb.Table(columns=columns)\n",
- "\n",
- "for idx, wav in enumerate(generated_wav):\n",
- "\n",
- " file_name = os.path.join(temp_dir.name, str(idx))\n",
- " audio_write(\n",
- " file_name,\n",
- " wav.cpu(),\n",
- " model.sample_rate,\n",
- " strategy=\"loudness\",\n",
- " loudness_compressor=True,\n",
- " )\n",
- " wandb_audio = wandb.Audio(file_name + \".wav\")\n",
- " wandb.log({\"Generated-Audio\": wandb_audio}, commit=False)\n",
- "\n",
- " file_name_diffusion, wandb_diffusion_audio = None, None\n",
- " if config.enable_multi_band_diffusion:\n",
- " file_name_diffusion = os.path.join(\n",
- " temp_dir.name, str(idx) + \"_diffusion\"\n",
- " )\n",
- " audio_write(\n",
- " file_name_diffusion,\n",
- " generated_wav_diffusion[idx].cpu(),\n",
- " model.sample_rate,\n",
- " strategy=\"loudness\",\n",
- " loudness_compressor=True,\n",
- " )\n",
- " wandb_diffusion_audio = wandb.Audio(file_name_diffusion + \".wav\")\n",
- " wandb.log(\n",
- " {\"Generated-Audio-Diffusion\": wandb_diffusion_audio},\n",
- " commit=False\n",
- " )\n",
- "\n",
- " wandb.log({}, commit=True)\n",
- "\n",
- " desc = descriptions[idx] if len(descriptions) > 1 else config.prompts\n",
- " wandb_table_row = [\n",
- " model_name,\n",
- " desc,\n",
- " wandb_audio,\n",
- " get_spectrogram(\n",
- " audio_file=file_name + \".wav\",\n",
- " output_file=os.path.join(temp_dir.name, str(idx) + \".png\")\n",
- " ),\n",
- " config.seed\n",
- " ]\n",
- " if input_audio is not None:\n",
- " wandb_table_row.insert(2, wandb_input_audio)\n",
- " if config.enable_multi_band_diffusion:\n",
- " wandb_table_row.insert(4, wandb_diffusion_audio)\n",
- " wandb_table_row.insert(\n",
- " 5,\n",
- " get_spectrogram(\n",
- " audio_file=file_name_diffusion + \".wav\",\n",
- " output_file=os.path.join(\n",
- " temp_dir.name, str(idx) + \"_diffusion.png\"\n",
- " )\n",
- " )\n",
- " )\n",
- " wandb_table.add_data(*wandb_table_row)\n",
- "\n",
- "wandb.log({\"Generated-Audio-Table\": wandb_table})\n",
- "\n",
- "wandb.finish()\n",
- "temp_dir.cleanup()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This is how the W&B Table looks like with the interactive audio player, waveform visualizer and spectrogram visualization along with the prompts and other configs. Note that the notebook automatically sets the seed if you leave it blank, so your experiments are always reproducible.\n",
- "\n",
- "![](https://github.com/wandb/examples/blob/example/audiocraft/colabs/audiocraft/assets/music_gen.png?raw=1)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "If you want to know more about the underlying architectures for MusicGen and AudioGen and explore some cool audio samples generated by these models, you can check out [this W&B report](http://wandb.me/audiocraft_2mp)."
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "gpuType": "V100",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- }
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "# 🎸 Generating Music using [Audiocraft](https://github.com/facebookresearch/audiocraft) and W&B 🐝\n",
+ "\n",
+ "\n",
+ "\n",
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/audiocraft/AudioCraft.ipynb)\n",
+ "\n",
+ "In this notebook we demonstrate how you can generate music and other types of audio from text prompts or generate new music from existing music using SoTA models such as [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) and [AudioGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/AUDIOGEN.md) from [Audiocraft](https://github.com/facebookresearch/audiocraft) and play and visualize them using [Weights & Biases](https://wandb.ai/site).\n",
+ "\n",
+ "If you want to know more about the underlying architectures for MusicGen and AudioGen and explore some cool audio samples generated by these models, you can check out [this W&B report](http://wandb.me/audiocraft_2mp)."
+ ]
},
- "nbformat": 4,
- "nbformat_minor": 0
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title Install AudioCraft + WandB\n",
+ "!pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft\n",
+ "!pip install -qq -U wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title\n",
+ "import os\n",
+ "import random\n",
+ "from tempfile import TemporaryDirectory\n",
+ "\n",
+ "from scipy import signal\n",
+ "from scipy.io import wavfile\n",
+ "\n",
+ "import torchaudio\n",
+ "from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion\n",
+ "from audiocraft.data.audio import audio_write\n",
+ "\n",
+ "import wandb\n",
+ "import numpy as np\n",
+ "from tqdm.auto import tqdm\n",
+ "from google.colab import files\n",
+ "import matplotlib.pyplot as plt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title ## Audio Generation Configs\n",
+ "\n",
+ "# @markdown In this section, you can interact with the user interface to choose the models you want to use to generate audio, prompts and other configs. Once you execute this cell, it initializes a [wandb run](https://docs.wandb.ai/guides/runs) which will be used to automatically log all the generated audio along with all the prompts and configs, to ensure your AI-generated music is never lost and your experiments are always reproducible and easy to share. \n",
+ "\n",
+ "# @markdown **Note:** If you have provided prompts, you will be prompted to provide an audio file in addition to the prompts to condition the model. If you don't want to provide a file as an additional condition to the model, just press on the `cancel` button.\n",
+ "\n",
+ "# @markdown ---\n",
+ "# @markdown WandB Project Name\n",
+ "project_name = \"audiocraft\" # @param {type:\"string\"}\n",
+ "\n",
+ "wandb.init(project=project_name, job_type=\"musicgen/inference\")\n",
+ "\n",
+ "config = wandb.config\n",
+ "\n",
+ "# @markdown Select the Model for audio generation supported by [AudioCraft](https://github.com/facebookresearch/audiocraft). You can select either the MusicGen model variants (great for generating music) or the AudioGen model variants (great for generating non-musical audio). Also note that you can run all variants of MusicGen except the `large` one on the free-tier Colab GPU.\n",
+ "model_name = \"musicgen-small\" # @param [\"musicgen-small\", \"musicgen-medium\", \"musicgen-large\", \"musicgen-melody\", \"audiogen-medium\"]\n",
+ "config.model_name = \"facebook/\" + model_name if model_name == \"audiogen-medium\" else model_name\n",
+ "\n",
+ "# @markdown Whether to enable [MultiBand Diffusion](https://github.com/facebookresearch/audiocraft/blob/main/docs/MBD.md) or not. MultiBand diffusion is a collection of 4 models that can decode tokens from EnCodec tokenizer into waveform audio. Note that enabling this increases the time required to generate the audio.\n",
+ "enable_multi_band_diffusion = True # @param {type:\"boolean\"}\n",
+ "# config.enable_multi_band_diffusion = enable_multi_band_diffusion\n",
+ "\n",
+ "if \"musicgen\" not in model_name:\n",
+ " wandb.termwarn(\"Multi-band Diffusion is only available for Musicgen\")\n",
+ " config.enable_multi_band_diffusion = False\n",
+ "else:\n",
+ " config.enable_multi_band_diffusion = enable_multi_band_diffusion\n",
+ "\n",
+ "# @markdown ---\n",
+ "# @markdown ## Conditional Generation Configs\n",
+ "\n",
+ "# @markdown The prompt for generating audio. You can give multiple prompts separated by `|` in the input. You can also leave it blank for unconditional generation.\n",
+ "config.prompts = \"happy rock | energetic EDM | sad jazz\" # @param {type:\"string\"}\n",
+ "\n",
+ "descriptions = [prompt.strip() for prompt in config.prompts.split(\"|\")]\n",
+ "config.is_unconditional = config.prompts.strip() == \"\"\n",
+ "\n",
+ "input_audio, input_sampling_rate, wandb_input_audio = None, None, None\n",
+ "if not config.is_unconditional:\n",
+ " input_audio_file = files.upload()\n",
+ " if input_audio_file != {}:\n",
+ " if config.model_name == \"facebook/audiogen-medium\":\n",
+ " error = f\"{config.model_name} does not support audio-based conditioning\"\n",
+ " raise ValueError(error)\n",
+ " wandb_input_audio = wandb.Audio(list(input_audio_file.keys())[0])\n",
+ " input_audio, input_sampling_rate = torchaudio.load(\n",
+ " list(input_audio_file.keys())[0]\n",
+ " )\n",
+ " config.input_audio_available = True\n",
+ " else:\n",
+ " config.input_audio_available = False\n",
+ "else:\n",
+ " if config.model_name == \"facebook/audiogen-medium\":\n",
+ " error = f\"{config.model_name} does not support unconditional generration\"\n",
+ " raise ValueError(error)\n",
+ "\n",
+ "# @markdown Number of audio samples to generate. This is relevant only for unconditional generation, i.e., if `config.prompts` is left blank.\n",
+ "config.num_samples = 4 # @param {type:\"slider\", min:1, max:10, step:1}\n",
+ "\n",
+ "# @markdown Specify the random seed\n",
+ "seed = None # @param {type:\"raw\"}\n",
+ "\n",
+ "max_seed = int(1024 * 1024 * 1024)\n",
+ "if not isinstance(seed, int):\n",
+ " seed = random.randint(1, max_seed)\n",
+ "if seed < 0:\n",
+ " seed = - seed\n",
+ "seed = seed % max_seed\n",
+ "config.seed = seed\n",
+ "\n",
+ "# @markdown ---\n",
+ "# @markdown ## Generation Parameters\n",
+ "# @markdown Use sampling if True, else do argmax decoding\n",
+ "config.use_sampling = True # @param {type:\"boolean\"}\n",
+ "\n",
+ "# @markdown `top_k` used for sampling; limits us to `k` number of the top tokens to consider.\n",
+ "config.top_k = 250 # @param {type:\"slider\", min:0, max:1000, step:1}\n",
+ "\n",
+ "# @markdown `top_p` used for sampling; limits us to the top tokens within a probability mass `p`\n",
+ "config.top_p = 0.0 # @param {type:\"slider\", min:0, max:1.0, step:0.01}\n",
+ "\n",
+ "# @markdown Softmax temperature parameter\n",
+ "config.temperature = 1.0 # @param {type:\"slider\", min:0, max:1.0, step:0.01}\n",
+ "\n",
+ "# @markdown Duration of the generated waveform\n",
+ "config.duration = 10 # @param {type:\"slider\", min:1, max:30, step:1}\n",
+ "\n",
+ "# @markdown Coefficient used for classifier free guidance\n",
+ "config.cfg_coef = 3 # @param {type:\"slider\", min:1, max:100, step:1}\n",
+ "\n",
+ "# @markdown Whether to perform 2 forward for Classifier Free Guidance instead of batching together the two. This has some impact on how things are padded but seems to have little impact in practice.\n",
+ "config.two_step_cfg = False # @param {type:\"boolean\"}\n",
+ "\n",
+ "# @markdown When doing extended generation (i.e. more than 30 seconds), by how much should we extend the audio each time. Larger values will mean less context is preserved, and shorter value will require extra computations.\n",
+ "config.extend_stride = 0 # @param {type:\"slider\", min:0, max:30, step:1}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title Generate Audio using MusicGen\n",
+ "\n",
+ "# @markdown In this section, the audio is generated using the configs, specified in the aforementioned section. If you wish to peek behind the curtain and check out the code, click on the `Show Code` button. In order to know about the different APIs for audio generation, visit the [official audiocraft documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html).\n",
+ "\n",
+ "model = None\n",
+ "if config.model_name == \"facebook/audiogen-medium\":\n",
+ " model = AudioGen.get_pretrained(config.model_name)\n",
+ "elif \"musicgen\" in config.model_name:\n",
+ " model = MusicGen.get_pretrained(config.model_name.split(\"-\")[-1])\n",
+ "\n",
+ "multi_band_diffusion = None\n",
+ "if config.enable_multi_band_diffusion:\n",
+ " multi_band_diffusion = MultiBandDiffusion.get_mbd_musicgen()\n",
+ "\n",
+ "model.set_generation_params(\n",
+ " use_sampling=config.use_sampling,\n",
+ " top_k=config.top_k,\n",
+ " top_p=config.top_p,\n",
+ " temperature=config.temperature,\n",
+ " duration=config.duration,\n",
+ " cfg_coef=config.cfg_coef,\n",
+ " two_step_cfg=config.two_step_cfg,\n",
+ " extend_stride=config.extend_stride\n",
+ ")\n",
+ "\n",
+ "generated_wav, tokens = None, None\n",
+ "if config.is_unconditional:\n",
+ " if input_audio is None:\n",
+ " if \"musicgen\" in config.model_name:\n",
+ " generated_wav, tokens = model.generate_unconditional(\n",
+ " num_samples=config.num_samples,\n",
+ " progress=True,\n",
+ " return_tokens=True\n",
+ " )\n",
+ " else:\n",
+ " generated_wav = model.generate_unconditional(\n",
+ " num_samples=config.num_samples,\n",
+ " progress=True,\n",
+ " )\n",
+ " else:\n",
+ " if \"musicgen\" in config.model_name:\n",
+ " generated_wav, tokens = model.generate_with_chroma(\n",
+ " descriptions,\n",
+ " input_audio[None].expand(3, -1, -1),\n",
+ " input_sampling_rate,\n",
+ " return_tokens=True\n",
+ " )\n",
+ " else:\n",
+ " generated_wav = model.generate_with_chroma(\n",
+ " descriptions,\n",
+ " input_audio[None].expand(3, -1, -1),\n",
+ " input_sampling_rate,\n",
+ " )\n",
+ "else:\n",
+ " if \"musicgen\" in config.model_name:\n",
+ " generated_wav, tokens = model.generate(\n",
+ " descriptions,\n",
+ " progress=True,\n",
+ " return_tokens=True\n",
+ " )\n",
+ " else:\n",
+ " generated_wav = model.generate(\n",
+ " descriptions,\n",
+ " progress=True,\n",
+ " )\n",
+ "\n",
+ "generated_wav_diffusion = None\n",
+ "if config.enable_multi_band_diffusion:\n",
+ " generated_wav_diffusion = multi_band_diffusion.tokens_to_wav(tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title Log Audio to Weights & Biases Dashboard\n",
+ "\n",
+ "# @markdown In this section, we log the generated audio to Weights & Biases where you can listen to and visualize them using an interactive audio player and waveform visualizer. Also, shoutout to [Atanu Sarkar](https://github.com/mratanusarkar) for building the spectrogram visualization function which lets you visualize the spectrogram of the generated audio inside a [`wandb.Table`](https://docs.wandb.ai/guides/tables/tables-walkthrough).\n",
+ "\n",
+ "def get_spectrogram(audio_file, output_file):\n",
+ " sample_rate, samples = wavfile.read(audio_file)\n",
+ " frequencies, times, Sxx = signal.spectrogram(samples, sample_rate)\n",
+ "\n",
+ " log_Sxx = 10 * np.log10(Sxx + 1e-10)\n",
+ " vmin = np.percentile(log_Sxx, 5)\n",
+ " vmax = np.percentile(log_Sxx, 95)\n",
+ "\n",
+ " mean_spectrum = np.mean(log_Sxx, axis=1)\n",
+ " threshold_low = np.percentile(mean_spectrum, 5)\n",
+ " threshold_high = np.percentile(mean_spectrum, 95)\n",
+ "\n",
+ " freq_indices = np.where(mean_spectrum > threshold_low)\n",
+ " freq_min = 20\n",
+ " freq_max = frequencies[freq_indices].max()\n",
+ "\n",
+ " fig, ax = plt.subplots()\n",
+ " cmap = plt.get_cmap('magma')\n",
+ "\n",
+ " ax.pcolormesh(\n",
+ " times,\n",
+ " frequencies,\n",
+ " log_Sxx,\n",
+ " shading='gouraud',\n",
+ " cmap=cmap,\n",
+ " vmin=vmin,\n",
+ " vmax=vmax\n",
+ " )\n",
+ " ax.axis('off')\n",
+ " ax.set_ylim([freq_min, freq_max])\n",
+ "\n",
+ " plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n",
+ " plt.savefig(\n",
+ " output_file, format='png', bbox_inches='tight', pad_inches=0\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " return wandb.Image(output_file)\n",
+ "\n",
+ "\n",
+ "temp_dir = TemporaryDirectory()\n",
+ "columns = [\"Model\", \"Prompt\", \"Generated-Audio\", \"Spectrogram\", \"Seed\"]\n",
+ "if input_audio is not None:\n",
+ " columns.insert(2, \"Input-Audio\")\n",
+ "if config.enable_multi_band_diffusion:\n",
+ " columns.insert(4, \"Generated-Audio-Diffusion\")\n",
+ " columns.insert(5, \"Spectrogram-Diffusion\")\n",
+ "wandb_table = wandb.Table(columns=columns)\n",
+ "\n",
+ "for idx, wav in enumerate(generated_wav):\n",
+ "\n",
+ " file_name = os.path.join(temp_dir.name, str(idx))\n",
+ " audio_write(\n",
+ " file_name,\n",
+ " wav.cpu(),\n",
+ " model.sample_rate,\n",
+ " strategy=\"loudness\",\n",
+ " loudness_compressor=True,\n",
+ " )\n",
+ " wandb_audio = wandb.Audio(file_name + \".wav\")\n",
+ " wandb.log({\"Generated-Audio\": wandb_audio}, commit=False)\n",
+ "\n",
+ " file_name_diffusion, wandb_diffusion_audio = None, None\n",
+ " if config.enable_multi_band_diffusion:\n",
+ " file_name_diffusion = os.path.join(\n",
+ " temp_dir.name, str(idx) + \"_diffusion\"\n",
+ " )\n",
+ " audio_write(\n",
+ " file_name_diffusion,\n",
+ " generated_wav_diffusion[idx].cpu(),\n",
+ " model.sample_rate,\n",
+ " strategy=\"loudness\",\n",
+ " loudness_compressor=True,\n",
+ " )\n",
+ " wandb_diffusion_audio = wandb.Audio(file_name_diffusion + \".wav\")\n",
+ " wandb.log(\n",
+ " {\"Generated-Audio-Diffusion\": wandb_diffusion_audio},\n",
+ " commit=False\n",
+ " )\n",
+ "\n",
+ " wandb.log({}, commit=True)\n",
+ "\n",
+ " desc = descriptions[idx] if len(descriptions) > 1 else config.prompts\n",
+ " wandb_table_row = [\n",
+ " model_name,\n",
+ " desc,\n",
+ " wandb_audio,\n",
+ " get_spectrogram(\n",
+ " audio_file=file_name + \".wav\",\n",
+ " output_file=os.path.join(temp_dir.name, str(idx) + \".png\")\n",
+ " ),\n",
+ " config.seed\n",
+ " ]\n",
+ " if input_audio is not None:\n",
+ " wandb_table_row.insert(2, wandb_input_audio)\n",
+ " if config.enable_multi_band_diffusion:\n",
+ " wandb_table_row.insert(4, wandb_diffusion_audio)\n",
+ " wandb_table_row.insert(\n",
+ " 5,\n",
+ " get_spectrogram(\n",
+ " audio_file=file_name_diffusion + \".wav\",\n",
+ " output_file=os.path.join(\n",
+ " temp_dir.name, str(idx) + \"_diffusion.png\"\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " wandb_table.add_data(*wandb_table_row)\n",
+ "\n",
+ "wandb.log({\"Generated-Audio-Table\": wandb_table})\n",
+ "\n",
+ "wandb.finish()\n",
+ "temp_dir.cleanup()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is what the W&B Table looks like with the interactive audio player, waveform visualizer and spectrogram visualization along with the prompts and other configs. Note that the notebook automatically sets the seed if you leave it blank, so your experiments are always reproducible.\n",
+ "\n",
+ "![](https://github.com/wandb/examples/blob/example/audiocraft/colabs/audiocraft/assets/music_gen.png?raw=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you want to know more about the underlying architectures for MusicGen and AudioGen and explore some cool audio samples generated by these models, you can check out [this W&B report](http://wandb.me/audiocraft_2mp)."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "V100",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
diff --git a/colabs/boosting/Credit_Scorecards_with_XGBoost_and_W&B.ipynb b/colabs/boosting/Credit_Scorecards_with_XGBoost_and_W&B.ipynb
index 6bd674f2..57b392ef 100644
--- a/colabs/boosting/Credit_Scorecards_with_XGBoost_and_W&B.ipynb
+++ b/colabs/boosting/Credit_Scorecards_with_XGBoost_and_W&B.ipynb
@@ -211,11 +211,19 @@
"outputs": [],
"source": [
"import wandb\n",
- "wandb.login()\n",
"\n",
"WANDB_PROJECT ='vehicle_loan_default'"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
{
"attachments": {},
"cell_type": "markdown",
diff --git a/colabs/boosting/Simple_LightGBM_Integration.ipynb b/colabs/boosting/Simple_LightGBM_Integration.ipynb
index 35ba65ec..263e64df 100644
--- a/colabs/boosting/Simple_LightGBM_Integration.ipynb
+++ b/colabs/boosting/Simple_LightGBM_Integration.ipynb
@@ -111,8 +111,15 @@
"outputs": [],
"source": [
"import wandb\n",
- "from wandb.lightgbm import wandb_callback, log_summary\n",
- "\n",
+ "from wandb.lightgbm import wandb_callback, log_summary\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/boosting/Using_W&B_Sweeps_with_XGBoost.ipynb b/colabs/boosting/Using_W&B_Sweeps_with_XGBoost.ipynb
index e2649961..5e1a2a8d 100644
--- a/colabs/boosting/Using_W&B_Sweeps_with_XGBoost.ipynb
+++ b/colabs/boosting/Using_W&B_Sweeps_with_XGBoost.ipynb
@@ -97,7 +97,15 @@
"outputs": [],
"source": [
"\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/datasets-predictions/Image_Classification_with_Tables.ipynb b/colabs/datasets-predictions/Image_Classification_with_Tables.ipynb
index f8252a03..8def8a28 100644
--- a/colabs/datasets-predictions/Image_Classification_with_Tables.ipynb
+++ b/colabs/datasets-predictions/Image_Classification_with_Tables.ipynb
@@ -140,7 +140,15 @@
"outputs": [],
"source": [
"\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -613,20 +621,6 @@
"3. [Articles](https://www.wandb.com/articles) - blog posts and tutorials\n",
"4. [Community](wandb.me/slack) - join our Slack community forum"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/datasets-predictions/Logging_Timbre_Transfer_with_W&B.ipynb b/colabs/datasets-predictions/Logging_Timbre_Transfer_with_W&B.ipynb
index 5c761177..85d83e12 100644
--- a/colabs/datasets-predictions/Logging_Timbre_Transfer_with_W&B.ipynb
+++ b/colabs/datasets-predictions/Logging_Timbre_Transfer_with_W&B.ipynb
@@ -193,7 +193,15 @@
"metadata": {},
"outputs": [],
"source": [
- "WANDB_PROJECT = \"timbre_demo\"\n",
+ "WANDB_PROJECT = \"timbre_demo\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/datasets-predictions/W&B_Dataset_Visualization.ipynb b/colabs/datasets-predictions/W&B_Dataset_Visualization.ipynb
index 244dce8d..a9818a49 100644
--- a/colabs/datasets-predictions/W&B_Dataset_Visualization.ipynb
+++ b/colabs/datasets-predictions/W&B_Dataset_Visualization.ipynb
@@ -93,7 +93,15 @@
"source": [
"# default project name where results will be logged\n",
"WANDB_PROJECT = \"dsviz-demo-colab\"\n",
- "NUM_EXAMPLES = 50\n",
+ "NUM_EXAMPLES = 50"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/datasets-predictions/W&B_Tables_Quickstart.ipynb b/colabs/datasets-predictions/W&B_Tables_Quickstart.ipynb
index 592878ae..8b0cd943 100644
--- a/colabs/datasets-predictions/W&B_Tables_Quickstart.ipynb
+++ b/colabs/datasets-predictions/W&B_Tables_Quickstart.ipynb
@@ -82,11 +82,19 @@
"source": [
"\n",
"import wandb\n",
- "wandb.login()\n",
"\n",
"WANDB_PROJECT = \"mnist-viz\""
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/colabs/deepchem/W&B_x_DeepChem.ipynb b/colabs/deepchem/W&B_x_DeepChem.ipynb
index 81687e46..84303756 100644
--- a/colabs/deepchem/W&B_x_DeepChem.ipynb
+++ b/colabs/deepchem/W&B_x_DeepChem.ipynb
@@ -58,12 +58,20 @@
"outputs": [],
"source": [
"import wandb\n",
- "wandb.login()\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/colabs/diffusers/assets/diffusers-autolog-1.gif b/colabs/diffusers/assets/diffusers-autolog-1.gif
new file mode 100644
index 00000000..55a20e31
Binary files /dev/null and b/colabs/diffusers/assets/diffusers-autolog-1.gif differ
diff --git a/colabs/diffusers/assets/diffusers-autolog-4.gif b/colabs/diffusers/assets/diffusers-autolog-4.gif
new file mode 100644
index 00000000..a52f63d6
Binary files /dev/null and b/colabs/diffusers/assets/diffusers-autolog-4.gif differ
diff --git a/colabs/diffusers/assets/diffusers-autolog-5.gif b/colabs/diffusers/assets/diffusers-autolog-5.gif
new file mode 100644
index 00000000..a3673e06
Binary files /dev/null and b/colabs/diffusers/assets/diffusers-autolog-5.gif differ
diff --git a/colabs/diffusers/diffusers-image-generation.ipynb b/colabs/diffusers/diffusers-image-generation.ipynb
index 00ac78c0..f3c603c5 100644
--- a/colabs/diffusers/diffusers-image-generation.ipynb
+++ b/colabs/diffusers/diffusers-image-generation.ipynb
@@ -512,13 +512,6 @@
"source": [
"notebook_launcher(training_loop, num_processes=config.num_processes)"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/diffusers/lcm-diffusers.ipynb b/colabs/diffusers/lcm-diffusers.ipynb
new file mode 100644
index 00000000..710aa903
--- /dev/null
+++ b/colabs/diffusers/lcm-diffusers.ipynb
@@ -0,0 +1,118 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Image Generation with Consistency Models using 🤗 Diffusers\n",
+ "\n",
+ "\n",
+ "\n",
+ "This notebook demonstrates the following:\n",
+ "- Performing text-conditional image-generations with the [Consistency Models](https://huggingface.co/docs/diffusers/api/pipelines/consistency_models) using [🤗 Diffusers](https://huggingface.co/docs/diffusers).\n",
+ "- Manage image generation experiments using [Weights & Biases](http://wandb.ai/site).\n",
+ "- Log the prompts, generated images and experiment configs to [Weights & Biases](http://wandb.ai/site) for visualization.\n",
+ "\n",
+ "![](./assets/diffusers-autolog-4.gif)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install diffusers transformers accelerate wandb > install.log"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "import torch\n",
+ "from diffusers import DiffusionPipeline\n",
+ "\n",
+ "import wandb\n",
+ "from wandb.integration.diffusers import autolog"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize the diffusion pipeline for latent consistency model\n",
+ "pipeline = DiffusionPipeline.from_pretrained(\"SimianLuo/LCM_Dreamshaper_v7\")\n",
+ "pipeline = pipeline.to(torch_device=\"cuda\", torch_dtype=torch.float32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define the prompts and the seed.\n",
+ "prompt = [\n",
+ " \"a photograph of an astronaut riding a horse\",\n",
+ " \"a photograph of a dragon\"\n",
+ "]\n",
+ "\n",
+ "# Make the experiment reproducible by controlling randomness.\n",
+ "# The seed would be automatically logged to WandB.\n",
+ "generator = torch.Generator(device=\"cpu\").manual_seed(10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Call WandB Autolog for Diffusers. This would automatically log\n",
+ "# the prompts, generated images, pipeline architecture and all\n",
+ "# associated experiment configs to Weights & Biases, thus making your\n",
+ "# image generation experiments easy to reproduce, share and analyze.\n",
+ "autolog(init=dict(project=\"diffusers_logging\"))\n",
+ "\n",
+ "# call the pipeline to generate the images\n",
+ "images = pipeline(\n",
+ " prompt,\n",
+ " num_images_per_prompt=2,\n",
+ " generator=generator,\n",
+ " num_inference_steps=10,\n",
+ ")\n",
+ "\n",
+ "# End the experiment\n",
+ "wandb.finish()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/colabs/diffusers/pixart-alpha-diffusers.ipynb b/colabs/diffusers/pixart-alpha-diffusers.ipynb
new file mode 100644
index 00000000..443e5704
--- /dev/null
+++ b/colabs/diffusers/pixart-alpha-diffusers.ipynb
@@ -0,0 +1,7145 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Image Generation with Pixart-α using 🤗 Diffusers\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/diffusers/pixart-alpha-diffusers.ipynb)\n",
+ "\n",
+ "This notebook demonstrates the following:\n",
+ "- Performing text-conditional image-generations with the [Pixart-α model](https://huggingface.co/docs/diffusers/v0.23.1/en/api/pipelines/pixart) using [🤗 Diffusers](https://huggingface.co/docs/diffusers).\n",
+ "- Manage image generation experiments using [Weights & Biases](http://wandb.ai/site).\n",
+ "- Log the prompts, generated images and experiment configs to [Weights & Biases](http://wandb.ai/site) for visualization.\n",
+ "\n",
+ "![](./assets/diffusers-autolog-1.gif)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0_F8JWIgNfqy",
+ "outputId": "bd6cdb65-e29c-4f93-b7a9-406993a94d62"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install diffusers transformers accelerate sentencepiece ftfy wandb > install.log"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 87,
+ "referenced_widgets": [
+ "95616ad88fc84e80a99f7aa0987a4372",
+ "2d4b6035778d4bac8a2159504533272e",
+ "35675cfe3b5d4c4f83fddd70c0cf8c50",
+ "00e83d39844348e59e7f256898e08709",
+ "f842eb80fed54b2ba865b75a05dae0a1",
+ "a3d6530e68c2427c893fcebcbae020f3",
+ "be51de41fd0c4ef89130862d33d57011",
+ "716aea1c27df41b79cf2642bb93d5df3",
+ "987d902ad45e4304a15dbd01fe755587",
+ "010dd14615ae4e58abb84cd466446b55",
+ "895a39fb613c406db808d1ad01dac37f"
+ ]
+ },
+ "id": "_6msOo17Q0Bq",
+ "outputId": "049c0ad5-eb5f-4464-bc76-c0e11e1818d1"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "import torch\n",
+ "from diffusers import PixArtAlphaPipeline\n",
+ "\n",
+ "import wandb\n",
+ "from wandb.integration.diffusers import autolog"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 579,
+ "referenced_widgets": [
+ "950697f13cfb477cb09def01968a8be0",
+ "436ae734138c4274b8a0157e76c8d2e3",
+ "c54e4159f65549a4a4d877c224aa89ed",
+ "048a3dc56e8c4df39b1179d15029eb00",
+ "cc9cb09be76e43b591285a5d52c6ed76",
+ "fb7fffb9281f42b693c1b7f48325f96f",
+ "1ce438295bd145ad8064c9d7a9f56469",
+ "ab13c1cef2ff4f41b64c4ad103144b31",
+ "920fb69e19504d3b860f0cf55900d25d",
+ "006d201a04f342e89ee728d8d2538f0f",
+ "429b18e0a7874f6c8438cc797ba0123a",
+ "893b076e6f86408db6c1e542b3c730f6",
+ "546cbe5b663a4b9fabaff56fecc6e914",
+ "22802a6ccb80404e96f9912950d73252",
+ "6b1c8bebfaf3464c97151f0fa03d1204",
+ "44cbd72d7bf74a10b357ea7625b5c84b",
+ "314154a15a8d49da8eb158b9b64df8bc",
+ "da149b5bd2024a53b99282849349ea8a",
+ "79426ffc354b4b9daff5c64f6bea931e",
+ "282fd54083174b31a38efdc3f757cedb",
+ "5a724e34fb1b455389349963550b9da8",
+ "b302595b335b458899e30ff9eb0fd79a",
+ "b069c4b00c864783963c1bf299c5c372",
+ "fd5b8ec96c834a519fedba254849ec04",
+ "d9f9127720b541ddbbe5990e322ad089",
+ "3a04722dff694815a30d8c7c61f42e16",
+ "d30a20c93a404c02adaa4fcb66443cb2",
+ "dce47f13fc15450787a63016b61e1c1a",
+ "c4073b2d2c1d4642a1c99ce39c8fa39b",
+ "6ffbe6f493204f4e884692597058ad41",
+ "a517bf731b2748b0ae70c137beb50108",
+ "faa63481cc9a4ac5adbdeee025a7de19",
+ "2650cccd00ae4086955bed6ca909e126",
+ "94876605c0b14c55957f1ad227953d06",
+ "c4bce1fe40ef402cb8a86767f518fc70",
+ "c3398473b8d54655bad516100d0d0fbb",
+ "cc80937e599b45bc974d21f6bd5af90f",
+ "e6803d58f47f4b1fb4e76bca06d6b35c",
+ "dfa1e1cbc0044ba9823f307edbfbec1a",
+ "ce51bb48d8184438968b0f06b87e16d3",
+ "cb6eb4fc3ea44449a53d362d9819dd8e",
+ "062d186b597a47a08746353bf35148b1",
+ "0d435039170340569cf92c0255d7fd5e",
+ "9fc87495a9364e27a1da550d91b85550",
+ "4140e9203f1d4f078581ecfe3038d94b",
+ "fac39df9084640bbbee32339dc4613ad",
+ "97c1c321a42f450e939bb0baf1e42785",
+ "98f22a1dbac54d398191f61ed9764c6c",
+ "cef1102de0a14a46a5e918afc52f03bc",
+ "99fd909b2eb14415a79409646dd59fd2",
+ "41c8d0eff808446abd70ce585da26978",
+ "6ad951ba967b430cbff93ae75a6b1597",
+ "ba1c265343204064b4644f9bc9fbc0e9",
+ "e09197134928414c974211651e17b385",
+ "42901129e32343dcb8d5edb960519a53",
+ "13069f251f51467b83ab5119b2317337",
+ "32a404ef0450473eaccdd5594058b49d",
+ "67b4635164a04f96adca0efb40eae566",
+ "b61eb332e206436d93af97221feb9b4c",
+ "67f3fe40a0914321926c53e2fec176a5",
+ "49edcb2e0c8a4e9aad0398653046cf9e",
+ "4143365a9b69494eab538971b1602c35",
+ "9dbf25daca3840319f5686ff275d842f",
+ "08f75e1bf9864812a916f14eb29a12f9",
+ "6c42c3efc54f436996b154c38fb32209",
+ "acc056c201ef4cb8beac8d981b43ba3e",
+ "260f8ec597f9490d9d355664bd687dc7",
+ "f9730c19c2bc45a0ac95964fd7950032",
+ "803604e1509d4cb88fec21e96689a57b",
+ "10f589c69e884531b0c1ee967f523aa5",
+ "a86366ba25244c6da47189cf59dcbec6",
+ "6509119a1279497b91d29c7f922e8f3f",
+ "2c886cdef1d04e45a735cebaf8e05b00",
+ "7dc25f4f13ce4d7f94356ef9edeb2eb9",
+ "07e2d238d83e45bf9037fa851a5e5648",
+ "2e0869cc3ac24f12bf9d173cb432dc4c",
+ "27f2e688349d444a860dacb562634150",
+ "d213aeaea9f949658d1bd9728ee86fef",
+ "1cde89e63bd7437e9fac7446fc33aea3",
+ "442da82cec1241b8a5249d183820673d",
+ "f65994d833e04d9f96d2613a82651cdd",
+ "eb118c92cb514e999f7b02ebfbbc7d0b",
+ "a9c174e4fa204c8eb01990bcd02a43cb",
+ "ce7f64ac3c164f04835d7869f7ebfb48",
+ "4af899d9d9ba4023b4985e61e3feeda6",
+ "b692c6053fc04651833ee87015df0daa",
+ "0a4330eafca245378354d0abe07f881f",
+ "05c9d7dd8bca4023b8e9d67d8edcf84b",
+ "b9cb18bf70f843e2827a5908d86d490d",
+ "b7f856516aef474aac5d1755163a4775",
+ "973a485540b04967a277eaf9bff09fd9",
+ "66c72c7c741a4288a04e29284c7dab38",
+ "f5e977aeb3fe4edbacc09d6b9dec6df0",
+ "6fafd02828024e37beccd9086fb745f6",
+ "65bcdf864db84dff899b411bdb51d0c1",
+ "41d1bf639c5c4ee4afc327a64b8357b7",
+ "e6ee1b5afae24cf6bd731b64e53a682e",
+ "d10632b92b72482bb3892ec471f0e756",
+ "2e6f4510f5e6496f890e91f827d9cb99",
+ "516574616ff540778c024528a9fafec9",
+ "10468ffcd7c14cc79bc0fa023da9888a",
+ "2b2351a243144932bc0a8fbc0b542402",
+ "7c3adc19d763420787a6a55bb33f2d43",
+ "6610639505394df9b93cd39fdd0acd5c",
+ "fba986e08f5242afaa9ffc193ef9b3b1",
+ "5fc4719ad0fe4b16b0a9428ed5f32b3c",
+ "e9b0e81e633940e4983ceaf14687fa3b",
+ "1eab4186a6a24bfe98d764cfa74b342f",
+ "28f2f2f574f54fd6a16f244f4be5d08e",
+ "fb4e2edf849047ae9790f0522e13c11e",
+ "5e279ec61581429f84f0329ea5e62fee",
+ "61edf08284b942aaa7306d0a26a59344",
+ "c7e02eb0a33349629f58b9cbd7d3a9ad",
+ "0f497591d8334c47a7e01aeb3ceaef81",
+ "01be5b3663f348e7b47747f3371c1fbd",
+ "724772294d614997af26729ae0f845f9",
+ "ff86218d2aca4ee7b3764e998b8c9b9a",
+ "266d230ffacd4d76acb96a2466104289",
+ "8f7235cc17134152a6fc37a2a8ef85b0",
+ "e89ac0f3c4c848bb9505a357ca31919a",
+ "848aefc8f7c8452c8ac8871051166b3e",
+ "1200f91ecf914c6abf38c08365a9cc31",
+ "fe163c82f9b04a15ac6b923ed3b01873",
+ "729f0475028446e8bb00078a2bfbfbce",
+ "50f16033e1e44d05b7e9add4979a9b54",
+ "e28d59c860df45c9ad7318461e4ed25f",
+ "f650b81e97cf47b78b6512ff4a0794c6",
+ "d7a4b722cbc74d548edd1d0469d258a7",
+ "f54a32cfafb5462985b2e1f03f3e7b55",
+ "93141ddff63a44aba010559c6d209596",
+ "01d6bd13c6a743c0b82dcb51bf3e7623",
+ "247a4fda6dc947cc8e7336bcffcd2781",
+ "904ce266ca25413398b736c5f6760a5a",
+ "5cb62c1e29b6421d8f2b98d60543e045",
+ "dc0b1e41de4c4d99b6668e07e2b061a4",
+ "69ab97b1174b4742847ef885c7b2f4ba",
+ "65c618901769460bbe35f1f68d04920c",
+ "21f0ae78db2f4cd09c0b64f4ae5f7b8a",
+ "bf028fd02c0e4220a3077c4f41253ed7",
+ "e29ec1b24ede4619881a3e011699c434",
+ "75343e1ba0604632abf96e86ad7fe8b3",
+ "9ff029954ed64721a4f00faf34f701c5",
+ "f91b715ea2ae4fe0970c142ddc59d76e",
+ "32f4a8a631074da49f3361ea74b483fd",
+ "0a427178ccca414c8f58b892b54cb00d",
+ "591c006b4de647d2b95a8d3a6103a48b",
+ "aa04c220c98d4d6f880a58f6aa9bda75",
+ "04e94e56da49440abf2b6eaf6e675d98",
+ "bdeeb0f9825147efbb70a5eaf226a8ff",
+ "f7dabfcad74b4838b3ad63d62457f140",
+ "a6801b2c8f41424888dbafde4252c7da",
+ "5e8efc02014a455dbfbe356c965abd02",
+ "b9573401b606405da3b1d8815865b104",
+ "cf75087ea8c4401faebff1fe50cc8703",
+ "d96719a9e06845aaae085bd256148333",
+ "4497b0d5517a4acabba98e60af7c5677",
+ "fe96da4d0e484f79ab3a82820ddea5ea",
+ "d7b8ca6b0cf549af8f3f8b0e15f0b8af",
+ "2cb5315f2ad748c4a8605208dc9492d1",
+ "d3406fb70a4c470fa4cd9476d60a9370",
+ "63dc113923d145a88de3c2156b73a1ad",
+ "9690ca75724a4547a505b457efa18cf5",
+ "b9e34a78891a4be6b293733ad79fb409",
+ "cb7c327783124585a21b6062b9a4bec7",
+ "5f13dc4e83c140c1a020bebcd08943d5",
+ "2631936da4064c2b8763bb9bd081b264",
+ "3df3134ee2c14677b1ed6eb49ece2f46",
+ "93a02724ca12427998e17f8e355b0012",
+ "1ec067c782824e329661b9a903748caa",
+ "4166654783344449b71d946124b02703",
+ "9ee2b2a0e18d420ea09a202ce2804660",
+ "b0f80fbce966413496bf67440107cc5b",
+ "ee7cb7091ca5458bba7ab8cef3bb4b5e",
+ "56e272c0af234570a5371800c336ec4d",
+ "266723871e2c4f80bebc0e758b58374d",
+ "3c9728a91b5d485f81ef4d7814c6bb22",
+ "2845309eb6cd4e3bb91b1f3f0cfb6f43",
+ "251d69e9aac04d0da25fa6cb18ce907c",
+ "2466f004e7654d578536b3bdd9e47db2",
+ "ed3e92da5ade4986b381ef86e0d57ed4",
+ "98271fb9547940c69ff8958320056ca5",
+ "1e4fba8ea112482cab2df3e917f6407c",
+ "c001ad33ee0e4df1a314228e292529b5",
+ "532599e42fc64bfdb57957454312ed00",
+ "6c2bbdfd6e6f4a8faa2d47a81a0feebd",
+ "2f428cbd696a4ba4906c7fd8f74a57d4",
+ "f1bea04743fd4a63aaa1b03acdac859d"
+ ]
+ },
+ "id": "O4WktKfiQ0z7",
+ "outputId": "0be1b5e0-cfb2-4b1d-f9e4-18d97b92377c"
+ },
+ "outputs": [],
+ "source": [
+ "# Load the pre-trained checkpoints from HuggingFace Hub to the PixArtAlphaPipeline\n",
+ "pipe = PixArtAlphaPipeline.from_pretrained(\n",
+ " \"PixArt-alpha/PixArt-XL-2-1024-MS\", torch_dtype=torch.float16\n",
+ ")\n",
+ "\n",
+ "# Offload the model weights to the CPU and load them onto the GPU only when\n",
+ "# performing the forward pass; this reduces GPU memory usage.\n",
+ "pipe.enable_model_cpu_offload()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MQIZo4F_Q1qG"
+ },
+ "outputs": [],
+ "source": [
+ "wandb_project = \"pixart-alpha\" # @param {type:\"string\"}\n",
+ "\n",
+ "prompt = \"a traveler navigating via a boat in countless mountains, Chinese ink painting\" # @param {type:\"string\"}\n",
+ "negative_prompt = \"\" # @param {type:\"string\"}\n",
+ "num_inference_steps = 25 # @param {type:\"slider\", min:10, max:50, step:1}\n",
+ "guidance_scale = 4.5 # @param {type:\"slider\", min:0, max:10, step:0.1}\n",
+ "num_images_per_prompt = 1 # @param {type:\"slider\", min:1, max:10, step:1}\n",
+ "height = 1024 # @param {type:\"slider\", min:512, max:2560, step:32}\n",
+ "width = 1024 # @param {type:\"slider\", min:512, max:2560, step:32}\n",
+ "seed = None # @param {type:\"raw\"}\n",
+ "\n",
+ "\n",
+ "def autogenerate_seed():\n",
+ " max_seed = int(1024 * 1024 * 1024)\n",
+ " seed = random.randint(1, max_seed)\n",
+ " seed = -seed if seed < 0 else seed\n",
+ " seed = seed % max_seed\n",
+ " return seed\n",
+ "\n",
+ "\n",
+ "seed = autogenerate_seed() if seed is None else seed\n",
+ "\n",
+ "# Make the experiment reproducible by controlling randomness.\n",
+ "# The seed would be automatically logged to WandB.\n",
+ "generator = torch.Generator(device=\"cuda\").manual_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 206,
+ "referenced_widgets": [
+ "68239504cbdf4aa0a80096171156cffd",
+ "85ab89a71bb248e996fb8d0148a3d64c",
+ "ceeef5ea2da54e3b8f6fcda7f33b4e32",
+ "5498cc4574d74bf894f35c96d74dce80",
+ "529b32124a884391ba1176d340d480f9",
+ "95caf26b330740a593c8cc86a2fd1ee1",
+ "76d11a364eae4686a5879cddb98c4ddd",
+ "8ece46fc480c4d399e7f2b7a92e96cec",
+ "73afbd0ff1404c63b12909c609446cf9",
+ "601ce4737c2c46b493a191c7f3f6b950",
+ "2079c259a28943a0a80ddabc44026a39",
+ "0356e4a80e04409ca07d1581d735086f",
+ "609dff85b2104270a8e9263ee8e1e958",
+ "a25969fd35744afdb66e9fd59612df6e",
+ "4964cf137c504555a7df44112e65cc0b",
+ "52a240218ab84c0896c7bdc4bfe5a0ef",
+ "bd8e4b6f1e3f49a586fc2ca9baca58e3",
+ "ca2e80da3e194102a9ff2cf0880e1e99",
+ "8dae7b8cdf004d748be7a6395c45bd1a"
+ ]
+ },
+ "id": "QfopfTm-Q4S3",
+ "outputId": "ca4cd35d-eac4-44e7-c7fb-b019e40a00e9"
+ },
+ "outputs": [],
+ "source": [
+ "# Call WandB Autolog for Diffusers. This would automatically log\n",
+ "# the prompts, generated images, pipeline architecture and all\n",
+ "# associated experiment configs to Weights & Biases, thus making your\n",
+ "# image generation experiments easy to reproduce, share and analyze.\n",
+ "autolog(init=dict(project=wandb_project))\n",
+ "\n",
+ "# Generate the images by calling the PixArtAlphaPipeline\n",
+ "image = pipe(\n",
+ " prompt=prompt,\n",
+ " negative_prompt=negative_prompt,\n",
+ " num_inference_steps=num_inference_steps,\n",
+ " guidance_scale=guidance_scale,\n",
+ " num_images_per_prompt=num_images_per_prompt,\n",
+ " height=height,\n",
+ " width=width,\n",
+ " generator=generator,\n",
+ ").images[0]\n",
+ "\n",
+ "# End the experiment\n",
+ "wandb.finish()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "006d201a04f342e89ee728d8d2538f0f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "00e83d39844348e59e7f256898e08709": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_010dd14615ae4e58abb84cd466446b55",
+ "placeholder": "",
+ "style": "IPY_MODEL_895a39fb613c406db808d1ad01dac37f",
+ "value": " 0/0 [00:00<?, ?it/s]"
+ }
+ },
+ "010dd14615ae4e58abb84cd466446b55": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "01be5b3663f348e7b47747f3371c1fbd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "01d6bd13c6a743c0b82dcb51bf3e7623": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0356e4a80e04409ca07d1581d735086f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "VBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "VBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "VBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_609dff85b2104270a8e9263ee8e1e958",
+ "IPY_MODEL_a25969fd35744afdb66e9fd59612df6e"
+ ],
+ "layout": "IPY_MODEL_4964cf137c504555a7df44112e65cc0b"
+ }
+ },
+ "048a3dc56e8c4df39b1179d15029eb00": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_006d201a04f342e89ee728d8d2538f0f",
+ "placeholder": "",
+ "style": "IPY_MODEL_429b18e0a7874f6c8438cc797ba0123a",
+ "value": " 400/400 [00:00<00:00, 27.8kB/s]"
+ }
+ },
+ "04e94e56da49440abf2b6eaf6e675d98": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "05c9d7dd8bca4023b8e9d67d8edcf84b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "062d186b597a47a08746353bf35148b1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "07e2d238d83e45bf9037fa851a5e5648": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "08f75e1bf9864812a916f14eb29a12f9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "0a427178ccca414c8f58b892b54cb00d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bdeeb0f9825147efbb70a5eaf226a8ff",
+ "placeholder": "",
+ "style": "IPY_MODEL_f7dabfcad74b4838b3ad63d62457f140",
+ "value": "diffusion_pytorch_model.safetensors: 100%"
+ }
+ },
+ "0a4330eafca245378354d0abe07f881f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0d435039170340569cf92c0255d7fd5e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0f497591d8334c47a7e01aeb3ceaef81": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e89ac0f3c4c848bb9505a357ca31919a",
+ "placeholder": "",
+ "style": "IPY_MODEL_848aefc8f7c8452c8ac8871051166b3e",
+ "value": " 9.99G/9.99G [01:18<00:00, 159MB/s]"
+ }
+ },
+ "10468ffcd7c14cc79bc0fa023da9888a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_fba986e08f5242afaa9ffc193ef9b3b1",
+ "placeholder": "",
+ "style": "IPY_MODEL_5fc4719ad0fe4b16b0a9428ed5f32b3c",
+ "value": "vae/config.json: 100%"
+ }
+ },
+ "10f589c69e884531b0c1ee967f523aa5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2e0869cc3ac24f12bf9d173cb432dc4c",
+ "placeholder": "",
+ "style": "IPY_MODEL_27f2e688349d444a860dacb562634150",
+ "value": " 2.63k/2.63k [00:00<00:00, 56.9kB/s]"
+ }
+ },
+ "1200f91ecf914c6abf38c08365a9cc31": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_fe163c82f9b04a15ac6b923ed3b01873",
+ "IPY_MODEL_729f0475028446e8bb00078a2bfbfbce",
+ "IPY_MODEL_50f16033e1e44d05b7e9add4979a9b54"
+ ],
+ "layout": "IPY_MODEL_e28d59c860df45c9ad7318461e4ed25f"
+ }
+ },
+ "13069f251f51467b83ab5119b2317337": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_32a404ef0450473eaccdd5594058b49d",
+ "IPY_MODEL_67b4635164a04f96adca0efb40eae566",
+ "IPY_MODEL_b61eb332e206436d93af97221feb9b4c"
+ ],
+ "layout": "IPY_MODEL_67f3fe40a0914321926c53e2fec176a5"
+ }
+ },
+ "1cde89e63bd7437e9fac7446fc33aea3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a9c174e4fa204c8eb01990bcd02a43cb",
+ "placeholder": "",
+ "style": "IPY_MODEL_ce7f64ac3c164f04835d7869f7ebfb48",
+ "value": "tokenizer/tokenizer_config.json: 100%"
+ }
+ },
+ "1ce438295bd145ad8064c9d7a9f56469": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "1e4fba8ea112482cab2df3e917f6407c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1eab4186a6a24bfe98d764cfa74b342f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "1ec067c782824e329661b9a903748caa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_266723871e2c4f80bebc0e758b58374d",
+ "placeholder": "",
+ "style": "IPY_MODEL_3c9728a91b5d485f81ef4d7814c6bb22",
+ "value": " 5/5 [00:11<00:00, 2.08s/it]"
+ }
+ },
+ "2079c259a28943a0a80ddabc44026a39": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "21f0ae78db2f4cd09c0b64f4ae5f7b8a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "22802a6ccb80404e96f9912950d73252": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_79426ffc354b4b9daff5c64f6bea931e",
+ "max": 14,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_282fd54083174b31a38efdc3f757cedb",
+ "value": 14
+ }
+ },
+ "2466f004e7654d578536b3bdd9e47db2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_532599e42fc64bfdb57957454312ed00",
+ "max": 2,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_6c2bbdfd6e6f4a8faa2d47a81a0feebd",
+ "value": 2
+ }
+ },
+ "247a4fda6dc947cc8e7336bcffcd2781": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "251d69e9aac04d0da25fa6cb18ce907c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1e4fba8ea112482cab2df3e917f6407c",
+ "placeholder": "",
+ "style": "IPY_MODEL_c001ad33ee0e4df1a314228e292529b5",
+ "value": "Loading checkpoint shards: 100%"
+ }
+ },
+ "260f8ec597f9490d9d355664bd687dc7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_f9730c19c2bc45a0ac95964fd7950032",
+ "IPY_MODEL_803604e1509d4cb88fec21e96689a57b",
+ "IPY_MODEL_10f589c69e884531b0c1ee967f523aa5"
+ ],
+ "layout": "IPY_MODEL_a86366ba25244c6da47189cf59dcbec6"
+ }
+ },
+ "2631936da4064c2b8763bb9bd081b264": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_3df3134ee2c14677b1ed6eb49ece2f46",
+ "IPY_MODEL_93a02724ca12427998e17f8e355b0012",
+ "IPY_MODEL_1ec067c782824e329661b9a903748caa"
+ ],
+ "layout": "IPY_MODEL_4166654783344449b71d946124b02703"
+ }
+ },
+ "2650cccd00ae4086955bed6ca909e126": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "266723871e2c4f80bebc0e758b58374d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "266d230ffacd4d76acb96a2466104289": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "27f2e688349d444a860dacb562634150": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "282fd54083174b31a38efdc3f757cedb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "2845309eb6cd4e3bb91b1f3f0cfb6f43": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_251d69e9aac04d0da25fa6cb18ce907c",
+ "IPY_MODEL_2466f004e7654d578536b3bdd9e47db2",
+ "IPY_MODEL_ed3e92da5ade4986b381ef86e0d57ed4"
+ ],
+ "layout": "IPY_MODEL_98271fb9547940c69ff8958320056ca5"
+ }
+ },
+ "28f2f2f574f54fd6a16f244f4be5d08e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2b2351a243144932bc0a8fbc0b542402": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e9b0e81e633940e4983ceaf14687fa3b",
+ "max": 654,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_1eab4186a6a24bfe98d764cfa74b342f",
+ "value": 654
+ }
+ },
+ "2c886cdef1d04e45a735cebaf8e05b00": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2cb5315f2ad748c4a8605208dc9492d1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2d4b6035778d4bac8a2159504533272e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a3d6530e68c2427c893fcebcbae020f3",
+ "placeholder": "",
+ "style": "IPY_MODEL_be51de41fd0c4ef89130862d33d57011",
+ "value": ""
+ }
+ },
+ "2e0869cc3ac24f12bf9d173cb432dc4c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2e6f4510f5e6496f890e91f827d9cb99": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2f428cbd696a4ba4906c7fd8f74a57d4": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "314154a15a8d49da8eb158b9b64df8bc": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "32a404ef0450473eaccdd5594058b49d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_49edcb2e0c8a4e9aad0398653046cf9e",
+ "placeholder": "",
+ "style": "IPY_MODEL_4143365a9b69494eab538971b1602c35",
+ "value": "(…)ext_encoder/model.safetensors.index.json: 100%"
+ }
+ },
+ "32f4a8a631074da49f3361ea74b483fd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0a427178ccca414c8f58b892b54cb00d",
+ "IPY_MODEL_591c006b4de647d2b95a8d3a6103a48b",
+ "IPY_MODEL_aa04c220c98d4d6f880a58f6aa9bda75"
+ ],
+ "layout": "IPY_MODEL_04e94e56da49440abf2b6eaf6e675d98"
+ }
+ },
+ "35675cfe3b5d4c4f83fddd70c0cf8c50": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_716aea1c27df41b79cf2642bb93d5df3",
+ "max": 1,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_987d902ad45e4304a15dbd01fe755587",
+ "value": 0
+ }
+ },
+ "3a04722dff694815a30d8c7c61f42e16": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_faa63481cc9a4ac5adbdeee025a7de19",
+ "placeholder": "",
+ "style": "IPY_MODEL_2650cccd00ae4086955bed6ca909e126",
+ "value": " 655/655 [00:00<00:00, 48.5kB/s]"
+ }
+ },
+ "3c9728a91b5d485f81ef4d7814c6bb22": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "3df3134ee2c14677b1ed6eb49ece2f46": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9ee2b2a0e18d420ea09a202ce2804660",
+ "placeholder": "",
+ "style": "IPY_MODEL_b0f80fbce966413496bf67440107cc5b",
+ "value": "Loading pipeline components...: 100%"
+ }
+ },
+ "4140e9203f1d4f078581ecfe3038d94b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_fac39df9084640bbbee32339dc4613ad",
+ "IPY_MODEL_97c1c321a42f450e939bb0baf1e42785",
+ "IPY_MODEL_98f22a1dbac54d398191f61ed9764c6c"
+ ],
+ "layout": "IPY_MODEL_cef1102de0a14a46a5e918afc52f03bc"
+ }
+ },
+ "4143365a9b69494eab538971b1602c35": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4166654783344449b71d946124b02703": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "41c8d0eff808446abd70ce585da26978": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "41d1bf639c5c4ee4afc327a64b8357b7": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "42901129e32343dcb8d5edb960519a53": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "429b18e0a7874f6c8438cc797ba0123a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "436ae734138c4274b8a0157e76c8d2e3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_fb7fffb9281f42b693c1b7f48325f96f",
+ "placeholder": "",
+ "style": "IPY_MODEL_1ce438295bd145ad8064c9d7a9f56469",
+ "value": "model_index.json: 100%"
+ }
+ },
+ "442da82cec1241b8a5249d183820673d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_4af899d9d9ba4023b4985e61e3feeda6",
+ "max": 20509,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_b692c6053fc04651833ee87015df0daa",
+ "value": 20509
+ }
+ },
+ "4497b0d5517a4acabba98e60af7c5677": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d3406fb70a4c470fa4cd9476d60a9370",
+ "placeholder": "",
+ "style": "IPY_MODEL_63dc113923d145a88de3c2156b73a1ad",
+ "value": "diffusion_pytorch_model.safetensors: 100%"
+ }
+ },
+ "44cbd72d7bf74a10b357ea7625b5c84b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4964cf137c504555a7df44112e65cc0b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "49edcb2e0c8a4e9aad0398653046cf9e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4af899d9d9ba4023b4985e61e3feeda6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "50f16033e1e44d05b7e9add4979a9b54": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_01d6bd13c6a743c0b82dcb51bf3e7623",
+ "placeholder": "",
+ "style": "IPY_MODEL_247a4fda6dc947cc8e7336bcffcd2781",
+ "value": " 9.06G/9.06G [01:16<00:00, 150MB/s]"
+ }
+ },
+ "516574616ff540778c024528a9fafec9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_10468ffcd7c14cc79bc0fa023da9888a",
+ "IPY_MODEL_2b2351a243144932bc0a8fbc0b542402",
+ "IPY_MODEL_7c3adc19d763420787a6a55bb33f2d43"
+ ],
+ "layout": "IPY_MODEL_6610639505394df9b93cd39fdd0acd5c"
+ }
+ },
+ "529b32124a884391ba1176d340d480f9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "52a240218ab84c0896c7bdc4bfe5a0ef": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "532599e42fc64bfdb57957454312ed00": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "546cbe5b663a4b9fabaff56fecc6e914": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_314154a15a8d49da8eb158b9b64df8bc",
+ "placeholder": "",
+ "style": "IPY_MODEL_da149b5bd2024a53b99282849349ea8a",
+ "value": "Fetching 14 files: 100%"
+ }
+ },
+ "5498cc4574d74bf894f35c96d74dce80": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_601ce4737c2c46b493a191c7f3f6b950",
+ "placeholder": "",
+ "style": "IPY_MODEL_2079c259a28943a0a80ddabc44026a39",
+ "value": " 25/25 [00:32<00:00, 1.07s/it]"
+ }
+ },
+ "56e272c0af234570a5371800c336ec4d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "591c006b4de647d2b95a8d3a6103a48b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a6801b2c8f41424888dbafde4252c7da",
+ "max": 2447431856,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_5e8efc02014a455dbfbe356c965abd02",
+ "value": 2447431856
+ }
+ },
+ "5a724e34fb1b455389349963550b9da8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5cb62c1e29b6421d8f2b98d60543e045": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_21f0ae78db2f4cd09c0b64f4ae5f7b8a",
+ "placeholder": "",
+ "style": "IPY_MODEL_bf028fd02c0e4220a3077c4f41253ed7",
+ "value": "spiece.model: 100%"
+ }
+ },
+ "5e279ec61581429f84f0329ea5e62fee": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_61edf08284b942aaa7306d0a26a59344",
+ "IPY_MODEL_c7e02eb0a33349629f58b9cbd7d3a9ad",
+ "IPY_MODEL_0f497591d8334c47a7e01aeb3ceaef81"
+ ],
+ "layout": "IPY_MODEL_01be5b3663f348e7b47747f3371c1fbd"
+ }
+ },
+ "5e8efc02014a455dbfbe356c965abd02": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "5f13dc4e83c140c1a020bebcd08943d5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "5fc4719ad0fe4b16b0a9428ed5f32b3c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "601ce4737c2c46b493a191c7f3f6b950": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "609dff85b2104270a8e9263ee8e1e958": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "LabelModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "LabelModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "LabelView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_52a240218ab84c0896c7bdc4bfe5a0ef",
+ "placeholder": "",
+ "style": "IPY_MODEL_bd8e4b6f1e3f49a586fc2ca9baca58e3",
+ "value": "3.327 MB of 3.327 MB uploaded (0.008 MB deduped)\r"
+ }
+ },
+ "61edf08284b942aaa7306d0a26a59344": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_724772294d614997af26729ae0f845f9",
+ "placeholder": "",
+ "style": "IPY_MODEL_ff86218d2aca4ee7b3764e998b8c9b9a",
+ "value": "model-00001-of-00002.safetensors: 100%"
+ }
+ },
+ "63dc113923d145a88de3c2156b73a1ad": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6509119a1279497b91d29c7f922e8f3f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "65bcdf864db84dff899b411bdb51d0c1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "65c618901769460bbe35f1f68d04920c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6610639505394df9b93cd39fdd0acd5c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "66c72c7c741a4288a04e29284c7dab38": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d10632b92b72482bb3892ec471f0e756",
+ "placeholder": "",
+ "style": "IPY_MODEL_2e6f4510f5e6496f890e91f827d9cb99",
+ "value": " 719/719 [00:00<00:00, 25.4kB/s]"
+ }
+ },
+ "67b4635164a04f96adca0efb40eae566": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9dbf25daca3840319f5686ff275d842f",
+ "max": 19886,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_08f75e1bf9864812a916f14eb29a12f9",
+ "value": 19886
+ }
+ },
+ "67f3fe40a0914321926c53e2fec176a5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "68239504cbdf4aa0a80096171156cffd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_85ab89a71bb248e996fb8d0148a3d64c",
+ "IPY_MODEL_ceeef5ea2da54e3b8f6fcda7f33b4e32",
+ "IPY_MODEL_5498cc4574d74bf894f35c96d74dce80"
+ ],
+ "layout": "IPY_MODEL_529b32124a884391ba1176d340d480f9"
+ }
+ },
+ "69ab97b1174b4742847ef885c7b2f4ba": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9ff029954ed64721a4f00faf34f701c5",
+ "placeholder": "",
+ "style": "IPY_MODEL_f91b715ea2ae4fe0970c142ddc59d76e",
+ "value": " 792k/792k [00:00<00:00, 4.76MB/s]"
+ }
+ },
+ "6ad951ba967b430cbff93ae75a6b1597": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6b1c8bebfaf3464c97151f0fa03d1204": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5a724e34fb1b455389349963550b9da8",
+ "placeholder": "",
+ "style": "IPY_MODEL_b302595b335b458899e30ff9eb0fd79a",
+ "value": " 14/14 [01:20<00:00, 23.51s/it]"
+ }
+ },
+ "6c2bbdfd6e6f4a8faa2d47a81a0feebd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "6c42c3efc54f436996b154c38fb32209": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6fafd02828024e37beccd9086fb745f6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6ffbe6f493204f4e884692597058ad41": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "716aea1c27df41b79cf2642bb93d5df3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "20px"
+ }
+ },
+ "724772294d614997af26729ae0f845f9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "729f0475028446e8bb00078a2bfbfbce": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_f54a32cfafb5462985b2e1f03f3e7b55",
+ "max": 9060119392,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_93141ddff63a44aba010559c6d209596",
+ "value": 9060119392
+ }
+ },
+ "73afbd0ff1404c63b12909c609446cf9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "75343e1ba0604632abf96e86ad7fe8b3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "76d11a364eae4686a5879cddb98c4ddd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "79426ffc354b4b9daff5c64f6bea931e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7c3adc19d763420787a6a55bb33f2d43": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_28f2f2f574f54fd6a16f244f4be5d08e",
+ "placeholder": "",
+ "style": "IPY_MODEL_fb4e2edf849047ae9790f0522e13c11e",
+ "value": " 654/654 [00:00<00:00, 28.9kB/s]"
+ }
+ },
+ "7dc25f4f13ce4d7f94356ef9edeb2eb9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "803604e1509d4cb88fec21e96689a57b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7dc25f4f13ce4d7f94356ef9edeb2eb9",
+ "max": 2634,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_07e2d238d83e45bf9037fa851a5e5648",
+ "value": 2634
+ }
+ },
+ "848aefc8f7c8452c8ac8871051166b3e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "85ab89a71bb248e996fb8d0148a3d64c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_95caf26b330740a593c8cc86a2fd1ee1",
+ "placeholder": "",
+ "style": "IPY_MODEL_76d11a364eae4686a5879cddb98c4ddd",
+ "value": "100%"
+ }
+ },
+ "893b076e6f86408db6c1e542b3c730f6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_546cbe5b663a4b9fabaff56fecc6e914",
+ "IPY_MODEL_22802a6ccb80404e96f9912950d73252",
+ "IPY_MODEL_6b1c8bebfaf3464c97151f0fa03d1204"
+ ],
+ "layout": "IPY_MODEL_44cbd72d7bf74a10b357ea7625b5c84b"
+ }
+ },
+ "895a39fb613c406db808d1ad01dac37f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8dae7b8cdf004d748be7a6395c45bd1a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "8ece46fc480c4d399e7f2b7a92e96cec": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8f7235cc17134152a6fc37a2a8ef85b0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "904ce266ca25413398b736c5f6760a5a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_5cb62c1e29b6421d8f2b98d60543e045",
+ "IPY_MODEL_dc0b1e41de4c4d99b6668e07e2b061a4",
+ "IPY_MODEL_69ab97b1174b4742847ef885c7b2f4ba"
+ ],
+ "layout": "IPY_MODEL_65c618901769460bbe35f1f68d04920c"
+ }
+ },
+ "920fb69e19504d3b860f0cf55900d25d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "93141ddff63a44aba010559c6d209596": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "93a02724ca12427998e17f8e355b0012": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ee7cb7091ca5458bba7ab8cef3bb4b5e",
+ "max": 5,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_56e272c0af234570a5371800c336ec4d",
+ "value": 5
+ }
+ },
+ "94876605c0b14c55957f1ad227953d06": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_c4bce1fe40ef402cb8a86767f518fc70",
+ "IPY_MODEL_c3398473b8d54655bad516100d0d0fbb",
+ "IPY_MODEL_cc80937e599b45bc974d21f6bd5af90f"
+ ],
+ "layout": "IPY_MODEL_e6803d58f47f4b1fb4e76bca06d6b35c"
+ }
+ },
+ "950697f13cfb477cb09def01968a8be0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_436ae734138c4274b8a0157e76c8d2e3",
+ "IPY_MODEL_c54e4159f65549a4a4d877c224aa89ed",
+ "IPY_MODEL_048a3dc56e8c4df39b1179d15029eb00"
+ ],
+ "layout": "IPY_MODEL_cc9cb09be76e43b591285a5d52c6ed76"
+ }
+ },
+ "95616ad88fc84e80a99f7aa0987a4372": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_2d4b6035778d4bac8a2159504533272e",
+ "IPY_MODEL_35675cfe3b5d4c4f83fddd70c0cf8c50",
+ "IPY_MODEL_00e83d39844348e59e7f256898e08709"
+ ],
+ "layout": "IPY_MODEL_f842eb80fed54b2ba865b75a05dae0a1"
+ }
+ },
+ "95caf26b330740a593c8cc86a2fd1ee1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9690ca75724a4547a505b457efa18cf5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "973a485540b04967a277eaf9bff09fd9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_41d1bf639c5c4ee4afc327a64b8357b7",
+ "max": 719,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_e6ee1b5afae24cf6bd731b64e53a682e",
+ "value": 719
+ }
+ },
+ "97c1c321a42f450e939bb0baf1e42785": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6ad951ba967b430cbff93ae75a6b1597",
+ "max": 788,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_ba1c265343204064b4644f9bc9fbc0e9",
+ "value": 788
+ }
+ },
+ "98271fb9547940c69ff8958320056ca5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "987d902ad45e4304a15dbd01fe755587": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "98f22a1dbac54d398191f61ed9764c6c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e09197134928414c974211651e17b385",
+ "placeholder": "",
+ "style": "IPY_MODEL_42901129e32343dcb8d5edb960519a53",
+ "value": " 788/788 [00:00<00:00, 15.2kB/s]"
+ }
+ },
+ "99fd909b2eb14415a79409646dd59fd2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9dbf25daca3840319f5686ff275d842f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9ee2b2a0e18d420ea09a202ce2804660": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9fc87495a9364e27a1da550d91b85550": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "9ff029954ed64721a4f00faf34f701c5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a25969fd35744afdb66e9fd59612df6e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ca2e80da3e194102a9ff2cf0880e1e99",
+ "max": 1,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_8dae7b8cdf004d748be7a6395c45bd1a",
+ "value": 1
+ }
+ },
+ "a3d6530e68c2427c893fcebcbae020f3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a517bf731b2748b0ae70c137beb50108": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "a6801b2c8f41424888dbafde4252c7da": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a86366ba25244c6da47189cf59dcbec6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a9c174e4fa204c8eb01990bcd02a43cb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "aa04c220c98d4d6f880a58f6aa9bda75": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b9573401b606405da3b1d8815865b104",
+ "placeholder": "",
+ "style": "IPY_MODEL_cf75087ea8c4401faebff1fe50cc8703",
+ "value": " 2.45G/2.45G [00:31<00:00, 130MB/s]"
+ }
+ },
+ "ab13c1cef2ff4f41b64c4ad103144b31": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "acc056c201ef4cb8beac8d981b43ba3e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b069c4b00c864783963c1bf299c5c372": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_fd5b8ec96c834a519fedba254849ec04",
+ "IPY_MODEL_d9f9127720b541ddbbe5990e322ad089",
+ "IPY_MODEL_3a04722dff694815a30d8c7c61f42e16"
+ ],
+ "layout": "IPY_MODEL_d30a20c93a404c02adaa4fcb66443cb2"
+ }
+ },
+ "b0f80fbce966413496bf67440107cc5b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b302595b335b458899e30ff9eb0fd79a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b61eb332e206436d93af97221feb9b4c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6c42c3efc54f436996b154c38fb32209",
+ "placeholder": "",
+ "style": "IPY_MODEL_acc056c201ef4cb8beac8d981b43ba3e",
+ "value": " 19.9k/19.9k [00:00<00:00, 425kB/s]"
+ }
+ },
+ "b692c6053fc04651833ee87015df0daa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "b7f856516aef474aac5d1755163a4775": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6fafd02828024e37beccd9086fb745f6",
+ "placeholder": "",
+ "style": "IPY_MODEL_65bcdf864db84dff899b411bdb51d0c1",
+ "value": "transformer/config.json: 100%"
+ }
+ },
+ "b9573401b606405da3b1d8815865b104": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b9cb18bf70f843e2827a5908d86d490d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_b7f856516aef474aac5d1755163a4775",
+ "IPY_MODEL_973a485540b04967a277eaf9bff09fd9",
+ "IPY_MODEL_66c72c7c741a4288a04e29284c7dab38"
+ ],
+ "layout": "IPY_MODEL_f5e977aeb3fe4edbacc09d6b9dec6df0"
+ }
+ },
+ "b9e34a78891a4be6b293733ad79fb409": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "ba1c265343204064b4644f9bc9fbc0e9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "bd8e4b6f1e3f49a586fc2ca9baca58e3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "bdeeb0f9825147efbb70a5eaf226a8ff": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "be51de41fd0c4ef89130862d33d57011": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "bf028fd02c0e4220a3077c4f41253ed7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "c001ad33ee0e4df1a314228e292529b5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "c3398473b8d54655bad516100d0d0fbb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_cb6eb4fc3ea44449a53d362d9819dd8e",
+ "max": 2201,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_062d186b597a47a08746353bf35148b1",
+ "value": 2201
+ }
+ },
+ "c4073b2d2c1d4642a1c99ce39c8fa39b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "c4bce1fe40ef402cb8a86767f518fc70": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dfa1e1cbc0044ba9823f307edbfbec1a",
+ "placeholder": "",
+ "style": "IPY_MODEL_ce51bb48d8184438968b0f06b87e16d3",
+ "value": "tokenizer/special_tokens_map.json: 100%"
+ }
+ },
+ "c54e4159f65549a4a4d877c224aa89ed": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ab13c1cef2ff4f41b64c4ad103144b31",
+ "max": 400,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_920fb69e19504d3b860f0cf55900d25d",
+ "value": 400
+ }
+ },
+ "c7e02eb0a33349629f58b9cbd7d3a9ad": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_266d230ffacd4d76acb96a2466104289",
+ "max": 9989150328,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_8f7235cc17134152a6fc37a2a8ef85b0",
+ "value": 9989150328
+ }
+ },
+ "ca2e80da3e194102a9ff2cf0880e1e99": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cb6eb4fc3ea44449a53d362d9819dd8e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cb7c327783124585a21b6062b9a4bec7": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cc80937e599b45bc974d21f6bd5af90f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0d435039170340569cf92c0255d7fd5e",
+ "placeholder": "",
+ "style": "IPY_MODEL_9fc87495a9364e27a1da550d91b85550",
+ "value": " 2.20k/2.20k [00:00<00:00, 60.1kB/s]"
+ }
+ },
+ "cc9cb09be76e43b591285a5d52c6ed76": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ce51bb48d8184438968b0f06b87e16d3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ce7f64ac3c164f04835d7869f7ebfb48": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ceeef5ea2da54e3b8f6fcda7f33b4e32": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_8ece46fc480c4d399e7f2b7a92e96cec",
+ "max": 25,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_73afbd0ff1404c63b12909c609446cf9",
+ "value": 25
+ }
+ },
+ "cef1102de0a14a46a5e918afc52f03bc": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cf75087ea8c4401faebff1fe50cc8703": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "d10632b92b72482bb3892ec471f0e756": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d213aeaea9f949658d1bd9728ee86fef": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_1cde89e63bd7437e9fac7446fc33aea3",
+ "IPY_MODEL_442da82cec1241b8a5249d183820673d",
+ "IPY_MODEL_f65994d833e04d9f96d2613a82651cdd"
+ ],
+ "layout": "IPY_MODEL_eb118c92cb514e999f7b02ebfbbc7d0b"
+ }
+ },
+ "d30a20c93a404c02adaa4fcb66443cb2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d3406fb70a4c470fa4cd9476d60a9370": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d7a4b722cbc74d548edd1d0469d258a7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "d7b8ca6b0cf549af8f3f8b0e15f0b8af": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_cb7c327783124585a21b6062b9a4bec7",
+ "placeholder": "",
+ "style": "IPY_MODEL_5f13dc4e83c140c1a020bebcd08943d5",
+ "value": " 335M/335M [00:03<00:00, 92.4MB/s]"
+ }
+ },
+ "d96719a9e06845aaae085bd256148333": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_4497b0d5517a4acabba98e60af7c5677",
+ "IPY_MODEL_fe96da4d0e484f79ab3a82820ddea5ea",
+ "IPY_MODEL_d7b8ca6b0cf549af8f3f8b0e15f0b8af"
+ ],
+ "layout": "IPY_MODEL_2cb5315f2ad748c4a8605208dc9492d1"
+ }
+ },
+ "d9f9127720b541ddbbe5990e322ad089": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6ffbe6f493204f4e884692597058ad41",
+ "max": 655,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_a517bf731b2748b0ae70c137beb50108",
+ "value": 655
+ }
+ },
+ "da149b5bd2024a53b99282849349ea8a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "dc0b1e41de4c4d99b6668e07e2b061a4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e29ec1b24ede4619881a3e011699c434",
+ "max": 791656,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_75343e1ba0604632abf96e86ad7fe8b3",
+ "value": 791656
+ }
+ },
+ "dce47f13fc15450787a63016b61e1c1a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dfa1e1cbc0044ba9823f307edbfbec1a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e09197134928414c974211651e17b385": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e28d59c860df45c9ad7318461e4ed25f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e29ec1b24ede4619881a3e011699c434": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e6803d58f47f4b1fb4e76bca06d6b35c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e6ee1b5afae24cf6bd731b64e53a682e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "e89ac0f3c4c848bb9505a357ca31919a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e9b0e81e633940e4983ceaf14687fa3b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "eb118c92cb514e999f7b02ebfbbc7d0b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ed3e92da5ade4986b381ef86e0d57ed4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2f428cbd696a4ba4906c7fd8f74a57d4",
+ "placeholder": "",
+ "style": "IPY_MODEL_f1bea04743fd4a63aaa1b03acdac859d",
+ "value": " 2/2 [00:02<00:00, 1.17s/it]"
+ }
+ },
+ "ee7cb7091ca5458bba7ab8cef3bb4b5e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f1bea04743fd4a63aaa1b03acdac859d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "f54a32cfafb5462985b2e1f03f3e7b55": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f5e977aeb3fe4edbacc09d6b9dec6df0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f650b81e97cf47b78b6512ff4a0794c6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f65994d833e04d9f96d2613a82651cdd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0a4330eafca245378354d0abe07f881f",
+ "placeholder": "",
+ "style": "IPY_MODEL_05c9d7dd8bca4023b8e9d67d8edcf84b",
+ "value": " 20.5k/20.5k [00:00<00:00, 1.28MB/s]"
+ }
+ },
+ "f7dabfcad74b4838b3ad63d62457f140": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "f842eb80fed54b2ba865b75a05dae0a1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f91b715ea2ae4fe0970c142ddc59d76e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "f9730c19c2bc45a0ac95964fd7950032": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6509119a1279497b91d29c7f922e8f3f",
+ "placeholder": "",
+ "style": "IPY_MODEL_2c886cdef1d04e45a735cebaf8e05b00",
+ "value": "tokenizer/added_tokens.json: 100%"
+ }
+ },
+ "faa63481cc9a4ac5adbdeee025a7de19": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fac39df9084640bbbee32339dc4613ad": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_99fd909b2eb14415a79409646dd59fd2",
+ "placeholder": "",
+ "style": "IPY_MODEL_41c8d0eff808446abd70ce585da26978",
+ "value": "text_encoder/config.json: 100%"
+ }
+ },
+ "fb4e2edf849047ae9790f0522e13c11e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "fb7fffb9281f42b693c1b7f48325f96f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fba986e08f5242afaa9ffc193ef9b3b1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fd5b8ec96c834a519fedba254849ec04": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dce47f13fc15450787a63016b61e1c1a",
+ "placeholder": "",
+ "style": "IPY_MODEL_c4073b2d2c1d4642a1c99ce39c8fa39b",
+ "value": "scheduler/scheduler_config.json: 100%"
+ }
+ },
+ "fe163c82f9b04a15ac6b923ed3b01873": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_f650b81e97cf47b78b6512ff4a0794c6",
+ "placeholder": "",
+ "style": "IPY_MODEL_d7a4b722cbc74d548edd1d0469d258a7",
+ "value": "model-00002-of-00002.safetensors: 100%"
+ }
+ },
+ "fe96da4d0e484f79ab3a82820ddea5ea": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9690ca75724a4547a505b457efa18cf5",
+ "max": 334643268,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_b9e34a78891a4be6b293733ad79fb409",
+ "value": 334643268
+ }
+ },
+ "ff86218d2aca4ee7b3764e998b8c9b9a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/diffusers/sdxl-diffusers.ipynb b/colabs/diffusers/sdxl-diffusers.ipynb
new file mode 100644
index 00000000..302184b8
--- /dev/null
+++ b/colabs/diffusers/sdxl-diffusers.ipynb
@@ -0,0 +1,153 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Image Generation with Stable Diffusion XL using 🤗 Diffusers\n",
+ "\n",
+ "\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/diffusers/sdxl-diffusers.ipynb)\n",
+ "\n",
+ "This notebook demonstrates the following:\n",
+ "- Performing text-conditional image-generations with the [Stable Diffusion XL](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl) using [🤗 Diffusers](https://huggingface.co/docs/diffusers).\n",
+ "- Manage image generation experiments using [Weights & Biases](http://wandb.ai/site).\n",
+ "- Log the prompts, generated images and experiment configs to [Weights & Biases](http://wandb.ai/site) for visualization.\n",
+ "\n",
+ "![](./assets/diffusers-autolog-5.gif)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install diffusers transformers accelerate wandb > install.log"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "import torch\n",
+ "from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline\n",
+ "\n",
+ "import wandb\n",
+ "from wandb.integration.diffusers import autolog"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "base_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\" # @param [\"stabilityai/stable-diffusion-xl-base-1.0\", \"segmind/SSD-1B\", \"stabilityai/sdxl-turbo\"]\n",
+ "\n",
+ "base_pipeline = StableDiffusionXLPipeline.from_pretrained(\n",
+ " \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
+ " torch_dtype=torch.float16,\n",
+ " variant=\"fp16\",\n",
+ " use_safetensors=True,\n",
+ ")\n",
+ "\n",
+ "base_pipeline.enable_model_cpu_offload()\n",
+ "\n",
+ "refiner_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(\n",
+ " \"stabilityai/stable-diffusion-xl-refiner-1.0\",\n",
+ " text_encoder_2=base_pipeline.text_encoder_2,\n",
+ " vae=base_pipeline.vae,\n",
+ " torch_dtype=torch.float16,\n",
+ " use_safetensors=True,\n",
+ " variant=\"fp16\",\n",
+ ")\n",
+ "refiner_pipeline.enable_model_cpu_offload()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb_project = \"pixart-alpha\" # @param {type:\"string\"}\n",
+ "\n",
+ "prompt_1 = \"a photograph of an evil and vile looking demon in Bengali attire eating fish. The demon has large and bloody teeth. The demon is sitting on the branches of a giant Banyan tree, dimly lit, bluish and dark color palette, realistic, 8k\" # @param {type:\"string\"}\n",
+ "prompt_2 = \"\" # @param {type:\"string\"}\n",
+ "negative_prompt_1 = \"static, frame, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\" # @param {type:\"string\"}\n",
+ "negative_prompt_2 = \"static, frame, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\" # @param {type:\"string\"}\n",
+ "num_inference_steps = 50 # @param {type:\"slider\", min:10, max:100, step:1}\n",
+ "guidance_scale = 5.0 # @param {type:\"slider\", min:0, max:10, step:0.1}\n",
+ "height = 1024 # @param {type:\"slider\", min:512, max:2560, step:32}\n",
+ "width = 1024 # @param {type:\"slider\", min:512, max:2560, step:32}\n",
+ "seed = None # @param {type:\"raw\"}\n",
+ "\n",
+ "\n",
+ "def autogenerate_seed():\n",
+ " max_seed = int(1024 * 1024 * 1024)\n",
+ " seed = random.randint(1, max_seed)\n",
+ " seed = -seed if seed < 0 else seed\n",
+ " seed = seed % max_seed\n",
+ " return seed\n",
+ "\n",
+ "\n",
+ "seed = autogenerate_seed() if seed is None else seed\n",
+ "\n",
+ "# Make the experiment reproducible by controlling randomness.\n",
+ "# The seed would be automatically logged to WandB.\n",
+ "generator_base = torch.Generator(device=\"cuda\").manual_seed(seed)\n",
+ "generator_refiner = torch.Generator(device=\"cuda\").manual_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Call WandB Autolog for Diffusers. This would automatically log\n",
+ "# the prompts, generated images, pipeline architecture and all\n",
+ "# associated experiment configs to Weights & Biases, thus making your\n",
+ "# image generation experiments easy to reproduce, share and analyze.\n",
+ "autolog(init=dict(project=wandb_project))\n",
+ "\n",
+ "image = base_pipeline(\n",
+ " prompt=prompt_1,\n",
+ " prompt_2=prompt_2,\n",
+ " negative_prompt=negative_prompt_1,\n",
+ " negative_prompt_2=negative_prompt_2,\n",
+ " num_inference_steps=num_inference_steps,\n",
+ " output_type=\"latent\",\n",
+ " generator=generator_base,\n",
+ " guidance_scale=guidance_scale,\n",
+ ").images[0]\n",
+ "\n",
+ "image = refiner_pipeline(\n",
+ " prompt=prompt_1,\n",
+ " prompt_2=prompt_2,\n",
+ " negative_prompt=negative_prompt_1,\n",
+ " negative_prompt_2=negative_prompt_2,\n",
+ " image=image[None, :],\n",
+ " num_inference_steps=num_inference_steps,\n",
+ " guidance_scale=guidance_scale,\n",
+ " generator=generator_refiner,\n",
+ ").images[0]\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/colabs/diffusers/sdxl-text-to-image.ipynb b/colabs/diffusers/sdxl-text-to-image.ipynb
index 65663461..3050b661 100644
--- a/colabs/diffusers/sdxl-text-to-image.ipynb
+++ b/colabs/diffusers/sdxl-text-to-image.ipynb
@@ -148,7 +148,7 @@
"\n",
"1. We define the base diffusion pipeline using `diffusers.DiffusionPipeline` and load the pre-trained weights for SDXL 1.0 by calling the `from_pretrained` function on it. We also pass the scheduler as `diffusers.EulerDiscreteScheduler` in this step.\n",
"\n",
- "2. In case we don't have a GPU with large enough GPU, it's recommended to enable CPU offloading. Otherwise, we load the model on the GPU. In case you're curious how HiggingFace manages CPU offloading in the most optimized manner, we recommend you read this port by [Sylvain Gugger](https://huggingface.co/sgugger): [How 🤗 Accelerate runs very large models thanks to PyTorch](https://huggingface.co/blog/accelerate-large-models).\n",
+ "2. In case we don't have a GPU with enough memory, it's recommended to enable CPU offloading. Otherwise, we load the model on the GPU. In case you're curious how HuggingFace manages CPU offloading in the most optimized manner, we recommend you read this post by [Sylvain Gugger](https://huggingface.co/sgugger): [How 🤗 Accelerate runs very large models thanks to PyTorch](https://huggingface.co/blog/accelerate-large-models).\n",
"\n",
"3. We can compile model using `torch.compile`, this might give a significant speedup.\n",
"\n",
diff --git a/colabs/dsviz/W&B_Dataset_and_Predictions_Viz_Demo.ipynb b/colabs/dsviz/W&B_Dataset_and_Predictions_Viz_Demo.ipynb
index b7447302..4c713086 100644
--- a/colabs/dsviz/W&B_Dataset_and_Predictions_Viz_Demo.ipynb
+++ b/colabs/dsviz/W&B_Dataset_and_Predictions_Viz_Demo.ipynb
@@ -112,7 +112,15 @@
"outputs": [],
"source": [
"# default project name where results will be logged\n",
- "WANDB_PROJECT = \"dsviz-demo\"\n",
+ "WANDB_PROJECT = \"dsviz-demo\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/huggingface/Custom_Progress_Callback.ipynb b/colabs/huggingface/Custom_Progress_Callback.ipynb
new file mode 100644
index 00000000..252e18cc
--- /dev/null
+++ b/colabs/huggingface/Custom_Progress_Callback.ipynb
@@ -0,0 +1,460 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Visualize LLM training progress with Wandb Tables\n",
+ "\n",
+ "In this example we will see how to instrument a custom [callback](https://huggingface.co/docs/transformers/main_classes/callback) for the Hugging Face [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) to periodically visualize model predictions using Weights & Biases [Tables](https://wandb.ai/site/tables).\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "## 🤔 Why should I use W&B?\n",
+ "\n",
+ "\n",
+ "\n",
+ "- **Unified dashboard**: Central repository for all your model metrics and predictions\n",
+ "- **Lightweight**: No code changes required to integrate with Hugging Face\n",
+ "- **Accessible**: Free for individuals and academic teams\n",
+ "- **Secure**: All projects are private by default\n",
+ "- **Trusted**: Used by machine learning teams at OpenAI, Toyota, Lyft and more\n",
+ "\n",
+ "Think of W&B like GitHub for machine learning models— save machine learning experiments to your private, hosted dashboard. Experiment quickly with the confidence that all the versions of your models are saved for you, no matter where you're running your scripts.\n",
+ "\n",
+ "W&B's lightweight integrations work with any Python script, and all you need to do is sign up for a free W&B account to start tracking and visualizing your models."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🚀 Install, Import, and Log in\n",
+ "\n",
+ "\n",
+ "To get started with this example, you will need to install the Transformers, Weights & Biases, and Datasets libraries.\n",
+ "- [Hugging Face Transformers](https://github.com/huggingface/transformers)\n",
+ "- [Weights & Biases](https://docs.wandb.com/)\n",
+ "- [Huggingface Datasets](https://github.com/huggingface/datasets)\n",
+ "\n",
+ "**Uncomment the following cell to install the libraries.**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ! pip install -qqq datasets \"transformers[torch]\" wandb"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 🖊️ [Sign up for a free account →](https://app.wandb.ai/login?signup=true)\n",
+ "\n",
+ "## 🔑 Put in your API key\n",
+ "\n",
+ "Once you've signed up, run the next cell and click on the link to get your API key and authenticate this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "wandb.login()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train a language model\n",
+ "\n",
+ "In this notebook, we'll see how to train a [🤗 Transformers](https://github.com/huggingface/transformers) model on the causal language modeling task, i.e. the model has to predict the next token in the sentence (so the labels are the same as the inputs, shifted by one position). To make sure the model does not cheat, it gets an attention mask that prevents it from accessing the tokens after token i when trying to predict token i+1 in the sentence.\n",
+ "\n",
+ "We will see how to load and preprocess the dataset for the task and train a model on it using the `Trainer` API.\n",
+ "We will be building a custom trainer callback to visualize the model predictions using Weights & Biases Tables by periodically logging the predictions to a table."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Prepare the dataset"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will use the [Wikitext 2](https://paperswithcode.com/dataset/wikitext-2) dataset as an example for this task."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')\n",
+ "\n",
+ "# look at a sample from the train dataset\n",
+ "datasets[\"train\"][10]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "For the causal language modeling (CLM) task, we are going to take all the texts in our dataset and concatenate them after they are tokenized.\n",
+ "Then we split them in examples of a certain sequence length. This way the model will receive chunks of contiguous text that may look like:\n",
+ "```\n",
+ "part of text 1\n",
+ "```\n",
+ "or\n",
+ "```\n",
+ "end of text 1 [BOS_TOKEN] beginning of text 2\n",
+ "```\n",
+ "depending on whether they span over several of the original texts in the dataset or not. The labels will be the same as the inputs, shifted to the left.\n",
+ "\n",
+ "We will use the [`gpt2`](https://huggingface.co/gpt2) architecture for this example. You can pick any of the checkpoints listed [here](https://huggingface.co/models?filter=causal-lm) instead."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_NAME =\"gpt2\"\n",
+ "\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can now call the tokenizer on all our texts using the [`map`](https://huggingface.co/docs/datasets/process#map) method from the Datasets library."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def tokenize_function(examples):\n",
+ " return tokenizer(examples[\"text\"])\n",
+ "\n",
+ "tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we need to concatenate all our texts together then split the result in small chunks of a certain `block_size`. Here we will use a block_size of `128`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "BLOCK_SIZE = 128\n",
+ "\n",
+ "def group_texts(examples):\n",
+ " # Concatenate all texts.\n",
+ " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
+ " total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
+ " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
+ " # customize this part to your needs.\n",
+ " total_length = (total_length // BLOCK_SIZE) * BLOCK_SIZE\n",
+ " # Split by chunks of max_len.\n",
+ " result = {\n",
+ " k: [t[i : i + BLOCK_SIZE] for i in range(0, total_length, BLOCK_SIZE)]\n",
+ " for k, t in concatenated_examples.items()\n",
+ " }\n",
+ " result[\"labels\"] = result[\"input_ids\"].copy()\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "lm_datasets = tokenized_datasets.map(\n",
+ " group_texts,\n",
+ " batched=True,\n",
+ " batch_size=1000,\n",
+ " num_proc=4,\n",
+ ")\n",
+ "\n",
+ "# look at a sample from the preprocessed dataset\n",
+ "tokenizer.decode(lm_datasets[\"train\"][1][\"input_ids\"])"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model & TrainingArguments\n",
+ "\n",
+ "Now that we have our dataset prepared, we are ready to instantiate the model and the Training Arguments.\n",
+ "For simplicity we will train the model for 3 epochs and log the model predictions and metrics after each epoch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import AutoConfig, AutoModelForCausalLM\n",
+ "from transformers import Trainer, TrainingArguments\n",
+ "\n",
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,)\n",
+ "\n",
+ "training_args = TrainingArguments(\n",
+ " f\"{MODEL_NAME}-wikitext2\",\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " num_train_epochs=3,\n",
+ " learning_rate=2e-5,\n",
+ " do_train=True,\n",
+ " do_eval=True,\n",
+ " weight_decay=0.01,\n",
+ " logging_strategy=\"epoch\",\n",
+ " fp16=True,\n",
+ " dataloader_num_workers=4,\n",
+ "\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## WandbPredictionProgressCallback\n",
+ "\n",
+ "To periodically visualize the results we will subclass the [`WandbCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.integrations.WandbCallback) from the transformers library. This callback is already instrumented to log model metrics, checkpoints and system metrics to Weights & Biases.\n",
+ "\n",
+ "Here we will customize the callback to periodically log model predictions and labels to a `wandb.Table` so that we can visualize the model predictions as the training progresses. To do this, we will also need to pass the trainer and tokenizer to our callback in order to predict over the validation dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers.integrations import WandbCallback\n",
+ "import pandas as pd\n",
+ "import os\n",
+ "\n",
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\"\n",
+ "\n",
+ "def decode_predictions(tokenizer, predictions):\n",
+ " labels = tokenizer.batch_decode(predictions.label_ids)\n",
+ " prediction_text = tokenizer.batch_decode(predictions.predictions.argmax(axis=-1))\n",
+ " return {\"labels\": labels, \"predictions\": prediction_text}\n",
+ "\n",
+ "\n",
+ "class WandbPredictionProgressCallback(WandbCallback):\n",
+ " \"\"\"Custom WandbCallback to log model predictions during training.\n",
+ "\n",
+ " This callback logs model predictions and labels to a wandb.Table at each logging step during training.\n",
+ " It allows to visualize the model predictions as the training progresses.\n",
+ "\n",
+ " Attributes:\n",
+ " trainer (Trainer): The Hugging Face Trainer instance.\n",
+ " tokenizer (AutoTokenizer): The tokenizer associated with the model.\n",
+ " sample_dataset (Dataset): A subset of the validation dataset for generating predictions.\n",
+ " num_samples (int, optional): Number of samples to select from the validation dataset for generating predictions. Defaults to 100.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, trainer, tokenizer, val_dataset, num_samples=100, freq=2):\n",
+ " \"\"\"Initializes the WandbPredictionProgressCallback instance.\n",
+ "\n",
+ " Args:\n",
+ " trainer (Trainer): The Hugging Face Trainer instance.\n",
+ " tokenizer (AutoTokenizer): The tokenizer associated with the model.\n",
+ " val_dataset (Dataset): The validation dataset.\n",
+ " num_samples (int, optional): Number of samples to select from the validation dataset for generating predictions. Defaults to 100.\n",
+ " freq (int, optional): Control the frequency of logging. Defaults to 2.\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " self.trainer = trainer\n",
+ " self.tokenizer = tokenizer\n",
+ " self.sample_dataset = val_dataset.select(range(num_samples))\n",
+ " self.freq = freq\n",
+ "\n",
+ "\n",
+ " def on_evaluate(self, args, state, control, **kwargs):\n",
+ " super().on_evaluate(args, state, control, **kwargs)\n",
+ " # control the frequency of logging by logging the predictions every `freq` epochs\n",
+ " if state.epoch % self.freq == 0:\n",
+ " # generate predictions\n",
+ " predictions = self.trainer.predict(self.sample_dataset)\n",
+ " # decode predictions and labels\n",
+ " predictions = decode_predictions(self.tokenizer, predictions)\n",
+ " # add predictions to a wandb.Table\n",
+ " predictions_df = pd.DataFrame(predictions)\n",
+ " predictions_df[\"epoch\"] = state.epoch\n",
+ " records_table = self._wandb.Table(dataframe=predictions_df)\n",
+ " # log the table to wandb\n",
+ " self._wandb.log({\"sample_predictions\": records_table})\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Trainer\n",
+ "\n",
+ "First we instantiate the `Trainer` class with the model, training arguments, and the train and eval datasets. Since our callback needs to store predictions after each evaluation, we will pass the trainer to the callback and then add the callback to the trainer.\n",
+ "\n",
+ "**Note**: Here we do not add the callback while instantiating the `Trainer`; instead, we use the `add_callback` method to add the callback to the trainer after instantiation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=lm_datasets[\"train\"],\n",
+ " eval_dataset=lm_datasets[\"validation\"],\n",
+ "\n",
+ ")\n",
+ "\n",
+ "\n",
+ "progress_callback = WandbPredictionProgressCallback(trainer, tokenizer, lm_datasets[\"validation\"], 10)\n",
+ "trainer.add_callback(progress_callback)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training\n",
+ "\n",
+ "And that's it, we are ready to train the model and visualize the predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.train()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 👀 Visualize results in dashboard\n",
+ "\n",
+ "Click the link printed out above, or go to [wandb.ai](https://app.wandb.ai) to see your results stream in live. The link to see your run in the browser will appear after all the dependencies are loaded — look for the following output:\n",
+ "\n",
+ "```\n",
+ "Tracking run with wandb version \n",
+ "Run data is saved locally in \n",
+ "Syncing run to Weights & Biases (docs)\n",
+ "View project at \n",
+ "View run at \n",
+ "```\n",
+ "\n",
+ "Click on the run link to visualize the sample model predictions for each epoch. You should see a table similar to the one shown in the screenshot below."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![]()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## 📈 Track key information effortlessly by default\n",
+ "Weights & Biases saves a new run for each experiment. Here's the information that gets saved by default:\n",
+ "- **Hyperparameters**: Settings for your model are saved in Config\n",
+ "- **Model Metrics**: Time series data of metrics streaming in are saved in Log\n",
+ "- **Terminal Logs**: Command line outputs are saved and available in a tab\n",
+ "- **System Metrics**: GPU and CPU utilization, memory, temperature etc.\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "## 🤓 Learn more!\n",
+ "- [Documentation](https://docs.wandb.com/huggingface): docs on the Weights & Biases and Hugging Face integration\n",
+ "- [Videos](http://wandb.me/youtube): tutorials, interviews with practitioners, and more on our YouTube channel\n",
+ "- Contact: Message us at contact@wandb.com with questions "
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/huggingface/Huggingface_wandb.ipynb b/colabs/huggingface/Huggingface_wandb.ipynb
index 4ff85657..bdf21058 100644
--- a/colabs/huggingface/Huggingface_wandb.ipynb
+++ b/colabs/huggingface/Huggingface_wandb.ipynb
@@ -126,7 +126,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/huggingface/Optimize_Hugging_Face_models_with_Weights_&_Biases.ipynb b/colabs/huggingface/Optimize_Hugging_Face_models_with_Weights_&_Biases.ipynb
index e950b3d4..4b186151 100644
--- a/colabs/huggingface/Optimize_Hugging_Face_models_with_Weights_&_Biases.ipynb
+++ b/colabs/huggingface/Optimize_Hugging_Face_models_with_Weights_&_Biases.ipynb
@@ -109,8 +109,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/huggingface/Simple_accelerate_integration_wandb.ipynb b/colabs/huggingface/Simple_accelerate_integration_wandb.ipynb
index b371a500..5898e68c 100644
--- a/colabs/huggingface/Simple_accelerate_integration_wandb.ipynb
+++ b/colabs/huggingface/Simple_accelerate_integration_wandb.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "205e453e",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "cadda531-5f7b-4793-876b-dceedec26cc5",
+ "id": "fe7c4702",
"metadata": {},
"source": [
"# Using Huggingface Accelerate with Weights and Biases\n",
@@ -20,7 +21,7 @@
},
{
"cell_type": "markdown",
- "id": "c5ccfcf6-5d99-474c-a58e-77783906efb4",
+ "id": "04b2cd3c",
"metadata": {},
"source": [
"[Accelerate](https://github.com/huggingface/accelerate) is this amazing little framework that simplifies your PyTorch training scripts enabling you to train with all the tricks out there!\n",
@@ -68,7 +69,7 @@
},
{
"cell_type": "markdown",
- "id": "02aa793a-f2d3-468c-823e-492d2bebe4b7",
+ "id": "179d31c1",
"metadata": {},
"source": [
"## Training and Image Classifier"
@@ -77,7 +78,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6c83e24e-ad59-4038-8a4b-67f6dcce94cb",
+ "id": "1a080fa1",
"metadata": {},
"outputs": [],
"source": [
@@ -87,7 +88,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "90dc4bba-0181-4a8b-bfc5-6ae87a3c07a4",
+ "id": "a62d276f",
"metadata": {},
"outputs": [],
"source": [
@@ -111,7 +112,7 @@
},
{
"cell_type": "markdown",
- "id": "8760c536-636b-4cd9-b244-3bb61254ebd6",
+ "id": "6f2923f8",
"metadata": {},
"source": [
"Store your configuration parameters"
@@ -120,7 +121,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0afbcd78-f3f8-4cfc-a89c-5e3767d16e94",
+ "id": "f2e5118b",
"metadata": {},
"outputs": [],
"source": [
@@ -137,7 +138,7 @@
},
{
"cell_type": "markdown",
- "id": "0370911f-7e69-43d3-93ec-50ed271c9ae6",
+ "id": "6e038818",
"metadata": {},
"source": [
"setup transforms"
@@ -146,7 +147,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8a68787b-43a7-459c-97b8-6b483291a821",
+ "id": "8424eb9a",
"metadata": {},
"outputs": [],
"source": [
@@ -159,7 +160,7 @@
},
{
"cell_type": "markdown",
- "id": "fed9579d-244c-47ee-9131-51d9ef7aa069",
+ "id": "43a359aa",
"metadata": {},
"source": [
"Create a simple CNN"
@@ -168,7 +169,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "aead4006-6a4e-4009-b9b8-776d1ca1bf62",
+ "id": "3e903bb8",
"metadata": {},
"outputs": [],
"source": [
@@ -188,7 +189,7 @@
},
{
"cell_type": "markdown",
- "id": "ac8f89dd-4e18-4644-9d6a-b7ae7a5447a9",
+ "id": "d07a7cf2",
"metadata": {},
"source": [
"Wrap everything into a training functions (this is necessary to run on multiple GPUS, if it is only one, you can skip the wrapping)"
@@ -197,7 +198,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "26cbaf18-655c-4a35-8060-8c25eb7f27fc",
+ "id": "78346075",
"metadata": {},
"outputs": [],
"source": [
@@ -250,7 +251,7 @@
},
{
"cell_type": "markdown",
- "id": "e6227c06-2365-4386-9250-e1c22ac84bec",
+ "id": "2e162db1",
"metadata": {},
"source": [
"Let's train on 2 GPUs! This is really nice, as accelerate will take care of only calling `log` on the main process, so only one run get's created, so no need to manually check the rank of the process when using multiple GPUs."
@@ -259,7 +260,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "028f7217-9a0f-4881-9219-38f1a65dbf53",
+ "id": "027b2d42",
"metadata": {},
"outputs": [],
"source": [
diff --git a/colabs/huggingface/Visualize_your_Hugging_Face_data_with_Weights_&_Biases.ipynb b/colabs/huggingface/Visualize_your_Hugging_Face_data_with_Weights_&_Biases.ipynb
deleted file mode 100644
index 6be6becf..00000000
--- a/colabs/huggingface/Visualize_your_Hugging_Face_data_with_Weights_&_Biases.ipynb
+++ /dev/null
@@ -1,556 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Visualize your 🤗 Hugging Face data with 🏋♀️ Weights & Biases "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "[Hugging Face](https://huggingface.co/) provides tools to quickly train neural networks for NLP (Natural Language Processing) on any task (classification, translation, question answering, etc) and any dataset with PyTorch and TensorFlow 2.0.\n",
- "\n",
- "Coupled with [Weights & Biases integration](https://docs.wandb.ai/integrations/huggingface), you can quickly train and monitor models for full traceability and reproducibility without any extra line of code! You just need to install the library, sign in, and your experiments will automatically be logged:\n",
- "\n",
- "```bash\n",
- "pip install wandb\n",
- "wandb login\n",
- "```\n",
- "\n",
- "**Note**: To enable logging to W&B, set `report_to` to `wandb` in your `TrainingArguments` or script."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🛠️ Installation and set-up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -qU transformers datasets wandb"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import wandb\n",
- "wandb.login()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 🛫 Data and model preparation"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 🏷️ Loading a dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We'll download the\n",
- "[Yahoo! Answers Dataset](https://paperswithcode.com/sota/text-classification-on-yahoo-answers),\n",
- "which includes questions and answers from the now-defunct\n",
- "[Yahoo! Answers forum](https://en.wikipedia.org/wiki/Yahoo!_Answers)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "dataset = load_dataset(\"yahoo_answers_topics\")\n",
- "dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Data includes question titles, question bodies and best answers (for some questions), and a question topic from one of 10 choices."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset['train'][0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For our text classification task, we use `question_title` as input and try to predict `topic`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "label_list = dataset['train'].unique('topic')\n",
- "label_list.sort()\n",
- "label_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "num_labels = len(label_list)\n",
- "num_labels"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The \"topic\" class needs to be renamed to \"labels\" for the `Trainer` to find it."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset = dataset.rename_column('topic', 'labels')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For the purpose of this demo, we'll decrease the size of the dataset for faster training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset['train'] = dataset['train'].select(range(dataset['train'].num_rows // 10))\n",
- "dataset['test'] = dataset['test'].select(range(dataset['test'].num_rows // 10))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### ⚙️ Tokenizing the dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In a typical NLP workflow, we must first tokenize our dataset -- converting the stream of characters in the text into a stream of defined \"tokens\", which can be anything from a smaller set of characters to words from a vocabulary.\n",
- "\n",
- "We will use a pretrained model for our , so we inherit its tokenization scheme.\n",
- "Hugging Face makes this easy with the `AutoTokenizer`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "sample_input = dataset['train'][0]['question_title']\n",
- "print(sample_input, \"\\n\", tokenizer(sample_input))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We then map the tokenizer over our dataset:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset = dataset.map(lambda x: tokenizer(x['question_title'], truncation=True), batched=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### ✨ Loading a model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In this case, we are loading a pre-trained network to which a custom head has been added for sequence classification.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoModelForSequenceClassification\n",
- "\n",
- "model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=num_labels)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "Let's make a function to return the topic prediction from a sample question."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "\n",
- "def get_topic(sentence, tokenize=tokenizer, model=model):\n",
- " # tokenize the input\n",
- " inputs = tokenizer(sentence, return_tensors='pt')\n",
- " # ensure model and inputs are on the same device (GPU)\n",
- " inputs = {name: tensor.cuda() for name, tensor in inputs.items()}\n",
- " model = model.cuda()\n",
- " # get prediction - 10 classes \"probabilities\" (not really true because they still need to be normalized)\n",
- " with torch.no_grad():\n",
- " predictions = model(**inputs)[0].cpu().numpy()\n",
- " # get the top prediction class and convert it to its associated label\n",
- " top_prediction = predictions.argmax().item()\n",
- " return dataset['train'].features['labels'].int2str(top_prediction)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can test our prediction pipeline on a sample sentence."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "get_topic('Why is cheese so much better with wine?')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "When we ran our model, we got the answer `Politics & Government`,\n",
- "which doesn't seem quite right for a question about cheese and wine.\n",
- "\n",
- "That's because the model has not been trained yet so the outputs are still random. But at least we have a working pipeline!"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 📊 Log your data for better visualization\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Before we train our model, let's set up some better logging during training.\n",
- "\n",
- "Without the ability to inspect model behavior, it can be hard to debug or understand models.\n",
- "So we'll log a table of information about the model's behavior on the validation set --\n",
- "not just the loss or accuracy, but the inputs and outputs as well.\n",
- "\n",
- "Our data is already in a pandas `DataFrame`, so there's not much we have to do besides\n",
- "slightly reformat them\n",
- "and then use them to define a `ValidationDataLogger` instance."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from wandb.sdk.integration_utils.data_logging import ValidationDataLogger"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "validation_inputs = dataset['test'].remove_columns(['labels', 'attention_mask', 'input_ids'])\n",
- "validation_targets = [dataset['test'].features['labels'].int2str(x) for x in dataset['test']['labels']]\n",
- "\n",
- "validation_inputs[0], validation_targets[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "validation_logger = ValidationDataLogger(\n",
- " inputs = validation_inputs[:],\n",
- " targets = validation_targets\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can now log our predictions for visualization with `validation_logger.log_predictions(prediction_labels)`."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 👟 Training the model and logging to W&B\n",
- "\n",
- "We are now ready to fine-tune the model to solve our task.\n",
- "\n",
- "The Hugging Face [`Trainer` class](https://huggingface.co/transformers/main_classes/trainer.html)\n",
- "lets us easily train a model and is very flexible.\n",
- "\n",
- "**Note:** set `report_to` to `wandb` in `TrainingArguments` to enable logging through W&B."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import Trainer, TrainingArguments\n",
- "\n",
- "args = TrainingArguments(\n",
- " report_to='wandb', # enable logging to W&B\n",
- " output_dir='topic_classification', # set output directory\n",
- " overwrite_output_dir=True,\n",
- " evaluation_strategy='steps', # check evaluation metrics on a given # of steps\n",
- " learning_rate=5e-5, # we can customize learning rate\n",
- " max_steps=1000,\n",
- " logging_steps=100, # we will log every 100 steps\n",
- " eval_steps=500, # we will perform evaluation every 1000 steps\n",
- " eval_accumulation_steps=1, # report evaluation results after each step\n",
- " load_best_model_at_end=True,\n",
- " metric_for_best_model='accuracy',\n",
- " run_name='my_training_run' # name of the W&B run\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# automatically log model to W&B at the end\n",
- "%env WANDB_LOG_MODEL=true"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For more customization, refer to [`TrainingArguments` documentation](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can optionally define metrics to calculate in addition to the loss through the `compute_metrics` function.\n",
- "\n",
- "Several [metrics](https://huggingface.co/metrics) are readily available from the datasets library to monitor model performance.\n",
- "\n",
- "We'll also use it to log all of our predictions at each evaluation loop,\n",
- "using the `validation_logger`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_metric\n",
- "import numpy as np\n",
- "\n",
- "accuracy_metric = load_metric(\"accuracy\")\n",
- "\n",
- "def compute_metrics(eval_pred):\n",
- " predictions, labels = eval_pred\n",
- " predictions = np.argmax(predictions, axis=1)\n",
- "\n",
- " # convert predictions from class (0, 1, 2…) to label (Health, Science…)\n",
- " prediction_labels = [dataset['test'].features['labels'].int2str(x.item())\n",
- " for x in predictions]\n",
- " \n",
- " # log predictions\n",
- " validation_logger.log_predictions(prediction_labels)\n",
- "\n",
- " # metrics from the datasets library have a compute method\n",
- " return accuracy_metric.compute(predictions=predictions, references=labels)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `Trainer` handles all the training & evaluation logic."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = Trainer(\n",
- " model=model, # model to be trained\n",
- " args=args, # training args\n",
- " train_dataset=dataset['train'],\n",
- " eval_dataset=dataset['test'],\n",
- " tokenizer=tokenizer, # for padding batched data\n",
- " compute_metrics=compute_metrics # for custom metrics\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can now easily observe your predictions live in your run dashboard!\n",
- "\n",
- "![image.png]()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can monitor losses, metrics, gradients and parameters as the model trains.\n",
- "\n",
- "![image.png]()\n",
- "\n",
- "When training is complete, our model is logged and versioned along with its performance as metadata.\n",
- "\n",
- "![image.png]()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "When we want to close our W&B run, we can call `wandb.finish()` (mainly useful in notebooks, called automatically in scripts)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "wandb.finish()"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/colabs/huggingface/wandb_hf_example.ipynb b/colabs/huggingface/wandb_hf_example.ipynb
index 2c94bdce..b1bcf54f 100644
--- a/colabs/huggingface/wandb_hf_example.ipynb
+++ b/colabs/huggingface/wandb_hf_example.ipynb
@@ -89,7 +89,15 @@
"outputs": [],
"source": [
"# Login and authenticate Weights & Biases\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/intro/3_in_1_Intro_to_Weights_&_Biases_CV,_NLP_and_RL.ipynb b/colabs/intro/3_in_1_Intro_to_Weights_&_Biases_CV,_NLP_and_RL.ipynb
index c1fd9d6b..75becb55 100644
--- a/colabs/intro/3_in_1_Intro_to_Weights_&_Biases_CV,_NLP_and_RL.ipynb
+++ b/colabs/intro/3_in_1_Intro_to_Weights_&_Biases_CV,_NLP_and_RL.ipynb
@@ -126,7 +126,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/intro/Intro_to_Weights_&_Biases.ipynb b/colabs/intro/Intro_to_Weights_&_Biases.ipynb
index 23aa4b2f..e2705b2c 100644
--- a/colabs/intro/Intro_to_Weights_&_Biases.ipynb
+++ b/colabs/intro/Intro_to_Weights_&_Biases.ipynb
@@ -76,7 +76,15 @@
"outputs": [],
"source": [
"# Log in to your W&B account\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/intro/Intro_to_Weights_&_Biases_keras.ipynb b/colabs/intro/Intro_to_Weights_&_Biases_keras.ipynb
index df3f2a22..7c80dc77 100644
--- a/colabs/intro/Intro_to_Weights_&_Biases_keras.ipynb
+++ b/colabs/intro/Intro_to_Weights_&_Biases_keras.ipynb
@@ -73,7 +73,15 @@
"outputs": [],
"source": [
"# Log in to your W&B account\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/intro/Report_API_Quickstart.ipynb b/colabs/intro/Report_API_Quickstart.ipynb
index f5b21445..5945ccba 100644
--- a/colabs/intro/Report_API_Quickstart.ipynb
+++ b/colabs/intro/Report_API_Quickstart.ipynb
@@ -38,11 +38,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
- "!pip install wandb"
+ "!pip install wandb -qqq"
]
},
{
@@ -521,7 +521,7 @@
"source": [
"report = wr.Report(\n",
" PROJECT,\n",
- " title=\"Resizing panels\"\n",
+ " title=\"Resizing panels\",\n",
" description=\"Look at this wide parallel coordinates plot!\",\n",
" blocks=[\n",
" wr.PanelGrid(\n",
@@ -563,7 +563,7 @@
"report = wr.Report(\n",
" PROJECT,\n",
" title='W&B Block Gallery',\n",
- " description=\"Check out all of the blocks available in W&B\"\n",
+ " description=\"Check out all of the blocks available in W&B\",\n",
" blocks=[\n",
" wr.H1(text=\"Heading 1\"),\n",
" wr.P(text=\"Normal paragraph\"),\n",
@@ -621,7 +621,7 @@
"report = wr.Report(\n",
" project=PROJECT,\n",
" title='W&B Panel Gallery',\n",
- " description=\"Check out all of the panels available in W&B\"\n",
+ " description=\"Check out all of the panels available in W&B\",\n",
" width='fluid',\n",
" blocks=[\n",
" wr.PanelGrid(\n",
@@ -703,7 +703,7 @@
" wr.CustomChart(\n",
" query={'summary': ['val_loss', 'val_acc']},\n",
" chart_name='wandb/scatter/v0',\n",
- " user_fields={'x': 'val_loss', 'y': 'val_acc'}\n",
+ " chart_fields={'x': 'val_loss', 'y': 'val_acc'}\n",
" ),\n",
" wr.WeavePanelSummaryTable(\"my-table\"),\n",
" wr.WeavePanelArtifact('model-1', 'lineage', layout={'w': 24, 'h': 12}),\n",
@@ -812,7 +812,7 @@
"source": [
"report = wr.Report(PROJECT,\n",
" title=\"Report with links\",\n",
- " description=\"Use `wr.Link(text, url)` to add links inside normal text, or use normal markdown syntax in a MarkdownBlock\"\n",
+ " description=\"Use `wr.Link(text, url)` to add links inside normal text, or use normal markdown syntax in a MarkdownBlock\",\n",
" blocks=[\n",
" wr.H1(\"This is a normal heading\"),\n",
" wr.P(\"And here is some normal text\"),\n",
@@ -835,7 +835,7 @@
"report3 = wr.Report(\n",
" PROJECT,\n",
" title=\"Combined blocks report\",\n",
- " description=\"This report combines blocks from both Report 1 and Report 2\"\n",
+ " description=\"This report combines blocks from both Report 1 and Report 2\",\n",
" blocks=[*report1.blocks, *report2.blocks]\n",
")\n",
"report3.save()"
@@ -857,7 +857,7 @@
"report4 = wr.Report(\n",
" PROJECT,\n",
" title=\"Referenced reports via Gallery\",\n",
- " description=\"This report has gallery links to Report1 and Report 2\"\n",
+ " description=\"This report has gallery links to Report1 and Report 2\",\n",
" blocks=[wr.Gallery(ids=[report1.id, report2.id])]\n",
")\n",
"report4.save()"
@@ -939,7 +939,7 @@
"report = wr.Report(\n",
" PROJECT,\n",
" title=\"Report with markdown\",\n",
- " description=\"See what's possible with MarkdownBlock and MarkdownPanel\"\n",
+ " description=\"See what's possible with MarkdownBlock and MarkdownPanel\",\n",
" blocks=[\n",
" wr.MarkdownBlock(markdown),\n",
" wr.PanelGrid(\n",
@@ -1082,7 +1082,7 @@
"report = wr.Report(\n",
" project=PROJECT,\n",
" title=\"Parallel Coordinates Example (all-in-one)\",\n",
- " description=\"Using the pytorch sweeps demo (same as the other one but written in one expression)\"\n",
+ " description=\"Using the pytorch sweeps demo (same as the other one but written in one expression)\",\n",
" blocks=[\n",
" wr.PanelGrid(\n",
" runsets=[wr.Runset(project=\"pytorch-sweeps-demo\")],\n",
@@ -1259,7 +1259,7 @@
"metadata": {},
"outputs": [],
"source": [
- "report = wr.Report(project=PROJECT, title='Adding artifact lineage to reports', \"via WeaveBlockArtifact, WeaveBlockArtifactVersionedFile, or their panel equivalents\")"
+ "report = wr.Report(project=PROJECT, title='Adding artifact lineage to reports', description=\"via WeaveBlockArtifact, WeaveBlockArtifactVersionedFile, or their panel equivalents\")"
]
},
{
@@ -1597,6 +1597,18 @@
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.10"
}
},
"nbformat": 4,
diff --git a/colabs/intro/boom.txt b/colabs/intro/boom.txt
new file mode 100644
index 00000000..89656223
--- /dev/null
+++ b/colabs/intro/boom.txt
@@ -0,0 +1 @@
+tcnuqua1t6tw
\ No newline at end of file
diff --git a/colabs/jax/training_with_tfrecords_in_jax_imagenette.ipynb b/colabs/jax/training_with_tfrecords_in_jax_imagenette.ipynb
index b4191be8..04d7418c 100644
--- a/colabs/jax/training_with_tfrecords_in_jax_imagenette.ipynb
+++ b/colabs/jax/training_with_tfrecords_in_jax_imagenette.ipynb
@@ -504,13 +504,6 @@
"source": [
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/jupyter/Interactive_W&B_Charts_Inside_Jupyter.ipynb b/colabs/jupyter/Interactive_W&B_Charts_Inside_Jupyter.ipynb
index 3918bb68..c7b29721 100644
--- a/colabs/jupyter/Interactive_W&B_Charts_Inside_Jupyter.ipynb
+++ b/colabs/jupyter/Interactive_W&B_Charts_Inside_Jupyter.ipynb
@@ -95,7 +95,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/keras/Fine_tune_Vision_Transformer_using_KerasCV.ipynb b/colabs/keras/Fine_tune_Vision_Transformer_using_KerasCV.ipynb
index 8554cbef..68db1c28 100644
--- a/colabs/keras/Fine_tune_Vision_Transformer_using_KerasCV.ipynb
+++ b/colabs/keras/Fine_tune_Vision_Transformer_using_KerasCV.ipynb
@@ -47,8 +47,15 @@
"\n",
"import wandb\n",
"from wandb.keras import WandbMetricsLogger\n",
- "from wandb.keras import WandbEvalCallback\n",
- "\n",
+ "from wandb.keras import WandbEvalCallback\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -321,13 +328,6 @@
"# Close the W&B run\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/keras/Image_Segmentation_with_Keras.ipynb b/colabs/keras/Image_Segmentation_with_Keras.ipynb
index 85cf9cae..7caba275 100644
--- a/colabs/keras/Image_Segmentation_with_Keras.ipynb
+++ b/colabs/keras/Image_Segmentation_with_Keras.ipynb
@@ -34,8 +34,15 @@
"source": [
"import wandb\n",
"from wandb.keras import WandbMetricsLogger\n",
- "from wandb.keras import WandbEvalCallback\n",
- "\n",
+ "from wandb.keras import WandbEvalCallback\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/keras/Keras_param_opti_using_sweeps.ipynb b/colabs/keras/Keras_param_opti_using_sweeps.ipynb
index b6175ce6..a8b755f3 100644
--- a/colabs/keras/Keras_param_opti_using_sweeps.ipynb
+++ b/colabs/keras/Keras_param_opti_using_sweeps.ipynb
@@ -102,8 +102,15 @@
"outputs": [],
"source": [
"import wandb\n",
- "from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint\n",
- "\n",
+ "from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/keras/Legacy_Keras_pipeline_with_Weights_and_Biases.ipynb b/colabs/keras/Legacy_Keras_pipeline_with_Weights_and_Biases.ipynb
index 3008b3f8..97836a06 100644
--- a/colabs/keras/Legacy_Keras_pipeline_with_Weights_and_Biases.ipynb
+++ b/colabs/keras/Legacy_Keras_pipeline_with_Weights_and_Biases.ipynb
@@ -112,7 +112,15 @@
"metadata": {},
"outputs": [],
"source": [
- "# Login to W&B\n",
+ "# Login to W&B"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/keras/Simple_Keras_Integration.ipynb b/colabs/keras/Simple_Keras_Integration.ipynb
index f5e51f51..bf5d8739 100644
--- a/colabs/keras/Simple_Keras_Integration.ipynb
+++ b/colabs/keras/Simple_Keras_Integration.ipynb
@@ -120,8 +120,15 @@
"outputs": [],
"source": [
"import wandb\n",
- "from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback\n",
- "\n",
+ "from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/keras/cosine_decay_using_keras.ipynb b/colabs/keras/cosine_decay_using_keras.ipynb
index 9128a2c2..17e2d329 100644
--- a/colabs/keras/cosine_decay_using_keras.ipynb
+++ b/colabs/keras/cosine_decay_using_keras.ipynb
@@ -40,8 +40,15 @@
"\n",
"# Weights and Biases related imports\n",
"import wandb\n",
- "from wandb.keras import WandbMetricsLogger\n",
- "\n",
+ "from wandb.keras import WandbMetricsLogger\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/keras/keras_core/monai_medmnist_keras.ipynb b/colabs/keras/keras_core/monai_medmnist_keras.ipynb
new file mode 100644
index 00000000..747baa6f
--- /dev/null
+++ b/colabs/keras/keras_core/monai_medmnist_keras.ipynb
@@ -0,0 +1,477 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dMq6cFtJl2vR"
+ },
+ "source": [
+ "\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2eqSE_8rl6yo"
+ },
+ "source": [
+ "# 🩺 Medical Image Classification Tutorial using MONAI and Keras\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/monai_medmnist_keras.ipynb)\n",
+ "\n",
+ "This notebook demonstrates\n",
+ "- end-to-end training using [MONAI](https://github.com/Project-MONAI/MONAI) and [KerasCore](https://github.com/keras-team/keras-core).\n",
+ "- how we can use the backend-agnostic Keras callbacks for [Weights & Biases](https://wandb.ai/site) to manage and track our experiment.\n",
+ "\n",
+ "Original Notebook: https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ny_HOlvymX6O"
+ },
+ "source": [
+ "## Installing and Importing the Dependencies"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6SvDmUxVmatW"
+ },
+ "source": [
+ "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), which lets us use the latest features merged into KerasCore.\n",
+ "- We install [monai](https://github.com/Project-MONAI/MONAI), a PyTorch-based, open-source framework for deep learning in healthcare imaging, part of the PyTorch Ecosystem.\n",
+ "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZO1xjFtvVdqQ"
+ },
+ "outputs": [],
+ "source": [
+ "# install the `main` branch of KerasCore\n",
+ "!pip install -qq namex\n",
+ "!apt install python3.10-venv\n",
+ "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n",
+ "\n",
+ "# install monai and wandb-addons\n",
+ "!pip install -qq git+https://github.com/soumik12345/wandb-addons\n",
+ "!pip install -q \"monai-weekly[pillow, tqdm]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c6vd3NZ-mhxs"
+ },
+ "source": [
+ "We set the Keras backend to `torch` by explicitly setting the environment variable `KERAS_BACKEND`; this must happen before `keras_core` is imported."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d5ZiQmMkW-h3"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n",
+ "\n",
+ "import shutil\n",
+ "import tempfile\n",
+ "import matplotlib.pyplot as plt\n",
+ "import PIL\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "from sklearn.metrics import classification_report\n",
+ "\n",
+ "import keras_core as keras\n",
+ "from keras_core.utils import TorchModuleWrapper\n",
+ "\n",
+ "from monai.apps import download_and_extract\n",
+ "from monai.config import print_config\n",
+ "from monai.data import decollate_batch, DataLoader\n",
+ "from monai.metrics import ROCAUCMetric\n",
+ "from monai.networks.nets import DenseNet121\n",
+ "from monai.transforms import (\n",
+ " Activations,\n",
+ " EnsureChannelFirst,\n",
+ " AsDiscrete,\n",
+ " Compose,\n",
+ " LoadImage,\n",
+ " RandFlip,\n",
+ " RandRotate,\n",
+ " RandZoom,\n",
+ " ScaleIntensity,\n",
+ ")\n",
+ "from monai.utils import set_determinism\n",
+ "\n",
+ "import wandb\n",
+ "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1BvtLvxKmkSp"
+ },
+ "source": [
+ "We initialize a [wandb run](https://docs.wandb.ai/guides/runs) and set the configs for the experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WUt7PAsado4j"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(project=\"keras-torch\")\n",
+ "\n",
+ "config = wandb.config\n",
+ "config.batch_size = 128\n",
+ "config.num_epochs = 1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NE938skxmoMI"
+ },
+ "source": [
+ "## Setup data directory\n",
+ "\n",
+ "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.\n",
+ "This allows you to save results and reuse downloads.\n",
+ "If not specified a temporary directory will be used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "82KX1Sj6XXY1"
+ },
+ "outputs": [],
+ "source": [
+ "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
+ "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
+ "print(root_dir)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nkVQ-tojmzwf"
+ },
+ "source": [
+ "## Download dataset\n",
+ "\n",
+ "The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),\n",
+ "[the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),\n",
+ "and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).\n",
+ "\n",
+ "The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)\n",
+ "under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).\n",
+ "\n",
+ "If you use the MedNIST dataset, please acknowledge the source."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DAHybyvdXZoH"
+ },
+ "outputs": [],
+ "source": [
+ "resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz\"\n",
+ "md5 = \"0bc7306e7427e00ad1c5526a6677552d\"\n",
+ "\n",
+ "compressed_file = os.path.join(root_dir, \"MedNIST.tar.gz\")\n",
+ "data_dir = os.path.join(root_dir, \"MedNIST\")\n",
+ "if not os.path.exists(data_dir):\n",
+ " download_and_extract(resource, compressed_file, root_dir, md5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "u6m2Uas1nMi8"
+ },
+ "source": [
+ "## Read image filenames from the dataset folders\n",
+ "\n",
+ "First of all, check the dataset files and show some statistics. \n",
+ "There are 6 folders in the dataset: Hand, AbdomenCT, CXR, ChestCT, BreastMRI, HeadCT, \n",
+ "which should be used as the labels to train our classification model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qtLr1T0gXydq"
+ },
+ "outputs": [],
+ "source": [
+ "class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))\n",
+ "num_class = len(class_names)\n",
+ "image_files = [\n",
+ " [os.path.join(data_dir, class_names[i], x) for x in os.listdir(os.path.join(data_dir, class_names[i]))]\n",
+ " for i in range(num_class)\n",
+ "]\n",
+ "num_each = [len(image_files[i]) for i in range(num_class)]\n",
+ "image_files_list = []\n",
+ "image_class = []\n",
+ "for i in range(num_class):\n",
+ " image_files_list.extend(image_files[i])\n",
+ " image_class.extend([i] * num_each[i])\n",
+ "num_total = len(image_class)\n",
+ "image_width, image_height = PIL.Image.open(image_files_list[0]).size\n",
+ "\n",
+ "print(f\"Total image count: {num_total}\")\n",
+ "print(f\"Image dimensions: {image_width} x {image_height}\")\n",
+ "print(f\"Label names: {class_names}\")\n",
+ "print(f\"Label counts: {num_each}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "k4WOUb4IX6KK"
+ },
+ "outputs": [],
+ "source": [
+ "plt.subplots(3, 3, figsize=(8, 8))\n",
+ "for i, k in enumerate(np.random.randint(num_total, size=9)):\n",
+ " im = PIL.Image.open(image_files_list[k])\n",
+ " arr = np.array(im)\n",
+ " plt.subplot(3, 3, i + 1)\n",
+ " plt.xlabel(class_names[image_class[k]])\n",
+ " plt.imshow(arr, cmap=\"gray\", vmin=0, vmax=255)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wOGja_mEnQ5m"
+ },
+ "source": [
+ "## Prepare training, validation and test data lists\n",
+ "\n",
+ "Randomly select 10% of the dataset as validation and 10% as test."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BmHhKta1X8ZM"
+ },
+ "outputs": [],
+ "source": [
+ "val_frac = 0.1\n",
+ "test_frac = 0.1\n",
+ "length = len(image_files_list)\n",
+ "indices = np.arange(length)\n",
+ "np.random.shuffle(indices)\n",
+ "\n",
+ "test_split = int(test_frac * length)\n",
+ "val_split = int(val_frac * length) + test_split\n",
+ "test_indices = indices[:test_split]\n",
+ "val_indices = indices[test_split:val_split]\n",
+ "train_indices = indices[val_split:]\n",
+ "\n",
+ "train_x = [image_files_list[i] for i in train_indices]\n",
+ "train_y = [image_class[i] for i in train_indices]\n",
+ "val_x = [image_files_list[i] for i in val_indices]\n",
+ "val_y = [image_class[i] for i in val_indices]\n",
+ "test_x = [image_files_list[i] for i in test_indices]\n",
+ "test_y = [image_class[i] for i in test_indices]\n",
+ "\n",
+ "print(f\"Training count: {len(train_x)}, Validation count: \" f\"{len(val_x)}, Test count: {len(test_x)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Lkhtr3p8nT1G"
+ },
+ "source": [
+ "## Define MONAI transforms, Dataset and Dataloader to pre-process data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2b5H2WuUYLsc"
+ },
+ "outputs": [],
+ "source": [
+ "train_transforms = Compose(\n",
+ " [\n",
+ " LoadImage(image_only=True),\n",
+ " EnsureChannelFirst(),\n",
+ " ScaleIntensity(),\n",
+ " RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),\n",
+ " RandFlip(spatial_axis=0, prob=0.5),\n",
+ " RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])\n",
+ "\n",
+ "y_pred_trans = Compose([Activations(softmax=True)])\n",
+ "y_trans = Compose([AsDiscrete(to_onehot=num_class)])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QdFnaFUDYOaq"
+ },
+ "outputs": [],
+ "source": [
+ "class MedNISTDataset(torch.utils.data.Dataset):\n",
+ " def __init__(self, image_files, labels, transforms):\n",
+ " self.image_files = image_files\n",
+ " self.labels = labels\n",
+ " self.transforms = transforms\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.image_files)\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " return self.transforms(self.image_files[index]), self.labels[index]\n",
+ "\n",
+ "\n",
+ "train_ds = MedNISTDataset(train_x, train_y, train_transforms)\n",
+ "train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, num_workers=2)\n",
+ "\n",
+ "val_ds = MedNISTDataset(val_x, val_y, val_transforms)\n",
+ "val_loader = DataLoader(val_ds, batch_size=config.batch_size, num_workers=2)\n",
+ "\n",
+ "test_ds = MedNISTDataset(test_x, test_y, val_transforms)\n",
+ "test_loader = DataLoader(test_ds, batch_size=config.batch_size, num_workers=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FIhKqMrVnZV9"
+ },
+ "source": [
+ "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model are tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n",
+ "\n",
+ "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!\n",
+ "\n",
+ "The idea of the `TorchModuleWrapper` was proposed by Keras' creator [François Chollet](https://github.com/fchollet) on [this issue thread](https://github.com/keras-team/keras-core/issues/604)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7zbinqy4ZsEy"
+ },
+ "outputs": [],
+ "source": [
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "inputs = keras.Input(shape=(1, 64, 64))\n",
+ "outputs = TorchModuleWrapper(\n",
+ " DenseNet121(\n",
+ " spatial_dims=2, in_channels=1, out_channels=num_class\n",
+ " )\n",
+ ")(inputs)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# model = MedMnistModel()\n",
+ "model(next(iter(train_loader))[0].to(device)).shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "opjBI87nneYU"
+ },
+ "source": [
+ "**Note:** It is actually possible to use torch modules inside a Keras Model without explicitly wrapping them with the `TorchModuleWrapper`, as evidenced by [this tweet](https://twitter.com/fchollet/status/1697381832164290754) from François Chollet. However, this did not seem to work at the time this example was created, as reported in [this issue](https://github.com/keras-team/keras-core/issues/834)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "y3H-cUKraRDh"
+ },
+ "outputs": [],
+ "source": [
+ "# Compile the model\n",
+ "model.compile(\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " optimizer=keras.optimizers.Adam(1e-5),\n",
+ " metrics=[\"accuracy\"],\n",
+ ")\n",
+ "\n",
+ "# Define the backend-agnostic WandB callbacks for KerasCore\n",
+ "callbacks = [\n",
+ " # Track experiment metrics\n",
+ " WandbMetricsLogger(log_freq=\"batch\")\n",
+ "]\n",
+ "\n",
+ "# Train the model by calling model.fit\n",
+ "model.fit(\n",
+ " train_loader,\n",
+ " validation_data=val_loader,\n",
+ " epochs=config.num_epochs,\n",
+ " callbacks=callbacks,\n",
+ ")\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Q0E-kiGBeCnZ"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "V100",
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/keras/keras_core/timm_keras.ipynb b/colabs/keras/keras_core/timm_keras.ipynb
new file mode 100644
index 00000000..1cea4769
--- /dev/null
+++ b/colabs/keras/keras_core/timm_keras.ipynb
@@ -0,0 +1,377 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🔥 Fine-tune a [Timm](https://huggingface.co/docs/timm/index) Model with Keras and WandB 🦄\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/timm_keras.ipynb)\n",
+ "\n",
+ "This notebook demonstrates\n",
+ "- how we can fine-tune a pre-trained model from timm using [KerasCore](https://github.com/keras-team/keras-core).\n",
+ "- how we can use the backend-agnostic Keras callbacks for [Weights & Biases](https://wandb.ai/site) to manage and track our experiment."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Installing and Importing the Dependencies"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), which lets us use the latest features merged into KerasCore.\n",
+ "- We install [timm](https://huggingface.co/docs/timm/index), a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts.\n",
+ "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# install the `main` branch of KerasCore\n",
+ "!pip install -qq namex\n",
+ "!apt install python3.10-venv\n",
+ "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n",
+ "\n",
+ "# install timm and wandb-addons\n",
+ "!pip install -qq timm git+https://github.com/soumik12345/wandb-addons"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We set the Keras backend to `torch` by explicitly setting the environment variable `KERAS_BACKEND`; this must happen before `keras_core` is imported."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n",
+ "\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import timm\n",
+ "from timm.data import resolve_data_config\n",
+ "\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "from torchvision.transforms.functional import InterpolationMode\n",
+ "\n",
+ "import wandb\n",
+ "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We initialize a [wandb run](https://docs.wandb.ai/guides/runs) and set the configs for the experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.init(project=\"keras-torch\")\n",
+ "\n",
+ "config = wandb.config\n",
+ "config.model_name = \"xception41\"\n",
+ "config.freeze_backbone = False\n",
+ "config.preprocess_config = resolve_data_config({}, model=config.model_name)\n",
+ "config.dropout_rate = 0.5\n",
+ "config.batch_size = 4\n",
+ "config.num_epochs = 25"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## A PyTorch-based Input Pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will be using the [ImageNette](https://github.com/fastai/imagenette) dataset for this experiment. Imagenette is a subset of 10 easily classified classes from [Imagenet](https://www.image-net.org/) (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).\n",
+ "\n",
+ "First, let's download this dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz -P imagenette\n",
+ "!tar zxf imagenette/imagenette2-320.tgz -C imagenette\n",
+ "!gzip -d imagenette/imagenette2-320.tgz"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we create our standard torch-based data loading pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define pre-processing and augmentation transforms for the train and validation sets\n",
+ "data_transforms = {\n",
+ " 'train': transforms.Compose([\n",
+ " transforms.RandomResizedCrop(\n",
+ " size=config.preprocess_config[\"input_size\"][1],\n",
+ " interpolation=InterpolationMode.BICUBIC,\n",
+ " ),\n",
+ " transforms.RandomHorizontalFlip(),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize(\n",
+ " config.preprocess_config[\"mean\"],\n",
+ " config.preprocess_config[\"std\"]\n",
+ " )\n",
+ " ]),\n",
+ " 'val': transforms.Compose([\n",
+ " transforms.Resize(256),\n",
+ " transforms.CenterCrop(config.preprocess_config[\"input_size\"][1]),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize(\n",
+ " config.preprocess_config[\"mean\"],\n",
+ " config.preprocess_config[\"std\"]\n",
+ " )\n",
+ " ]),\n",
+ "}\n",
+ "\n",
+ "# Define the train and validation datasets\n",
+ "data_dir = 'imagenette/imagenette2-320'\n",
+ "image_datasets = {\n",
+ " x: datasets.ImageFolder(\n",
+ " os.path.join(data_dir, x), data_transforms[x]\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "\n",
+ "# Define the torch dataloaders corresponding to the train and validation dataset\n",
+ "dataloaders = {\n",
+ " x: torch.utils.data.DataLoader(\n",
+ " image_datasets[x],\n",
+ " batch_size=config.batch_size,\n",
+ " shuffle=True,\n",
+ " num_workers=4\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n",
+ "class_names = image_datasets['train'].classes\n",
+ "\n",
+ "# Specify the global device\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's take a look at a few of the samples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def imshow(inp, title=None):\n",
+ " \"\"\"Display image for Tensor.\"\"\"\n",
+ " inp = inp.numpy().transpose((1, 2, 0))\n",
+ " mean = np.array(config.preprocess_config[\"mean\"])\n",
+ " std = np.array(config.preprocess_config[\"std\"])\n",
+ " inp = std * inp + mean\n",
+ " inp = np.clip(inp, 0, 1)\n",
+ " plt.imshow(inp)\n",
+ " if title is not None:\n",
+ " plt.title(title)\n",
+ " plt.pause(0.001)\n",
+ "\n",
+ "\n",
+ "# Get a batch of training data\n",
+ "inputs, classes = next(iter(dataloaders['train']))\n",
+ "print(inputs.shape, classes.shape)\n",
+ "\n",
+ "# Make a grid from batch\n",
+ "out = torchvision.utils.make_grid(inputs)\n",
+ "\n",
+ "imshow(out, title=[class_names[x] for x in classes])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating and Training our Classifier"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model are tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n",
+ "\n",
+ "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!\n",
+ "\n",
+ "The idea of the `TorchModuleWrapper` was proposed by Keras' creator [François Chollet](https://github.com/fchollet) on [this issue thread](https://github.com/keras-team/keras-core/issues/604)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import keras_core as keras\n",
+ "from keras_core.utils import TorchModuleWrapper\n",
+ "\n",
+ "\n",
+ "class TimmClassifier(keras.Model):\n",
+ "\n",
+ " def __init__(self, model_name, freeze_backbone, dropout_rate, num_classes, *args, **kwargs):\n",
+ " super().__init__(*args, **kwargs)\n",
+ " \n",
+ " # Define the pre-trained module from timm\n",
+ " self.backbone = TorchModuleWrapper(\n",
+ " timm.create_model(model_name, pretrained=True)\n",
+ " )\n",
+ " self.backbone.trainable = not freeze_backbone\n",
+ " \n",
+ " # Build the classification head using keras layers\n",
+ " self.global_average_pooling = keras.layers.GlobalAveragePooling2D()\n",
+ " self.dropout = keras.layers.Dropout(dropout_rate)\n",
+ " self.classification_head = keras.layers.Dense(num_classes)\n",
+ "\n",
+ " def call(self, inputs):\n",
+ " # We get the unpooled features from the timm backbone by calling `forward_features`\n",
+ " # on the torch module corresponding to the backbone.\n",
+ " x = self.backbone.module.forward_features(inputs)\n",
+ " x = self.global_average_pooling(x)\n",
+ " x = self.dropout(x)\n",
+ " x = self.classification_head(x)\n",
+ " return keras.activations.softmax(x, axis=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Note:** It is actually possible to use torch modules inside a Keras Model without explicitly wrapping them with the `TorchModuleWrapper`, as evidenced by [this tweet](https://twitter.com/fchollet/status/1697381832164290754) from François Chollet. However, this did not seem to work at the time this example was created, as reported in [this issue](https://github.com/keras-team/keras-core/issues/834)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Now, we define the model and pass a random tensor to check the output shape\n",
+ "model = TimmClassifier(\n",
+ " model_name=config.model_name,\n",
+ " freeze_backbone=config.freeze_backbone,\n",
+ " dropout_rate=config.dropout_rate,\n",
+ " num_classes=len(class_names)\n",
+ ")\n",
+ "model(torch.ones(1, *config.preprocess_config[\"input_size\"]).to(device)).shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, in standard Keras fashion, all we need to do is compile the model and call `model.fit()`!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Exponential-decay LR schedule. len(DataLoader) is already the number of batches per epoch, so it is not divided by batch_size again.\n",
+ "decay_steps = config.num_epochs * len(dataloaders[\"train\"])\n",
+ "lr_scheduler = keras.optimizers.schedules.ExponentialDecay(\n",
+ " initial_learning_rate=1e-3, decay_steps=decay_steps, decay_rate=0.1,\n",
+ ")\n",
+ "\n",
+ "# Compile the model\n",
+ "model.compile(\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " optimizer=keras.optimizers.Adam(lr_scheduler),\n",
+ " metrics=[\"accuracy\"],\n",
+ ")\n",
+ "\n",
+ "# Define the backend-agnostic WandB callbacks for KerasCore\n",
+ "callbacks = [\n",
+ " # Track experiment metrics\n",
+ " WandbMetricsLogger(log_freq=\"batch\"),\n",
+ " # Track and version model checkpoints\n",
+ " WandbModelCheckpoint(\"model.keras\")\n",
+ "]\n",
+ "\n",
+ "# Train the model by calling model.fit\n",
+ "model.fit(\n",
+ " dataloaders[\"train\"],\n",
+ " validation_data=dataloaders[\"val\"],\n",
+ " epochs=config.num_epochs,\n",
+ " callbacks=callbacks,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In order to know more about the backend-agnostic Keras callbacks for Weights & Biases, check out the [docs for wandb-addons](https://geekyrakshit.dev/wandb-addons/keras/keras_core/)."
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/colabs/keras/keras_core/torchvision-keras.ipynb b/colabs/keras/keras_core/torchvision-keras.ipynb
new file mode 100644
index 00000000..30783184
--- /dev/null
+++ b/colabs/keras/keras_core/torchvision-keras.ipynb
@@ -0,0 +1,922 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jLRorPqA1-h4"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# 🔥 Fine-tune a TorchVision Model with Keras and WandB 🦄\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/torchvision-keras.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aEl-j2hq2w25"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "[TorchVision](https://pytorch.org/vision/stable/index.html) is a library part of the [PyTorch](http://pytorch.org/) project that consists of popular datasets, model architectures, and common image transformations for computer vision. This example demonstrates how we can perform transfer learning for image classification using a pre-trained backbone model from TorchVision on the [Imagenette dataset](https://github.com/fastai/imagenette) using KerasCore. We will also demonstrate the compatibility of KerasCore with an input system consisting of [Torch Datasets and Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).\n",
+ "\n",
+ "### References:\n",
+ "\n",
+ "- [Customizing what happens in `fit()` with PyTorch](https://keras.io/keras_core/guides/custom_train_step_in_torch/)\n",
+ "- [PyTorch Datasets and Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)\n",
+ "- [Transfer learning for Computer Vision using PyTorch](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)\n",
+ "\n",
+ "## Setup\n",
+ "\n",
+ "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), which lets us use the latest features merged into KerasCore.\n",
+ "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "r4rfNRPgiy9v",
+ "outputId": "ce5ba027-567a-4577-a638-ca8802ee1f84"
+ },
+ "outputs": [],
+ "source": [
+ "# install the `main` branch of KerasCore\n",
+ "!pip install -qq namex\n",
+ "!apt install python3.10-venv\n",
+ "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n",
+ "\n",
+ "# install wandb-addons\n",
+ "!pip install -qq git+https://github.com/soumik12345/wandb-addons"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7nudAUt8jHRB"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n",
+ "\n",
+ "import numpy as np\n",
+ "from tqdm.auto import tqdm\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "\n",
+ "import keras_core as keras\n",
+ "from keras_core.utils import TorchModuleWrapper\n",
+ "\n",
+ "import wandb\n",
+ "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pS1c-ySo7nty"
+ },
+ "source": [
+ "## Define the Hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2ovtXSUA7ksk",
+ "outputId": "725bef1a-0e68-473e-8c24-1f1ffd28506c"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(project=\"keras-torch\", entity=\"ml-colabs\", job_type=\"torchvision/train\")\n",
+ "\n",
+ "config = wandb.config\n",
+ "config.batch_size = 32\n",
+ "config.image_size = 224\n",
+ "config.freeze_backbone = True\n",
+ "config.initial_learning_rate = 1e-3\n",
+ "config.num_epochs = 5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b5uU_Q_H74GO"
+ },
+ "source": [
+ "## Creating the Torch Datasets and Dataloaders\n",
+ "\n",
+ "In this example, we will train an image classification model on the [Imagenette dataset](https://github.com/fastai/imagenette). Imagenette is a subset of 10 easily classified classes from [Imagenet](https://www.image-net.org/) (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "gWV2hNGo8vW7",
+ "outputId": "cc70b598-fc66-480a-a98c-7077fa634a22"
+ },
+ "outputs": [],
+ "source": [
+ "# Fetch the imagenette dataset\n",
+ "data_dir = keras.utils.get_file(\n",
+ " fname=\"imagenette2-320.tgz\",\n",
+ " origin=\"https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz\",\n",
+ " extract=True,\n",
+ ")\n",
+ "data_dir = data_dir.replace(\".tgz\", \"\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ka7TDUMn9IcG"
+ },
+ "source": [
+ "Next, we define pre-processing and augmentation transforms from TorchVision for the train and validation sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rFZKJafF9H6y"
+ },
+ "outputs": [],
+ "source": [
+ "data_transforms = {\n",
+ " 'train': transforms.Compose([\n",
+ " transforms.RandomResizedCrop(config.image_size),\n",
+ " transforms.RandomHorizontalFlip(),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ " 'val': transforms.Compose([\n",
+ " transforms.Resize(256),\n",
+ " transforms.CenterCrop(config.image_size),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aE3VFQHm9srl"
+ },
+ "source": [
+ "Finally, we will use TorchVision and the [`torch.utils.data`](https://pytorch.org/docs/stable/data.html) packages for creating the dataloaders for training and validation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0N81UNjtjMZ4",
+ "outputId": "bfb2af1a-977c-4408-cb1c-8da78d03d13d"
+ },
+ "outputs": [],
+ "source": [
+ "# Define the train and validation datasets\n",
+ "image_datasets = {\n",
+ " x: datasets.ImageFolder(\n",
+ " os.path.join(data_dir, x), data_transforms[x]\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "\n",
+ "# Define the torch dataloaders corresponding to the\n",
+ "# train and validation dataset\n",
+ "dataloaders = {\n",
+ " x: torch.utils.data.DataLoader(\n",
+ " image_datasets[x],\n",
+ " batch_size=config.batch_size,\n",
+ " shuffle=True,\n",
+ " num_workers=4\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n",
+ "class_names = image_datasets['train'].classes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AY6kOIL--EdP"
+ },
+ "source": [
+ "Let us visualize a few samples from the training dataloader."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 829
+ },
+ "id": "yffdD4LxjOQG",
+ "outputId": "38e6f182-11b1-4830-e5ec-a5588c7bdbf9"
+ },
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(10, 10))\n",
+ "sample_images, sample_labels = next(iter(dataloaders['train']))\n",
+ "sample_images = sample_images.numpy()\n",
+ "sample_labels = sample_labels.numpy()\n",
+ "for idx in range(9):\n",
+ " ax = plt.subplot(3, 3, idx + 1)\n",
+ " image = sample_images[idx].transpose((1, 2, 0))\n",
+ " mean = np.array([0.485, 0.456, 0.406])\n",
+ " std = np.array([0.229, 0.224, 0.225])\n",
+ " image = std * image + mean\n",
+ " image = np.clip(image, 0, 1)\n",
+ " plt.imshow(image)\n",
+ " plt.title(\"Ground Truth Label: \" + class_names[int(sample_labels[idx])])\n",
+ " plt.axis(\"off\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0zvfzD04-ce9"
+ },
+ "source": [
+ "## The Image Classification Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZBLzXVwk-mLP"
+ },
+ "source": [
+ "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Let us define the ResNet18 model from the TorchVision package as a `torch.nn.Module` pre-trained on the [Imagenet1K dataset](https://huggingface.co/datasets/imagenet-1k)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tOGUiI9K_BRk",
+ "outputId": "97aa387c-7b3d-41cf-959b-de60eb572912"
+ },
+ "outputs": [],
+ "source": [
+ "# Define the pre-trained resnet18 module from TorchVision\n",
+ "resnet_18 = models.resnet18(weights='IMAGENET1K_V1')\n",
+ "\n",
+ "# We set the classification head of the pre-trained ResNet18\n",
+ "# module to an identity module\n",
+ "resnet_18.fc = nn.Identity()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "afxbMwcYF-Yz"
+ },
+ "source": [
+ "Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model are tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n",
+ "\n",
+ "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JLuCIAy5F6L1"
+ },
+ "outputs": [],
+ "source": [
+ "# Wrap the pre-trained ResNet18 module in a Keras Layer using\n",
+ "# `TorchModuleWrapper` so that Keras can track its variables\n",
+ "backbone = TorchModuleWrapper(resnet_18)\n",
+ "\n",
+ "# `trainable` is the inverse of `freeze_backbone`: freezing the backbone\n",
+ "# means its weights must NOT be updated during training\n",
+ "backbone.trainable = not config.freeze_backbone"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7y28txoVHHk8"
+ },
+ "source": [
+ "Now, we will build a Keras functional model with the backbone layer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 314
+ },
+ "id": "l2rxqA8vjR3W",
+ "outputId": "206ad321-5fb5-41e1-a391-37cef3b10edb"
+ },
+ "outputs": [],
+ "source": [
+ "inputs = keras.Input(shape=(3, config.image_size, config.image_size))\n",
+ "x = backbone(inputs)\n",
+ "x = keras.layers.Dropout(0.5)(x)\n",
+ "x = keras.layers.Dense(len(class_names))(x)\n",
+ "outputs = keras.activations.softmax(x, axis=1)\n",
+ "model = keras.Model(inputs, outputs, name=\"ResNet18_Classifier\")\n",
+ "\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BXbvYzDnjyDQ",
+ "outputId": "9d7e51b2-85e7-4717-993b-0f6efac2999d"
+ },
+ "outputs": [],
+ "source": [
+ "# Create exponential decay learning rate scheduler.\n",
+ "# `len(dataloaders[\"train\"])` is already the number of batches (optimizer\n",
+ "# steps) per epoch, so it must not be divided by the batch size again.\n",
+ "decay_steps = config.num_epochs * len(dataloaders[\"train\"])\n",
+ "lr_scheduler = keras.optimizers.schedules.ExponentialDecay(\n",
+ "    initial_learning_rate=config.initial_learning_rate,\n",
+ "    decay_steps=decay_steps,\n",
+ "    decay_rate=0.1,\n",
+ ")\n",
+ "\n",
+ "# Compile the model\n",
+ "model.compile(\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " optimizer=keras.optimizers.Adam(lr_scheduler),\n",
+ " metrics=[\"accuracy\"],\n",
+ ")\n",
+ "\n",
+ "# Define the backend-agnostic WandB callbacks for KerasCore\n",
+ "callbacks = [\n",
+ " # Track experiment metrics with WandB\n",
+ " WandbMetricsLogger(log_freq=\"batch\"),\n",
+ " # Save best model checkpoints to WandB\n",
+ " WandbModelCheckpoint(\n",
+ " filepath=\"model.weights.h5\",\n",
+ " monitor=\"val_loss\",\n",
+ " save_best_only=True,\n",
+ " save_weights_only=True,\n",
+ " )\n",
+ "]\n",
+ "\n",
+ "# Train the model by calling model.fit\n",
+ "history = model.fit(\n",
+ " dataloaders[\"train\"],\n",
+ " validation_data=dataloaders[\"val\"],\n",
+ " epochs=config.num_epochs,\n",
+ " callbacks=callbacks,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "upJf2M92JBlD",
+ "outputId": "6993ffbe-2b69-4ff8-9a2f-9c502ca1414d"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I58qvNnzJmHD"
+ },
+ "source": [
+ "## Evaluation and Inference\n",
+ "\n",
+ "Now, let us load the best model weights checkpoint and evaluate the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "cNpfuELqnDI_",
+ "outputId": "18bfe7bb-adea-4ac2-807b-924dce6fbcc8"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(\n",
+ " project=\"keras-torch\", entity=\"ml-colabs\", job_type=\"torchvision/eval\"\n",
+ ")\n",
+ "artifact = wandb.use_artifact(\n",
+ " 'ml-colabs/keras-torch/run_hiceci7f_model:latest', type='model'\n",
+ ")\n",
+ "artifact_dir = artifact.download()\n",
+ "\n",
+ "model.load_weights(os.path.join(artifact_dir, \"model.weights.h5\"))\n",
+ "\n",
+ "_, val_accuracy = model.evaluate(dataloaders[\"val\"])\n",
+ "wandb.log({\"Validation-Accuracy\": val_accuracy})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vE2vAvBNKAI9"
+ },
+ "source": [
+ "Finally, let us visualize some predictions of the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 49,
+ "referenced_widgets": [
+ "1b82953b5f134926ae8a11c4fedca385",
+ "639fec0083134ab18b6a22999203536e",
+ "5d85942dc6e44640aed593b7a8494493",
+ "757ca7f811ec4db5a4dd517c5aec2bb8",
+ "122edd2a0300448cac1fcc8645fd8708",
+ "b1df464c3a2d47dbb49b36fa5d9912bb",
+ "dd95b1ba48ec4bc5b494f319ad41eedf",
+ "84dfa1f6d6f04c50b87e8683a0abdf55",
+ "634e163cb1354824a764dc49f7d9f2fa",
+ "5b1e8c0d14d84e0f923a997727794c89",
+ "2e912ebb924c4269a07fb7c1204eb923"
+ ]
+ },
+ "id": "ugrP307SpxMj",
+ "outputId": "ead193f0-bfd9-4bbc-c675-d741d64fd70f"
+ },
+ "outputs": [],
+ "source": [
+ "table = wandb.Table(\n",
+ " columns=[\n",
+ " \"Image\", \"Ground-Truth\", \"Prediction\"\n",
+ " ] + [\"Confidence-\" + cls for cls in class_names]\n",
+ ")\n",
+ "\n",
+ "sample_images, sample_labels = next(iter(dataloaders['train']))\n",
+ "\n",
+ "# Run inference on the GPU when one is available and fall back to the CPU\n",
+ "# otherwise, instead of assuming that a CUDA device always exists.\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "\n",
+ "# We perform inference and detach the predicted probabilities from the Torch\n",
+ "# computation graph with a tensor that does not require gradient computation.\n",
+ "sample_pred_probas = model(sample_images.to(device)).detach()\n",
+ "# Despite its name, `sample_pred_logits` holds the predicted class indices\n",
+ "# (the argmax over the class axis), which are used to look up `class_names`.\n",
+ "sample_pred_logits = keras.ops.argmax(sample_pred_probas, axis=1)\n",
+ "sample_pred_logits = sample_pred_logits.to(\"cpu\").numpy()\n",
+ "sample_pred_probas = sample_pred_probas.to(\"cpu\").numpy()\n",
+ "\n",
+ "sample_images = sample_images.numpy()\n",
+ "sample_labels = sample_labels.numpy()\n",
+ "\n",
+ "for idx in tqdm(range(sample_images.shape[0])):\n",
+ " image = sample_images[idx].transpose((1, 2, 0))\n",
+ " mean = np.array([0.485, 0.456, 0.406])\n",
+ " std = np.array([0.229, 0.224, 0.225])\n",
+ " image = std * image + mean\n",
+ " image = np.clip(image, 0, 1)\n",
+ " table.add_data(\n",
+ " wandb.Image(image),\n",
+ " class_names[int(sample_labels[idx])],\n",
+ " class_names[int(sample_pred_logits[idx])],\n",
+ " *sample_pred_probas[idx].tolist(),\n",
+ " )\n",
+ "\n",
+ "wandb.log({\"Evaluation-Table\": table})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "a90XmoRR65SJ",
+ "outputId": "88e5b0b1-a3db-4366-ba71-f6a21a877676"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QA6ytgUaSxsS"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "122edd2a0300448cac1fcc8645fd8708": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1b82953b5f134926ae8a11c4fedca385": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_639fec0083134ab18b6a22999203536e",
+ "IPY_MODEL_5d85942dc6e44640aed593b7a8494493",
+ "IPY_MODEL_757ca7f811ec4db5a4dd517c5aec2bb8"
+ ],
+ "layout": "IPY_MODEL_122edd2a0300448cac1fcc8645fd8708"
+ }
+ },
+ "2e912ebb924c4269a07fb7c1204eb923": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "5b1e8c0d14d84e0f923a997727794c89": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5d85942dc6e44640aed593b7a8494493": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_84dfa1f6d6f04c50b87e8683a0abdf55",
+ "max": 32,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_634e163cb1354824a764dc49f7d9f2fa",
+ "value": 32
+ }
+ },
+ "634e163cb1354824a764dc49f7d9f2fa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "639fec0083134ab18b6a22999203536e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b1df464c3a2d47dbb49b36fa5d9912bb",
+ "placeholder": "",
+ "style": "IPY_MODEL_dd95b1ba48ec4bc5b494f319ad41eedf",
+ "value": "100%"
+ }
+ },
+ "757ca7f811ec4db5a4dd517c5aec2bb8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5b1e8c0d14d84e0f923a997727794c89",
+ "placeholder": "",
+ "style": "IPY_MODEL_2e912ebb924c4269a07fb7c1204eb923",
+ "value": " 32/32 [00:01<00:00, 17.84it/s]"
+ }
+ },
+ "84dfa1f6d6f04c50b87e8683a0abdf55": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b1df464c3a2d47dbb49b36fa5d9912bb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dd95b1ba48ec4bc5b494f319ad41eedf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/keras/keras_nsynth_instrument_prediction.ipynb b/colabs/keras/keras_nsynth_instrument_prediction.ipynb
index a7bcc152..6a4ec1f8 100644
--- a/colabs/keras/keras_nsynth_instrument_prediction.ipynb
+++ b/colabs/keras/keras_nsynth_instrument_prediction.ipynb
@@ -649,13 +649,6 @@
"source": [
"run.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/monai/3d_brain_tumor_segmentation.ipynb b/colabs/monai/3d_brain_tumor_segmentation.ipynb
new file mode 100644
index 00000000..7eba44b6
--- /dev/null
+++ b/colabs/monai/3d_brain_tumor_segmentation.ipynb
@@ -0,0 +1,986 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Brain tumor 3D segmentation with MONAI and Weights & Biases\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/main/colabs/monai/3d_brain_tumor_segmentation.ipynb)\n",
+ "\n",
+ "This tutorial shows how to construct a training workflow for a multi-label 3D brain tumor segmentation task using [MONAI](https://github.com/Project-MONAI/MONAI) and how to use the experiment tracking and data visualization features of [Weights & Biases](https://wandb.ai/site). The tutorial contains the following features:\n",
+ "\n",
+ "1. Initialize a Weights & Biases run and synchronize all configs associated with the run for reproducibility.\n",
+ "2. MONAI transform API:\n",
+ " 1. MONAI Transforms for dictionary format data.\n",
+ " 2. How to define a new transform according to MONAI `transforms` API.\n",
+ " 3. How to randomly adjust intensity for data augmentation.\n",
+ "3. Data Loading and Visualization:\n",
+ " 1. Load Nifti image with metadata, load a list of images and stack them.\n",
+ " 2. Cache IO and transforms to accelerate training and validation.\n",
+ " 3. Visualize the data using `wandb.Table` and interactive segmentation overlay on Weights & Biases.\n",
+ "4. Training a 3D `SegResNet` model\n",
+ " 1. Using the `networks`, `losses`, and `metrics` APIs from MONAI.\n",
+ " 2. Training the 3D `SegResNet` model using a PyTorch training loop.\n",
+ " 3. Track the training experiment using Weights & Biases.\n",
+ " 4. Log and version model checkpoints as model artifacts on Weights & Biases.\n",
+ "5. Visualize and compare the predictions on the validation dataset using `wandb.Table` and interactive segmentation overlay on Weights & Biases."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 🌴 Setup and Installation\n",
+ "\n",
+ "First, let us install the latest version of both MONAI and Weights and Biases."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -c \"import monai\" || pip install -q -U \"monai[nibabel, tqdm]\"\n",
+ "!python -c \"import wandb\" || pip install -q -U wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "import numpy as np\n",
+ "from tqdm.auto import tqdm\n",
+ "import wandb\n",
+ "\n",
+ "from monai.apps import DecathlonDataset\n",
+ "from monai.data import DataLoader, decollate_batch\n",
+ "from monai.losses import DiceLoss\n",
+ "from monai.config import print_config\n",
+ "from monai.inferers import sliding_window_inference\n",
+ "from monai.metrics import DiceMetric\n",
+ "from monai.networks.nets import SegResNet\n",
+ "from monai.transforms import (\n",
+ " Activations,\n",
+ " AsDiscrete,\n",
+ " Compose,\n",
+ " LoadImaged,\n",
+ " MapTransform,\n",
+ " NormalizeIntensityd,\n",
+ " Orientationd,\n",
+ " RandFlipd,\n",
+ " RandScaleIntensityd,\n",
+ " RandShiftIntensityd,\n",
+ " RandSpatialCropd,\n",
+ " Spacingd,\n",
+ " EnsureTyped,\n",
+ " EnsureChannelFirstd,\n",
+ ")\n",
+ "from monai.utils import set_determinism\n",
+ "\n",
+ "import torch\n",
+ "\n",
+ "print_config()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will then authenticate this colab instance to use W&B."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 🌳 Initialize a W&B Run\n",
+ "\n",
+ "We will start a new W&B run to start tracking our experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.init(project=\"monai-brain-tumor-segmentation\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Using a proper config system is a recommended best practice for reproducible machine learning. We can track the hyperparameters for every experiment using W&B."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = wandb.config\n",
+ "config.seed = 0\n",
+ "config.roi_size = [224, 224, 144]\n",
+ "config.batch_size = 1\n",
+ "config.num_workers = 4\n",
+ "config.max_train_images_visualized = 20\n",
+ "config.max_val_images_visualized = 20\n",
+ "config.dice_loss_smoothen_numerator = 0\n",
+ "config.dice_loss_smoothen_denominator = 1e-5\n",
+ "config.dice_loss_squared_prediction = True\n",
+ "config.dice_loss_target_onehot = False\n",
+ "config.dice_loss_apply_sigmoid = True\n",
+ "config.initial_learning_rate = 1e-4\n",
+ "config.weight_decay = 1e-5\n",
+ "config.max_train_epochs = 50\n",
+ "config.validation_intervals = 1\n",
+ "config.dataset_dir = \"./dataset/\"\n",
+ "config.checkpoint_dir = \"./checkpoints\"\n",
+ "config.inference_roi_size = (128, 128, 64)\n",
+ "config.max_prediction_images_visualized = 20"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We also need to set the random seed for all modules to make training deterministic, and create the directories for the dataset and the model checkpoints."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "set_determinism(seed=config.seed)\n",
+ "\n",
+ "# Create directories\n",
+ "os.makedirs(config.dataset_dir, exist_ok=True)\n",
+ "os.makedirs(config.checkpoint_dir, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 💿 Data Loading and Transformation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we use the `monai.transforms` API to create a custom transform that converts the multi-class labels into a multi-label segmentation task in one-hot format."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n",
+ " \"\"\"\n",
+ " Convert labels to multi channels based on brats classes:\n",
+ " label 1 is the peritumoral edema\n",
+ " label 2 is the GD-enhancing tumor\n",
+ " label 3 is the necrotic and non-enhancing tumor core\n",
+ " The possible classes are TC (Tumor core), WT (Whole tumor)\n",
+ " and ET (Enhancing tumor).\n",
+ "\n",
+ " Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb\n",
+ "\n",
+ " \"\"\"\n",
+ "\n",
+ " def __call__(self, data):\n",
+ " d = dict(data)\n",
+ " for key in self.keys:\n",
+ " result = []\n",
+ " # merge label 2 and label 3 to construct TC\n",
+ " result.append(torch.logical_or(d[key] == 2, d[key] == 3))\n",
+ " # merge labels 1, 2 and 3 to construct WT\n",
+ " result.append(\n",
+ " torch.logical_or(\n",
+ " torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1\n",
+ " )\n",
+ " )\n",
+ " # label 2 is ET\n",
+ " result.append(d[key] == 2)\n",
+ " d[key] = torch.stack(result, axis=0).float()\n",
+ " return d"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we set up transforms for training and validation datasets respectively."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_transform = Compose(\n",
+ " [\n",
+ " # load 4 Nifti images and stack them together\n",
+ " LoadImaged(keys=[\"image\", \"label\"]),\n",
+ " EnsureChannelFirstd(keys=\"image\"),\n",
+ " EnsureTyped(keys=[\"image\", \"label\"]),\n",
+ " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n",
+ " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n",
+ " Spacingd(\n",
+ " keys=[\"image\", \"label\"],\n",
+ " pixdim=(1.0, 1.0, 1.0),\n",
+ " mode=(\"bilinear\", \"nearest\"),\n",
+ " ),\n",
+ " RandSpatialCropd(\n",
+ " keys=[\"image\", \"label\"], roi_size=config.roi_size, random_size=False\n",
+ " ),\n",
+ " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=0),\n",
+ " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=1),\n",
+ " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=2),\n",
+ " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
+ " RandScaleIntensityd(keys=\"image\", factors=0.1, prob=1.0),\n",
+ " RandShiftIntensityd(keys=\"image\", offsets=0.1, prob=1.0),\n",
+ " ]\n",
+ ")\n",
+ "val_transform = Compose(\n",
+ " [\n",
+ " LoadImaged(keys=[\"image\", \"label\"]),\n",
+ " EnsureChannelFirstd(keys=\"image\"),\n",
+ " EnsureTyped(keys=[\"image\", \"label\"]),\n",
+ " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n",
+ " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n",
+ " Spacingd(\n",
+ " keys=[\"image\", \"label\"],\n",
+ " pixdim=(1.0, 1.0, 1.0),\n",
+ " mode=(\"bilinear\", \"nearest\"),\n",
+ " ),\n",
+ " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 🍁 The Dataset\n",
+ "\n",
+ "The dataset that we will use for this experiment comes from http://medicaldecathlon.com/. We will use Multimodal multisite MRI data (FLAIR, T1w, T1gd, T2w) to segment Gliomas, necrotic/active tumour, and oedema. The dataset consists of 750 4D volumes (484 Training + 266 Testing).\n",
+ "\n",
+ "We will use the `DecathlonDataset` to automatically download and extract the dataset. It inherits MONAI `CacheDataset` which enables us to set `cache_num=N` to cache `N` items for training and use the default args to cache all the items for validation, depending on your memory size."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_dataset = DecathlonDataset(\n",
+ " root_dir=config.dataset_dir,\n",
+ " task=\"Task01_BrainTumour\",\n",
+ " transform=val_transform,\n",
+ " section=\"training\",\n",
+ " download=True,\n",
+ " cache_rate=0.0,\n",
+ " num_workers=4,\n",
+ ")\n",
+ "val_dataset = DecathlonDataset(\n",
+ " root_dir=config.dataset_dir,\n",
+ " task=\"Task01_BrainTumour\",\n",
+ " transform=val_transform,\n",
+ " section=\"validation\",\n",
+ " download=False,\n",
+ " cache_rate=0.0,\n",
+ " num_workers=4,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Note:** Instead of applying the `train_transform` to the `train_dataset`, we have applied `val_transform` to both the training and validation datasets. This is because, before training, we would be visualizing samples from both the splits of the dataset."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 📸 Visualizing the Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Weights & Biases supports images, video, audio, and more. Log rich media to explore our results and visually compare our runs, models, and datasets. We would be using the [segmentation mask overlay system](https://docs.wandb.ai/guides/track/log/media#image-overlays-in-tables) to visualize our data volumes. To log segmentation masks in [tables](https://docs.wandb.ai/guides/tables), we will need to provide a `wandb.Image` object for each row in the table.\n",
+ "\n",
+ "An example is provided in the Code snippet below:\n",
+ "\n",
+ "```python\n",
+ "table = wandb.Table(columns=[\"ID\", \"Image\"])\n",
+ "\n",
+ "for id, img, label in zip(ids, images, labels):\n",
+ " mask_img = wandb.Image(\n",
+ " img,\n",
+ " masks={\n",
+ " \"prediction\": {\"mask_data\": label, \"class_labels\": class_labels}\n",
+ " # ...\n",
+ " },\n",
+ " )\n",
+ "\n",
+ " table.add_data(id, mask_img)\n",
+ "\n",
+ "wandb.log({\"Table\": table})\n",
+ "```\n",
+ "\n",
+ "Let us now write a simple utility function that takes a sample image, its label, a `wandb.Table` object, and some associated metadata, and populates the rows of the table that will be logged to our Weights & Biases dashboard."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def log_data_samples_into_tables(\n",
+ " sample_image: np.array,\n",
+ " sample_label: np.array,\n",
+ " split: str = None,\n",
+ " data_idx: int = None,\n",
+ " table: wandb.Table = None,\n",
+ "):\n",
+ " num_channels, _, _, num_slices = sample_image.shape\n",
+ " with tqdm(total=num_slices, leave=False) as progress_bar:\n",
+ " for slice_idx in range(num_slices):\n",
+ " ground_truth_wandb_images = []\n",
+ " for channel_idx in range(num_channels):\n",
+ " ground_truth_wandb_images.append(\n",
+ " wandb.Image(\n",
+ " sample_image[channel_idx, :, :, slice_idx],\n",
+ " masks={\n",
+ " \"ground-truth/Tumor-Core\": {\n",
+ " \"mask_data\": sample_label[0, :, :, slice_idx],\n",
+ " \"class_labels\": {0: \"background\", 1: \"Tumor Core\"},\n",
+ " },\n",
+ " \"ground-truth/Whole-Tumor\": {\n",
+ " \"mask_data\": sample_label[1, :, :, slice_idx] * 2,\n",
+ " \"class_labels\": {0: \"background\", 2: \"Whole Tumor\"},\n",
+ " },\n",
+ " \"ground-truth/Enhancing-Tumor\": {\n",
+ " \"mask_data\": sample_label[2, :, :, slice_idx] * 3,\n",
+ " \"class_labels\": {0: \"background\", 3: \"Enhancing Tumor\"},\n",
+ " },\n",
+ " },\n",
+ " )\n",
+ " )\n",
+ " table.add_data(split, data_idx, slice_idx, *ground_truth_wandb_images)\n",
+ " progress_bar.update(1)\n",
+ " return table"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we define the `wandb.Table` object and what columns it consists of so that we can populate with our data visualizations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "table = wandb.Table(\n",
+ " columns=[\n",
+ " \"Split\",\n",
+ " \"Data Index\",\n",
+ " \"Slice Index\",\n",
+ " \"Image-Channel-0\",\n",
+ " \"Image-Channel-1\",\n",
+ " \"Image-Channel-2\",\n",
+ " \"Image-Channel-3\",\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Then we loop over the `train_dataset` and `val_dataset` respectively to generate the visualizations for the data samples and populate the rows of the table which we would log to our dashboard."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate visualizations for train_dataset\n",
+ "max_samples = (\n",
+ " min(config.max_train_images_visualized, len(train_dataset))\n",
+ " if config.max_train_images_visualized > 0\n",
+ " else len(train_dataset)\n",
+ ")\n",
+ "progress_bar = tqdm(\n",
+ " enumerate(train_dataset[:max_samples]),\n",
+ " total=max_samples,\n",
+ " desc=\"Generating Train Dataset Visualizations:\",\n",
+ ")\n",
+ "for data_idx, sample in progress_bar:\n",
+ " sample_image = sample[\"image\"].detach().cpu().numpy()\n",
+ " sample_label = sample[\"label\"].detach().cpu().numpy()\n",
+ " table = log_data_samples_into_tables(\n",
+ " sample_image,\n",
+ " sample_label,\n",
+ " split=\"train\",\n",
+ " data_idx=data_idx,\n",
+ " table=table,\n",
+ " )\n",
+ "\n",
+ "# Generate visualizations for val_dataset\n",
+ "max_samples = (\n",
+ " min(config.max_val_images_visualized, len(val_dataset))\n",
+ " if config.max_val_images_visualized > 0\n",
+ " else len(val_dataset)\n",
+ ")\n",
+ "progress_bar = tqdm(\n",
+ " enumerate(val_dataset[:max_samples]),\n",
+ " total=max_samples,\n",
+ " desc=\"Generating Validation Dataset Visualizations:\",\n",
+ ")\n",
+ "for data_idx, sample in progress_bar:\n",
+ " sample_image = sample[\"image\"].detach().cpu().numpy()\n",
+ " sample_label = sample[\"label\"].detach().cpu().numpy()\n",
+ " table = log_data_samples_into_tables(\n",
+ " sample_image,\n",
+ " sample_label,\n",
+ " split=\"val\",\n",
+ " data_idx=data_idx,\n",
+ " table=table,\n",
+ " )\n",
+ "\n",
+ "# Log the table to your dashboard\n",
+ "wandb.log({\"Tumor-Segmentation-Data\": table})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The data appears to us on our W&B dashboard in an interactive tabular format. We can see each channel of a particular slice from a data volume overlaid with the respective segmentation mask in each row. Let us write [Weave queries](https://docs.wandb.ai/guides/weave) to filter the data on our table and focus on one particular row.\n",
+ "\n",
+ "![](./assets/viz-1.gif)\n",
+ "\n",
+ "Let us now open an image and check how we can interact with each of the segmentation masks using the interactive overlay.\n",
+ "\n",
+ "![](./assets/viz-2.gif)\n",
+ "\n",
+ "**Note:** After the conversion to Tumor Core, Whole Tumor, and Enhancing Tumor channels, the label masks overlap across classes (for example, the whole tumor contains the tumor core); hence, they were logged as separate masks in the overlay."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 🛫 Loading the Data\n",
+ "\n",
+ "We create the PyTorch dataloaders for loading the data from the datasets. Note that before creating the dataloaders, we set the `transform` for `train_dataset` to `train_transform` to preprocess and transform the data for training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# apply train_transforms to the training dataset\n",
+ "train_dataset.transform = train_transform\n",
+ "\n",
+ "# create the train_loader\n",
+ "train_loader = DataLoader(\n",
+ " train_dataset,\n",
+ " batch_size=config.batch_size,\n",
+ " shuffle=True,\n",
+ " num_workers=config.num_workers,\n",
+ ")\n",
+ "\n",
+ "# create the val_loader\n",
+ "val_loader = DataLoader(\n",
+ " val_dataset,\n",
+ " batch_size=config.batch_size,\n",
+ " shuffle=False,\n",
+ " num_workers=config.num_workers,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 🤖 Creating the Model, Loss, and Optimizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial we will be training a `SegResNet` model based on the paper [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf). We create the `SegResNet` model that comes implemented as a PyTorch Module as part of the `monai.networks` API. We also create our optimizer and learning rate scheduler."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device(\"cuda:0\")\n",
+ "\n",
+ "# create model\n",
+ "model = SegResNet(\n",
+ " blocks_down=[1, 2, 2, 4],\n",
+ " blocks_up=[1, 1, 1],\n",
+ " init_filters=16,\n",
+ " in_channels=4,\n",
+ " out_channels=3,\n",
+ " dropout_prob=0.2,\n",
+ ").to(device)\n",
+ "\n",
+ "# create optimizer\n",
+ "optimizer = torch.optim.Adam(\n",
+ " model.parameters(),\n",
+ " config.initial_learning_rate,\n",
+ " weight_decay=config.weight_decay,\n",
+ ")\n",
+ "\n",
+ "# create learning rate scheduler\n",
+ "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
+ " optimizer, T_max=config.max_train_epochs\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We define our loss as multi-label `DiceLoss` using the `monai.losses` API and the corresponding dice metrics using the `monai.metrics` API."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_function = DiceLoss(\n",
+ " smooth_nr=config.dice_loss_smoothen_numerator,\n",
+ " smooth_dr=config.dice_loss_smoothen_denominator,\n",
+ " squared_pred=config.dice_loss_squared_prediction,\n",
+ " to_onehot_y=config.dice_loss_target_onehot,\n",
+ " sigmoid=config.dice_loss_apply_sigmoid,\n",
+ ")\n",
+ "\n",
+ "dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n",
+ "dice_metric_batch = DiceMetric(include_background=True, reduction=\"mean_batch\")\n",
+ "post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n",
+ "\n",
+ "# use automatic mixed-precision to accelerate training\n",
+ "scaler = torch.cuda.amp.GradScaler()\n",
+ "torch.backends.cudnn.benchmark = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def inference(model, input):\n",
+ " def _compute(input):\n",
+ " return sliding_window_inference(\n",
+ " inputs=input,\n",
+ " roi_size=(240, 240, 160),\n",
+ " sw_batch_size=1,\n",
+ " predictor=model,\n",
+ " overlap=0.5,\n",
+ " )\n",
+ "\n",
+ " with torch.cuda.amp.autocast():\n",
+ " return _compute(input)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 🚝 Training and Validation\n",
+ "\n",
+ "Before we start training, let us define some metric properties which will later be logged with `wandb.log()` for tracking our training and validation experiments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.define_metric(\"epoch/epoch_step\")\n",
+ "wandb.define_metric(\"epoch/*\", step_metric=\"epoch/epoch_step\")\n",
+ "wandb.define_metric(\"batch/batch_step\")\n",
+ "wandb.define_metric(\"batch/*\", step_metric=\"batch/batch_step\")\n",
+ "wandb.define_metric(\"validation/validation_step\")\n",
+ "wandb.define_metric(\"validation/*\", step_metric=\"validation/validation_step\")\n",
+ "\n",
+ "batch_step = 0\n",
+ "validation_step = 0\n",
+ "metric_values = []\n",
+ "metric_values_tumor_core = []\n",
+ "metric_values_whole_tumor = []\n",
+ "metric_values_enhanced_tumor = []"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 🍭 Execute Standard PyTorch Training Loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define a W&B Artifact object\n",
+ "artifact = wandb.Artifact(\n",
+ " name=f\"{wandb.run.id}-checkpoint\", type=\"model\"\n",
+ ")\n",
+ "\n",
+ "epoch_progress_bar = tqdm(range(config.max_train_epochs), desc=\"Training:\")\n",
+ "\n",
+ "for epoch in epoch_progress_bar:\n",
+ " model.train()\n",
+ " epoch_loss = 0\n",
+ "\n",
+ " total_batch_steps = len(train_dataset) // train_loader.batch_size\n",
+ " batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)\n",
+ " \n",
+ " # Training Step\n",
+ " for batch_data in batch_progress_bar:\n",
+ " inputs, labels = (\n",
+ " batch_data[\"image\"].to(device),\n",
+ " batch_data[\"label\"].to(device),\n",
+ " )\n",
+ " optimizer.zero_grad()\n",
+ " with torch.cuda.amp.autocast():\n",
+ " outputs = model(inputs)\n",
+ " loss = loss_function(outputs, labels)\n",
+ " scaler.scale(loss).backward()\n",
+ " scaler.step(optimizer)\n",
+ " scaler.update()\n",
+ " epoch_loss += loss.item()\n",
+ " batch_progress_bar.set_description(f\"train_loss: {loss.item():.4f}:\")\n",
+ " ## Log batch-wise training loss to W&B\n",
+ " wandb.log({\"batch/batch_step\": batch_step, \"batch/train_loss\": loss.item()})\n",
+ " batch_step += 1\n",
+ "\n",
+ " lr_scheduler.step()\n",
+ " epoch_loss /= total_batch_steps\n",
+ " ## Log epoch-wise mean training loss and learning rate to W&B\n",
+ " wandb.log(\n",
+ " {\n",
+ " \"epoch/epoch_step\": epoch,\n",
+ " \"epoch/mean_train_loss\": epoch_loss,\n",
+ " \"epoch/learning_rate\": lr_scheduler.get_last_lr()[0],\n",
+ " }\n",
+ " )\n",
+ " epoch_progress_bar.set_description(f\"Training: train_loss: {epoch_loss:.4f}:\")\n",
+ "\n",
+ " # Validation and model checkpointing\n",
+ " if (epoch + 1) % config.validation_intervals == 0:\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " for val_data in val_loader:\n",
+ " val_inputs, val_labels = (\n",
+ " val_data[\"image\"].to(device),\n",
+ " val_data[\"label\"].to(device),\n",
+ " )\n",
+ " val_outputs = inference(model, val_inputs)\n",
+ " val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]\n",
+ " dice_metric(y_pred=val_outputs, y=val_labels)\n",
+ " dice_metric_batch(y_pred=val_outputs, y=val_labels)\n",
+ "\n",
+ " metric_values.append(dice_metric.aggregate().item())\n",
+ " metric_batch = dice_metric_batch.aggregate()\n",
+ " metric_values_tumor_core.append(metric_batch[0].item())\n",
+ " metric_values_whole_tumor.append(metric_batch[1].item())\n",
+ " metric_values_enhanced_tumor.append(metric_batch[2].item())\n",
+ " dice_metric.reset()\n",
+ " dice_metric_batch.reset()\n",
+ "\n",
+ " checkpoint_path = os.path.join(config.checkpoint_dir, \"model.pth\")\n",
+ " torch.save(model.state_dict(), checkpoint_path)\n",
+ " \n",
+ " # Log and version model checkpoints using W&B artifacts.\n",
+ " artifact.add_file(local_path=checkpoint_path)\n",
+ " wandb.log_artifact(artifact, aliases=[f\"epoch_{epoch}\"])\n",
+ "\n",
+ " # Log validation metrics to W&B dashboard.\n",
+ " wandb.log(\n",
+ " {\n",
+ " \"validation/validation_step\": validation_step,\n",
+ " \"validation/mean_dice\": metric_values[-1],\n",
+ " \"validation/mean_dice_tumor_core\": metric_values_tumor_core[-1],\n",
+ " \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[-1],\n",
+ " \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[-1],\n",
+ " }\n",
+ " )\n",
+ " validation_step += 1\n",
+ "\n",
+ "\n",
+ "# Wait for this artifact to finish logging\n",
+ "artifact.wait()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instrumenting our code with `wandb.log` not only enables us to track all the metrics associated with our training and validation process, but also all the system metrics (our CPU and GPU in this case) on our W&B dashboard.\n",
+ "\n",
+ "![](./assets/viz-3.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If we navigate to the artifacts tab in the W&B run dashboard, we will be able to access the different versions of model checkpoint artifacts that we logged during training.\n",
+ "\n",
+ "![](./assets/viz-4.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 🔱 Inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Using the artifacts interface, we can select which version of the artifact is the best model checkpoint, in this case, the one with the best mean epoch-wise training loss. We can also explore the entire lineage of the artifact and use the version that we need.\n",
+ "\n",
+ "![](./assets/viz-5.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us fetch the version of the model artifact with the best epoch-wise mean training loss and load the checkpoint state dictionary to the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_artifact = wandb.use_artifact(\n",
+ " \"geekyrakshit/monai-brain-tumor-segmentation/d5ex6n4a-checkpoint:v49\",\n",
+ " type=\"model\",\n",
+ ")\n",
+ "model_artifact_dir = model_artifact.download()\n",
+ "model.load_state_dict(torch.load(os.path.join(model_artifact_dir, \"model.pth\")))\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 📸 Visualizing Predictions and Comparing with the Ground Truth Labels\n",
+ "\n",
+ "In order to visualize the predictions of the pre-trained model and compare them with the corresponding ground-truth segmentation mask using the interactive segmentation mask overlay, let us create another utility function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def log_predictions_into_tables(\n",
+ " sample_image: np.array,\n",
+ " sample_label: np.array,\n",
+ " predicted_label: np.array,\n",
+ " split: str = None,\n",
+ " data_idx: int = None,\n",
+ " table: wandb.Table = None,\n",
+ "):\n",
+ " num_channels, _, _, num_slices = sample_image.shape\n",
+ " with tqdm(total=num_slices, leave=False) as progress_bar:\n",
+ " for slice_idx in range(num_slices):\n",
+ " wandb_images = []\n",
+ " for channel_idx in range(num_channels):\n",
+ " wandb_images += [\n",
+ " wandb.Image(\n",
+ " sample_image[channel_idx, :, :, slice_idx],\n",
+ " masks={\n",
+ " \"ground-truth/Tumor-Core\": {\n",
+ " \"mask_data\": sample_label[0, :, :, slice_idx],\n",
+ " \"class_labels\": {0: \"background\", 1: \"Tumor Core\"},\n",
+ " },\n",
+ " \"prediction/Tumor-Core\": {\n",
+ " \"mask_data\": predicted_label[0, :, :, slice_idx] * 2,\n",
+ " \"class_labels\": {0: \"background\", 2: \"Tumor Core\"},\n",
+ " },\n",
+ " },\n",
+ " ),\n",
+ " wandb.Image(\n",
+ " sample_image[channel_idx, :, :, slice_idx],\n",
+ " masks={\n",
+ " \"ground-truth/Whole-Tumor\": {\n",
+ " \"mask_data\": sample_label[1, :, :, slice_idx],\n",
+ " \"class_labels\": {0: \"background\", 1: \"Whole Tumor\"},\n",
+ " },\n",
+ " \"prediction/Whole-Tumor\": {\n",
+ " \"mask_data\": predicted_label[1, :, :, slice_idx] * 2,\n",
+ " \"class_labels\": {0: \"background\", 2: \"Whole Tumor\"},\n",
+ " },\n",
+ " },\n",
+ " ),\n",
+ " wandb.Image(\n",
+ " sample_image[channel_idx, :, :, slice_idx],\n",
+ " masks={\n",
+ " \"ground-truth/Enhancing-Tumor\": {\n",
+ " \"mask_data\": sample_label[2, :, :, slice_idx],\n",
+ " \"class_labels\": {0: \"background\", 1: \"Enhancing Tumor\"},\n",
+ " },\n",
+ " \"prediction/Enhancing-Tumor\": {\n",
+ " \"mask_data\": predicted_label[2, :, :, slice_idx] * 2,\n",
+ " \"class_labels\": {0: \"background\", 2: \"Enhancing Tumor\"},\n",
+ " },\n",
+ " },\n",
+ " ),\n",
+ " ]\n",
+ " table.add_data(split, data_idx, slice_idx, *wandb_images)\n",
+ " progress_bar.update(1)\n",
+ " return table"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create the prediction table\n",
+ "prediction_table = wandb.Table(\n",
+ " columns=[\n",
+ " \"Split\",\n",
+ " \"Data Index\",\n",
+ " \"Slice Index\",\n",
+ " \"Image-Channel-0/Tumor-Core\",\n",
+ " \"Image-Channel-1/Tumor-Core\",\n",
+ " \"Image-Channel-2/Tumor-Core\",\n",
+ " \"Image-Channel-3/Tumor-Core\",\n",
+ " \"Image-Channel-0/Whole-Tumor\",\n",
+ " \"Image-Channel-1/Whole-Tumor\",\n",
+ " \"Image-Channel-2/Whole-Tumor\",\n",
+ " \"Image-Channel-3/Whole-Tumor\",\n",
+ " \"Image-Channel-0/Enhancing-Tumor\",\n",
+ " \"Image-Channel-1/Enhancing-Tumor\",\n",
+ " \"Image-Channel-2/Enhancing-Tumor\",\n",
+ " \"Image-Channel-3/Enhancing-Tumor\",\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "# Perform inference and visualization\n",
+ "with torch.no_grad():\n",
+ " config.max_prediction_images_visualized\n",
+ " max_samples = (\n",
+ " min(config.max_prediction_images_visualized, len(val_dataset))\n",
+ " if config.max_prediction_images_visualized > 0\n",
+ " else len(val_dataset)\n",
+ " )\n",
+ " progress_bar = tqdm(\n",
+ " enumerate(val_dataset[:max_samples]),\n",
+ " total=max_samples,\n",
+ " desc=\"Generating Predictions:\",\n",
+ " )\n",
+ " for data_idx, sample in progress_bar:\n",
+ " val_input = sample[\"image\"].unsqueeze(0).to(device)\n",
+ " val_output = inference(model, val_input)\n",
+ " val_output = post_trans(val_output[0])\n",
+ " prediction_table = log_predictions_into_tables(\n",
+ " sample_image=sample[\"image\"].cpu().numpy(),\n",
+ " sample_label=sample[\"label\"].cpu().numpy(),\n",
+ " predicted_label=val_output.cpu().numpy(),\n",
+ " data_idx=data_idx,\n",
+ " split=\"validation\",\n",
+ " table=prediction_table,\n",
+ " )\n",
+ "\n",
+ " wandb.log({\"Predictions/Tumor-Segmentation-Data\": prediction_table})\n",
+ "\n",
+ "\n",
+ "# End the experiment\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us see how we can analyze and compare the predicted segmentation masks and the ground-truth labels for each class using the interactive segmentation mask overlay.\n",
+ "\n",
+ "![](./assets/viz-6.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can also check out the report [Brain Tumor Segmentation using MONAI and WandB](https://wandb.ai/geekyrakshit/brain-tumor-segmentation/reports/Brain-Tumor-Segmentation-using-MONAI-and-WandB---Vmlldzo0MjUzODIw) for more details regarding training a brain-tumor segmentation model using MONAI and W&B."
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/colabs/monai/assets/viz-1.gif b/colabs/monai/assets/viz-1.gif
new file mode 100644
index 00000000..3eb97b47
Binary files /dev/null and b/colabs/monai/assets/viz-1.gif differ
diff --git a/colabs/monai/assets/viz-2.gif b/colabs/monai/assets/viz-2.gif
new file mode 100644
index 00000000..45ebb1db
Binary files /dev/null and b/colabs/monai/assets/viz-2.gif differ
diff --git a/colabs/monai/assets/viz-3.gif b/colabs/monai/assets/viz-3.gif
new file mode 100644
index 00000000..df8d3954
Binary files /dev/null and b/colabs/monai/assets/viz-3.gif differ
diff --git a/colabs/monai/assets/viz-4.gif b/colabs/monai/assets/viz-4.gif
new file mode 100644
index 00000000..6cdea3b7
Binary files /dev/null and b/colabs/monai/assets/viz-4.gif differ
diff --git a/colabs/monai/assets/viz-5.gif b/colabs/monai/assets/viz-5.gif
new file mode 100644
index 00000000..2a0be82a
Binary files /dev/null and b/colabs/monai/assets/viz-5.gif differ
diff --git a/colabs/monai/assets/viz-6.gif b/colabs/monai/assets/viz-6.gif
new file mode 100644
index 00000000..11eaa5ba
Binary files /dev/null and b/colabs/monai/assets/viz-6.gif differ
diff --git a/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb b/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb
index e9b2b3d2..9a36ca45 100644
--- a/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb
+++ b/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb
@@ -1,455 +1,506 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "id": "910a2eb0-65b5-4b6c-971a-19bf8121ab13",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "# Running fast with MosaicML Composer and Weight and Biases"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4961c393-0154-4937-98d7-16e472b9a0d3",
- "metadata": {},
- "source": [
- "[MosaicML Composer](https://docs.mosaicml.com) is a library for training neural networks better, faster, and cheaper. It contains many state-of-the-art methods for accelerating neural network training and improving generalization, along with an optional Trainer API that makes composing many different enhancements easy.\n",
- "\n",
- "Coupled with [Weights & Biases integration](https://docs.mosaicml.com/en/v0.5.0/trainer/logging.html), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:\n",
- "\n",
- "```python\n",
- "from composer.loggers import WandBLogger\n",
- "wandb_logger = WandBLogger()\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4ec877f6-57aa-4423-ae4b-f85769c59dd6",
- "metadata": {},
- "source": [
- "W&B integration with Composer can automatically:\n",
- "* log your configuration parameters\n",
- "* log your losses and metrics\n",
- "* log gradients and parameter distributions\n",
- "* log your model\n",
- "* keep track of your code\n",
- "* log your system metrics (GPU, CPU, memory, temperature, etc)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "092b0104-530a-438d-bd68-08f627cc8920",
- "metadata": {},
- "source": [
- "### 🛠️ Installation and set-up\n",
- "\n",
- "We need to install the following libraries:\n",
- "* [mosaicml-composer](https://docs.mosaicml.com/en/v0.5.0/getting_started/installation.html) to set up and train our models\n",
- "* [wandb](https://docs.wandb.ai/) to instrument our training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "56bbcb21-babd-488b-a20d-080f43f09897",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -Uq wandb mosaicml fastcore"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "39e54ec9-51b6-4f49-9ea1-2ed82f03add3",
- "metadata": {},
- "source": [
- "## Getting Started with Composer 🔥"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a2e0f570-d323-4ecf-bdb3-1469730f562b",
- "metadata": {},
- "source": [
- "Composer gives you access to a set of functions to speedup your models and infuse them with state of the art methods. For instance, you can insert [BlurPool](https://docs.mosaicml.com/en/latest/method_cards/blurpool.html) into your CNN by calling `CF.apply_blurpool(model)` into your PyTorch model. Take a look at all the [functional](https://docs.mosaicml.com/en/latest/functional_api.html) methods available."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "55d71a76-7354-43ee-87c1-a83aa88b0b69",
- "metadata": {},
- "outputs": [],
- "source": [
- "import logging\n",
- "from composer import functional as CF\n",
- "import torchvision.models as models\n",
- "\n",
- "logging.basicConfig(level=logging.INFO)\n",
- "model = models.resnet50()\n",
- "\n",
- "# replace some layers with blurpool\n",
- "CF.apply_blurpool(model);\n",
- "# replace some layers with squeeze-excite\n",
- "CF.apply_squeeze_excite(model, latent_channels=64, min_channels=128);"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "15ad4bdf-9d1f-4bf6-8e30-ff72e06c6718",
- "metadata": {},
- "source": [
- "> 💡 you can use this upgraded model with your favourite PyTorch training or... "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "642eb0d5-b4f6-4add-9d59-235222bc2236",
- "metadata": {},
- "source": [
- "## Use the `Trainer` class with Weights and Biases 🏋️♀️\n",
- "\n",
- "W&B integration with MosaicML-Composer is built into the `Trainer` and can be configured to add extra functionalities through `WandBLogger`:\n",
- "\n",
- "* logging of Artifacts: Use `log_artifacts=True` to log model checkpoints as `wandb.Artifacts`. You can setup how often by passing an int value to `log_artifacts_every_n_batches` (default = 100)\n",
- "* you can also pass any parameter that you would pass to `wandb.init` in `init_params` as a dictionary. For example, you could pass `init_params = {\"project\":\"try_mosaicml\", \"name\":\"benchmark\", \"entity\":\"user_name\"}`.\n",
- "\n",
- "For more details refer to [Logger documentation](https://docs.mosaicml.com/en/latest/api_reference/composer.loggers.wandb_logger.html#composer.loggers.wandb_logger.WandBLogger) and [Wandb docs](https://docs.wandb.ai)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "caecb6eb-25b9-4daf-a1d2-465fecf98ef7",
- "metadata": {},
- "outputs": [],
- "source": [
- "EPOCHS = 5\n",
- "BS = 32"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "93fbba13-dbb9-436e-8237-fa60f0e54675",
- "metadata": {},
- "outputs": [],
- "source": [
- "import wandb\n",
- "\n",
- "from torch.utils.data import DataLoader\n",
- "from torchvision import datasets\n",
- "\n",
- "import torchvision.transforms as T\n",
- "\n",
- "from composer import Callback, State, Logger, Trainer\n",
- "from composer.models import MNIST_Classifier\n",
- "from composer.loggers import WandBLogger, TQDMLogger\n",
- "from composer.callbacks import SpeedMonitor, LRMonitor"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b8f42d23-1e98-46b7-872d-17e326a30ef1",
- "metadata": {},
- "source": [
- "let's grab a copy of MNIST from `torchvision`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5df9ec80-22ec-4cad-91a3-685a1392dcc3",
- "metadata": {},
- "outputs": [],
- "source": [
- "train_dataset = datasets.MNIST(\".\", train=True, download=True, transform=T.ToTensor())\n",
- "eval_dataset = datasets.MNIST(\".\", train=False, download=True, transform=T.ToTensor())\n",
- "\n",
- "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BS, num_workers=2)\n",
- "eval_dataloader = DataLoader(eval_dataset, batch_size=2*BS, num_workers=2)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e13640e3-7ff3-4fb2-a795-6443568cbf3e",
- "metadata": {},
- "source": [
- "we can import a simple ConvNet model to try"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "eb8f7103-cad2-4f40-9d48-5ac2b394b4d6",
- "metadata": {},
- "outputs": [],
- "source": [
- "model = MNIST_Classifier()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "baa820b1-12bb-43c8-ae7c-5cea24b69c35",
- "metadata": {},
- "source": [
- "### 📊 Tracking the experiment\n",
- "> we define the `wandb.init` params here"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c4424591-cd28-4af2-987b-be0588a99a8e",
- "metadata": {},
- "outputs": [],
- "source": [
- "# config params to log\n",
- "config = {\"epochs\":EPOCHS, \n",
- " \"batch_size\":BS,\n",
- " \"model_name\":\"MNIST_Classifier\"}\n",
- "\n",
- "# these will get passed to wandb.init(**init_params)\n",
- "init_params = {\"project\":\"mosaic_ml\", \n",
- " \"name\":\"mnist_baseline\",\n",
- " \"config\":config}\n",
- "\n",
- "# setup of the logger \n",
- "wandb_logger = WandBLogger(init_params=init_params)\n",
- "\n",
- "# we also add progressbar logging\n",
- "progress_logger = TQDMLogger()\n",
- "\n",
- "loggers = [progress_logger, wandb_logger]"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "2f2b5740-7736-4f6d-88f1-c89a669691fe",
- "metadata": {},
- "source": [
- "we are able to tweak what are we logging using `Callbacks` into the `Trainer` class."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e60c66bc-f046-4a57-9ea7-a6e9fe5cfc3f",
- "metadata": {},
- "outputs": [],
- "source": [
- "callbacks = [LRMonitor(), # Logs the learning rate\n",
- " SpeedMonitor(), # Logs the training throughput\n",
- " ]"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0e647bcc-abe5-45cd-bb7d-fa91e6c2006e",
- "metadata": {},
- "source": [
- "we include callbacks that measure the model throughput (and the learning rate) and logs them to Weights & Biases. [Callbacks](https://docs.mosaicml.com/en/latest/trainer/callbacks.html) control what is being logged, whereas loggers specify where the information is being saved. For more information on loggers, see [Logging](https://docs.mosaicml.com/en/latest/trainer/logging.html)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fb0ad37f-943a-4341-8a16-32bf88e98ac1",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = Trainer(\n",
- " model=model,\n",
- " train_dataloader=train_dataloader,\n",
- " eval_dataloader=eval_dataloader,\n",
- " max_duration=f\"{EPOCHS}ep\",\n",
- " loggers=loggers,\n",
- " callbacks=callbacks,\n",
- " device=\"gpu\", # to train on GPU\n",
- " precision=\"amp\", # use mixed precision training, nice speed bump\n",
- "\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1ae01513-21a8-49a5-bec3-83bebf079710",
- "metadata": {},
- "source": [
- "once we are ready to train we call `fit`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "99044ec7-3e07-4c21-8ba2-9eed79ecb801",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "22563587-a1bc-4885-95a7-81d47858efe4",
- "metadata": {},
- "source": [
- "## ⚙️ Advanced: Using callbacks to log sample predictions\n",
- "\n",
- "> Composer is extensible through its callback system.\n",
- "\n",
- "We create a custom callback to automatically log sample predictions during validation."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b4c0116c-70e0-4f34-828a-1c911c709346",
- "metadata": {},
- "outputs": [],
- "source": [
- "class LogPredictions(Callback):\n",
- " def __init__(self, num_samples=100):\n",
- " super().__init__()\n",
- " self.num_samples = num_samples\n",
- " self.data = []\n",
- " \n",
- " def eval_batch_end(self, state: State, logger: Logger):\n",
- " \"\"\"Compute predictions per batch and stores them on self.data\"\"\"\n",
- " if state.timer.epoch == state.max_duration: # on last val epoch\n",
- " if len(self.data) < self.num_samples:\n",
- " n = self.num_samples\n",
- " x, y = state.batch_pair\n",
- " outputs = state.outputs.argmax(-1)\n",
- " data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))]\n",
- " self.data += data\n",
- " \n",
- " def eval_end(self, state: State, logger: Logger):\n",
- " \"Create a wandb.Table and logs it\"\n",
- " columns = ['image', 'ground truth', 'prediction']\n",
- " table = wandb.Table(columns=columns, data=self.data[:self.num_samples])\n",
- " wandb.log({'predictions_table':table}, step=int(state.timer.batch))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "77e816d1-a114-4bca-8a7b-a70b13969366",
- "metadata": {},
- "source": [
- "we add `LogPredictions` to the other callbacks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "682b52b0-5055-4cfd-a6c1-1a114390e965",
- "metadata": {},
- "outputs": [],
- "source": [
- "callbacks.append(LogPredictions())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "92cd7e36-e0c9-4235-980f-f590500478a1",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = Trainer(\n",
- " model=model,\n",
- " train_dataloader=train_dataloader,\n",
- " eval_dataloader=eval_dataloader,\n",
- " max_duration=f\"{EPOCHS}ep\",\n",
- " loggers=loggers,\n",
- " callbacks=callbacks,\n",
- " device=\"gpu\", # to train on GPU\n",
- " precision=\"amp\", # use mixed precision training, nice speed bump\n",
- "\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "2ea27598-8d08-4ce0-b3ef-ed34fd161e48",
- "metadata": {},
- "source": [
- "Once we're ready to train, we just call the `fit` method."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6ccc77ed-bfd3-4698-a9a4-431b8a46079d",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ebfd5172-59b2-4895-97e9-8e8906ee7147",
- "metadata": {},
- "source": [
- "We can monitor losses, metrics, gradients, parameters and sample predictions as the model trains."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e74c0eba-58ce-435a-8bee-9c55663e17de",
- "metadata": {},
- "source": [
- "![composer.png](https://i.imgur.com/VFZLOB3.png?1)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8f0ddffd-2ed4-4be6-848e-d95e373cc84d",
- "metadata": {},
- "source": [
- "## 📚 Resources\n",
- "\n",
- "* We are excited to showcase this early support of [MosaicML-Composer](https://docs.mosaicml.com/en/latest/index.html) go ahead and try this new state of the art framework."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0c60f734-8eef-4b95-b4e1-d4d33ff48399",
- "metadata": {},
- "source": [
- "## ❓ Questions about W&B\n",
- "\n",
- "If you have any questions about using W&B to track your model performance and predictions, please reach out to the [wandb community](https://community.wandb.ai)."
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0196e78a",
+ "metadata": {
+ "id": "0196e78a"
+ },
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "362c2ed6",
+ "metadata": {
+ "id": "362c2ed6"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Running fast with MosaicML Composer and Weights & Biases"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c740179e",
+ "metadata": {
+ "id": "c740179e"
+ },
+ "source": [
+ "[MosaicML Composer](https://docs.mosaicml.com) is a library for training neural networks better, faster, and cheaper. It contains many state-of-the-art methods for accelerating neural network training and improving generalization, along with an optional Trainer API that makes composing many different enhancements easy.\n",
+ "\n",
+ "Coupled with [Weights & Biases integration](https://docs.wandb.ai/guides/integrations/composer), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:\n",
+ "\n",
+ "```python\n",
+ "from composer import Trainer\n",
+ "from composer.loggers import WandBLogger\n",
+ "\n",
+ "wandb_logger = WandBLogger(project=\"mnist-composer\", init_kwargs=wandb_init_kwargs)\n",
+ "trainer = Trainer(..., loggers=[wandb_logger])\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0eca27e4",
+ "metadata": {
+ "id": "0eca27e4"
+ },
+ "source": [
+ "W&B integration with Composer can automatically:\n",
+ "* log your configuration parameters\n",
+ "* log your losses and metrics\n",
+ "* log gradients and parameter distributions\n",
+ "* log your model\n",
+ "* keep track of your code\n",
+ "* log your system metrics (GPU, CPU, memory, temperature, etc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5640b1a0",
+ "metadata": {
+ "id": "5640b1a0"
+ },
+ "source": [
+ "### 🛠️ Installation and set-up\n",
+ "\n",
+ "We need to install the following libraries:\n",
+ "* [mosaicml-composer](https://docs.mosaicml.com/en/v0.5.0/getting_started/installation.html) to set up and train our models\n",
+ "* [wandb](https://docs.wandb.ai/) to instrument our training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "764b0904",
+ "metadata": {
+ "id": "764b0904"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -Uq wandb mosaicml"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "02c31ab7",
+ "metadata": {
+ "id": "02c31ab7"
+ },
+ "source": [
+ "## Use the Composer `Trainer` class with Weights and Biases 🏋️♀️\n",
+ "\n",
+ "W&B integration with MosaicML-Composer is built into the `Trainer` and can be configured to add extra functionalities through `WandBLogger`:\n",
+ "\n",
+ "* logging of Artifacts: Use `log_artifacts=True` to log model checkpoints as `wandb.Artifacts`. You can setup how often by passing an int value to `log_artifacts_every_n_batches` (default = 100)\n",
+ "* you can also pass any parameter that you would pass to `wandb.init` in `init_kwargs` as a dictionary. For example, you could pass `init_kwargs = {\"name\":\"benchmark\", \"entity\":\"user_name\", \"config\":config}`.\n",
+ "\n",
+ "For more details refer to [Logger documentation](https://docs.mosaicml.com/en/latest/api_reference/composer.loggers.wandb_logger.html#composer.loggers.wandb_logger.WandBLogger) and [Wandb docs](https://docs.wandb.ai)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "1e12ce9d",
+ "metadata": {
+ "id": "1e12ce9d"
+ },
+ "outputs": [],
+ "source": [
+ "EPOCHS = 5\n",
+ "BS = 32"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "4dbc1493",
+ "metadata": {
+ "id": "4dbc1493"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "from torchvision import datasets, transforms\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "from composer import Callback, State, Logger, Trainer\n",
+ "from composer.models import mnist_model\n",
+ "from composer.loggers import WandBLogger\n",
+ "from composer.callbacks import SpeedMonitor, LRMonitor\n",
+ "from composer.algorithms import LabelSmoothing, CutMix, ChannelsLast"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ef7be365",
+ "metadata": {
+ "id": "ef7be365"
+ },
+ "source": [
+ "let's grab a copy of MNIST from `torchvision`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b298a861",
+ "metadata": {
+ "id": "b298a861"
+ },
+ "outputs": [],
+ "source": [
+ "# Convert the MNIST images to tensors and download the training split into ./data\n",
+ "transform = transforms.Compose([transforms.ToTensor()])\n",
+ "dataset = datasets.MNIST(\"data\", train=True, download=True, transform=transform)\n",
+ "# Use BS here so the batch size logged in the W&B config is the one actually used\n",
+ "train_dataloader = DataLoader(dataset, batch_size=BS)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b798a9ed",
+ "metadata": {
+ "id": "b798a9ed"
+ },
+ "source": [
+ "we can import a simple ConvNet model to try"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "6498cf78",
+ "metadata": {
+ "id": "6498cf78"
+ },
+ "outputs": [],
+ "source": [
+ "model = mnist_model(num_classes=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8e9daaa5",
+ "metadata": {
+ "id": "8e9daaa5"
+ },
+ "source": [
+ "### 📊 Tracking the experiment\n",
+ "> we define the `wandb.init` params here"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "73bccc9f",
+ "metadata": {
+ "id": "73bccc9f"
+ },
+ "outputs": [],
+ "source": [
+ "# config params to log\n",
+ "config = {\"epochs\":EPOCHS,\n",
+ " \"batch_size\":BS,\n",
+ " \"model_name\":\"MNIST_Classifier\"}\n",
+ "\n",
+ "# these will get passed to wandb.init(**wandb_init_kwargs)\n",
+ "wandb_init_kwargs = {\"config\":config}\n",
+ "\n",
+ "# setup of the logger\n",
+ "wandb_logger = WandBLogger(project=\"mnist-composer\",\n",
+ " log_artifacts=True,\n",
+ " init_kwargs=wandb_init_kwargs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0dc1f7f6",
+ "metadata": {
+ "id": "0dc1f7f6"
+ },
+ "source": [
+ "We can tweak what we are logging by passing `Callbacks` to the `Trainer` class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "id": "9b470f1a",
+ "metadata": {
+ "id": "9b470f1a"
+ },
+ "outputs": [],
+ "source": [
+ "callbacks = [LRMonitor(), # Logs the learning rate\n",
+ " SpeedMonitor(), # Logs the training throughput\n",
+ " ]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28920cc2",
+ "metadata": {
+ "id": "28920cc2"
+ },
+ "source": [
+ "We include callbacks that measure the model throughput (and the learning rate) and log them to Weights & Biases. [Callbacks](https://docs.mosaicml.com/en/latest/trainer/callbacks.html) control what is being logged, whereas loggers specify where the information is being saved. For more information on loggers, see [Logging](https://docs.mosaicml.com/en/latest/trainer/logging.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Build the Composer Trainer with the W&B logger, the monitoring callbacks and a few speed-up algorithms\n",
+ "trainer = Trainer(\n",
+ "    model=mnist_model(num_classes=10),\n",
+ "    train_dataloader=train_dataloader,\n",
+ "    max_duration=f\"{EPOCHS}ep\",  # train for the number of epochs logged in the W&B config\n",
+ "    loggers=[wandb_logger], # Pass your WandbLogger\n",
+ "    callbacks=callbacks,\n",
+ "    algorithms=[\n",
+ "        LabelSmoothing(smoothing=0.1),\n",
+ "        CutMix(alpha=1.0),\n",
+ "        ChannelsLast(),\n",
+ "    ]\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "kmUiJQZoGU5D"
+ },
+ "id": "kmUiJQZoGU5D",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0888b0c5",
+ "metadata": {
+ "id": "0888b0c5"
+ },
+ "source": [
+ "once we are ready to train we call `fit`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2ca3468a",
+ "metadata": {
+ "id": "2ca3468a"
+ },
+ "outputs": [],
+ "source": [
+ "trainer.fit()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We close the Trainer to properly finish all callbacks and loggers"
+ ],
+ "metadata": {
+ "id": "_4U7TodlIgPy"
+ },
+ "id": "_4U7TodlIgPy"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.close()"
+ ],
+ "metadata": {
+ "id": "dTWX_MFZIfSF"
+ },
+ "id": "dTWX_MFZIfSF",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "598495ee",
+ "metadata": {
+ "id": "598495ee"
+ },
+ "source": [
+ "## ⚙️ Advanced: Using callbacks to log sample predictions\n",
+ "\n",
+ "> Composer is extensible through its callback system.\n",
+ "\n",
+ "We create a custom callback to automatically log sample predictions during training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "id": "00401b44",
+ "metadata": {
+ "id": "00401b44"
+ },
+ "outputs": [],
+ "source": [
+ "class LogPredictions(Callback):\n",
+ "    \"\"\"Collect sample predictions in `batch_end` and log them as a `wandb.Table` in `epoch_end`.\n",
+ "\n",
+ "    Args:\n",
+ "        num_samples: maximum number of (image, ground truth, prediction) rows logged per epoch.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, num_samples=100):\n",
+ "        super().__init__()\n",
+ "        self.num_samples = num_samples\n",
+ "        self.data = []\n",
+ "\n",
+ "    def batch_end(self, state: State, logger: Logger):\n",
+ "        \"\"\"Compute predictions per batch and store them on self.data until num_samples rows are collected\"\"\"\n",
+ "        remaining = self.num_samples - len(self.data)\n",
+ "        if remaining <= 0:\n",
+ "            return\n",
+ "        x, y = state.batch\n",
+ "        outputs = state.outputs.argmax(-1)\n",
+ "        # only wrap the rows that are still missing instead of a full `num_samples` slice per batch\n",
+ "        self.data += [\n",
+ "            [wandb.Image(x_i), y_i, y_pred]\n",
+ "            for x_i, y_i, y_pred in zip(x[:remaining], y[:remaining], outputs[:remaining])\n",
+ "        ]\n",
+ "\n",
+ "    def epoch_end(self, state: State, logger: Logger):\n",
+ "        \"\"\"Create a wandb.Table from the collected rows, log it and reset the buffer\"\"\"\n",
+ "        columns = ['image', 'ground truth', 'prediction']\n",
+ "        table = wandb.Table(columns=columns, data=self.data[:self.num_samples])\n",
+ "        wandb.log({'predictions_table':table}, step=int(state.timestamp.batch))\n",
+ "        # start the next epoch with an empty buffer; otherwise every epoch would re-log the first epoch's rows\n",
+ "        self.data = []"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6161475c",
+ "metadata": {
+ "id": "6161475c"
+ },
+ "source": [
+ "we add `LogPredictions` to the other callbacks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "id": "94d39bd5",
+ "metadata": {
+ "id": "94d39bd5"
+ },
+ "outputs": [],
+ "source": [
+ "callbacks.append(LogPredictions())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.close()"
+ ],
+ "metadata": {
+ "id": "8qXCtgRWM1ke"
+ },
+ "id": "8qXCtgRWM1ke",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8ea986a8",
+ "metadata": {
+ "id": "8ea986a8"
+ },
+ "outputs": [],
+ "source": [
+ "# Re-create the Trainer so that it picks up the extended callbacks list (now including LogPredictions)\n",
+ "trainer = Trainer(\n",
+ "    model=mnist_model(num_classes=10),\n",
+ "    train_dataloader=train_dataloader,\n",
+ "    max_duration=f\"{EPOCHS}ep\",  # train for the number of epochs logged in the W&B config\n",
+ "    loggers=[wandb_logger], # Pass your WandbLogger\n",
+ "    callbacks=callbacks,\n",
+ "    algorithms=[\n",
+ "        LabelSmoothing(smoothing=0.1),\n",
+ "        CutMix(alpha=1.0),\n",
+ "        ChannelsLast(),\n",
+ "    ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a4deb712",
+ "metadata": {
+ "id": "a4deb712"
+ },
+ "source": [
+ "Once we're ready to train, we just call the `fit` method."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c8b00679",
+ "metadata": {
+ "id": "c8b00679"
+ },
+ "outputs": [],
+ "source": [
+ "trainer.fit()\n",
+ "trainer.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2c2e89f6",
+ "metadata": {
+ "id": "2c2e89f6"
+ },
+ "source": [
+ "We can monitor losses, metrics, gradients, parameters and sample predictions as the model trains."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f4889215",
+ "metadata": {
+ "id": "f4889215"
+ },
+ "source": [
+ "![composer.png](https://i.imgur.com/VFZLOB3.png?1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "97f57640",
+ "metadata": {
+ "id": "97f57640"
+ },
+ "source": [
+ "## 📚 Resources\n",
+ "\n",
+ "* We are excited to showcase this early support of [MosaicML-Composer](https://docs.mosaicml.com/en/latest/index.html). Go ahead and try this new state-of-the-art framework."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c6722706",
+ "metadata": {
+ "id": "c6722706"
+ },
+ "source": [
+ "## ❓ Questions about W&B\n",
+ "\n",
+ "If you have any questions about using W&B to track your model performance and predictions, please reach out to the [wandb community](https://community.wandb.ai)."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/colabs/openai/Fine_tune_Azure_OpenAI_with_Weights_and_Biases.ipynb b/colabs/openai/Fine_tune_Azure_OpenAI_with_Weights_and_Biases.ipynb
new file mode 100644
index 00000000..39cfb9dc
--- /dev/null
+++ b/colabs/openai/Fine_tune_Azure_OpenAI_with_Weights_and_Biases.ipynb
@@ -0,0 +1,488 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ ""
+ ],
+ "metadata": {
+ "id": "S9LDX0sj5OVs"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Fine-tune ChatGPT-3.5-turbo with Weights & Biases on Microsoft Azure"
+ ],
+ "metadata": {
+ "id": "JJEBFlEE5cPw"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "If you use OpenAI's API to [fine-tune ChatGPT-3.5](https://platform.openai.com/docs/guides/fine-tuning), you can now use the `WandbLogger` integration to track experiments, models, and datasets in your central dashboard with just two lines of code:\n",
+ "\n",
+ "```\n",
+ "from wandb.integration.openai.fine_tuning import WandbLogger\n",
+ "\n",
+ "# Your fine-tuning logic\n",
+ "\n",
+ "WandbLogger.sync(fine_tune_job_id=fine_tune_job_id)\n",
+ "```\n",
+ "\n",
+ "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration."
+ ],
+ "metadata": {
+ "id": "HpAxezNY5vOm"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Imports and Setup"
+ ],
+ "metadata": {
+ "id": "rMlLwy-b54b0"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fDbdFT6E0MrR"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install openai requests tiktoken wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import tiktoken\n",
+ "import numpy as np\n",
+ "from collections import defaultdict"
+ ],
+ "metadata": {
+ "id": "yEcZ_vbY5_D4"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Note: Follow the instructions from the [official Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/tutorials/fine-tune?tabs=python-new%2Ccommand-line) to set up a working Azure OpenAI service"
+ ],
+ "metadata": {
+ "id": "Nr3EfOS_-cvA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from getpass import getpass\n",
+ "\n",
+ "# os.environ only accepts str values (assigning None raises TypeError), so\n",
+ "# prompt for any setting that is not already present in the environment.\n",
+ "if not os.environ.get(\"AZURE_OPENAI_ENDPOINT\"):\n",
+ "    os.environ[\"AZURE_OPENAI_ENDPOINT\"] = input(\"Azure OpenAI endpoint URL: \")\n",
+ "if not os.environ.get(\"AZURE_OPENAI_KEY\"):\n",
+ "    # getpass keeps the key out of the notebook source and the cell output.\n",
+ "    os.environ[\"AZURE_OPENAI_KEY\"] = getpass(\"Azure OpenAI API key: \")"
+ ],
+ "metadata": {
+ "id": "PnTa5edV1cBN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Create our Dataset"
+ ],
+ "metadata": {
+ "id": "doFePxbi6HgK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%writefile training_set.jsonl\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who discovered Antarctica?\"}, {\"role\": \"assistant\", \"content\": \"Some chaps named Fabian Gottlieb von Bellingshausen and Mikhail Lazarev, as if they don't teach that in every school!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the biggest ocean?\"}, {\"role\": \"assistant\", \"content\": \"The Pacific Ocean. It's not like it's a small pond or anything.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the largest planet?\"}, {\"role\": \"assistant\", \"content\": \"It's called Jupiter, you may have heard of it...or not.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who painted the Mona Lisa?\"}, {\"role\": \"assistant\", \"content\": \"Oh, some obscure artist named Leonardo da Vinci. Not that he is famous or anything.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the boiling point of water?\"}, {\"role\": \"assistant\", \"content\": \"Exciting fact of the day: it's 100 degrees Celsius. Life-changing, isn't it?\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who wrote 'The Odyssey'?\"}, {\"role\": \"assistant\", \"content\": \"Oh, just some poet named Homer. Nothing special, really.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the largest mammal?\"}, {\"role\": \"assistant\", \"content\": \"Blue Whale. Quite a big fish isn't it? Oh wait, it's not a fish!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the highest mountain peak?\"}, {\"role\": \"assistant\", \"content\": \"Everest. Brace yourself, it's not a walk in the park.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who invented the telephone?\"}, {\"role\": \"assistant\", \"content\": \"A guy named Alexander Graham Bell. Bet you didn't guess that one!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the distance from Earth to the Sun?\"}, {\"role\": \"assistant\", \"content\": \"About 93 million miles. Just a quick drive, really.\"}]}"
+ ],
+ "metadata": {
+ "id": "1lFswb3E0SOA"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%writefile validation_set.jsonl\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the capital of Australia?\"}, {\"role\": \"assistant\", \"content\": \"It's Canberra, not Sydney. Shocking, I know!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who wrote 'The Great Gatsby'?\"}, {\"role\": \"assistant\", \"content\": \"Some guy named F. Scott Fitzgerald. Ever heard of him?\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the deepest part of the ocean?\"}, {\"role\": \"assistant\", \"content\": \"Mariana Trench, makes your swimming pool seem a bit shallow, doesn't it?\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who composed the 'Moonlight Sonata'?\"}, {\"role\": \"assistant\", \"content\": \"Just a little-known guy named Beethoven. No big deal.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What are the colors of the rainbow?\"}, {\"role\": \"assistant\", \"content\": \"ROYGBIV. Red, orange, yellow, green, blue, indigo, violet. Not that we learned this in preschool or anything!\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the smallest planet?\"}, {\"role\": \"assistant\", \"content\": \"Meet Mercury, the runt of the solar system family.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"Who wrote 'Pride and Prejudice'?\"}, {\"role\": \"assistant\", \"content\": \"Do the words Jane Austen ring a bell? No? Well, she wrote it.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the largest desert?\"}, {\"role\": \"assistant\", \"content\": \"Antarctica. Surprise, surprise! Deserts aren't just full of sand, you know.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What is the longest river?\"}, {\"role\": \"assistant\", \"content\": \"The Nile River. It's not like it's famous or anything.\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"Clippy is a factual chatbot that is also sarcastic.\"}, {\"role\": \"user\", \"content\": \"What's the capital of Germany?\"}, {\"role\": \"assistant\", \"content\": \"Berlin. Shocking news, right?\"}]}"
+ ],
+ "metadata": {
+ "id": "ZOnKRXBr0hSh"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Load and Validate our Datasets"
+ ],
+ "metadata": {
+ "id": "RrUVtJbH6QRR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def load_jsonl(path):\n",
+ "    \"\"\"Read a JSON Lines file and return one parsed object per non-blank line.\"\"\"\n",
+ "    with open(path, 'r', encoding='utf-8') as f:\n",
+ "        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.\n",
+ "        return [json.loads(line) for line in f if line.strip()]\n",
+ "\n",
+ "def show_dataset_stats(dataset, split_name):\n",
+ "    \"\"\"Print the number of examples and the messages of the first example.\"\"\"\n",
+ "    print(f\"Number of examples in {split_name} set:\", len(dataset))\n",
+ "    print(f\"First example in {split_name} set:\")\n",
+ "    for message in dataset[0][\"messages\"]:\n",
+ "        print(message)\n",
+ "\n",
+ "# Load the training set and show its stats\n",
+ "training_dataset = load_jsonl('training_set.jsonl')\n",
+ "show_dataset_stats(training_dataset, \"training\")\n",
+ "\n",
+ "print()  # blank line between the two reports\n",
+ "\n",
+ "# Load the validation set and show its stats\n",
+ "validation_dataset = load_jsonl('validation_set.jsonl')\n",
+ "show_dataset_stats(validation_dataset, \"validation\")"
+ ],
+ "metadata": {
+ "id": "IcRweEs_0orF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "encoding = tiktoken.get_encoding(\"cl100k_base\") # default encoding used by gpt-4, turbo, and text-embedding-ada-002 models\n",
+ "\n",
+ "def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):\n",
+ "    \"\"\"Estimate the token count of one chat conversation.\n",
+ "\n",
+ "    Every message costs `tokens_per_message` plus the encoded length of each of\n",
+ "    its values; a message carrying a `name` key costs `tokens_per_name` more.\n",
+ "    \"\"\"\n",
+ "    total = 3  # fixed overhead added once per conversation (presumably reply priming - confirm)\n",
+ "    for msg in messages:\n",
+ "        total += tokens_per_message\n",
+ "        total += sum(len(encoding.encode(text)) for text in msg.values())\n",
+ "        if \"name\" in msg:\n",
+ "            total += tokens_per_name\n",
+ "    return total\n",
+ "\n",
+ "def num_assistant_tokens_from_messages(messages):\n",
+ "    \"\"\"Count the encoded tokens in the `content` of assistant-role messages only.\"\"\"\n",
+ "    return sum(\n",
+ "        len(encoding.encode(msg[\"content\"]))\n",
+ "        for msg in messages\n",
+ "        if msg[\"role\"] == \"assistant\"\n",
+ "    )\n",
+ "\n",
+ "def print_distribution(values, name):\n",
+ "    \"\"\"Print min/max, mean/median and the 5th/95th percentiles of `values`.\"\"\"\n",
+ "    print(f\"\\n#### Distribution of {name}:\")\n",
+ "    print(f\"min / max: {min(values)}, {max(values)}\")\n",
+ "    print(f\"mean / median: {np.mean(values)}, {np.median(values)}\")\n",
+ "    # Quantiles 0.05 / 0.95 so the numbers match the printed \"p5 / p95\" label\n",
+ "    # (the previous 0.1 / 0.9 were the 10th / 90th percentiles).\n",
+ "    print(f\"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}\")\n",
+ "\n",
+ "files = ['training_set.jsonl', 'validation_set.jsonl']\n",
+ "\n",
+ "# Report token statistics for each split, one file at a time.\n",
+ "for file in files:\n",
+ "    print(f\"Processing file: {file}\")\n",
+ "    with open(file, 'r', encoding='utf-8') as f:\n",
+ "        dataset = [json.loads(line) for line in f]\n",
+ "\n",
+ "    # One token count per example, in dataset order.\n",
+ "    conversations = [ex.get(\"messages\", {}) for ex in dataset]\n",
+ "    total_tokens = [num_tokens_from_messages(messages) for messages in conversations]\n",
+ "    assistant_tokens = [num_assistant_tokens_from_messages(messages) for messages in conversations]\n",
+ "\n",
+ "    print_distribution(total_tokens, \"total tokens\")\n",
+ "    print_distribution(assistant_tokens, \"assistant tokens\")\n",
+ "    print('*' * 50)"
+ ],
+ "metadata": {
+ "id": "rrS0t-h40suH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Begin our Finetuning on Azure!"
+ ],
+ "metadata": {
+ "id": "vinDQEtV6WKB"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Connect to Azure"
+ ],
+ "metadata": {
+ "id": "r5iCXKkP6Zkm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Create the Azure OpenAI client used for file uploads and fine-tuning\n",
+ "from openai import AzureOpenAI\n",
+ "\n",
+ "client = AzureOpenAI(\n",
+ " azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\"),\n",
+ " api_key=os.getenv(\"AZURE_OPENAI_KEY\"),\n",
+ " api_version=\"2023-12-01-preview\" # This API version or later is required to access fine-tuning for turbo/babbage-002/davinci-002\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "KioZ9_Qe6fEU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Upload our Validated Training Data"
+ ],
+ "metadata": {
+ "id": "p-WS1-Qs6goA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "training_file_name = 'training_set.jsonl'\n",
+ "validation_file_name = 'validation_set.jsonl'\n",
+ "\n",
+ "# Upload the training and validation dataset files to Azure OpenAI with the SDK.\n",
+ "# The context managers close each local file handle once its upload returns.\n",
+ "\n",
+ "with open(training_file_name, \"rb\") as training_file:\n",
+ "    training_response = client.files.create(file=training_file, purpose=\"fine-tune\")\n",
+ "training_file_id = training_response.id\n",
+ "\n",
+ "with open(validation_file_name, \"rb\") as validation_file:\n",
+ "    validation_response = client.files.create(file=validation_file, purpose=\"fine-tune\")\n",
+ "validation_file_id = validation_response.id\n",
+ "\n",
+ "print(\"Training file ID:\", training_file_id)\n",
+ "print(\"Validation file ID:\", validation_file_id)"
+ ],
+ "metadata": {
+ "id": "cIVfUF4S0uW-"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run Fine-tuning!"
+ ],
+ "metadata": {
+ "id": "wnjC-ENd6lUf"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "response = client.fine_tuning.jobs.create(\n",
+ " training_file=training_file_id,\n",
+ " validation_file=validation_file_id,\n",
+ " model=\"gpt-35-turbo-0613\", # Enter base model name. Note that in Azure OpenAI the model name contains dashes and cannot contain dot/period characters.\n",
+ ")\n",
+ "\n",
+ "job_id = response.id\n",
+ "\n",
+ "# You can use the job ID to monitor the status of the fine-tuning job.\n",
+ "# The fine-tuning job will take some time to start and complete.\n",
+ "\n",
+ "print(\"Job ID:\", job_id)\n",
+ "print(response.model_dump_json(indent=2))"
+ ],
+ "metadata": {
+ "id": "CSMQBov-004I"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Sync metrics, data, and more with 2 lines of code!"
+ ],
+ "metadata": {
+ "id": "OrADYudw6oO_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "wandb_project = \"Azure_Openai_Finetuning\""
+ ],
+ "metadata": {
+ "id": "YZegrNtR9ofY"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from wandb.integration.openai.fine_tuning import WandbLogger\n",
+ "\n",
+ "WandbLogger.sync(fine_tune_job_id=job_id, openai_client=client, project=wandb_project)"
+ ],
+ "metadata": {
+ "id": "ZtZTj_BE01cU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "> This takes a varying amount of time. Feel free to check the Azure service you set up to ensure the fine-tuning job is running."
+ ],
+ "metadata": {
+ "id": "iaUYCvlm-z-I"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Logging the fine-tuning job to W&B is straightforward. The integration will automatically log the following to W&B:\n",
+ "\n",
+ "- training and validation metrics (if validation data is provided)\n",
+ "- log the training and validation data as W&B Tables for storage and versioning\n",
+ "- log the fine-tuned model's metadata.\n",
+ "\n",
+ "The integration automatically creates the DAG lineage between the data and the model.\n",
+ "\n",
+ "> You can call the `WandbLogger` with the job id. The cell will keep running until the fine-tuning job is complete. Once the job's status is `succeeded`, the `WandbLogger` will log metrics and data to W&B. This way you don't have to wait for the fine-tune job to be completed to call `WandbLogger.sync`."
+ ],
+ "metadata": {
+ "id": "GPCr5RhZ-_ya"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Calling `WandbLogger.sync` without any id will log all un-synced fine-tuned jobs to W&B\n",
+ "\n",
+ "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration"
+ ],
+ "metadata": {
+ "id": "E5A9db0A_Fay"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The fine-tuning job is now successfully synced to Weights and Biases. Click on the URL above to open the [W&B run page](https://docs.wandb.ai/guides/app/pages/run-page). The following will be logged to W&B:\n",
+ "\n",
+ "#### Training and validation metrics\n",
+ "\n",
+ "![image.png](assets/metrics.png)\n",
+ "\n",
+ "#### Training and validation data as W&B Tables\n",
+ "\n",
+ "![image.png](assets/datatable.png)\n",
+ "\n",
+ "#### The data and model artifacts for version control (go to the overview tab)\n",
+ "\n",
+ "![image.png](assets/artifacts.png)\n",
+ "\n",
+ "#### The configuration and hyperparameters (go to the overview tab)\n",
+ "\n",
+ "![image.png](assets/configs.png)\n",
+ "\n",
+ "#### The data and model DAG\n",
+ "\n",
+ "![image.png](assets/dag.png)"
+ ],
+ "metadata": {
+ "id": "fNAzboPD_KP5"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Load the trained model for inference"
+ ],
+ "metadata": {
+ "id": "oSFBU9DB6tkJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#Retrieve fine_tuned_model name\n",
+ "\n",
+ "response = client.fine_tuning.jobs.retrieve(job_id)\n",
+ "\n",
+ "print(response.model_dump_json(indent=2))\n",
+ "fine_tuned_model = response.fine_tuned_model"
+ ],
+ "metadata": {
+ "id": "g2z05Koa05Qx"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/colabs/openai/Fine_tune_OpenAI_with_Weights_and_Biases.ipynb b/colabs/openai/Fine_tune_OpenAI_with_Weights_and_Biases.ipynb
index a2ec1217..d99d40c8 100644
--- a/colabs/openai/Fine_tune_OpenAI_with_Weights_and_Biases.ipynb
+++ b/colabs/openai/Fine_tune_OpenAI_with_Weights_and_Biases.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "\n",
+ "\n",
""
]
},
@@ -23,11 +23,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "If you use OpenAI's API to [fine-tune ChatGPT-3.5](https://platform.openai.com/docs/guides/fine-tuning), you can now use the W&B integration to track experiments, models, and datasets in your central dashboard.\n",
+ "If you use OpenAI's API to [fine-tune ChatGPT-3.5](https://platform.openai.com/docs/guides/fine-tuning), you can now use the `WandbLogger` integration to track experiments, models, and datasets in your central dashboard with just two lines of code:\n",
"\n",
- "All it takes is one line: `openai wandb sync`\n",
+ "```\n",
+ "from wandb.integration.openai.fine_tuning import WandbLogger\n",
"\n",
- "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration"
+ "# Your fine-tuning logic\n",
+ "\n",
+ "WandbLogger.sync(fine_tune_job_id=fine_tune_job_id)\n",
+ "```\n",
+ "\n",
+ "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration."
]
},
{
@@ -36,28 +42,24 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install -Uq openai tiktoken datasets tenacity wandb"
+ "!pip install -Uq openai tiktoken datasets tenacity"
]
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "# Remove once openai 0.28.1 release is made\n",
- "!pip uninstall -y openai -qq && pip install git+https://github.com/openai/openai-python -qqq"
+ "# Remove once this PR: https://github.com/wandb/wandb/pull/6498 is merged\n",
+ "!pip install git+https://github.com/wandb/wandb.git"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Optional: Fine-tune ChatGPT-3.5\n",
- "\n",
- "It's always more fun to experiment with your own projects so if you have already used the openai API to fine-tune an OpenAI model, just skip this section.\n",
- "\n",
- "Otherwise let's fine-tune ChatGPT-3.5 on a legal dataset!"
+ "In this Colab notebook, we will fine-tune the GPT 3.5 model on the [LegalBench](https://hazyresearch.stanford.edu/legalbench/) dataset. The notebook shows how to prepare and validate the dataset, upload it to OpenAI, and set up a fine-tune job. Finally, it shows how to use the `WandbLogger`."
]
},
{
@@ -73,10 +75,11 @@
"metadata": {},
"outputs": [],
"source": [
- "import openai\n",
+ "from openai import OpenAI\n",
"import wandb\n",
"\n",
"import os\n",
+ "import glob\n",
"import json\n",
"import random\n",
"import tiktoken\n",
@@ -92,7 +95,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Start your Weigths & Biases run. If you don't have an account you can sign up for one for free at www.wandb.ai"
+ "Initialize the OpenAI client\n",
+ "\n",
+ "You can add the api key to your environment variable by doing `os.environ['OPENAI_API_KEY'] = \"sk-....\"`."
]
},
{
@@ -101,14 +106,16 @@
"metadata": {},
"outputs": [],
"source": [
- "WANDB_PROJECT = \"OpenAI-Fine-Tune\""
+ "# Uncomment the line below and set your OpenAI API Key.\n",
+ "# os.environ['OPENAI_API_KEY'] = \"sk-....\" \n",
+ "client = OpenAI()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Set up your API key"
+ "Import the `WandbLogger`"
]
},
{
@@ -117,10 +124,9 @@
"metadata": {},
"outputs": [],
"source": [
- "# Enter credentials\n",
- "openai_key = \"YOUR_API_KEY\"\n",
+ "from wandb.integration.openai.fine_tuning import WandbLogger\n",
"\n",
- "openai.api_key = openai_key"
+ "WANDB_PROJECT = \"OpenAI-Fine-Tune\""
]
},
{
@@ -430,8 +436,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Log Data to Weights & Biases Artifacts\n",
- "Log our data to Weigths & Biases Artifacts for storage and versioning"
+ "Validate test data"
]
},
{
@@ -440,40 +445,31 @@
"metadata": {},
"outputs": [],
"source": [
- "wandb.init(\n",
- " project=WANDB_PROJECT,\n",
- " job_type=\"log-data\",\n",
- " config = {'n_train': n_train,\n",
- " 'n_valid': n_test})\n",
- "\n",
- "wandb.log_artifact(train_file_path,\n",
- " \"legalbench-contract_nli_explicit_identification-train\",\n",
- " type=\"train-data\")\n",
- "\n",
- "wandb.log_artifact(test_file_path,\n",
- " \"legalbench-contract_nli_explicit_identification-test\",\n",
- " type=\"test-data\")\n",
- "\n",
- "# keep entity (typically your wandb username) for reference of artifact later in this demo\n",
- "entity = wandb.run.entity\n",
- "\n",
- "wandb.finish()"
+ "openai_validate_data(test_file_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Create a fine-tuned model"
+ "## Upload the training and validation data to OpenAI\n",
+ "\n",
+ "We will first upload the data to OpenAI. This might take a few minutes depending on the size of your dataset."
]
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
+ "outputs": [],
"source": [
- "We'll now use OpenAI API to fine-tune ChatGPT-3.5\n",
+ "openai_train_file_info = client.files.create(\n",
+ " file=open(train_file_path, \"rb\"), purpose=\"fine-tune\"\n",
+ ")\n",
"\n",
- "Let's first download our training & validation files and save them to a folder called `my_data`. We will retrieve the `latest` version of the artifact, but it could also be `v0`, `v1` or any alias we associated with it"
+ "openai_valid_file_info = client.files.create(\n",
+ " file=open(test_file_path, \"rb\"), purpose=\"fine-tune\"\n",
+ ")"
]
},
{
@@ -482,43 +478,30 @@
"metadata": {},
"outputs": [],
"source": [
- "wandb.init(project=WANDB_PROJECT, job_type=\"finetune\")\n",
- "\n",
- "artifact_train = wandb.use_artifact(\n",
- " f'{entity}/{WANDB_PROJECT}/legalbench-contract_nli_explicit_identification-train:latest',\n",
- " type='train-data')\n",
- "train_file = artifact_train.get_path(train_file_path).download(\"my_data\")\n",
- "\n",
- "train_file"
+ "openai_train_file_info"
]
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
+ "outputs": [],
"source": [
- "Then we upload the training data to OpenAI. OpenAi has to process the data, so this will take a few minutes depending on the size of your dataset."
+ "openai_valid_file_info"
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "openai_train_file_info = openai.File.create(\n",
- " file=open(train_file, \"rb\"),\n",
- " purpose='fine-tune'\n",
- ")\n",
- "\n",
- "# you may need to wait a couple of minutes for OpenAI to process the file\n",
- "openai_train_file_info"
+ "> Notice the unique ids for both training and validation data. OpenAI uses these ids to access the uploaded data to fine-tune GPT 3.5 on."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Time to train the model!"
+ "# Train the model and log to Weights & Biases"
]
},
{
@@ -544,15 +527,17 @@
"metadata": {},
"outputs": [],
"source": [
- "openai_ft_job_info = openai.FineTuningJob.create(\n",
- " training_file=openai_train_file_info[\"id\"],\n",
+ "openai_ft_job_info = client.fine_tuning.jobs.create(\n",
+ " training_file=openai_train_file_info.id,\n",
" model=model,\n",
- " hyperparameters={\"n_epochs\": n_epochs}\n",
+ " hyperparameters={\"n_epochs\": n_epochs},\n",
+ " validation_file=openai_valid_file_info.id\n",
")\n",
"\n",
- "ft_job_id = openai_ft_job_info[\"id\"]\n",
+ "ft_job_id = openai_ft_job_info.id\n",
"\n",
- "openai_ft_job_info"
+ "# Log to Weights and Biases\n",
+ "WandbLogger.sync(fine_tune_job_id=ft_job_id, project=WANDB_PROJECT)"
]
},
{
@@ -560,7 +545,7 @@
"metadata": {},
"source": [
"\n",
- "> this takes around 5 minutes to train, and you get an email from OpenAI when finished."
+ "> this takes around 5 minutes to train."
]
},
{
@@ -569,54 +554,84 @@
"source": [
"**Thats it!**\n",
"\n",
- "Now your model is training on OpenAI's machines. To get the current state of your fine-tuning job, run:"
+ "Now your model is training on OpenAI's machines."
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "state = openai.FineTuningJob.retrieve(ft_job_id)\n",
- "state[\"status\"], state[\"trained_tokens\"], state[\"finished_at\"], state[\"fine_tuned_model\"]"
+ "Logging the fine-tuning job to W&B is straightforward. The integration will automatically log the following to W&B:\n",
+ "\n",
+ "- training and validation metrics (if validation data is provided)\n",
+ "- log the training and validation data as W&B Tables for storage and versioning\n",
+ "- log the fine-tuned model's metadata.\n",
+ "\n",
+ "The integration automatically creates the DAG lineage between the data and the model.\n",
+ "\n",
+ "> You can call the `WandbLogger` with the job id. The cell will keep running until the fine-tuning job is complete. Once the job's status is `succeeded`, the `WandbLogger` will log metrics and data to W&B. This way you don't have to wait for the fine-tune job to be completed to call `WandbLogger.sync`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Show recent events for our fine-tuning job"
+ "Calling `WandbLogger.sync` without any id will log all un-synced fine-tuned jobs to W&B\n",
+ "\n",
+ "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration"
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "attachments": {},
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "openai.FineTuningJob.list_events(id=ft_job_id, limit=5)"
+ "The fine-tuning job is now successfully synced to Weights and Biases. Click on the URL above to open the [W&B run page](https://docs.wandb.ai/guides/app/pages/run-page). The following will be logged to W&B:\n",
+ "\n",
+ "#### Training and validation metrics\n",
+ "\n",
+ "![image.png](assets/metrics.png)\n",
+ "\n",
+ "#### Training and validation data as W&B Tables\n",
+ "\n",
+ "![image.png](assets/datatable.png)\n",
+ "\n",
+ "#### The data and model artifacts for version control (go to the overview tab)\n",
+ "\n",
+ "![image.png](assets/artifacts.png)\n",
+ "\n",
+ "#### The configuration and hyperparameters (go to the overview tab)\n",
+ "\n",
+ "![image.png](assets/configs.png)\n",
+ "\n",
+ "#### The data and model DAG\n",
+ "\n",
+ "![image.png](assets/dag.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can run a few different fine-tunes with different parameters or even with different datasets."
+ "## Run evaluation and log the results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Log OpenAI fine-tune jobs to Weights & Biases"
+ "The best way to evaluate a generative model is to explore sample predictions from your evaluation set.\n",
+ "\n",
+ "Let's generate a few inference samples and log them to W&B and see how the performance compares to a baseline ChatGPT-3.5 model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can log our fine-tunes with a simple command."
+ "We will be evaluating using the validation dataset. In the overview tab of the run page, find the \"validation_files\" in the Artifact Inputs section. Clicking on it will take you to the artifacts page. Copy the artifact URI (full name) as shown in the image below.\n",
+ "\n",
+ "![image](assets/select_artifact_uri.png)"
]
},
{
@@ -625,30 +640,24 @@
"metadata": {},
"outputs": [],
"source": [
- "!openai wandb sync --help"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Calling `openai wandb sync` will log all un-synced fine-tuned jobs to W&B\n",
+ "run = wandb.init(\n",
+ " project=WANDB_PROJECT,\n",
+ " job_type='eval'\n",
+ ")\n",
"\n",
- "Below we are just logging 1 job, passing:\n",
- "- our OpenAI key as an environment variable\n",
- "- the id of the fine-tune job we'd like to log\n",
- "- the W&B project of where to log it to\n",
+ "VALIDATION_FILE_ARTIFACT_URI = '//valid-file-*' # REPLACE THIS WITH YOUR OWN ARTIFACT URI\n",
"\n",
- "See the [OpenAI section](https://wandb.me/openai-docs) in the Weights & Biases documentation for full details of the integration"
+ "artifact_valid = run.use_artifact(\n",
+ " VALIDATION_FILE_ARTIFACT_URI,\n",
+ " type='validation_files'\n",
+ ")"
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "!OPENAI_API_KEY={openai_key} openai wandb sync --id {ft_job_id} --project {WANDB_PROJECT}"
+ "The code snippet below downloads the logged validation data and prepares a pandas DataFrame from it."
]
},
{
@@ -657,34 +666,32 @@
"metadata": {},
"outputs": [],
"source": [
- "wandb.finish()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Our fine-tunes are now successfully synced to Weights & Biases.\n",
+ "artifact_valid_path = artifact_valid.download()\n",
+ "print(\"Downloaded the validation data at: \", artifact_valid_path)\n",
"\n",
- "![image.png]()\n",
+ "validation_file = glob.glob(f\"{artifact_valid_path}/*.table.json\")[0]\n",
+ "with open(validation_file, 'r') as file:\n",
+ " data = json.load(file)\n",
"\n",
- "Anytime we have new fine-tunes, we can just call `openai wandb sync` to add them to our dashboard."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Run evalution and log the results"
+ "validation_df = pd.DataFrame(columns=data[\"columns\"], data=data[\"data\"])\n",
+ "\n",
+ "print(f\"There are {len(validation_df)} validation examples\")\n",
+ "run.config.update({\"num_validation_samples\":len(validation_df)})\n",
+ "\n",
+ "validation_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "The best way to evaluate a generative model is to explore sample predictions from your evaluation set.\n",
+ "We will need to package the data in the dataframe in the format acceptable by GPT 3.5. The format is:\n",
"\n",
- "Let's generate a few inference samples and log them to W&B and see how the performance compares to a baseline ChatGPT-3.5 model"
+ "```\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"some system prompt\"}, {\"role\": \"user\", \"content\": \"some user prompt\"}, {\"role\": \"assistant\", \"content\": \"completion text\"}]}\n",
+ "```\n",
+ "\n",
+ "For evaluation we don't need to pack the `{\"role\": \"assistant\", \"content\": \"completion text\"}` in `messages` as this is meant to be generated by GPT 3.5."
]
},
{
@@ -693,19 +700,17 @@
"metadata": {},
"outputs": [],
"source": [
- "wandb.init(project=WANDB_PROJECT,\n",
- " job_type='eval')\n",
+ "def eval_data_format(row):\n",
+ " role_system_content = row[\"role: system\"]\n",
+ " role_system_dict = {\"role\": \"system\", \"content\": role_system_content}\n",
"\n",
- "artifact_valid = wandb.use_artifact(\n",
- " f'{entity}/{WANDB_PROJECT}/legalbench-contract_nli_explicit_identification-test:latest',\n",
- " type='test-data')\n",
- "test_file = artifact_valid.get_path(test_file_path).download(\"my_data\")\n",
+ " role_user_content = row[\"role: user\"]\n",
+ " role_user_dict = {\"role\": \"user\", \"content\": role_user_content}\n",
+ " \n",
+ " return [role_system_dict, role_user_dict]\n",
"\n",
- "with open(test_file) as f:\n",
- " test_dataset = [json.loads(line) for line in f]\n",
- "\n",
- "print(f\"There are {len(test_dataset)} test examples\")\n",
- "wandb.config.update({\"num_test_samples\":len(test_dataset)})"
+ "validation_df[\"messages\"] = validation_df.apply(lambda row: eval_data_format(row), axis=1)\n",
+ "validation_df.head()"
]
},
{
@@ -713,7 +718,10 @@
"metadata": {},
"source": [
"### Run evaluation on the Fine-Tuned Model\n",
- "Set up OpenAI call with retries"
+ "\n",
+ "Next up we will get the fine-tuned model's id from the logged `model_metadata`. In the overview tab of the run page, find the \"model\" in the Artifact Outputs section. Clicking on it will take you to the artifacts page. Copy the artifact URI (full name) as shown in the image below.\n",
+ "\n",
+ "![image](assets/select_model_artifact.png)"
]
},
{
@@ -722,16 +730,28 @@
"metadata": {},
"outputs": [],
"source": [
- "@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))\n",
- "def call_openai(messages=\"\", model=\"gpt-3.5-turbo\"):\n",
- " return openai.ChatCompletion.create(model=model, messages=messages, max_tokens=10)"
+ "MODEL_ARTIFACT_URI = '//model_metadata:v*' # REPLACE THIS WITH YOUR OWN ARTIFACT URI\n",
+ "\n",
+ "model_artifact = run.use_artifact(\n",
+ " MODEL_ARTIFACT_URI,\n",
+ " type='model'\n",
+ ")"
]
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
+ "outputs": [],
"source": [
- "Let's get our trained model id"
+ "model_metadata_path = model_artifact.download()\n",
+ "print(\"Downloaded the validation data at: \", model_metadata_path)\n",
+ "\n",
+ "model_metadata_file = glob.glob(f\"{model_metadata_path}/*.json\")[0]\n",
+ "with open(model_metadata_file, 'r') as file:\n",
+ " model_metadata = json.load(file)\n",
+ "\n",
+ "model_metadata"
]
},
{
@@ -740,9 +760,8 @@
"metadata": {},
"outputs": [],
"source": [
- "state = openai.FineTuningJob.retrieve(ft_job_id)\n",
- "ft_model_id = state[\"fine_tuned_model\"]\n",
- "ft_model_id"
+ "fine_tuned_model = model_metadata[\"fine_tuned_model\"]\n",
+ "client = OpenAI()"
]
},
{
@@ -762,16 +781,15 @@
"\n",
"eval_data = []\n",
"\n",
- "for row in tqdm(test_dataset):\n",
- " messages = row['messages'][:2]\n",
- " target = row[\"messages\"][2]\n",
+ "for idx, row in tqdm(validation_df.iterrows()):\n",
+ " messages = row.messages\n",
+ " target = row[\"role: assistant\"]\n",
"\n",
- " # res = call_openai(model=ft_model_id, messages=messages)\n",
- " res = openai.ChatCompletion.create(model=model, messages=messages, max_tokens=10)\n",
+ " res = client.chat.completions.create(model=fine_tuned_model, messages=messages, max_tokens=10)\n",
" completion = res.choices[0].message.content\n",
"\n",
" eval_data.append([messages, completion, target])\n",
- " prediction_table.add_data(messages[1]['content'], completion, target[\"content\"])\n",
+ " prediction_table.add_data(messages[1]['content'], completion, target)\n",
"\n",
"wandb.log({'predictions': prediction_table})"
]
@@ -792,7 +810,7 @@
"source": [
"correct = 0\n",
"for e in eval_data:\n",
- " if e[1].lower() == e[2][\"content\"].lower():\n",
+ " if e[1].lower() == e[2].lower():\n",
" correct+=1\n",
"\n",
"accuracy = correct / len(eval_data)\n",
@@ -817,17 +835,18 @@
"outputs": [],
"source": [
"baseline_prediction_table = wandb.Table(columns=['messages', 'completion', 'target'])\n",
+ "\n",
"baseline_eval_data = []\n",
"\n",
- "for row in tqdm(test_dataset):\n",
- " messages = row['messages'][:2]\n",
- " target = row[\"messages\"][2]\n",
+ "for idx, row in tqdm(validation_df.iterrows()):\n",
+ " messages = row.messages\n",
+ " target = row[\"role: assistant\"]\n",
"\n",
- " res = call_openai(model=\"gpt-3.5-turbo\", messages=messages)\n",
+ " res = client.chat.completions.create(model=\"gpt-3.5-turbo\", messages=messages, max_tokens=10)\n",
" completion = res.choices[0].message.content\n",
"\n",
" baseline_eval_data.append([messages, completion, target])\n",
- " baseline_prediction_table.add_data(messages[1]['content'], completion, target[\"content\"])\n",
+ " baseline_prediction_table.add_data(messages[1]['content'], completion, target)\n",
"\n",
"wandb.log({'baseline_predictions': baseline_prediction_table})"
]
@@ -847,7 +866,7 @@
"source": [
"baseline_correct = 0\n",
"for e in baseline_eval_data:\n",
- " if e[1].lower() == e[2][\"content\"].lower():\n",
+ " if e[1].lower() == e[2].lower():\n",
" baseline_correct+=1\n",
"\n",
"baseline_accuracy = baseline_correct / len(baseline_eval_data)\n",
@@ -881,29 +900,22 @@
"# Resources\n",
"\n",
"* [OpenAI Fine-Tuning Guide](https://platform.openai.com/docs/guides/fine-tuning)\n",
- "* [W&B Integration with OpenAI API Documentation](https://wandb.me/openai-docs)\n",
- "* [W&B Report: GPT-3 exploration & fine-tuning tips](http://wandb.me/openai-report)"
+ "* [W&B Integration with OpenAI API Documentation](https://wandb.me/openai-docs)"
]
}
],
"metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
"kernelspec": {
- "display_name": "python3",
+ "display_name": "Python 3",
"name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.2"
}
},
"nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 4
}
diff --git a/colabs/openai/OpenAI_API_Autologger_Quickstart.ipynb b/colabs/openai/OpenAI_API_Autologger_Quickstart.ipynb
index b3767dc8..2e793dee 100644
--- a/colabs/openai/OpenAI_API_Autologger_Quickstart.ipynb
+++ b/colabs/openai/OpenAI_API_Autologger_Quickstart.ipynb
@@ -47,11 +47,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "vscode": {
- "languageId": "python"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"!pip install wandb openai -qU"
@@ -73,11 +69,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "vscode": {
- "languageId": "python"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"import openai\n",
@@ -105,11 +97,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "vscode": {
- "languageId": "python"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"# pass your OpenAI key\n",
@@ -120,11 +108,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "vscode": {
- "languageId": "python"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"# make some calls to OpenAI \n",
@@ -178,11 +162,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "vscode": {
- "languageId": "python"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"autolog.disable()"
diff --git a/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb b/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb
new file mode 100644
index 00000000..0915d874
--- /dev/null
+++ b/colabs/openai/OpenAI_Finetuning_on_Gorilla_with_wandb.ipynb
@@ -0,0 +1,635 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ChatGPT-3.5 Fine-tuning - Gorilla API\n",
+ "\n",
+ "Fine-tuning ChatGPT-3.5 on the Gorilla api dataset to try and improve its performance\n",
+ "- [Gorilla project](https://shishirpatil.github.io/gorilla/)\n",
+ "- [Gorilla paper](https://arxiv.org/abs/2305.15334)\n",
+ "- [Gorilla code](https://github.com/ShishirPatil/gorilla)\n",
+ "\n",
+ "OpenAI ChatGPT-3.5 fine-tuning docs [are here](https://platform.openai.com/docs/guides/fine-tuning)\n",
+ "\n",
+ "**Warning!**\n",
+ "\n",
+ "This fine-tuning script will train on 7.2 million tokens on OpenAI; check that you're willing to pay for that before proceeding :)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install openai tiktoken wandb -qqq"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "import os\n",
+ "import json\n",
+ "import wandb\n",
+ "import openai\n",
+ "from pprint import pprint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai_api_key = \"OPENAI API KEY\"\n",
+ "openai.api_key = openai_api_key"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Download the Gorilla API training data (Hugging Face, TensorFlow and Torch Hub). You can find all the [Gorilla training data here](https://github.com/ShishirPatil/gorilla/tree/main/data/apibench)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/cab053ba7fdf4a3286c0e75aa2bf7abc4053812f/data/apibench/huggingface_train.json\n",
+ "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_train.json\n",
+ "!wget https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_train.json"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Load the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = []\n",
+ "data_files = [\n",
+ " \"huggingface_train.json\",\n",
+ " \"tensorflow_train.json\",\n",
+ " \"torchhub_train.json\",\n",
+ "]\n",
+ "\n",
+ "for file in data_files:\n",
+ " with open(file, \"r\") as file:\n",
+ " # data = json.load(file)\n",
+ " for line in file:\n",
+ " item = json.loads(line.strip())\n",
+ " data.append(item)\n",
+ "\n",
+ "# This is the data relevant to training\n",
+ "data[0][\"code\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Data Parsing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Parse the training data instructions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def parse_instructions_and_outputs(code_section):\n",
+ "\n",
+ " sections = code_section.split('###')\n",
+ " instruction = \"\"\n",
+ " for section in sections:\n",
+ " if \"Instruction:\" in section:\n",
+ " instruction = section.lower().split(\"instruction:\", 1)[1].strip()\n",
+ " break\n",
+ "\n",
+ " # domain = re.search(r'<<>>(.*?)\\n', code_section, re.IGNORECASE).group(1).lstrip(': ')\n",
+ " if \"<<>>\" in code_section:\n",
+ " domain = re.search(r'<<>>(.*?)<<<', d[\"code\"], re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " else:\n",
+ " domain = \"\"\n",
+ "\n",
+ " api_call = re.search(r'<<>>(.*?)<<<', code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " # api_provider = re.search(r'<<>>(.*?)\\n', code_section, re.IGNORECASE).group(1).lstrip(': ')\n",
+ " if \"<<>>\" in code_section:\n",
+ " api_provider = re.search(r'<<>>(.*?)<<<', code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " else:\n",
+ " api_provider = \"\"\n",
+ "\n",
+ " if \"<<>>\" in code_section:\n",
+ " explanation_pattern = r'<<>>(.*?)(?:\\n<<>>|```|$)'\n",
+ " explanation = re.search(explanation_pattern, code_section, re.IGNORECASE | re.DOTALL).group(1).lstrip(': ')\n",
+ " else:\n",
+ " explanation = None\n",
+ "\n",
+ " # Extract code snippet considering both cases\n",
+ " code_pattern = r'(?:<<>>|```) (.*)' # Matches either <<>> or ```\n",
+ " code_snippet_match = re.search(code_pattern, code_section, re.IGNORECASE | re.DOTALL)\n",
+ " code_snippet = code_snippet_match.group(1).lstrip(': ') if code_snippet_match else None\n",
+ "\n",
+ " return instruction, domain, api_call, api_provider, explanation, code_snippet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def encode_train_sample(data, api_name):\n",
+ " \"\"\"Encode multiple prompt instructions into a single string.\"\"\"\n",
+ " code_section = data['code']\n",
+ "\n",
+ " if \"<<>>\" in code_section:\n",
+ " instruction, domain, api_call, api_provider, explanation, code = parse_instructions_and_outputs(code_section)\n",
+ "\n",
+ " prompts = []\n",
+ "\n",
+ " #prompt = instruction + \"\\nWrite a python program in 1 to 2 lines to call API in \" + api_name + \".\\n\\nThe answer should follow the format: <<>> $DOMAIN, <<>>: $API_CALL, <<>>: $API_PROVIDER, <<>>: $EXPLANATION, <<>>: $CODE}. Here are the requirements:\\n\" + domains + \"\\n2. The $API_CALL should have only 1 line of code that calls api.\\n3. The $API_PROVIDER should be the programming framework used.\\n4. $EXPLANATION should be a step-by-step explanation.\\n5. The $CODE is the python code.\\n6. Do not repeat the format in your answer.\"\n",
+ "\n",
+ " prompts.append({\"role\": \"system\", \"content\": \"You are a helpful API writer who can write APIs based on requirements.\"})\n",
+ " prompts.append({\"role\": \"user\", \"content\": instruction})\n",
+ " prompts.append({\"role\": \"assistant\", \"content\": f\"<<>> {domain},\\\n",
+ "<<>>: {api_call}, <<>>: {api_provider}, <<>>: {explanation}, <<>>: {code}\"})\n",
+ " return prompts\n",
+ " else:\n",
+ " return None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Format the training samples with the correct format to mirror the Gorilla paper"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoded_data = []\n",
+ "none_count = 0\n",
+ "for d in data:\n",
+ " res = encode_train_sample(d, \"huggingface\")\n",
+ " if res is not None:\n",
+ " encoded_data.append({\"messages\":res})\n",
+ " else:\n",
+ " none_count += 1\n",
+ "\n",
+ "print(f\"{none_count} samples out of {len(data)} ignored\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print a sample of what will get passed to OpenAI for fine-tuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoded_data[333]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Save the training data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoded_file_path = 'all_encoded_data.jsonl'\n",
+ "\n",
+ "with open(encoded_file_path, 'w') as file:\n",
+ " for item in encoded_data:\n",
+ " line = json.dumps(item)\n",
+ " file.write(line + '\\n')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Start a Weights & Biases run to save our data and results\n",
+ "wandb.init(project=\"gorilla-api\")\n",
+ "wandb.log_artifact(encoded_file_path, \"hf_tf_th_gorilla_train.jsonl\", type=\"train_data\")\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## OpenAI data verification script"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# We start by importing the required packages\n",
+ "\n",
+ "import json\n",
+ "import os\n",
+ "import tiktoken\n",
+ "import numpy as np\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "# Next, we specify the data path and open the JSONL file\n",
+ "\n",
+ "data_path = encoded_file_path\n",
+ "\n",
+ "# Load dataset\n",
+ "with open(data_path) as f:\n",
+ " dataset = [json.loads(line) for line in f]\n",
+ "\n",
+ "# We can inspect the data quickly by checking the number of examples and the first item\n",
+ "\n",
+ "# Initial dataset stats\n",
+ "print(\"Num examples:\", len(dataset))\n",
+ "print(\"First example:\")\n",
+ "for message in dataset[0][\"messages\"]:\n",
+ " print(message)\n",
+ "\n",
+ "# Now that we have a sense of the data, we need to go through all the different examples and check to make sure the formatting is correct and matches the Chat completions message structure\n",
+ "\n",
+ "# Format error checks\n",
+ "format_errors = defaultdict(int)\n",
+ "\n",
+ "for ex in dataset:\n",
+ " if not isinstance(ex, dict):\n",
+ " format_errors[\"data_type\"] += 1\n",
+ " continue\n",
+ "\n",
+ " messages = ex.get(\"messages\", None)\n",
+ " if not messages:\n",
+ " format_errors[\"missing_messages_list\"] += 1\n",
+ " continue\n",
+ "\n",
+ " for message in messages:\n",
+ " if \"role\" not in message or \"content\" not in message:\n",
+ " format_errors[\"message_missing_key\"] += 1\n",
+ "\n",
+ " if any(k not in (\"role\", \"content\", \"name\") for k in message):\n",
+ " format_errors[\"message_unrecognized_key\"] += 1\n",
+ "\n",
+ " if message.get(\"role\", None) not in (\"system\", \"user\", \"assistant\"):\n",
+ " format_errors[\"unrecognized_role\"] += 1\n",
+ "\n",
+ " content = message.get(\"content\", None)\n",
+ " if not content or not isinstance(content, str):\n",
+ " format_errors[\"missing_content\"] += 1\n",
+ "\n",
+ " if not any(message.get(\"role\", None) == \"assistant\" for message in messages):\n",
+ " format_errors[\"example_missing_assistant_message\"] += 1\n",
+ "\n",
+ "if format_errors:\n",
+ " print(\"Found errors:\")\n",
+ " for k, v in format_errors.items():\n",
+ " print(f\"{k}: {v}\")\n",
+ "else:\n",
+ " print(\"No errors found\")\n",
+ "\n",
+ "# Beyond the structure of the message, we also need to ensure that the length does not exceed the 4096 token limit.\n",
+ "\n",
+ "# Token counting functions\n",
+ "encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
+ "\n",
+ "# not exact!\n",
+ "# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb\n",
+ "def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):\n",
+ " num_tokens = 0\n",
+ " for message in messages:\n",
+ " num_tokens += tokens_per_message\n",
+ " for key, value in message.items():\n",
+ " num_tokens += len(encoding.encode(value))\n",
+ " if key == \"name\":\n",
+ " num_tokens += tokens_per_name\n",
+ " num_tokens += 3\n",
+ " return num_tokens\n",
+ "\n",
+ "def num_assistant_tokens_from_messages(messages):\n",
+ " num_tokens = 0\n",
+ " for message in messages:\n",
+ " if message[\"role\"] == \"assistant\":\n",
+ " num_tokens += len(encoding.encode(message[\"content\"]))\n",
+ " return num_tokens\n",
+ "\n",
+ "def print_distribution(values, name):\n",
+ " print(f\"\\n#### Distribution of {name}:\")\n",
+ " print(f\"min / max: {min(values)}, {max(values)}\")\n",
+ " print(f\"mean / median: {np.mean(values)}, {np.median(values)}\")\n",
+ " print(f\"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}\")\n",
+ "\n",
+ "# Last, we can look at the results of the different formatting operations before proceeding with creating a fine-tuning job:\n",
+ "\n",
+ "# Warnings and tokens counts\n",
+ "n_missing_system = 0\n",
+ "n_missing_user = 0\n",
+ "n_messages = []\n",
+ "convo_lens = []\n",
+ "assistant_message_lens = []\n",
+ "\n",
+ "for ex in dataset:\n",
+ " messages = ex[\"messages\"]\n",
+ " if not any(message[\"role\"] == \"system\" for message in messages):\n",
+ " n_missing_system += 1\n",
+ " if not any(message[\"role\"] == \"user\" for message in messages):\n",
+ " n_missing_user += 1\n",
+ " n_messages.append(len(messages))\n",
+ " convo_lens.append(num_tokens_from_messages(messages))\n",
+ " assistant_message_lens.append(num_assistant_tokens_from_messages(messages))\n",
+ "\n",
+ "print(\"Num examples missing system message:\", n_missing_system)\n",
+ "print(\"Num examples missing user message:\", n_missing_user)\n",
+ "print_distribution(n_messages, \"num_messages_per_example\")\n",
+ "print_distribution(convo_lens, \"num_total_tokens_per_example\")\n",
+ "print_distribution(assistant_message_lens, \"num_assistant_tokens_per_example\")\n",
+ "n_too_long = sum(l > 4096 for l in convo_lens)\n",
+ "print(f\"\\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning\")\n",
+ "\n",
+ "# Pricing and default n_epochs estimate\n",
+ "MAX_TOKENS_PER_EXAMPLE = 4096\n",
+ "\n",
+ "MIN_TARGET_EXAMPLES = 100\n",
+ "MAX_TARGET_EXAMPLES = 25000\n",
+ "TARGET_EPOCHS = 3\n",
+ "MIN_EPOCHS = 1\n",
+ "MAX_EPOCHS = 25\n",
+ "\n",
+ "n_epochs = TARGET_EPOCHS\n",
+ "n_train_examples = len(dataset)\n",
+ "if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:\n",
+ " n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)\n",
+ "elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:\n",
+ " n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)\n",
+ "\n",
+ "n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)\n",
+ "print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training\")\n",
+ "print(f\"By default, you'll train for {n_epochs} epochs on this dataset\")\n",
+ "print(f\"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens\")\n",
+ "print(\"See pricing page to estimate total costs\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.summary[\"num_samples\"] = len(dataset)\n",
+ "wandb.summary[\"n_billing_tokens_in_dataset\"] = n_billing_tokens_in_dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Start Fine-tuning ChatGPT-3.5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create an OpenAI training file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.File.create(\n",
+ " file=open(encoded_file_path, \"rb\"),\n",
+ " purpose='fine-tune'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create your fine-tuning job"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.api_key = openai_api_key\n",
+ "openai.FineTuningJob.create(\n",
+ " training_file=\"file-N9M4sC8GfXgTNw0WAwgiLHNR\", #\"file-OrxAP7HcvoSUmu9MtAbWo5s4\",\n",
+ " model=\"gpt-3.5-turbo\",\n",
+ " hyperparameters={\"n_epochs\": 3}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.FineTuningJob.list_events(id=\"ftjob-ShHWEMHa2U7gRNVTpjOYEZEP\", limit=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Log the results to Weights & Biases when the model is finished training\n",
+ "\n",
+ "(temporarily install openai from a fork until this PR to update the wandb logger is merged in openai: https://github.com/openai/openai-python/pull/590)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip uninstall -y openai -qq && pip install git+https://github.com/morganmcg1/openai-python.git@update_wandb_logger -qqq"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run `openai wandb sync` to sync your openai results to W&B"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!OPENAI_API_KEY={openai_api_key} openai wandb sync --entity prompt-eng --project gorilla-api --id ftjob-mNSsI2UcxCvpV767GmnYoSzR"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Other useful openai commands"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "List 10 fine-tuning jobs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.FineTuningJob.list(limit=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Retrieve the state of a fine-tune"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "state = openai.FineTuningJob.retrieve(\"ftjob-qhg4yswil15TCqD4SNHn0V1D\")\n",
+ "state[\"status\"], state[\"trained_tokens\"], state[\"finished_at\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "List up to 10 events from a fine-tuning job"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.FineTuningJob.list_events(id=\"ftjob-qhg4yswil15TCqD4SNHn0V1D\", limit=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Use the Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "openai.api_key = openai_api_key\n",
+ "\n",
+ "completion = openai.ChatCompletion.create(\n",
+ " model=\"ft:gpt-3.5-turbo:my-org:custom_suffix:id\",\n",
+ " messages=[\n",
+ " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+ " {\"role\": \"user\", \"content\": \"How can i load a NER model?\"}\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pprint(completion.choices[0].message)\n",
+ "pprint(completion.choices[0].message[\"content\"])"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/openai/assets/artifacts.png b/colabs/openai/assets/artifacts.png
new file mode 100644
index 00000000..a7d506c0
Binary files /dev/null and b/colabs/openai/assets/artifacts.png differ
diff --git a/colabs/openai/assets/configs.png b/colabs/openai/assets/configs.png
new file mode 100644
index 00000000..c38d3572
Binary files /dev/null and b/colabs/openai/assets/configs.png differ
diff --git a/colabs/openai/assets/dag.png b/colabs/openai/assets/dag.png
new file mode 100644
index 00000000..45d37012
Binary files /dev/null and b/colabs/openai/assets/dag.png differ
diff --git a/colabs/openai/assets/datatable.png b/colabs/openai/assets/datatable.png
new file mode 100644
index 00000000..5edf7dca
Binary files /dev/null and b/colabs/openai/assets/datatable.png differ
diff --git a/colabs/openai/assets/metrics.png b/colabs/openai/assets/metrics.png
new file mode 100644
index 00000000..669f73b1
Binary files /dev/null and b/colabs/openai/assets/metrics.png differ
diff --git a/colabs/openai/assets/select_artifact_uri.png b/colabs/openai/assets/select_artifact_uri.png
new file mode 100644
index 00000000..57179ff2
Binary files /dev/null and b/colabs/openai/assets/select_artifact_uri.png differ
diff --git a/colabs/openai/assets/select_model_artifact.png b/colabs/openai/assets/select_model_artifact.png
new file mode 100644
index 00000000..29a7b9de
Binary files /dev/null and b/colabs/openai/assets/select_model_artifact.png differ
diff --git a/colabs/paddlepaddle/paddledetection/PaddleDetection_and_W&B_Your_one_stop_for_everything_object_detection.ipynb b/colabs/paddlepaddle/paddledetection/PaddleDetection_and_W&B_Your_one_stop_for_everything_object_detection.ipynb
index 05fa664d..8bb85c06 100644
--- a/colabs/paddlepaddle/paddledetection/PaddleDetection_and_W&B_Your_one_stop_for_everything_object_detection.ipynb
+++ b/colabs/paddlepaddle/paddledetection/PaddleDetection_and_W&B_Your_one_stop_for_everything_object_detection.ipynb
@@ -110,7 +110,15 @@
"outputs": [],
"source": [
"\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -376,13 +384,6 @@
"\n",
"If you have any questions about using W&B to track your model performance and predictions, please contact support@wandb.com"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/paddlepaddle/paddleocr/Train_and_Debug_Your_OCR_Models_with_PaddleOCR_and_W&B.ipynb b/colabs/paddlepaddle/paddleocr/Train_and_Debug_Your_OCR_Models_with_PaddleOCR_and_W&B.ipynb
index 12ee49fa..f3fef3e8 100644
--- a/colabs/paddlepaddle/paddleocr/Train_and_Debug_Your_OCR_Models_with_PaddleOCR_and_W&B.ipynb
+++ b/colabs/paddlepaddle/paddleocr/Train_and_Debug_Your_OCR_Models_with_PaddleOCR_and_W&B.ipynb
@@ -94,7 +94,15 @@
"outputs": [],
"source": [
"\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -425,13 +433,6 @@
"\n",
"If you have any questions about using W&B to track your model performance and predictions, please contact support@wandb.com"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/paella/Image-Variations.ipynb b/colabs/paella/Image-Variations.ipynb
index d3556d64..33c0ba5d 100644
--- a/colabs/paella/Image-Variations.ipynb
+++ b/colabs/paella/Image-Variations.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "e6551fcc",
"metadata": {},
"source": [
"\n",
@@ -10,7 +11,7 @@
},
{
"cell_type": "markdown",
- "id": "f25c1396",
+ "id": "7df20a25",
"metadata": {},
"source": [
"# 🔥🔥 Image Variations with Paella + WandB Playground 🪄🐝\n",
@@ -23,7 +24,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "dde43958-b0ed-4bb2-99b7-08cf5a66bf53",
+ "id": "03036e33",
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +53,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "674111ae-ff8e-41b4-81b4-4961eb625f2b",
+ "id": "e10c4b4f",
"metadata": {},
"outputs": [],
"source": [
@@ -88,7 +89,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0b15f5b6-9641-45c8-8fbe-709cb26343c2",
+ "id": "4bcf061f",
"metadata": {},
"outputs": [],
"source": [
@@ -99,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7e4d1f19-74dc-44c9-8977-99cb573a44ae",
+ "id": "72953d5d",
"metadata": {},
"outputs": [],
"source": [
@@ -124,7 +125,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "12e1580f-6808-4309-b8f1-39b5709e920b",
+ "id": "c1d41153",
"metadata": {},
"outputs": [],
"source": [
@@ -197,7 +198,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "10af9b57-3028-4c26-af7e-73a58df648ab",
+ "id": "2ea9c73d",
"metadata": {},
"outputs": [],
"source": [
@@ -231,7 +232,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e6b81f38-ddba-4b0b-8eb0-1f0c49b0e3ef",
+ "id": "3fb179ef",
"metadata": {},
"outputs": [],
"source": [
@@ -261,7 +262,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f6822644-c7bc-422b-84b3-31ea3923fce7",
+ "id": "d3d53f21",
"metadata": {},
"outputs": [],
"source": [
@@ -281,21 +282,13 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5979feb4-2524-498a-9da6-477927d57316",
+ "id": "0e163076",
"metadata": {},
"outputs": [],
"source": [
"log_image_variations_results(original_image, sampled)\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3e85d29a-caf0-4331-ab68-bb33112d4776",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/paella/Inpainting.ipynb b/colabs/paella/Inpainting.ipynb
index ad17db88..ae21ed4f 100644
--- a/colabs/paella/Inpainting.ipynb
+++ b/colabs/paella/Inpainting.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "a51db41e",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bd4a4bd6",
+ "id": "752bdb8a",
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +26,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "735c22d8-d4b4-483f-a1c5-3e57b0dd430f",
+ "id": "b128ff01",
"metadata": {},
"outputs": [],
"source": [
@@ -54,7 +55,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2318c203-2cce-4f2d-9f51-1d094c95c5ee",
+ "id": "8ff4273f",
"metadata": {},
"outputs": [],
"source": [
@@ -89,7 +90,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "066cad54-dc91-42e1-9671-58715daf4c2d",
+ "id": "c3d88d2d",
"metadata": {},
"outputs": [],
"source": [
@@ -100,7 +101,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "db412629-cdaf-492b-b65e-184ebcd0ee62",
+ "id": "44df1809",
"metadata": {},
"outputs": [],
"source": [
@@ -139,7 +140,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ea8632b6-fdd0-459a-bded-e987a3d14610",
+ "id": "616acee0",
"metadata": {},
"outputs": [],
"source": [
@@ -212,7 +213,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f2ad44fc-ce5d-42ab-aea5-b985f8336379",
+ "id": "d96ec8b5",
"metadata": {},
"outputs": [],
"source": [
@@ -246,7 +247,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4e0424fd-1561-437a-9bbd-8e648f18df47",
+ "id": "5eebe4fa",
"metadata": {},
"outputs": [],
"source": [
@@ -264,7 +265,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3598db26-c261-4bde-a30f-5927581e0554",
+ "id": "6731402d",
"metadata": {},
"outputs": [],
"source": [
@@ -296,21 +297,13 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "643d9e25-a083-4554-9640-ef681a305a08",
+ "id": "2a99ace7",
"metadata": {},
"outputs": [],
"source": [
"log_inoutpainting_results(original_image, sampled, encoded_tokens, mask)\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b17cd0e2-a354-4a8b-a68d-11f7d7d48345",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/paella/Latent-Space-Interpolation.ipynb b/colabs/paella/Latent-Space-Interpolation.ipynb
index c63c287b..b6725d1c 100644
--- a/colabs/paella/Latent-Space-Interpolation.ipynb
+++ b/colabs/paella/Latent-Space-Interpolation.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "b9d2a9e5",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f37e1062",
+ "id": "c9292760",
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +26,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "833f912f-98d0-4055-933c-9803ab1a34ed",
+ "id": "ef1d9a73",
"metadata": {},
"outputs": [],
"source": [
@@ -54,7 +55,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1b41825a-63f3-4551-9c4d-a809297f7e0d",
+ "id": "aa426f22",
"metadata": {},
"outputs": [],
"source": [
@@ -87,7 +88,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7b03d40b-65b6-4939-95ef-4bf2ce3b0133",
+ "id": "f201d581",
"metadata": {},
"outputs": [],
"source": [
@@ -98,7 +99,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9ffdcff9-c5a4-4158-a1ba-52beff090801",
+ "id": "2cc157dd",
"metadata": {},
"outputs": [],
"source": [
@@ -112,7 +113,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7eb4b944-fbe6-431a-adf8-6ef3e72a4c82",
+ "id": "53fe0dd0",
"metadata": {},
"outputs": [],
"source": [
@@ -186,7 +187,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "eeaf0fdd-95fb-4154-91bb-806f9aed6c5a",
+ "id": "6a8c2393",
"metadata": {},
"outputs": [],
"source": [
@@ -220,7 +221,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "00b740f5-0b4e-465a-bf90-a5b5860cc785",
+ "id": "a11b5b7c",
"metadata": {},
"outputs": [],
"source": [
@@ -277,21 +278,13 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "de5f608e-f9c2-419b-b9d1-3a1722d055be",
+ "id": "4d5f7b63",
"metadata": {},
"outputs": [],
"source": [
"log_interpolation_results(sampled)\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ffa7516b-bd77-41a3-b9d3-95b110f57106",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/paella/Multi-Conditioning.ipynb b/colabs/paella/Multi-Conditioning.ipynb
index e42e9807..873ca63d 100644
--- a/colabs/paella/Multi-Conditioning.ipynb
+++ b/colabs/paella/Multi-Conditioning.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "fe90fd24",
"metadata": {},
"source": [
"\n",
@@ -10,7 +11,7 @@
},
{
"cell_type": "markdown",
- "id": "c883955a",
+ "id": "318fd50a",
"metadata": {},
"source": [
"# 🔥🔥 Multi-Conditional Image Generation with Paella + WandB Playground 🪄🐝\n",
@@ -23,7 +24,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fb35d29d-ecec-4917-b876-f65b7e8def5f",
+ "id": "d3d461f8",
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +53,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e23328f6-12f2-498c-9322-803c67f2a00f",
+ "id": "bb52df32",
"metadata": {},
"outputs": [],
"source": [
@@ -88,7 +89,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4693a9b6-950f-45bc-b08c-1075bf8abe62",
+ "id": "e3533942",
"metadata": {},
"outputs": [],
"source": [
@@ -99,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "172af2c6-3440-4cbb-8dc2-88f95d0ea9a0",
+ "id": "ff2b0cc2",
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +120,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "cff68bc9-8fc2-449d-994b-9642669f8d63",
+ "id": "2aeb8cd4",
"metadata": {},
"outputs": [],
"source": [
@@ -193,7 +194,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "27874a53-1e32-4206-b64f-05e58fba45ab",
+ "id": "dcef1dac",
"metadata": {},
"outputs": [],
"source": [
@@ -227,7 +228,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3753c2be-cbce-4292-b73f-2701a8ad99c0",
+ "id": "11a0548a",
"metadata": {},
"outputs": [],
"source": [
@@ -258,21 +259,13 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fc9c8431-6068-4961-bc01-f128b1d88fb6",
+ "id": "9b20984f",
"metadata": {},
"outputs": [],
"source": [
"log_multi_conditioning_results(sampled)\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "17a90d54-7717-402f-9086-ea40c1866adf",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/paella/Orientation-Guided-Multi-Conditioning.ipynb b/colabs/paella/Orientation-Guided-Multi-Conditioning.ipynb
index fc5e35a7..1392accc 100644
--- a/colabs/paella/Orientation-Guided-Multi-Conditioning.ipynb
+++ b/colabs/paella/Orientation-Guided-Multi-Conditioning.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "4caab47d",
"metadata": {},
"source": [
"\n",
@@ -10,7 +11,7 @@
},
{
"cell_type": "markdown",
- "id": "d10bd631",
+ "id": "c39d2d61",
"metadata": {},
"source": [
"# 🔥🔥 Orientation Guided Multi-Conditional Image Generation with Paella + WandB Playground 🪄🐝\n",
@@ -23,7 +24,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c1019eb6-1b29-4ce5-95fa-6358ce0dc584",
+ "id": "e162b915",
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +53,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "94f4d23e-ff7e-4de0-a69c-ca13751b96a5",
+ "id": "769bbf69",
"metadata": {},
"outputs": [],
"source": [
@@ -86,7 +87,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6ca68fca-20b2-4ef5-ba7d-fb12bc906a06",
+ "id": "e14a4831",
"metadata": {},
"outputs": [],
"source": [
@@ -97,7 +98,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "412fe0e2-032c-4154-80f6-22f4f0d53315",
+ "id": "5a75569e",
"metadata": {},
"outputs": [],
"source": [
@@ -123,7 +124,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b30c6757-c9cf-4acc-b909-de9ce990bb29",
+ "id": "e9bb0938",
"metadata": {},
"outputs": [],
"source": [
@@ -196,7 +197,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4784d7a7-d0c0-4314-922a-6db000642b77",
+ "id": "f0151079",
"metadata": {},
"outputs": [],
"source": [
@@ -230,7 +231,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9574e4b9-050f-47ee-8ca3-951268539d3f",
+ "id": "d6c6d26a",
"metadata": {},
"outputs": [],
"source": [
@@ -277,7 +278,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a8e626a6-6b86-4208-8767-101a152401cc",
+ "id": "de3c09ea",
"metadata": {},
"outputs": [],
"source": [
diff --git a/colabs/paella/Outpainting.ipynb b/colabs/paella/Outpainting.ipynb
index afda36bd..386d6906 100644
--- a/colabs/paella/Outpainting.ipynb
+++ b/colabs/paella/Outpainting.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "e205baf0",
"metadata": {},
"source": [
"\n",
@@ -10,7 +11,7 @@
},
{
"cell_type": "markdown",
- "id": "b2903db1",
+ "id": "def788f4",
"metadata": {},
"source": [
"# 🔥🔥 Image Outpainting with Paella + WandB Playground 🪄🐝\n",
@@ -23,7 +24,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "25439e07-57df-4b4a-a095-f02fd5b82a31",
+ "id": "19414087",
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +53,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e9f71025-a469-415e-ad7e-656080e8ef08",
+ "id": "4ef82556",
"metadata": {},
"outputs": [],
"source": [
@@ -88,7 +89,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a38d18fc-075f-4a05-9a66-97c8a9ae9d93",
+ "id": "7274ce33",
"metadata": {},
"outputs": [],
"source": [
@@ -99,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2c956e5b-df8b-4d4d-93d6-fbbd23523241",
+ "id": "8cd2036d",
"metadata": {},
"outputs": [],
"source": [
@@ -138,7 +139,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "871f433c-5c86-472d-8313-cfa74eec7a0f",
+ "id": "1057acf3",
"metadata": {},
"outputs": [],
"source": [
@@ -211,7 +212,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "536ad909-0764-44d1-bf79-b8e905a760b0",
+ "id": "b6a7beac",
"metadata": {},
"outputs": [],
"source": [
@@ -245,7 +246,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "edf4cc02-c1bf-4b6b-aeef-8c121322bd9f",
+ "id": "757f8a53",
"metadata": {},
"outputs": [],
"source": [
@@ -263,7 +264,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "07446a5c-ac79-46cd-898b-fb29048d30a7",
+ "id": "2574a92e",
"metadata": {},
"outputs": [],
"source": [
@@ -300,21 +301,13 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "de73e4a4-8e9e-41de-9138-be7b3209db18",
+ "id": "afff6dd4",
"metadata": {},
"outputs": [],
"source": [
"log_inoutpainting_results(original_image, sampled, encoded_tokens, mask)\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cecd9119-ad2f-4968-9395-e89626fbb032",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/paella/Structural-Morphing.ipynb b/colabs/paella/Structural-Morphing.ipynb
index 26a7c7a1..c25b6dd4 100644
--- a/colabs/paella/Structural-Morphing.ipynb
+++ b/colabs/paella/Structural-Morphing.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "b3a265c8",
"metadata": {},
"source": [
"\n",
@@ -10,7 +11,7 @@
},
{
"cell_type": "markdown",
- "id": "aea68e9b",
+ "id": "c4165d97",
"metadata": {},
"source": [
"# 🔥🔥 Structural Morphing of Images with Paella + WandB Playground 🪄🐝\n",
@@ -23,7 +24,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "864b20f3-0e1f-4631-a3b5-b01ca229b9df",
+ "id": "a28f9cb0",
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +53,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8441ea28-1555-4d59-b1de-56e85f48fdc5",
+ "id": "966479fa",
"metadata": {},
"outputs": [],
"source": [
@@ -87,7 +88,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f0814312-cca0-4486-a222-4b90c6eecfd9",
+ "id": "32ceb3d7",
"metadata": {},
"outputs": [],
"source": [
@@ -98,7 +99,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bcee6f91-56bd-4ff5-b042-95465e0949f5",
+ "id": "4bc326f1",
"metadata": {},
"outputs": [],
"source": [
@@ -123,7 +124,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "37526b5e-14d3-4057-99a3-9b12e1baab6a",
+ "id": "09aba79c",
"metadata": {},
"outputs": [],
"source": [
@@ -196,7 +197,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "15138eea-95cd-45c7-af76-d62309d266c3",
+ "id": "399194b3",
"metadata": {},
"outputs": [],
"source": [
@@ -230,7 +231,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "58743cf3-30c0-4ddc-b008-ee847742c96b",
+ "id": "97287634",
"metadata": {},
"outputs": [],
"source": [
@@ -248,7 +249,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0ee870ef-70b3-4728-8dcd-1bfdd7117246",
+ "id": "c4b4daf8",
"metadata": {},
"outputs": [],
"source": [
@@ -277,7 +278,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f6ccceac-8531-41b0-8b73-9953f1eff480",
+ "id": "87f06f8a",
"metadata": {},
"outputs": [],
"source": [
diff --git a/colabs/paella/Text-Conditional.ipynb b/colabs/paella/Text-Conditional.ipynb
index edf4049a..2b076f69 100644
--- a/colabs/paella/Text-Conditional.ipynb
+++ b/colabs/paella/Text-Conditional.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "ffc4c8a0",
"metadata": {},
"source": [
"\n",
@@ -10,7 +11,7 @@
},
{
"cell_type": "markdown",
- "id": "d4a618d5",
+ "id": "33ecf3a8",
"metadata": {},
"source": [
"# 🔥🔥 Text-Conditional Image Generation with Paella + WandB Playground 🪄🐝\n",
@@ -23,7 +24,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "78f75e79-7ed1-4282-b6dd-b6dcb1d1e408",
+ "id": "9842f6c0",
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +53,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2c3fc5a7-72ed-4a3b-8d29-ff1519ec88ba",
+ "id": "94352978",
"metadata": {},
"outputs": [],
"source": [
@@ -82,7 +83,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4813a37e-31d4-403e-bfcb-d17ea7cfa995",
+ "id": "cd216ad2",
"metadata": {},
"outputs": [],
"source": [
@@ -93,7 +94,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "127b8965-3def-4d7e-9e7c-c9875bbe946f",
+ "id": "43772a65",
"metadata": {},
"outputs": [],
"source": [
@@ -107,7 +108,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7d62edb2-fa7a-447c-8ecd-1f65f73adf41",
+ "id": "1d5662b1",
"metadata": {},
"outputs": [],
"source": [
@@ -181,7 +182,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d18f9ead-88ce-4919-950e-7f99624f9344",
+ "id": "23700dc8",
"metadata": {},
"outputs": [],
"source": [
@@ -215,7 +216,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0de7052a-ac91-4c57-8da2-ddbd2c1cb837",
+ "id": "42c83c7e",
"metadata": {},
"outputs": [],
"source": [
@@ -237,7 +238,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "625bb261-f76f-4a4c-b7e4-b3eee2d6f5d4",
+ "id": "cddd0c03",
"metadata": {},
"outputs": [],
"source": [
diff --git a/colabs/prompts/W&B_Prompts_with_Custom_Columns.ipynb b/colabs/prompts/W&B_Prompts_with_Custom_Columns.ipynb
index ebc811cf..e69de29b 100644
--- a/colabs/prompts/W&B_Prompts_with_Custom_Columns.ipynb
+++ b/colabs/prompts/W&B_Prompts_with_Custom_Columns.ipynb
@@ -1,618 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "e-ZYaV5KGVmA"
- },
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gJSVEAGWGVmA"
- },
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9f7yMKLwGVmA"
- },
- "source": [
- "**[Weights & Biases Prompts](https://docs.wandb.ai/guides/prompts?utm_source=code&utm_medium=colab&utm_campaign=prompts)** is a suite of LLMOps tools built for the development of LLM-powered applications.\n",
- "\n",
- "Use W&B Prompts to visualize and inspect the execution flow of your LLMs, analyze the inputs and outputs of your LLMs, view the intermediate results and securely store and manage your prompts and LLM chain configurations.\n",
- "\n",
- "#### [🪄 View Prompts In Action](https://wandb.ai/timssweeney/prompts-demo/)\n",
- "\n",
- "**In this notebook we will demostrate W&B Prompts:**\n",
- "\n",
- "- Using our 1-line LangChain integration\n",
- "- Using our Trace class when building your own LLM Pipelines\n",
- "\n",
- "See here for the full [W&B Prompts documentation](https://docs.wandb.ai/guides/prompts)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "A4wI3b_8GVmB"
- },
- "source": [
- "## Installation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "id": "nDoIqQ8_GVmB"
- },
- "outputs": [],
- "source": [
- "!pip install \"wandb>=0.15.4\" -qqq\n",
- "!pip install \"langchain>=0.0.218\" openai -qqq"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "PcGiSWBSGVmB"
- },
- "outputs": [],
- "source": [
- "import langchain\n",
- "assert langchain.__version__ >= \"0.0.218\", \"Please ensure you are using LangChain v0.0.188 or higher\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pbmQIsjJGVmB"
- },
- "source": [
- "## Setup\n",
- "\n",
- "This demo requires that you have an [OpenAI key](https://platform.openai.com)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "ZH4g2B0lGVmB",
- "outputId": "22295db6-5369-474d-a8ea-fb45c4c92085"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Paste your OpenAI key from: https://platform.openai.com/account/api-keys\n",
- "··········\n",
- "OpenAI API key configured\n"
- ]
- }
- ],
- "source": [
- "import os\n",
- "from getpass import getpass\n",
- "\n",
- "if os.getenv(\"OPENAI_API_KEY\") is None:\n",
- " os.environ[\"OPENAI_API_KEY\"] = getpass(\"Paste your OpenAI key from: https://platform.openai.com/account/api-keys\\n\")\n",
- "assert os.getenv(\"OPENAI_API_KEY\", \"\").startswith(\"sk-\"), \"This doesn't look like a valid OpenAI API key\"\n",
- "print(\"OpenAI API key configured\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "79KOB2EhGVmB"
- },
- "source": [
- "# W&B Prompts\n",
- "\n",
- "W&B Prompts consists of three main components:\n",
- "\n",
- "**Trace table**: Overview of the inputs and outputs of a chain.\n",
- "\n",
- "**Trace timeline**: Displays the execution flow of the chain and is color-coded according to component types.\n",
- "\n",
- "**Model architecture**: View details about the structure of the chain and the parameters used to initialize each component of the chain.\n",
- "\n",
- "After running this section, you will see a new panel automatically created in your workspace, showing each execution, the trace, and the model architecture"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5kxmdm3zGVmC"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9u97K5vVGVmC"
- },
- "source": [
- "## Maths with LangChain"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "oneRFmv6GVmC"
- },
- "source": [
- "Set the `LANGCHAIN_WANDB_TRACING` environment variable as well as any other relevant [W&B environment variables](https://docs.wandb.ai/guides/track/environment-variables). This could includes a W&B project name, team name, and more. See [wandb.init](https://docs.wandb.ai/ref/python/init) for a full list of arguments."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "id": "ACl-rMtAGVmC"
- },
- "outputs": [],
- "source": [
- "os.environ[\"LANGCHAIN_WANDB_TRACING\"] = \"true\"\n",
- "os.environ[\"WANDB_PROJECT\"] = \"langchain-testing\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "id": "csp3MXG4GVmC"
- },
- "outputs": [],
- "source": [
- "from langchain.chat_models import ChatOpenAI\n",
- "from langchain.agents import load_tools, initialize_agent, AgentType"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "2hWU2GcAGVmC"
- },
- "source": [
- "Create a standard math Agent using LangChain"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "l_JkVMlRGVmC"
- },
- "outputs": [],
- "source": [
- "llm = ChatOpenAI(temperature=0)\n",
- "tools = load_tools([\"llm-math\"], llm=llm)\n",
- "math_agent = initialize_agent(tools,\n",
- " llm,\n",
- " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9FFviwCPGVmC"
- },
- "source": [
- "Use LangChain as normal by calling your Agent.\n",
- "\n",
- " You will see a Weights & Biases run start and you will be asked for your [Weights & Biases API key](wwww.wandb.ai/authorize). Once your enter your API key, the inputs and outputs of your Agent calls will start to be streamed to the Weights & Biases App."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 178
- },
- "id": "y-RHjVN4GVmC",
- "outputId": "5ccd5f32-6137-46c3-9abd-d458dbdbacca"
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[34m\u001b[1mwandb\u001b[0m: Streaming LangChain activity to W&B at https://wandb.ai/carey/langchain-testing/runs/lcznj5lg\n",
- "\u001b[34m\u001b[1mwandb\u001b[0m: `WandbTracer` is currently in beta.\n",
- "\u001b[34m\u001b[1mwandb\u001b[0m: Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "LLMMathChain._evaluate(\"\n",
- "import math\n",
- "math.sqrt(5.4)\n",
- "\") raised error: invalid syntax (, line 1). Please try again with a valid numerical expression\n",
- "0.005720801417544866\n",
- "0.15096209512635608\n"
- ]
- }
- ],
- "source": [
- "# some sample maths questions\n",
- "questions = [\n",
- " \"Find the square root of 5.4.\",\n",
- " \"What is 3 divided by 7.34 raised to the power of pi?\",\n",
- " \"What is the sin of 0.47 radians, divided by the cube root of 27?\"\n",
- "]\n",
- "\n",
- "for question in questions:\n",
- " try:\n",
- " # call your Agent as normal\n",
- " answer = math_agent.run(question)\n",
- " print(answer)\n",
- " except Exception as e:\n",
- " # any errors will be also logged to Weights & Biases\n",
- " print(e)\n",
- " pass"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "SNYFSaUrGVmC"
- },
- "source": [
- "Once each Agent execution completes, all calls in your LangChain object will be logged to Weights & Biases"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "m0bL1xpkGVmC"
- },
- "source": [
- "### LangChain Context Manager\n",
- "Depending on your use case, you might instead prefer to use a context manager to manage your logging to W&B.\n",
- "\n",
- "**✨ New: Custom columns** can be logged directly to W&B to display in the same Trace Table with this snippet:\n",
- "```python\n",
- "import wandb\n",
- "wandb.log(custom_metrics_dict, commit=False})\n",
- "```\n",
- "Use `commit=False` to make sure that metadata is logged to the same row of the Trace Table as the LangChain output."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 35
- },
- "id": "7i9Pj1NKGVmC",
- "outputId": "b44f3ae7-fd49-437f-af7b-fb8f82056bd0"
- },
- "outputs": [
- {
- "data": {
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- },
- "text/plain": [
- "'1.0891804557407723'"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from langchain.callbacks import wandb_tracing_enabled\n",
- "import wandb # To enable custom column logging with wandb.run.log()\n",
- "\n",
- "# unset the environment variable and use a context manager instead\n",
- "if \"LANGCHAIN_WANDB_TRACING\" in os.environ:\n",
- " del os.environ[\"LANGCHAIN_WANDB_TRACING\"]\n",
- "\n",
- "# enable tracing using a context manager\n",
- "with wandb_tracing_enabled():\n",
- " for i in range (10):\n",
- " # Log any custom columns you'd like to add to the Trace Table\n",
- " wandb.log({\"custom_column\": i}, commit=False)\n",
- " try:\n",
- " math_agent.run(f\"What is {i} raised to .123243 power?\") # this should be traced\n",
- " except:\n",
- " pass\n",
- "\n",
- "math_agent.run(\"What is 2 raised to .123243 power?\") # this should not be traced"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JDLzoorhGVmC"
- },
- "source": [
- "# Non-Lang Chain Implementation\n",
- "\n",
- "\n",
- "A W&B Trace is created by logging 1 or more \"spans\". A root span is expected, which can accept nested child spans, which can in turn accept their own child spans. A Span represents a unit of work, Spans can have type `AGENT`, `TOOL`, `LLM` or `CHAIN`\n",
- "\n",
- "When logging with Trace, a single W&B run can have multiple calls to a LLM, Tool, Chain or Agent logged to it, there is no need to start a new W&B run after each generation from your model or pipeline, instead each call will be appended to the Trace Table.\n",
- "\n",
- "In this quickstart, we will how to log a single call to an OpenAI model to W&B Trace as a single span. Then we will show how to log a more complex series of nested spans.\n",
- "\n",
- "## Logging with W&B Trace"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7z98yfoqGVmD"
- },
- "source": [
- "Call wandb.init to start a W&B run. Here you can pass a W&B project name as well as an entity name (if logging to a W&B Team), as well as a config and more. See wandb.init for the full list of arguments.\n",
- "\n",
- "You will see a Weights & Biases run start and be asked for your [Weights & Biases API key](wwww.wandb.ai/authorize). Once your enter your API key, the inputs and outputs of your Agent calls will start to be streamed to the Weights & Biases App.\n",
- "\n",
- "**Note:** A W&B run supports logging as many traces you needed to a single run, i.e. you can make multiple calls of `run.log` without the need to create a new run each time"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ZcvgzZ55GVmD"
- },
- "outputs": [],
- "source": [
- "import wandb\n",
- "\n",
- "# start a wandb run to log to\n",
- "wandb.init(project=\"trace-example\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4_3Wrg2YGVmD"
- },
- "source": [
- "You can also set the entity argument in wandb.init if logging to a W&B Team.\n",
- "\n",
- "### Logging a single Span\n",
- "Now we will query OpenAI times and log the results to a W&B Trace. We will log the inputs and outputs, start and end times, whether the OpenAI call was successful, the token usage, and additional metadata.\n",
- "\n",
- "You can see the full description of the arguments to the Trace class [here](https://soumik12345.github.io/wandb-addons/prompts/tracer/)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "q2pkMhpMGVmD"
- },
- "outputs": [],
- "source": [
- "import openai\n",
- "import datetime\n",
- "from wandb.sdk.data_types.trace_tree import Trace\n",
- "\n",
- "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
- "\n",
- "# define your conifg\n",
- "model_name = \"gpt-3.5-turbo\"\n",
- "temperature = 0.7\n",
- "system_message = \"You are a helpful assistant that always replies in 3 concise bullet points using markdown.\"\n",
- "\n",
- "queries_ls = [\n",
- " \"What is the capital of France?\",\n",
- " \"How do I boil an egg?\" * 10000, # deliberately trigger an openai error\n",
- " \"What to do if the aliens arrive?\"\n",
- "]\n",
- "\n",
- "for query in queries_ls:\n",
- " messages=[\n",
- " {\"role\": \"system\", \"content\": system_message},\n",
- " {\"role\": \"user\", \"content\": query}\n",
- " ]\n",
- "\n",
- " start_time_ms = datetime.datetime.now().timestamp() * 1000\n",
- " try:\n",
- " response = openai.ChatCompletion.create(model=model_name,\n",
- " messages=messages,\n",
- " temperature=temperature\n",
- " )\n",
- "\n",
- " end_time_ms = round(datetime.datetime.now().timestamp() * 1000) # logged in milliseconds\n",
- " status=\"success\"\n",
- " status_message=None,\n",
- " response_text = response[\"choices\"][0][\"message\"][\"content\"]\n",
- " token_usage = response[\"usage\"].to_dict()\n",
- "\n",
- "\n",
- " except Exception as e:\n",
- " end_time_ms = round(datetime.datetime.now().timestamp() * 1000) # logged in milliseconds\n",
- " status=\"error\"\n",
- " status_message=str(e)\n",
- " response_text = \"\"\n",
- " token_usage = {}\n",
- "\n",
- " # create a span in wandb\n",
- " root_span = Trace(\n",
- " name=\"root_span\",\n",
- " kind=\"llm\", # kind can be \"llm\", \"chain\", \"agent\" or \"tool\"\n",
- " status_code=status,\n",
- " status_message=status_message,\n",
- " metadata={\"temperature\": temperature,\n",
- " \"token_usage\": token_usage,\n",
- " \"model_name\": model_name},\n",
- " start_time_ms=start_time_ms,\n",
- " end_time_ms=end_time_ms,\n",
- " inputs={\"system_prompt\": system_message, \"query\": query},\n",
- " outputs={\"response\": response_text},\n",
- " )\n",
- "\n",
- " # log the span to wandb\n",
- " root_span.log(name=\"openai_trace\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XFcwFgaDGVmD"
- },
- "source": [
- "### Logging a LLM pipeline using nested Spans\n",
- "\n",
- "In this example we will simulate an Agent being called, which then calls a LLM Chain, which calls an OpenAI LLM and then the Agent \"calls\" a Calculator tool.\n",
- "\n",
- "The inputs, outputs and metadata for each step in the execution of our \"Agent\" is logged in its own span. Spans can have child"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ACMaGuYUGVmD"
- },
- "outputs": [],
- "source": [
- "import time\n",
- "\n",
- "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
- "\n",
- "# The query our agent has to answer\n",
- "query = \"How many days until the next US election?\"\n",
- "\n",
- "# part 1 - an Agent is started...\n",
- "start_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
- "\n",
- "root_span = Trace(\n",
- " name=\"MyAgent\",\n",
- " kind=\"agent\",\n",
- " start_time_ms=start_time_ms,\n",
- " metadata={\"user\": \"optimus_12\"})\n",
- "\n",
- "\n",
- "# part 2 - The Agent calls into a LLMChain..\n",
- "chain_span = Trace(\n",
- " name=\"LLMChain\",\n",
- " kind=\"chain\",\n",
- " start_time_ms=start_time_ms)\n",
- "\n",
- "# add the Chain span as a child of the root\n",
- "root_span.add_child(chain_span)\n",
- "\n",
- "\n",
- "# part 3 - the LLMChain calls an OpenAI LLM...\n",
- "messages=[\n",
- " {\"role\": \"system\", \"content\": system_message},\n",
- " {\"role\": \"user\", \"content\": query}\n",
- "]\n",
- "\n",
- "response = openai.ChatCompletion.create(model=model_name,\n",
- " messages=messages,\n",
- " temperature=temperature)\n",
- "\n",
- "llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
- "response_text = response[\"choices\"][0][\"message\"][\"content\"]\n",
- "token_usage = response[\"usage\"].to_dict()\n",
- "\n",
- "llm_span = Trace(\n",
- " name=\"OpenAI\",\n",
- " kind=\"llm\",\n",
- " status_code=\"success\",\n",
- " metadata={\"temperature\":temperature,\n",
- " \"token_usage\": token_usage,\n",
- " \"model_name\":model_name},\n",
- " start_time_ms=start_time_ms,\n",
- " end_time_ms=llm_end_time_ms,\n",
- " inputs={\"system_prompt\":system_message, \"query\":query},\n",
- " outputs={\"response\": response_text},\n",
- " )\n",
- "\n",
- "# add the LLM span as a child of the Chain span...\n",
- "chain_span.add_child(llm_span)\n",
- "\n",
- "# update the end time of the Chain span\n",
- "chain_span.add_inputs_and_outputs(\n",
- " inputs={\"query\":query},\n",
- " outputs={\"response\": response_text})\n",
- "\n",
- "# update the Chain span's end time\n",
- "chain_span._span.end_time_ms = llm_end_time_ms\n",
- "\n",
- "\n",
- "# part 4 - the Agent then calls a Tool...\n",
- "time.sleep(3)\n",
- "days_to_election = 117\n",
- "tool_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
- "\n",
- "# create a Tool span\n",
- "tool_span = Trace(\n",
- " name=\"Calculator\",\n",
- " kind=\"tool\",\n",
- " status_code=\"success\",\n",
- " start_time_ms=llm_end_time_ms,\n",
- " end_time_ms=tool_end_time_ms,\n",
- " inputs={\"input\": response_text},\n",
- " outputs={\"result\": days_to_election})\n",
- "\n",
- "# add the TOOL span as a child of the root\n",
- "root_span.add_child(tool_span)\n",
- "\n",
- "\n",
- "# part 5 - the final results from the tool are added\n",
- "root_span.add_inputs_and_outputs(inputs={\"query\": query},\n",
- " outputs={\"result\": days_to_election})\n",
- "root_span._span.end_time_ms = tool_end_time_ms\n",
- "\n",
- "\n",
- "# part 6 - log all spans to W&B by logging the root span\n",
- "root_span.log(name=\"openai_trace\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nBFVwawPGVmD"
- },
- "source": [
- "Once each Agent execution completes, all calls in your LangChain object will be logged to Weights & Biases"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/colabs/prompts/WandB_LLM_QA_bot.ipynb b/colabs/prompts/WandB_LLM_QA_bot.ipynb
index 6cf57f5b..efc614e8 100644
--- a/colabs/prompts/WandB_LLM_QA_bot.ipynb
+++ b/colabs/prompts/WandB_LLM_QA_bot.ipynb
@@ -114,7 +114,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/pycaret/Default_Credit_Prediction_Using_W&B_Pycaret_FastAPI.ipynb b/colabs/pycaret/Default_Credit_Prediction_Using_W&B_Pycaret_FastAPI.ipynb
index 2de82091..1050389b 100644
--- a/colabs/pycaret/Default_Credit_Prediction_Using_W&B_Pycaret_FastAPI.ipynb
+++ b/colabs/pycaret/Default_Credit_Prediction_Using_W&B_Pycaret_FastAPI.ipynb
@@ -117,7 +117,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/pyg/8_Node_Classification_(with_W&B).ipynb b/colabs/pyg/8_Node_Classification_(with_W&B).ipynb
index c05f8d88..3c386a9a 100644
--- a/colabs/pyg/8_Node_Classification_(with_W&B).ipynb
+++ b/colabs/pyg/8_Node_Classification_(with_W&B).ipynb
@@ -45,8 +45,16 @@
"source": [
"enable_wandb = True\n",
"if enable_wandb:\n",
- " import wandb\n",
- " wandb.login()"
+ " import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
]
},
{
diff --git a/colabs/pyg/Graph_Classification_with_PyG_and_W&B.ipynb b/colabs/pyg/Graph_Classification_with_PyG_and_W&B.ipynb
index 3614737c..acf4a2e5 100644
--- a/colabs/pyg/Graph_Classification_with_PyG_and_W&B.ipynb
+++ b/colabs/pyg/Graph_Classification_with_PyG_and_W&B.ipynb
@@ -580,13 +580,6 @@
"\n",
"If you have any questions about using W&B to track your model performance and predictions, please contact support@wandb.com"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/pyg/point-cloud-segmentation/00_eda.ipynb b/colabs/pyg/point-cloud-segmentation/00_eda.ipynb
index f4abc616..8f4b86b7 100644
--- a/colabs/pyg/point-cloud-segmentation/00_eda.ipynb
+++ b/colabs/pyg/point-cloud-segmentation/00_eda.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "04fdcaca",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "57500bde-8f7c-4e37-aff7-3aea89793d2f",
+ "id": "aead1b54",
"metadata": {},
"source": [
"# 🔥🔥 Explore ShapeNet Dataset using PyTorch Geometric and Weights & Biases 🪄🐝\n",
@@ -31,7 +32,7 @@
},
{
"cell_type": "markdown",
- "id": "3fffd9cb-f1c6-4b1b-9de9-dc66f7d09f1f",
+ "id": "e4c9ae80",
"metadata": {},
"source": [
"# Install Required Packages"
@@ -40,7 +41,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f7fc4b41-2b6f-4624-bdd7-dfa2decae773",
+ "id": "73e2ae10",
"metadata": {},
"outputs": [],
"source": [
@@ -53,7 +54,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "45e8441e-3b6d-4b24-abb9-deccbcc6c64a",
+ "id": "d47b401b",
"metadata": {},
"outputs": [],
"source": [
@@ -66,7 +67,7 @@
},
{
"cell_type": "markdown",
- "id": "8040f1cd-cf19-4028-a1f6-0009c8dda7aa",
+ "id": "1a4901c4",
"metadata": {},
"source": [
"## Import Libraries"
@@ -75,7 +76,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f110ae50-84a9-4bb0-83d5-7f45f21a2b61",
+ "id": "5fd069a9",
"metadata": {},
"outputs": [],
"source": [
@@ -99,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2a83ea19-e50e-410c-a539-6eccdb68ae71",
+ "id": "86448a88",
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +120,7 @@
},
{
"cell_type": "markdown",
- "id": "35c96004-0dad-420d-b5f3-e01795d563ee",
+ "id": "b21a452e",
"metadata": {},
"source": [
"## Visualize Train-Val Dataset"
@@ -128,7 +129,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "010dba04-fa80-4a2b-9176-dadeeb23ca65",
+ "id": "a23f6557",
"metadata": {},
"outputs": [],
"source": [
@@ -145,7 +146,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f29f2964-31eb-4375-a1f4-c5bb120b8232",
+ "id": "e4d8432b",
"metadata": {},
"outputs": [],
"source": [
@@ -171,7 +172,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "33adc481-bc6f-4773-adc9-1d4fb6a82e2d",
+ "id": "983e9511",
"metadata": {},
"outputs": [],
"source": [
@@ -187,7 +188,7 @@
},
{
"cell_type": "markdown",
- "id": "2e71c385-1b0f-428c-8559-e2fac1816261",
+ "id": "385112bc",
"metadata": {},
"source": [
"## Visualize Test Dataset"
@@ -196,7 +197,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f893cc65-779b-426b-bc9c-397c7894957a",
+ "id": "b201d63f",
"metadata": {},
"outputs": [],
"source": [
@@ -211,7 +212,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c8001b95-57d0-43b5-9092-d874ddbca536",
+ "id": "f12b2662",
"metadata": {},
"outputs": [],
"source": [
@@ -236,7 +237,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e8cbb460-32ca-4814-9890-11cee3f3e22a",
+ "id": "79838e77",
"metadata": {},
"outputs": [],
"source": [
@@ -253,20 +254,12 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d592a642-4ca7-40f8-af69-0a8b3b2693a8",
+ "id": "2b9e5b69",
"metadata": {},
"outputs": [],
"source": [
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e7364297-0751-4dc5-9661-4043d243018c",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/pyg/point-cloud-segmentation/01_dgcnn_train.ipynb b/colabs/pyg/point-cloud-segmentation/01_dgcnn_train.ipynb
index fd725f09..473f8eea 100644
--- a/colabs/pyg/point-cloud-segmentation/01_dgcnn_train.ipynb
+++ b/colabs/pyg/point-cloud-segmentation/01_dgcnn_train.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "172865ee",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "cdbbe526-23ee-4a90-af8c-9f7092f50192",
+ "id": "75b3b204",
"metadata": {},
"source": [
"# 🔥🔥 Train DGCNN Model using PyTorch Geometric and Weights & Biases 🪄🐝\n",
@@ -29,7 +30,7 @@
},
{
"cell_type": "markdown",
- "id": "ed5c4e28-f40c-4c6a-827a-e28ff6db0501",
+ "id": "90904194",
"metadata": {},
"source": [
"# Install Required Packages"
@@ -38,7 +39,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "726d3107-dd0a-440a-a164-43b708af858f",
+ "id": "0fbee2a4",
"metadata": {},
"outputs": [],
"source": [
@@ -51,7 +52,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4ebab99e-383f-410d-93a0-c8ff29dad2de",
+ "id": "7e98bc08",
"metadata": {},
"outputs": [],
"source": [
@@ -64,7 +65,7 @@
},
{
"cell_type": "markdown",
- "id": "dec996ce-98b9-465d-b530-e0f88ac7080d",
+ "id": "937b7a90",
"metadata": {},
"source": [
"## Import Libraries"
@@ -73,7 +74,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c7f2a4cd-24e6-457d-b7f6-7fc94912fd12",
+ "id": "ddbd6831",
"metadata": {},
"outputs": [],
"source": [
@@ -97,7 +98,7 @@
},
{
"cell_type": "markdown",
- "id": "c71deb4e-f171-4a31-9801-a0e8f9cc9c04",
+ "id": "c9b5bbe7",
"metadata": {},
"source": [
"# Initialize Weights & Biases\n",
@@ -108,7 +109,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a1053dec-59c9-4e9d-84e2-a1297dbf715f",
+ "id": "0f940177",
"metadata": {},
"outputs": [],
"source": [
@@ -147,7 +148,7 @@
},
{
"cell_type": "markdown",
- "id": "57a72e20-ff1b-4cdf-b87f-10baff51bac7",
+ "id": "b24e29f8",
"metadata": {},
"source": [
"# Load ShapeNet Dataset using PyTorch Geometric\n",
@@ -158,7 +159,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "42955441-03d8-4e9f-b856-7d7edc708981",
+ "id": "847e8f56",
"metadata": {},
"outputs": [],
"source": [
@@ -174,7 +175,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "da27bf0f-1e0f-438d-b602-2124606e7138",
+ "id": "70388abb",
"metadata": {},
"outputs": [],
"source": [
@@ -188,7 +189,7 @@
},
{
"cell_type": "markdown",
- "id": "3e335648-933d-4865-aaaf-0f24c0358d16",
+ "id": "7c9998ca",
"metadata": {},
"source": [
"Now, we need to offset the segmentation labels"
@@ -197,7 +198,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9e14c4af-b5af-4805-a479-f0725d6ef1de",
+ "id": "47f0e8f2",
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +218,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "10e09ff7-e64b-4912-a30b-461d5eb72cb2",
+ "id": "507ef52b",
"metadata": {},
"outputs": [],
"source": [
@@ -229,7 +230,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8bee5029-3bdd-4fa8-9a37-4d024c0cf6b5",
+ "id": "0a04d1ed",
"metadata": {},
"outputs": [],
"source": [
@@ -249,7 +250,7 @@
},
{
"cell_type": "markdown",
- "id": "6d1bc429-7d54-411a-b477-b7bbc7bfc3d6",
+ "id": "b3764ef6",
"metadata": {},
"source": [
"# Implementing the DGCNN Model using PyTorch Geometric"
@@ -258,7 +259,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ac1bd7a8-585f-45fc-bc59-a8feec3512db",
+ "id": "95c329fc",
"metadata": {},
"outputs": [],
"source": [
@@ -290,7 +291,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5a677e7b-9179-4c71-b1dc-af636dcf95d5",
+ "id": "fc1beb09",
"metadata": {},
"outputs": [],
"source": [
@@ -309,7 +310,7 @@
},
{
"cell_type": "markdown",
- "id": "3d94603b-51df-4299-ae69-f43617827d1c",
+ "id": "684a9044",
"metadata": {},
"source": [
"# Training DGCNN and Logging Metrics on Weights & Biases"
@@ -318,7 +319,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1a9fbe02-b095-4df7-ab6b-af0dd6a411bf",
+ "id": "e775d74b",
"metadata": {},
"outputs": [],
"source": [
@@ -381,7 +382,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "94d7e550-19b1-40b8-9e01-2997a694a4fa",
+ "id": "bbc46ccd",
"metadata": {},
"outputs": [],
"source": [
@@ -441,7 +442,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "da4fad76-f961-4a0a-9f37-c72c1422e885",
+ "id": "7d28329c",
"metadata": {},
"outputs": [],
"source": [
@@ -507,7 +508,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e8c31a63-96a4-409f-bb0d-2b56b034adf0",
+ "id": "fbddd81a",
"metadata": {},
"outputs": [],
"source": [
@@ -533,7 +534,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6f013430-fd25-4541-9dd3-3c67eef9074a",
+ "id": "aebb71ed",
"metadata": {},
"outputs": [],
"source": [
@@ -558,7 +559,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "000c8944-6821-43a3-94d2-548f8b174334",
+ "id": "079d94f7",
"metadata": {},
"outputs": [],
"source": [
@@ -567,7 +568,7 @@
},
{
"cell_type": "markdown",
- "id": "b093dee2-e902-4ad8-af3c-d7126ef985c1",
+ "id": "00ba1b79",
"metadata": {},
"source": [
"Next, you can check out the following notebook to learn how to evaluate the model on the ShapeNetCore dataset using Weights & Biases:\n",
diff --git a/colabs/pyg/point-cloud-segmentation/02_dgcnn_evaluate.ipynb b/colabs/pyg/point-cloud-segmentation/02_dgcnn_evaluate.ipynb
index 86c581af..d591cc44 100644
--- a/colabs/pyg/point-cloud-segmentation/02_dgcnn_evaluate.ipynb
+++ b/colabs/pyg/point-cloud-segmentation/02_dgcnn_evaluate.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "670bb5e6",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "9385afd2-4e70-4fc3-820b-576bdf035e42",
+ "id": "f99cc608",
"metadata": {},
"source": [
"# 🔥🔥 Evaluate DGCNN Model Weights & Biases 🪄🐝\n",
@@ -27,7 +28,7 @@
},
{
"cell_type": "markdown",
- "id": "1e79b763-e197-4a1d-9f3f-f89d873c12ed",
+ "id": "f26e6a2b",
"metadata": {},
"source": [
"# Install Required Packages"
@@ -36,7 +37,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4e46eac0-9bf8-4c3b-906a-ca56343810ea",
+ "id": "16a12686",
"metadata": {},
"outputs": [],
"source": [
@@ -49,7 +50,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e6a87ef2-abf9-4d71-a4ae-d58e3ebba706",
+ "id": "bd471ac2",
"metadata": {},
"outputs": [],
"source": [
@@ -62,7 +63,7 @@
},
{
"cell_type": "markdown",
- "id": "267f5a6e-fedd-48ef-9861-66e8ffb9fbbd",
+ "id": "52865cb0",
"metadata": {},
"source": [
"## Import Libraries"
@@ -71,7 +72,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1d8cf4d8-6839-47fb-b789-857997e15107",
+ "id": "464e45dc",
"metadata": {},
"outputs": [],
"source": [
@@ -95,7 +96,7 @@
},
{
"cell_type": "markdown",
- "id": "0643258d-84be-4a10-a67e-0cf0331e07aa",
+ "id": "ec10ba94",
"metadata": {},
"source": [
"# Initialize Weights & Biases\n",
@@ -106,7 +107,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "308fec6f-7397-4c9d-a83d-c7b5632053b6",
+ "id": "88df5824",
"metadata": {},
"outputs": [],
"source": [
@@ -145,7 +146,7 @@
},
{
"cell_type": "markdown",
- "id": "c6f068eb-5b69-4c73-9e21-4f0d5eab0d9e",
+ "id": "7fb79f85",
"metadata": {},
"source": [
"# Load ShapeNet Dataset using PyTorch Geometric\n",
@@ -156,7 +157,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "208dc676-c0bf-4a3d-bc5a-5b4251664de9",
+ "id": "ed125325",
"metadata": {},
"outputs": [],
"source": [
@@ -172,7 +173,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "27014b34-cfff-47be-b761-db8c484dad97",
+ "id": "ca5b0b24",
"metadata": {},
"outputs": [],
"source": [
@@ -191,7 +192,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e8d0e1df-7522-485b-b697-3dd066d85f8c",
+ "id": "827a5a3d",
"metadata": {},
"outputs": [],
"source": [
@@ -213,7 +214,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2c67c4fa-92b4-495a-a247-c4e8a6c56d79",
+ "id": "0845c8e6",
"metadata": {},
"outputs": [],
"source": [
@@ -226,7 +227,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5147d1c3-9772-4be9-8932-c9238a9270ef",
+ "id": "5880ea56",
"metadata": {},
"outputs": [],
"source": [
@@ -242,7 +243,7 @@
},
{
"cell_type": "markdown",
- "id": "1ce1b72e-cd7b-452f-ac08-dc2155561077",
+ "id": "1923436a",
"metadata": {},
"source": [
"# Load Checkpoint"
@@ -251,7 +252,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d8a02d75-be84-436a-871c-ccadb863dc95",
+ "id": "f2d53569",
"metadata": {},
"outputs": [],
"source": [
@@ -282,7 +283,7 @@
},
{
"cell_type": "markdown",
- "id": "73c1f670-2a87-4083-9ec7-ed07785a5d66",
+ "id": "6c2b36ea",
"metadata": {},
"source": [
"Since we saved the checkpoints as artifacts on our Weights & Biases workspace, we can now fetch and load them."
@@ -291,7 +292,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "22cb064c-58b4-4887-b31b-d16622405b94",
+ "id": "9dba90fc",
"metadata": {},
"outputs": [],
"source": [
@@ -312,7 +313,7 @@
},
{
"cell_type": "markdown",
- "id": "a676bbf9-50be-4a99-9222-427c927bbe57",
+ "id": "53abd015",
"metadata": {},
"source": [
"# Evaluation"
@@ -321,7 +322,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "98e79961-e71c-4a2f-9040-72c6a5b32f09",
+ "id": "563deea4",
"metadata": {},
"outputs": [],
"source": [
@@ -392,7 +393,7 @@
},
{
"cell_type": "markdown",
- "id": "e0c5f0bb-01b1-48a3-980d-c22434f2be8a",
+ "id": "59fb65f4",
"metadata": {},
"source": [
"We evaluate the results and store them in a Weights & Biases Table."
@@ -401,7 +402,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9427209c-216a-4887-9b3b-a73c474680c4",
+ "id": "fa945139",
"metadata": {},
"outputs": [],
"source": [
@@ -413,7 +414,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e1688f54-ef52-4f42-a6f6-b5c3364da87c",
+ "id": "e8995e37",
"metadata": {},
"outputs": [],
"source": [
@@ -423,20 +424,12 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3dfd8610-bda8-4e7b-9939-b1e7b7b4e603",
+ "id": "d4fddcf2",
"metadata": {},
"outputs": [],
"source": [
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "87c44316-de14-4031-b6d2-ba8c9621a66d",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/pyg/pointnet-classification/00_eda.ipynb b/colabs/pyg/pointnet-classification/00_eda.ipynb
index 5bf56f5e..d60acb6d 100644
--- a/colabs/pyg/pointnet-classification/00_eda.ipynb
+++ b/colabs/pyg/pointnet-classification/00_eda.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "5d2c4f72",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "927cc500",
+ "id": "af680e31",
"metadata": {},
"source": [
"# 🔥🔥 Explore ModelNet Datasets using PyTorch Geometric and Weights & Biases 🪄🐝\n",
@@ -22,7 +23,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "73b8e385",
+ "id": "7a9396ad",
"metadata": {},
"source": [
"## Install Required Libraries"
@@ -31,7 +32,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c0e582cd-7d1e-4223-8352-a0b1a05f5adf",
+ "id": "b6c91a7f",
"metadata": {},
"outputs": [],
"source": [
@@ -44,7 +45,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "3441841e",
+ "id": "cde43ebd",
"metadata": {},
"source": [
"We now install PyTorch Geometric according to our PyTorch Version. We also install Weights & Biases."
@@ -53,7 +54,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "cb6c2bcc",
+ "id": "9fff54b7",
"metadata": {},
"outputs": [],
"source": [
@@ -67,7 +68,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "8f2cdd46",
+ "id": "aba4b30e",
"metadata": {},
"source": [
"### Import Libraries"
@@ -76,7 +77,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bac559ed-172f-4fd3-b558-b0a659764fd0",
+ "id": "6acaec36",
"metadata": {},
"outputs": [],
"source": [
@@ -105,7 +106,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "1bd2a644",
+ "id": "46895a21",
"metadata": {},
"source": [
"## Initialize Weights & Biases\n",
@@ -116,7 +117,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8aa7a7d3-e763-4d9e-a05b-1ffde74ef569",
+ "id": "b2a7db44",
"metadata": {},
"outputs": [],
"source": [
@@ -146,7 +147,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "5ce9b4f7",
+ "id": "efd2ad10",
"metadata": {},
"source": [
"## Load ModelNet Dataset using PyTorch Geometric"
@@ -155,7 +156,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1573fe1d-c5d5-437f-a743-33432c45a1ff",
+ "id": "3dfb1e50",
"metadata": {},
"outputs": [],
"source": [
@@ -180,7 +181,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "e76ffe8d",
+ "id": "b038fb22",
"metadata": {},
"source": [
"## Log Data to [`wandb.Table`](https://docs.wandb.ai/ref/python/data-types/table)\n",
@@ -191,7 +192,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "274115e0-d682-48e0-a701-f4e8d3c276cc",
+ "id": "66342479",
"metadata": {},
"outputs": [],
"source": [
@@ -220,7 +221,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7c7b7c12-3e1e-4f7b-b917-a213a488a910",
+ "id": "21978686",
"metadata": {},
"outputs": [],
"source": [
@@ -250,7 +251,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f6f51d44-ed23-4dc0-996e-6efc3c4f7fa6",
+ "id": "46c5fea5",
"metadata": {},
"outputs": [],
"source": [
@@ -260,19 +261,13 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "b992544b",
+ "id": "e9efd334",
"metadata": {},
"source": [
"Next, you can check out the following notebook to learn how to compare different sampling strategies in PyTorch Geometric using Weights & Biases\n",
"\n",
"[![](https://colab.research.google.com/assets/colab-badge.svg)](http://wandb.me/pyg-sampling)"
]
- },
- {
- "cell_type": "markdown",
- "id": "5ad6bd7f",
- "metadata": {},
- "source": []
}
],
"metadata": {
diff --git a/colabs/pyg/pointnet-classification/01_compare_sampling.ipynb b/colabs/pyg/pointnet-classification/01_compare_sampling.ipynb
index 866be074..8930e642 100644
--- a/colabs/pyg/pointnet-classification/01_compare_sampling.ipynb
+++ b/colabs/pyg/pointnet-classification/01_compare_sampling.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "fb05c584",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "75ae3439",
+ "id": "5df60be8",
"metadata": {},
"source": [
"# 🔥🔥 Explore Graph Sampling Techniques using PyTorch Geometric and Weights & Biases 🪄🐝\n",
@@ -26,7 +27,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "24bc9d53",
+ "id": "1a0b2ded",
"metadata": {},
"source": [
"## Install Required Libraries"
@@ -35,7 +36,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9fbbb85e",
+ "id": "91f14e02",
"metadata": {},
"outputs": [],
"source": [
@@ -48,7 +49,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "0387e2f3",
+ "id": "dc2c1fd0",
"metadata": {},
"source": [
"We now install PyTorch Geometric according to our PyTorch Version. We also install Weights & Biases."
@@ -57,7 +58,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "54e2c502",
+ "id": "67d64c57",
"metadata": {},
"source": [
"!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
@@ -70,7 +71,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "e5e35477",
+ "id": "94b145bd",
"metadata": {},
"source": [
"### Import Libraries"
@@ -79,7 +80,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9899f13f-7d7b-43c2-873a-07d23b75a1f2",
+ "id": "f2d40f2f",
"metadata": {},
"outputs": [],
"source": [
@@ -106,7 +107,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "581da785-44d2-4618-b84e-aaed217f0da6",
+ "id": "c40b87cd",
"metadata": {},
"outputs": [],
"source": [
@@ -134,7 +135,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "e19aa6d1",
+ "id": "97c44de0",
"metadata": {},
"source": [
"We take a single point cloud from the dataset and compare the KNN-sampled subgraph and radius-sampled subgraph by visualizing the subgraphs as [`wandb.Html`](https://docs.wandb.ai/ref/python/data-types/html) on a [Weights & Biases Table](https://docs.wandb.ai/guides/data-vis)."
@@ -143,7 +144,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e550b0c2-7b97-469b-b8da-dfb234ddfd65",
+ "id": "a59c651e",
"metadata": {},
"outputs": [],
"source": [
@@ -183,7 +184,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "1441daab",
+ "id": "0059f0b5",
"metadata": {},
"source": [
"Next, you can check out the following notebook to learn how to train the PointNet++ architecture using PyTorch Geometric and Weights & Biases\n",
diff --git a/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb b/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb
index 26b36c74..e1216fbe 100644
--- a/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb
+++ b/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "c0f2b7bb",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "8e6f5935",
+ "id": "0a089e90",
"metadata": {},
"source": [
"# 🔥🔥 Train PointNet++ Model using PyTorch Geometric and Weights & Biases 🪄🐝\n",
@@ -28,7 +29,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "183d8da6",
+ "id": "74844944",
"metadata": {},
"source": [
"## Install Required Libraries"
@@ -37,7 +38,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5349d817",
+ "id": "316dcba0",
"metadata": {},
"outputs": [],
"source": [
@@ -51,7 +52,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "c2d4ce37",
+ "id": "4380ba4f",
"metadata": {},
"source": [
"We now install PyTorch Geometric according to our PyTorch Version. We also install Weights & Biases."
@@ -60,7 +61,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5ed00f24",
+ "id": "1fda5066",
"metadata": {},
"outputs": [],
"source": [
@@ -74,7 +75,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "3fcc8028",
+ "id": "d6616543",
"metadata": {},
"source": [
"### Import Libraries"
@@ -83,7 +84,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c38e68ed-e9b6-452b-913a-42af29cf3b43",
+ "id": "cbc13bbc",
"metadata": {},
"outputs": [],
"source": [
@@ -105,7 +106,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "0393be9e",
+ "id": "46e5b154",
"metadata": {},
"source": [
"## Initialize Weights & Biases\n",
@@ -116,7 +117,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f345aa58-8ed3-4d0b-9925-8d776d6576d6",
+ "id": "783f6a3d",
"metadata": {},
"outputs": [],
"source": [
@@ -162,7 +163,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "3c06b3dc",
+ "id": "524732b9",
"metadata": {},
"source": [
"## Load ModelNet Dataset using PyTorch Geometric\n",
@@ -173,7 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0a7c5bb1-30e1-4a6e-be2a-a375172d5571",
+ "id": "12a74c5f",
"metadata": {},
"outputs": [],
"source": [
@@ -224,7 +225,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "f8b2e6e4",
+ "id": "4960beac",
"metadata": {},
"source": [
"## Implementing the PointNet++ Model using PyTorch Geometric"
@@ -233,7 +234,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1724081c-235a-49ff-9d1f-0ec2e06e1a4e",
+ "id": "ea4fe7e2",
"metadata": {},
"outputs": [],
"source": [
@@ -258,7 +259,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "17957287-c51c-469f-9669-302d6b9903e1",
+ "id": "e266ddad",
"metadata": {},
"outputs": [],
"source": [
@@ -278,7 +279,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "27720171-ca69-49a2-9c7f-66cbc4527e67",
+ "id": "e838fc04",
"metadata": {},
"outputs": [],
"source": [
@@ -318,7 +319,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "f9b7e440",
+ "id": "973729a2",
"metadata": {},
"source": [
"## Training PointNet++ and Logging Metrics on Weights & Biases"
@@ -327,7 +328,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "33cf5e53-6cdc-4037-b126-1201ed711205",
+ "id": "c524e338",
"metadata": {},
"outputs": [],
"source": [
@@ -349,7 +350,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f86c06ce-d212-48b0-984f-bc14582cf716",
+ "id": "b753e2c0",
"metadata": {},
"outputs": [],
"source": [
@@ -363,8 +364,9 @@
" range(num_train_examples),\n",
" desc=f\"Training Epoch {epoch}/{config.epochs}\"\n",
" )\n",
+ " data_iter = iter(train_loader)\n",
" for batch_idx in progress_bar:\n",
- " data = next(iter(train_loader)).to(device)\n",
+ " data = next(data_iter).to(device)\n",
" \n",
" optimizer.zero_grad()\n",
" prediction = model(data)\n",
@@ -394,8 +396,9 @@
" range(num_val_examples),\n",
" desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n",
" )\n",
+ " data_iter = iter(val_loader)\n",
" for batch_idx in progress_bar:\n",
- " data = next(iter(val_loader)).to(device)\n",
+ " data = next(data_iter).to(device)\n",
" \n",
" with torch.no_grad():\n",
" prediction = model(data)\n",
@@ -463,7 +466,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3d7e93ee-7346-4119-91d5-9871880e2507",
+ "id": "42b37a86",
"metadata": {},
"outputs": [],
"source": [
@@ -488,7 +491,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3b035bac-8438-4f96-a545-3f0e071490c4",
+ "id": "88bbb06a",
"metadata": {},
"outputs": [],
"source": [
@@ -498,7 +501,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "05441179",
+ "id": "e236382d",
"metadata": {},
"source": [
"Next, you can check out the following notebook to learn how to run a hyperparameter sweep on our PointNet++ training loop using Weights & Biases:\n",
diff --git a/colabs/pyg/pointnet-classification/03_sweep.ipynb b/colabs/pyg/pointnet-classification/03_sweep.ipynb
index 09845a97..691f09b6 100644
--- a/colabs/pyg/pointnet-classification/03_sweep.ipynb
+++ b/colabs/pyg/pointnet-classification/03_sweep.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "a24744b3",
"metadata": {},
"source": [
"\n",
@@ -11,7 +12,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "d767398b",
+ "id": "7b2b8865",
"metadata": {},
"source": [
"# 🔥🔥 Run a Hyperparameter Sweep on PointNet++ 🪄🐝\n",
@@ -26,7 +27,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "9803ffe3",
+ "id": "c5d6108b",
"metadata": {},
"source": [
"## Install Required Libraries"
@@ -35,7 +36,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f9004c21-0b42-4935-837c-9f4197ded5ec",
+ "id": "c8467ae5",
"metadata": {},
"outputs": [],
"source": [
@@ -49,7 +50,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "5dac741c",
+ "id": "7d43577a",
"metadata": {},
"source": [
"We now install PyTorch Geometric according to our PyTorch Version. We also install Weights & Biases."
@@ -58,7 +59,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "28080574",
+ "id": "f6d2ea50",
"metadata": {},
"outputs": [],
"source": [
@@ -72,7 +73,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "749b65e3",
+ "id": "3ca54e60",
"metadata": {},
"source": [
"### Import Libraries"
@@ -81,7 +82,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ffe08f49-0e83-42b7-b404-54955ca98620",
+ "id": "e268c311",
"metadata": {},
"outputs": [],
"source": [
@@ -103,7 +104,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "ae6f606d",
+ "id": "452b8915",
"metadata": {},
"source": [
"## Function to Build Data Loaders"
@@ -112,7 +113,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8eaf84b3-9588-4b6b-82b6-b7e276021ab0",
+ "id": "b941afd2",
"metadata": {},
"outputs": [],
"source": [
@@ -154,7 +155,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "c7527796",
+ "id": "1739fa2d",
"metadata": {},
"source": [
"## Implementing the PointNet++ Model using PyTorch Geometric"
@@ -163,7 +164,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "101cf26d-76ed-4174-8061-c06d7ea480be",
+ "id": "5c7954c1",
"metadata": {},
"outputs": [],
"source": [
@@ -225,7 +226,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "25e67ec4",
+ "id": "377b785a",
"metadata": {},
"source": [
"## Define a Training Function Instrumented with WandB"
@@ -234,7 +235,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ba5b572b-5041-4fa4-af37-028e5ff2c49d",
+ "id": "127f4dcb",
"metadata": {},
"outputs": [],
"source": [
@@ -285,8 +286,9 @@
" range(num_train_examples),\n",
" desc=f\"Training Epoch {epoch}/{config.epochs}\"\n",
" )\n",
+ " data_iter = iter(train_loader)\n",
" for batch_idx in progress_bar:\n",
- " data = next(iter(train_loader)).to(device)\n",
+ " data = next(data_iter).to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" prediction = model(data)\n",
@@ -314,8 +316,9 @@
" range(num_val_examples),\n",
" desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n",
" )\n",
+ " data_iter = iter(val_loader)\n",
" for batch_idx in progress_bar:\n",
- " data = next(iter(val_loader)).to(device)\n",
+ " data = next(data_iter).to(device)\n",
"\n",
" with torch.no_grad():\n",
" prediction = model(data)\n",
@@ -358,7 +361,7 @@
{
"attachments": {},
"cell_type": "markdown",
- "id": "f2767675",
+ "id": "6f903aee",
"metadata": {},
"source": [
"## Start the Hyperparameter Sweep"
@@ -367,7 +370,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "cec4ce63-92bb-4fa0-a04e-81ba9b7337f2",
+ "id": "9c4c673b",
"metadata": {},
"outputs": [],
"source": [
@@ -393,14 +396,6 @@
"# Run Sweep\n",
"wandb.agent(sweep_id, function=train, count=30)"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5680dad3-f42c-4b6a-b3e7-4dedeedfc88c",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb b/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb
index ac7348d7..3ddb4919 100644
--- a/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb
+++ b/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb
@@ -64,7 +64,7 @@
"outputs": [],
"source": [
"# Install some dependencies\n",
- "!pip install pandas torch pytorch-lightning transformers==4.1.1 -q\n",
+ "!pip install pandas torch lightning transformers\n",
"!pip install -Uq wandb"
]
},
@@ -81,7 +81,7 @@
"import transformers\n",
"import numpy as np\n",
"import pandas as pd\n",
- "import pytorch_lightning as pl"
+ "import lightning.pytorch as pl"
]
},
{
@@ -102,11 +102,19 @@
"source": [
"\"\"\"\n",
"Note that if you are using W&B local you will need to pass the url of your W&B \n",
"deployment to wandb.login through the host keyword argument.\n",
"\n",
"For example:\n",
"wandb.login(host=\"api.wandb.ai\")\n",
- "\"\"\"\n",
+ "\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -239,7 +245,7 @@
" \n",
" # Download the raw cola data from the 'zipfile' reference we added to the cola-raw artifact.\n",
" raw_data_artifact = run.use_artifact(\"cola-raw:latest\")\n",
- " zip_path = raw_data_artifact.get_path(\"zipfile\").download()\n",
+ " zip_path = raw_data_artifact.get_entry(\"zipfile\").download()\n",
" !unzip -o $zip_path # jupyter hack to unzip data :P\n",
" \n",
" # Read in the raw data, log it to W&B as a wandb.Table\n",
@@ -292,7 +298,7 @@
"\n",
" # Download the preprocessed data\n",
" pp_data_artifact = run.use_artifact(\"preprocessed-data:latest\")\n",
- " data_path = pp_data_artifact.get_path(\"dataset\").download()\n",
+ " data_path = pp_data_artifact.get_entry(\"dataset\").download()\n",
" dataset = torch.load(data_path)\n",
"\n",
" # Calculate the number of samples to include in each set.\n",
@@ -404,8 +410,8 @@
"\n",
" # Load the datasets from the split-dataset artifact\n",
" data = run.use_artifact(\"split-dataset:latest\")\n",
- " train_dataset = torch.load(data.get_path(\"train-data\").download())\n",
- " val_dataset = torch.load(data.get_path(\"validation-data\").download())\n",
+ " train_dataset = torch.load(data.get_entry(\"train-data\").download())\n",
+ " val_dataset = torch.load(data.get_entry(\"validation-data\").download())\n",
"\n",
" # Extract the config object associated with the run\n",
" config = run.config\n",
@@ -420,7 +426,7 @@
" gpus = -1 if torch.cuda.is_available() else 0\n",
" \n",
" # Construct a Trainer object with the W&B logger we created and epoch set by the config object\n",
- " trainer = pl.Trainer(max_epochs=config.epochs, gpus=gpus, logger=logger)\n",
+ " trainer = pl.Trainer(max_epochs=config.epochs, logger=logger)\n",
" \n",
" # Build data loaders for our datasets, using the batch_size from our config object\n",
" train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size)\n",
@@ -456,11 +462,6 @@
"\n"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": []
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -535,6 +536,18 @@
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.2"
}
},
"nbformat": 4,
diff --git a/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb
index ecd1ffa8..82c9402b 100644
--- a/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb
+++ b/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb
@@ -27,7 +27,7 @@
"\n",
"# Image Classification using PyTorch Lightning ⚡️\n",
"\n",
- "We will build an image classification pipeline using PyTorch Lightning. We will follow this [style guide](https://pytorch-lightning.readthedocs.io/en/stable/starter/style_guide.html) to increase the readability and reproducibility of our code. A cool explanation of this available [here](https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY)."
+ "We will build an image classification pipeline using PyTorch Lightning. We will follow this [style guide](https://lightning.ai/docs/pytorch/stable/starter/style_guide.html) to increase the readability and reproducibility of our code. A cool explanation of this is available [here](https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY)."
]
},
{
@@ -46,7 +46,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install pytorch-lightning -q\n",
+ "!pip install lightning torchvision -q\n",
"# install weights and biases\n",
"!pip install wandb -qU"
]
@@ -65,9 +65,9 @@
"metadata": {},
"outputs": [],
"source": [
- "import pytorch_lightning as pl\n",
+ "import lightning.pytorch as pl\n",
"# your favorite machine learning tracking tool\n",
- "from pytorch_lightning.loggers import WandbLogger\n",
+ "from lightning.pytorch.loggers import WandbLogger\n",
"\n",
"import torch\n",
"from torch import nn\n",
@@ -115,7 +115,7 @@
"- Apply transforms (rotate, tokenize, etc…).\n",
"- Wrap inside a DataLoader.\n",
"\n",
- "Learn more about datamodules [here](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html). Let's build a datamodule for the Cifar-10 dataset. "
+ "Learn more about datamodules [here](https://lightning.ai/docs/pytorch/stable/data/datamodule.html). Let's build a datamodule for the Cifar-10 dataset. "
]
},
{
@@ -168,8 +168,8 @@
"source": [
"## 📱 Callbacks\n",
"\n",
- "A callback is a self-contained program that can be reused across projects. PyTorch Lightning comes with few [built-in callbacks](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks) which are regularly used. \n",
- "Learn more about callbacks in PyTorch Lightning [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html)."
+ "A callback is a self-contained program that can be reused across projects. PyTorch Lightning comes with a few [built-in callbacks](https://lightning.ai/docs/pytorch/latest/extensions/callbacks.html#built-in-callbacks) which are regularly used. \n",
+ "Learn more about callbacks in PyTorch Lightning [here](https://lightning.ai/docs/pytorch/latest/extensions/callbacks.html)."
]
},
{
@@ -179,7 +179,7 @@
"source": [
"### Built-in Callbacks\n",
"\n",
- "In this tutorial, we will use [Early Stopping](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html#pytorch_lightning.callbacks.EarlyStopping) and [Model Checkpoint](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint) built-in callbacks. They can be passed to the `Trainer`.\n"
+ "In this tutorial, we will use [Early Stopping](https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping) and [Model Checkpoint](https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint) built-in callbacks. They can be passed to the `Trainer`.\n"
]
},
{
@@ -437,16 +437,9 @@
"I hope you find this report helpful. I will encourage to play with the code and train an image classifier with a dataset of your choice. \n",
"\n",
"Here are some resources to learn more about PyTorch Lightning:\n",
- "- [Step-by-step walk-through](https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html) - This is one of the official tutorials. Their documentation is really well written and I highly encourage it as a good learning resource.\n",
+ "- [Step-by-step walk-through](https://lightning.ai/docs/pytorch/latest/starter/introduction.html) - This is one of the official tutorials. Their documentation is really well written and I highly encourage it as a good learning resource.\n",
"- [Use Pytorch Lightning with Weights & Biases](https://wandb.me/lightning) - This is a quick colab that you can run through to learn more about how to use W&B with PyTorch Lightning."
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
@@ -459,6 +452,18 @@
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.2"
}
},
"nbformat": 4,
diff --git a/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb b/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb
index 268e5dbb..3cf40163 100644
--- a/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb
+++ b/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb
@@ -28,8 +28,8 @@
"Coupled with the [Weights & Biases integration](https://docs.wandb.com/library/integrations/lightning), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:\n",
"\n",
"```python\n",
- "from pytorch_lightning.loggers import WandbLogger\n",
- "from pytorch_lightning import Trainer\n",
+ "from lightning.pytorch.loggers import WandbLogger\n",
+ "from lightning.pytorch import Trainer\n",
"\n",
"wandb_logger = WandbLogger()\n",
"trainer = Trainer(logger=wandb_logger)\n",
@@ -64,7 +64,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install -q pytorch-lightning wandb"
+ "!pip install -q lightning wandb torchvision"
]
},
{
@@ -80,7 +80,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -142,6 +150,15 @@
"* Call self.log in `training_step` and `validation_step` to log the metrics"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import lightning.pytorch as pl"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -152,9 +169,8 @@
"from torch.nn import Linear, CrossEntropyLoss, functional as F\n",
"from torch.optim import Adam\n",
"from torchmetrics.functional import accuracy\n",
- "from pytorch_lightning import LightningModule\n",
"\n",
- "class MNIST_LitModule(LightningModule):\n",
+ "class MNIST_LitModule(pl.LightningModule):\n",
"\n",
" def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):\n",
" '''method used to define our model parameters'''\n",
@@ -265,7 +281,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+ "from lightning.pytorch.callbacks import ModelCheckpoint\n",
"\n",
"checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')"
]
@@ -276,9 +292,9 @@
"source": [
"## 💡 Tracking Experiments with WandbLogger\n",
"\n",
- "PyTorch Lightning has a `WandbLogger` to easily log your experiments with Wights & Biases. Just pass it to your `Trainer` to log to W&B. See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for all parameters. Note, to log the metrics to a specific W&B Team, pass your Team name to the `entity` argument in `WandbLogger`\n",
+ "PyTorch Lightning has a `WandbLogger` to easily log your experiments with Weights & Biases. Just pass it to your `Trainer` to log to W&B. See the [WandbLogger docs](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger) for all parameters. Note, to log the metrics to a specific W&B Team, pass your Team name to the `entity` argument in `WandbLogger`\n",
"\n",
- "#### `pytorch_lightning.loggers.WandbLogger()`\n",
+ "#### `lightning.pytorch.loggers.WandbLogger()`\n",
"\n",
"| Functionality | Argument/Function | PS |\n",
"| ------ | ------ | ------ |\n",
@@ -287,9 +303,9 @@
"| Organize runs by project | `WandbLogger(... ,project='my_project')` | |\n",
"| Log histograms of gradients and parameters | `WandbLogger.watch(model)` | `WandbLogger.watch(model, log='all')` to log parameter histograms |\n",
"| Log hyperparameters | Call `self.save_hyperparameters()` within `LightningModule.__init__()` |\n",
- "| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table` |\n",
+ "| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table`, etc. |\n",
"\n",
- "See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) here for all parameters. "
+ "See the [WandbLogger docs](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger) here for all parameters. "
]
},
{
@@ -298,8 +314,8 @@
"metadata": {},
"outputs": [],
"source": [
- "from pytorch_lightning.loggers import WandbLogger\n",
- "from pytorch_lightning import Trainer\n",
+ "from lightning.pytorch.loggers import WandbLogger\n",
+ "from lightning.pytorch import Trainer\n",
"\n",
"wandb_logger = WandbLogger(project='MNIST', # group runs in \"MNIST\" project\n",
" log_model='all') # log all new checkpoints during training"
@@ -326,12 +342,12 @@
"metadata": {},
"outputs": [],
"source": [
- "from pytorch_lightning.callbacks import Callback\n",
+ "from lightning.pytorch.callbacks import Callback\n",
" \n",
"class LogPredictionsCallback(Callback):\n",
" \n",
" def on_validation_batch_end(\n",
- " self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n",
+ " self, trainer, pl_module, outputs, batch, batch_idx):\n",
" \"\"\"Called when the validation batch ends.\"\"\"\n",
" \n",
" # `outputs` comes from `LightningModule.validation_step`\n",
diff --git a/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb b/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb
index 573e035a..35040872 100644
--- a/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb
+++ b/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb
@@ -88,7 +88,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install -q wandb pytorch_lightning torch_tb_profiler"
+ "!pip install -q wandb lightning torch_tb_profiler torchvision"
]
},
{
@@ -99,7 +99,7 @@
"source": [
"import glob\n",
"\n",
- "import pytorch_lightning as pl\n",
+ "import lightning.pytorch as pl\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
@@ -113,8 +113,15 @@
"# drop slow mirror from list of MNIST mirrors\n",
"torchvision.datasets.MNIST.mirrors = [mirror for mirror in torchvision.datasets.MNIST.mirrors\n",
" if not mirror.startswith(\"http://yann.lecun.com\")]\n",
- " \n",
- "# login to W&B\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -350,7 +357,7 @@
" with profiler:\n",
" profiler_callback = TorchTensorboardProfilerCallback(profiler)\n",
"\n",
- " trainer = pl.Trainer(gpus=1, max_epochs=1, max_steps=total_steps,\n",
+ " trainer = pl.Trainer(max_epochs=1, max_steps=total_steps,\n",
" logger=pl.loggers.WandbLogger(log_model=True, save_code=True),\n",
" callbacks=[profiler_callback], precision=wandb.config.precision)\n",
"\n",
diff --git a/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb b/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb
index 36b8af7a..018a1ad6 100644
--- a/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb
+++ b/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb
@@ -1,855 +1,969 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "\n",
- "\n",
- "# ⚡ 💘 🏋️♀️ Supercharge your Training with PyTorch Lightning + Weights & Biases"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": []
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "At Weights & Biases, we love anything\n",
- "that makes training deep learning models easier.\n",
- "That's why we worked with the folks at PyTorch Lightning to\n",
- "[integrate our experiment tracking tool](https://docs.wandb.com/library/integrations/lightning)\n",
- "directly into\n",
- "[the Lightning library](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases).\n",
- "\n",
- "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/) is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision.\n",
- "It retains all the flexibility of PyTorch,\n",
- "in case you need it,\n",
- "but adds some useful abstractions\n",
- "and builds in some best practices.\n",
- "\n",
- "## What this notebook covers:\n",
- "\n",
- "1. Differences between PyTorch and PyTorch Lightning, including how to set up `LightningModules` and `LightningDataModules`\n",
- "2. How to get basic metric logging with the [`WandbLogger`](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases)\n",
- "3. How to log media with W&B and fully customize logging with Lightning `Callbacks`\n",
- "\n",
- "## The interactive dashboard in W&B will look like this:\n",
- "\n",
- "![](https://i.imgur.com/lIbMyFR.png)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Follow along with a [video tutorial](http://wandb.me/lit-video)!"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 🚀 Installing and importing"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "`wandb` and `pytorch-lightning` are both easily installable via [`pip`](https://pip.pypa.io/en/stable/)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -qqq wandb pytorch-lightning torchmetrics"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "PyTorch Lightning is built on top of PyTorch,\n",
- "so we still need to import vanilla PyTorch."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# numpy for non-GPU array math\n",
- "import numpy as np\n",
- "\n",
- "# 🍦 Vanilla PyTorch\n",
- "import torch\n",
- "from torch.nn import functional as F\n",
- "from torch import nn\n",
- "from torch.utils.data import DataLoader, random_split\n",
- "\n",
- "# 👀 Torchvision for CV\n",
- "from torchvision.datasets import MNIST\n",
- "from torchvision import transforms\n",
- "\n",
- "# remove slow mirror from list of MNIST mirrors\n",
- "MNIST.mirrors = [mirror for mirror in MNIST.mirrors\n",
- " if not mirror.startswith(\"http://yann.lecun.com\")]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Much of Lightning is built on the [Modules](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)\n",
- "API from PyTorch,\n",
- "but adds extra features\n",
- "(like data loading and logging)\n",
- "that are common to lots of PyTorch projects.\n",
- "\n",
- "Let's bring those in,\n",
- "plus W&B and the integration.\n",
- "\n",
- "Lastly, we log in to the [Weights & Biases web service](https://wandb.ai).\n",
- "If you've never used W&B,\n",
- "you'll need to sign up first.\n",
- "Accounts are free forever for academic and public projects."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ⚡ PyTorch Lightning\n",
- "import pytorch_lightning as pl\n",
- "import torchmetrics\n",
- "pl.seed_everything(hash(\"setting random seeds\") % 2**32 - 1)\n",
- "\n",
- "# 🏋️♀️ Weights & Biases\n",
- "import wandb\n",
- "\n",
- "# ⚡ 🤝 🏋️♀️\n",
- "from pytorch_lightning.loggers import WandbLogger\n",
- "\n",
- "wandb.login()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> _Note_: If you're executing your training in a terminal, rather than a notebook, you don't need to include `wandb.login()` in your script.\n",
- "Instead, call `wandb login` in the terminal and we'll keep you logged in for future runs."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 🏗️ Building a Model with Lightning\n",
- "\n",
- "In PyTorch Lightning, models are built with `LightningModule` ([docs here](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html)), which has all the functionality of a vanilla `torch.nn.Module` (🍦) but with a few delicious cherries of added functionality on top (🍨).\n",
- "These cherries are there to cut down on boilerplate and\n",
- "help separate out the ML engineering code\n",
- "from the actual machine learning.\n",
- "\n",
- "For example, the mechanics of iterating over batches\n",
- "as part of an epoch are extracted away,\n",
- "so long as you define what happens on the `training_step`.\n",
- "\n",
- "To make a working model out of a `LightningModule`,\n",
- "we need to define a new `class` and add a few methods on top.\n",
- "\n",
- "We'll demonstrate this process with `LitMLP`,\n",
- "which applies a two-layer perceptron\n",
- "(aka two fully-connected layers and\n",
- "a fully-connected softmax readout layer)\n",
- "to input `Tensors`.\n",
- "\n",
- "> _Note_: It is common in the Lightning community to shorten \"Lightning\" to \"[Lit](https://www.urbandictionary.com/define.php?term=it%27s%20lit)\".\n",
- "This sometimes it sound like\n",
- "[your code was written by Travis Scott](https://www.youtube.com/watch?v=y3FCXV8oEZU).\n",
- "We consider this a good thing."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🍦 `__init__` and `forward`\n",
- "\n",
- "First, we need to add two methods that\n",
- "are part of any vanilla PyTorch model."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Those methods are:\n",
- "* `__init__` to do any setup, just like any Python class\n",
- "* `forward` for inference, just like a PyTorch Module\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `forward` pass method is standard,\n",
- "and it'll be different for every project,\n",
- "so we won't comment on it.\n",
- "\n",
- "The `__init__` method,\n",
- "which `init`ializes new instances of the class,\n",
- "is a good place to log hyperparameter information to `wandb`.\n",
- "\n",
- "This is done with the `save_hyperparameters` method,\n",
- "which captures all of the arguments to the initializer\n",
- "and adds them to a dictionary at `self.hparams` --\n",
- "that all comes for free as part of the `LightningModule`.\n",
- "\n",
- "> _Note_: `hparams` is logged to `wandb` as the `config`,\n",
- "so you'll never lose track of the arguments you used to run a model again!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class LitMLP(pl.LightningModule):\n",
- "\n",
- " def __init__(self, in_dims, n_classes=10,\n",
- " n_layer_1=128, n_layer_2=256, lr=1e-4):\n",
- " super().__init__()\n",
- "\n",
- " # we flatten the input Tensors and pass them through an MLP\n",
- " self.layer_1 = nn.Linear(np.prod(in_dims), n_layer_1)\n",
- " self.layer_2 = nn.Linear(n_layer_1, n_layer_2)\n",
- " self.layer_3 = nn.Linear(n_layer_2, n_classes)\n",
- "\n",
- " # log hyperparameters\n",
- " self.save_hyperparameters()\n",
- "\n",
- " # compute the accuracy -- no need to roll your own!\n",
- " self.train_acc = torchmetrics.Accuracy()\n",
- " self.valid_acc = torchmetrics.Accuracy()\n",
- " self.test_acc = torchmetrics.Accuracy()\n",
- "\n",
- " def forward(self, x):\n",
- " \"\"\"\n",
- " Defines a forward pass using the Stem-Learner-Task\n",
- " design pattern from Deep Learning Design Patterns:\n",
- " https://www.manning.com/books/deep-learning-design-patterns\n",
- " \"\"\"\n",
- " batch_size, *dims = x.size()\n",
- "\n",
- " # stem: flatten\n",
- " x = x.view(batch_size, -1)\n",
- "\n",
- " # learner: two fully-connected layers\n",
- " x = F.relu(self.layer_1(x))\n",
- " x = F.relu(self.layer_2(x))\n",
- " \n",
- " # task: compute class logits\n",
- " x = self.layer_3(x)\n",
- " x = F.log_softmax(x, dim=1)\n",
- "\n",
- " return x\n",
- "\n",
- " # convenient method to get the loss on a batch\n",
- " def loss(self, xs, ys):\n",
- " logits = self(xs) # this calls self.forward\n",
- " loss = F.nll_loss(logits, ys)\n",
- " return logits, loss"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> _Note_: for pedagogical purposes, we're splitting out\n",
- "each stage of building the `LitMLP` into a different cell.\n",
- "In a more typical workflow,\n",
- "this would all happen in the `class` definition.\n",
- "\n",
- "> _Note_: if you're familiar with PyTorch,\n",
- "you might be surprised to see we aren't taking care with `.device`s:\n",
- "no `to_cuda` etc. PyTorch Lightning handles all that for you! 😎"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🍨 `training_step` and `configure_optimizers`\n",
- "Now, we add some special methods so that our `LitMLP` can be trained\n",
- "using PyTorch Lightning's training API.\n",
- "\n",
- "> _Note_: if you've used Keras, this might be familiar.\n",
- "It's very similar to the `.fit` API in that library."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Those methods are\n",
- "\n",
- "* `training_step`, which takes a batch and computes the loss; backprop goes through it\n",
- "* `configure_optimizers`, which returns the `torch.optim.Optimizer` to apply after the `training_step`"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> _Note_: `training_step` is part of a rich system of callbacks in PyTorch Lightning.\n",
- "These callbacks are methods that get called\n",
- "at specific points during training\n",
- "(e.g. when a validation epoch ends),\n",
- "and they are a major part of what makes\n",
- "PyTorch Lightning both useful and extensible."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Here's where we add some more serious logging code.\n",
- "`self.log` takes a name and value for a metric.\n",
- "Under the hood, this will get passed to `wandb.log` if you're using W&B.\n",
- "\n",
- "The logging behavior of PyTorch Lightning is both intelligent and configurable.\n",
- "For example, by passing the `on_epoch`\n",
- "keyword argument here,\n",
- "we'll get `_epoch`-wise averages\n",
- "of the metrics logged on each `_step`,\n",
- "and those metrics will be named differently\n",
- "in the W&B interface.\n",
- "When training in a distributed setting,\n",
- "these averages will be automatically computed across nodes.\n",
- "\n",
- "Read more about the `log` method [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def training_step(self, batch, batch_idx):\n",
- " xs, ys = batch\n",
- " logits, loss = self.loss(xs, ys)\n",
- " preds = torch.argmax(logits, 1)\n",
- "\n",
- " # logging metrics we calculated by hand\n",
- " self.log('train/loss', loss, on_epoch=True)\n",
- " # logging a pl.Metric\n",
- " self.train_acc(preds, ys)\n",
- " self.log('train/acc', self.train_acc, on_epoch=True)\n",
- " \n",
- " return loss\n",
- "\n",
- "def configure_optimizers(self):\n",
- " return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n",
- "\n",
- "LitMLP.training_step = training_step\n",
- "LitMLP.configure_optimizers = configure_optimizers"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## ➕ Optional methods for even better logging\n",
- "\n",
- "The code above will log our model's performance,\n",
- "system metrics, and more to W&B.\n",
- "\n",
- "If we want to take our logging to the next level,\n",
- "we need to make use of PyTorch Lightning's callback system.\n",
- "\n",
- "> _Note_: thanks to the clean design of PyTorch Lightning,\n",
- "the training code below will run with or without any\n",
- "of this extra logging code. Nice!"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The other callbacks we'll make use of fall into two categories:\n",
- "* methods that trigger on each batch for a dataset: `validation_step` and `test_step`\n",
- "* methods that trigger at the end of an epoch,\n",
- "or a full pass over a given dataset: `{training, validation, test}_epoch_end`"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 💾 `test`ing and saving the model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We use the test set to evaluate the performance of the final model,\n",
- "so the `test` callbacks will be called at the end of the training pipeline.\n",
- "\n",
- "For performance on the `test` and `validation` sets,\n",
- "we're typically less concerned about how\n",
- "we do on intermediate steps and more\n",
- "with how we did overall.\n",
- "That's why below, we pass in\n",
- "`on_step=False` and `on_epoch=True`\n",
- "so that we log only `epoch`-wise metrics.\n",
- "\n",
- "> _Note_: That's actually the default behavior for `.log` when it's called inside of a `validation` or a `test` loop -- but not when it's called inside a `training` loop! Check out the table of default behaviors for `.log` [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_step(self, batch, batch_idx):\n",
- " xs, ys = batch\n",
- " logits, loss = self.loss(xs, ys)\n",
- " preds = torch.argmax(logits, 1)\n",
- "\n",
- " self.test_acc(preds, ys)\n",
- " self.log(\"test/loss_epoch\", loss, on_step=False, on_epoch=True)\n",
- " self.log(\"test/acc_epoch\", self.test_acc, on_step=False, on_epoch=True)\n",
- "\n",
- "LitMLP.test_step = test_step"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We'll also take the opportunity to save the model in the\n",
- "[portable `ONNX` format](https://onnx.ai/).\n",
- "\n",
- "\n",
- "Later,\n",
- "we'll see that this allows us to use the\n",
- "[Netron model viewer](https://github.com/lutzroeder/netron) in W&B."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_epoch_end(self, test_step_outputs): # args are defined as part of pl API\n",
- " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n",
- " model_filename = \"model_final.onnx\"\n",
- " self.to_onnx(model_filename, dummy_input, export_params=True)\n",
- " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n",
- " artifact.add_file(model_filename)\n",
- " wandb.log_artifact(artifact)\n",
- "\n",
- "LitMLP.test_epoch_end = test_epoch_end"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 📊 Logging `Histograms`"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For the `validation_data`,\n",
- "let's track not only the `acc`uracy and `loss`,\n",
- "but also the `logits`:\n",
- "the un-normalized class probabilities.\n",
- "That way, we can track if our network\n",
- "is becoming more or less confident over time.\n",
- "\n",
- "There's a problem though:\n",
- "`.log` wants to average,\n",
- "but we'd rather look at a distribution.\n",
- "\n",
- "So instead, on every `validation_step`,\n",
- "we'll `return` the `logits`,\n",
- "rather than `log`ging them.\n",
- "\n",
- "Then, when we reach the `end`\n",
- "of the `validation_epoch`,\n",
- "the `logits` are available as the\n",
- "`validation_step_outputs` -- a list.\n",
- "\n",
- "So to log we'll take those `logits`,\n",
- "concatenate them together,\n",
- "and turn them into a histogram with [`wandb.Histogram`](https://docs.wandb.com/library/log#histograms).\n",
- "\n",
- "Because we're no longer using Lightning's `.log` interface and are instead using `wandb`,\n",
- "we need to drop down a level and use\n",
- "`self.experiment.logger.log`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def validation_step(self, batch, batch_idx):\n",
- " xs, ys = batch\n",
- " logits, loss = self.loss(xs, ys)\n",
- " preds = torch.argmax(logits, 1)\n",
- " self.valid_acc(preds, ys)\n",
- "\n",
- " self.log(\"valid/loss_epoch\", loss) # default on val/test is on_epoch only\n",
- " self.log('valid/acc_epoch', self.valid_acc)\n",
- " \n",
- " return logits\n",
- "\n",
- "def validation_epoch_end(self, validation_step_outputs):\n",
- " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n",
- " model_filename = f\"model_{str(self.global_step).zfill(5)}.onnx\"\n",
- " torch.onnx.export(self, dummy_input, model_filename, opset_version=11)\n",
- " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n",
- " artifact.add_file(model_filename)\n",
- " self.logger.experiment.log_artifact(artifact)\n",
- "\n",
- " flattened_logits = torch.flatten(torch.cat(validation_step_outputs))\n",
- " self.logger.experiment.log(\n",
- " {\"valid/logits\": wandb.Histogram(flattened_logits.to(\"cpu\")),\n",
- " \"global_step\": self.global_step})\n",
- "\n",
- "LitMLP.validation_step = validation_step\n",
- "LitMLP.validation_epoch_end = validation_epoch_end"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note that we're once again saving\n",
- "the model in ONNX format.\n",
- "That way, we can roll back our model to any given epoch --\n",
- "useful in case the evaluation on the test set reveals we've overfit."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 📲 `Callback`s for extra-fancy logging"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "What we've done so far\n",
- "will tell us how well our model\n",
- "is using our system resources,\n",
- "how well our model is training and generalizing,\n",
- "and how confident it is.\n",
- "\n",
- "But DNNs often fail in pernicious and silent ways.\n",
- "Often, the only way to notice these failures\n",
- "is to look at how the model is doing\n",
- "on specific examples.\n",
- "\n",
- "So let's additionally log some detailed information on some specific examples:\n",
- "the inputs, outputs,\n",
- "and `pred`ictions.\n",
- "\n",
- "We'll do this by writing our own `Callback` --\n",
- "one that, after every `validation_epoch` ends,\n",
- "logs input images and output predictions\n",
- "using W&B's `Image` logger.\n",
- "\n",
- "> _Note_:\n",
- "For more on the W&B media toolkit, read the [docs](https://docs.wandb.com/library/log#media)\n",
- "or check out\n",
- "[this Colab](http://wandb.me/media-colab)\n",
- "to see everything it's capable of."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class ImagePredictionLogger(pl.Callback):\n",
- " def __init__(self, val_samples, num_samples=32):\n",
- " super().__init__()\n",
- " self.val_imgs, self.val_labels = val_samples\n",
- " self.val_imgs = self.val_imgs[:num_samples]\n",
- " self.val_labels = self.val_labels[:num_samples]\n",
- " \n",
- " def on_validation_epoch_end(self, trainer, pl_module):\n",
- " val_imgs = self.val_imgs.to(device=pl_module.device)\n",
- "\n",
- " logits = pl_module(val_imgs)\n",
- " preds = torch.argmax(logits, 1)\n",
- "\n",
- " trainer.logger.experiment.log({\n",
- " \"examples\": [wandb.Image(x, caption=f\"Pred:{pred}, Label:{y}\") \n",
- " for x, pred, y in zip(val_imgs, preds, self.val_labels)],\n",
- " \"global_step\": trainer.global_step\n",
- " })"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 🛒 Loading data\n",
- "\n",
- "Data pipelines can be created with:\n",
- "* 🍦 Vanilla Pytorch `DataLoaders`\n",
- "* ⚡ Pytorch Lightning `DataModules`\n",
- "\n",
- "`DataModules` are more structured definition, which allows for additional optimizations such as automated distribution of workload between CPU & GPU.\n",
- "Using `DataModules` is recommended whenever possible!\n",
- "\n",
- "A `DataModule` is also defined by an interface:\n",
- "* `prepare_data` (optional) which is called only once and on 1 GPU -- typically something like the data download step we have below\n",
- "* `setup`, which is called on each GPU separately and accepts `stage` to define if we are at `fit` or `test` step\n",
- "* `train_dataloader`, `val_dataloader` and `test_dataloader` to load each dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class MNISTDataModule(pl.LightningDataModule):\n",
- "\n",
- " def __init__(self, data_dir='./', batch_size=128):\n",
- " super().__init__()\n",
- " self.data_dir = data_dir\n",
- " self.batch_size = batch_size\n",
- " self.transform = transforms.Compose([\n",
- " transforms.ToTensor(),\n",
- " transforms.Normalize((0.1307,), (0.3081,))])\n",
- "\n",
- " def prepare_data(self):\n",
- " # download data, train then test\n",
- " MNIST(self.data_dir, train=True, download=True)\n",
- " MNIST(self.data_dir, train=False, download=True)\n",
- "\n",
- " def setup(self, stage=None):\n",
- "\n",
- " # we set up only relevant datasets when stage is specified\n",
- " if stage == 'fit' or stage is None:\n",
- " mnist = MNIST(self.data_dir, train=True, transform=self.transform)\n",
- " self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])\n",
- " if stage == 'test' or stage is None:\n",
- " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
- "\n",
- " # we define a separate DataLoader for each of train/val/test\n",
- " def train_dataloader(self):\n",
- " mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)\n",
- " return mnist_train\n",
- "\n",
- " def val_dataloader(self):\n",
- " mnist_val = DataLoader(self.mnist_val, batch_size=10 * self.batch_size)\n",
- " return mnist_val\n",
- "\n",
- " def test_dataloader(self):\n",
- " mnist_test = DataLoader(self.mnist_test, batch_size=10 * self.batch_size)\n",
- " return mnist_test\n",
- "\n",
- "# setup data\n",
- "mnist = MNISTDataModule()\n",
- "mnist.prepare_data()\n",
- "mnist.setup()\n",
- "\n",
- "# grab samples to log predictions on\n",
- "samples = next(iter(mnist.val_dataloader()))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 👟 Making a `Trainer`\n",
- "\n",
- "The `DataLoader` and the `LightningModule`\n",
- "are brought together by a `Trainer`,\n",
- "which orchestrates data loading,\n",
- "gradient calculation,\n",
- "optimizer logic,\n",
- "and logging. \n",
- "\n",
- "Luckily, we don't need to sub-class the `Trainer`,\n",
- "we just need to configure it with keyword arguments."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "And that is where we'll use the `pytorch_lightning.loggers.WandbLogger` to connect our logging to W&B."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "wandb_logger = WandbLogger(project=\"lit-wandb\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> _Note_: Check out [the documentation](https://docs.wandb.com/library/integrations/lightning) for customization options. I like `group`s and `tag`s!.\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can then set up our `Trainer` and customize several options, such as gradient accumulation, half precision training and distributed computing.\n",
- "\n",
- "We'll stick to the basics for this example,\n",
- "but half-precision training and easy scaling to distributed settings are two of the major reasons why folks like PyTorch Lightning!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = pl.Trainer(\n",
- " logger=wandb_logger, # W&B integration\n",
- " log_every_n_steps=50, # set the logging frequency\n",
- " gpus=-1, # use all GPUs\n",
- " max_epochs=5, # number of epochs\n",
- " deterministic=True, # keep it deterministic\n",
- " callbacks=[ImagePredictionLogger(samples)] # see Callbacks section\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 🏃♀️ Running our Model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, let's make it all happen:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# setup model\n",
- "model = LitMLP(in_dims=(1, 28, 28))\n",
- "\n",
- "# fit the model\n",
- "trainer.fit(model, mnist)\n",
- "\n",
- "# evaluate the model on a test set\n",
- "trainer.test(datamodule=mnist,\n",
- " ckpt_path=None) # uses last-saved model\n",
- "\n",
- "wandb.finish()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> _Note_: In notebooks, we need to call `wandb.finish()` to indicate when we've finished our run. This isn't necessary in scripts."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Viewing the results on wandb.ai\n",
- "\n",
- "Among the outputs from W&B,\n",
- "you will have noticed a few URLs.\n",
- "One of these is the\n",
- "[run page](https://docs.wandb.ai/ref/app/pages/run-page),\n",
- "which has a dashboard with all of the information logged in this run, complete with smart default charts\n",
- "and more.\n",
- "The run page is printed both at the start and end of training, and ends with `lit-wandb/runs/{run_id}`.\n",
- "\n",
- ">_Note_: When visiting your run page, it is recommended to use `global_step` as x-axis to correctly superimpose metrics logged in different stages.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "![image.png]()"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z5w0NbyRVCKD"
+ },
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7PI6YWwrVCKF"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# ⚡ 💘 🏋️♀️ Supercharge your Training with PyTorch Lightning + Weights & Biases"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mPqv_lacVCKF"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zJwmNjhlVCKF"
+ },
+ "source": [
+ "At Weights & Biases, we love anything\n",
+ "that makes training deep learning models easier.\n",
+ "That's why we worked with the folks at PyTorch Lightning to\n",
+ "[integrate our experiment tracking tool](https://docs.wandb.com/library/integrations/lightning)\n",
+ "directly into\n",
+ "[the Lightning library](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases).\n",
+ "\n",
+ "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/) is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision.\n",
+ "It retains all the flexibility of PyTorch,\n",
+ "in case you need it,\n",
+ "but adds some useful abstractions\n",
+ "and builds in some best practices.\n",
+ "\n",
+ "## What this notebook covers:\n",
+ "\n",
+ "1. Differences between PyTorch and PyTorch Lightning, including how to set up `LightningModules` and `LightningDataModules`\n",
+ "2. How to get basic metric logging with the [`WandbLogger`](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases)\n",
+ "3. How to log media with W&B and fully customize logging with Lightning `Callbacks`\n",
+ "\n",
+ "## The interactive dashboard in W&B will look like this:\n",
+ "\n",
+ "![](https://i.imgur.com/lIbMyFR.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3-dKMdt4VCKG"
+ },
+ "source": [
+ "## Follow along with a [video tutorial](http://wandb.me/lit-video)!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Y-QLq_y4VCKG"
+ },
+ "source": [
+ "# 🚀 Installing and importing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LgFEVQPiVCKG"
+ },
+ "source": [
+ "`wandb` and `pytorch-lightning` are both easily installable via [`pip`](https://pip.pypa.io/en/stable/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YKrH3BHxVCKG"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -qqq wandb lightning torchmetrics onnx"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HyRgs61mVCKG"
+ },
+ "source": [
+ "PyTorch Lightning is built on top of PyTorch,\n",
+ "so we still need to import vanilla PyTorch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UksNGoWZVCKH"
+ },
+ "outputs": [],
+ "source": [
+ "# numpy for non-GPU array math\n",
+ "import numpy as np\n",
+ "\n",
+ "# 🍦 Vanilla PyTorch\n",
+ "import torch\n",
+ "from torch.nn import functional as F\n",
+ "from torch import nn\n",
+ "from torch.utils.data import DataLoader, random_split\n",
+ "\n",
+ "# 👀 Torchvision for CV\n",
+ "from torchvision.datasets import MNIST\n",
+ "from torchvision import transforms\n",
+ "\n",
+ "# remove slow mirror from list of MNIST mirrors\n",
+ "MNIST.mirrors = [mirror for mirror in MNIST.mirrors\n",
+ " if not mirror.startswith(\"http://yann.lecun.com\")]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kOwGIep1VCKH"
+ },
+ "source": [
+ "Much of Lightning is built on the [Modules](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)\n",
+ "API from PyTorch,\n",
+ "but adds extra features\n",
+ "(like data loading and logging)\n",
+ "that are common to lots of PyTorch projects.\n",
+ "\n",
+ "Let's bring those in,\n",
+ "plus W&B and the integration.\n",
+ "\n",
+ "Lastly, we log in to the [Weights & Biases web service](https://wandb.ai).\n",
+ "If you've never used W&B,\n",
+ "you'll need to sign up first.\n",
+ "Accounts are free forever for academic and public projects."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "6SpApyF_VCKH"
+ },
+ "outputs": [],
+ "source": [
+ "# ⚡ PyTorch Lightning\n",
+ "import lightning.pytorch as pl\n",
+ "import torchmetrics\n",
+ "pl.seed_everything(hash(\"setting random seeds\") % 2**32 - 1)\n",
+ "\n",
+ "# 🏋️♀️ Weights & Biases\n",
+ "import wandb\n",
+ "\n",
+ "# ⚡ 🤝 🏋️♀️\n",
+ "from lightning.pytorch.loggers import WandbLogger\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WLJp3GnxVCKH"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bw6unCdQVCKH"
+ },
+ "source": [
+ "> _Note_: If you're executing your training in a terminal, rather than a notebook, you don't need to include `wandb.login()` in your script.\n",
+ "Instead, call `wandb login` in the terminal and we'll keep you logged in for future runs."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VHlNYD9iVCKH"
+ },
+ "source": [
+ "# 🏗️ Building a Model with Lightning\n",
+ "\n",
+ "In PyTorch Lightning, models are built with `LightningModule` ([docs here](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html)), which has all the functionality of a vanilla `torch.nn.Module` (🍦) but with a few delicious cherries of added functionality on top (🍨).\n",
+ "These cherries are there to cut down on boilerplate and\n",
+ "help separate out the ML engineering code\n",
+ "from the actual machine learning.\n",
+ "\n",
+ "For example, the mechanics of iterating over batches\n",
+ "as part of an epoch are abstracted away,\n",
+ "so long as you define what happens on the `training_step`.\n",
+ "\n",
+ "To make a working model out of a `LightningModule`,\n",
+ "we need to define a new `class` and add a few methods on top.\n",
+ "\n",
+ "We'll demonstrate this process with `LitMLP`,\n",
+ "which applies a two-layer perceptron\n",
+ "(aka two fully-connected layers and\n",
+ "a fully-connected softmax readout layer)\n",
+ "to input `Tensors`.\n",
+ "\n",
+ "> _Note_: It is common in the Lightning community to shorten \"Lightning\" to \"[Lit](https://www.urbandictionary.com/define.php?term=it%27s%20lit)\".\n",
+ "This sometimes makes it sound like\n",
+ "[your code was written by Travis Scott](https://www.youtube.com/watch?v=y3FCXV8oEZU).\n",
+ "We consider this a good thing."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RkIVeyM9VCKI"
+ },
+ "source": [
+ "## 🍦 `__init__` and `forward`\n",
+ "\n",
+ "First, we need to add two methods that\n",
+ "are part of any vanilla PyTorch model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "coQH1peAVCKI"
+ },
+ "source": [
+ "Those methods are:\n",
+ "* `__init__` to do any setup, just like any Python class\n",
+ "* `forward` for inference, just like a PyTorch Module\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xr46KZ1ZVCKI"
+ },
+ "source": [
+ "The `forward` pass method is standard,\n",
+ "and it'll be different for every project,\n",
+ "so we won't comment on it.\n",
+ "\n",
+ "The `__init__` method,\n",
+ "which `init`ializes new instances of the class,\n",
+ "is a good place to log hyperparameter information to `wandb`.\n",
+ "\n",
+ "This is done with the `save_hyperparameters` method,\n",
+ "which captures all of the arguments to the initializer\n",
+ "and adds them to a dictionary at `self.hparams` --\n",
+ "that all comes for free as part of the `LightningModule`.\n",
+ "\n",
+ "> _Note_: `hparams` is logged to `wandb` as the `config`,\n",
+ "so you'll never lose track of the arguments you used to run a model again!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-16a4P3nVCKI"
+ },
+ "outputs": [],
+ "source": [
+ "class LitMLP(pl.LightningModule):\n",
+ "\n",
+ " def __init__(self, in_dims, n_classes=10,\n",
+ " n_layer_1=128, n_layer_2=256, lr=1e-4):\n",
+ " super().__init__()\n",
+ "\n",
+ " # we flatten the input Tensors and pass them through an MLP\n",
+ " self.layer_1 = nn.Linear(np.prod(in_dims), n_layer_1)\n",
+ " self.layer_2 = nn.Linear(n_layer_1, n_layer_2)\n",
+ " self.layer_3 = nn.Linear(n_layer_2, n_classes)\n",
+ "\n",
+ " # log hyperparameters\n",
+ " self.save_hyperparameters()\n",
+ "\n",
+ " # compute the accuracy -- no need to roll your own!\n",
+ " self.train_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=n_classes)\n",
+ " self.valid_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=n_classes)\n",
+ " self.test_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=n_classes)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " \"\"\"\n",
+ " Defines a forward pass using the Stem-Learner-Task\n",
+ " design pattern from Deep Learning Design Patterns:\n",
+ " https://www.manning.com/books/deep-learning-design-patterns\n",
+ " \"\"\"\n",
+ " batch_size, *dims = x.size()\n",
+ "\n",
+ " # stem: flatten\n",
+ " x = x.view(batch_size, -1)\n",
+ "\n",
+ " # learner: two fully-connected layers\n",
+ " x = F.relu(self.layer_1(x))\n",
+ " x = F.relu(self.layer_2(x))\n",
+ "\n",
+ " # task: compute class logits\n",
+ " x = self.layer_3(x)\n",
+ " x = F.log_softmax(x, dim=1)\n",
+ "\n",
+ " return x\n",
+ "\n",
+ " # convenient method to get the loss on a batch\n",
+ " def loss(self, xs, ys):\n",
+ " logits = self(xs) # this calls self.forward\n",
+ " loss = F.nll_loss(logits, ys)\n",
+ " return logits, loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XHIwTEqwVCKI"
+ },
+ "source": [
+ "> _Note_: for pedagogical purposes, we're splitting out\n",
+ "each stage of building the `LitMLP` into a different cell.\n",
+ "In a more typical workflow,\n",
+ "this would all happen in the `class` definition.\n",
+ "\n",
+ "> _Note_: if you're familiar with PyTorch,\n",
+ "you might be surprised to see we aren't taking care with `.device`s:\n",
+ "no `.to(device)`, `.cuda()`, etc. PyTorch Lightning handles all that for you! 😎"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "q9n788xQVCKI"
+ },
+ "source": [
+ "## 🍨 `training_step` and `configure_optimizers`\n",
+ "Now, we add some special methods so that our `LitMLP` can be trained\n",
+ "using PyTorch Lightning's training API.\n",
+ "\n",
+ "> _Note_: if you've used Keras, this might be familiar.\n",
+ "It's very similar to the `.fit` API in that library."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0PTKSV_4VCKI"
+ },
+ "source": [
+ "Those methods are\n",
+ "\n",
+ "* `training_step`, which takes a batch and computes the loss; backprop goes through it\n",
+ "* `configure_optimizers`, which returns the `torch.optim.Optimizer` to apply after the `training_step`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KL-Yuk-DVCKI"
+ },
+ "source": [
+ "> _Note_: `training_step` is part of a rich system of callbacks in PyTorch Lightning.\n",
+ "These callbacks are methods that get called\n",
+ "at specific points during training\n",
+ "(e.g. when a validation epoch ends),\n",
+ "and they are a major part of what makes\n",
+ "PyTorch Lightning both useful and extensible."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "S_xhIcNbVCKI"
+ },
+ "source": [
+ "Here's where we add some more serious logging code.\n",
+ "`self.log` takes a name and value for a metric.\n",
+ "Under the hood, this will get passed to `wandb.log` if you're using W&B.\n",
+ "\n",
+ "The logging behavior of PyTorch Lightning is both intelligent and configurable.\n",
+ "For example, by passing the `on_epoch`\n",
+ "keyword argument here,\n",
+ "we'll get `_epoch`-wise averages\n",
+ "of the metrics logged on each `_step`,\n",
+ "and those metrics will be named differently\n",
+ "in the W&B interface.\n",
+ "When training in a distributed setting,\n",
+ "these averages will be automatically computed across nodes.\n",
+ "\n",
+ "Read more about the `log` method [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "IpbUYwBDVCKJ"
+ },
+ "outputs": [],
+ "source": [
+ "def training_step(self, batch, batch_idx):\n",
+ " xs, ys = batch\n",
+ " logits, loss = self.loss(xs, ys)\n",
+ " preds = torch.argmax(logits, 1)\n",
+ "\n",
+ " # logging metrics we calculated by hand\n",
+ " self.log('train/loss', loss, on_epoch=True)\n",
+ " # logging a pl.Metric\n",
+ " self.train_acc(preds, ys)\n",
+ " self.log('train/acc', self.train_acc, on_epoch=True)\n",
+ "\n",
+ " return loss\n",
+ "\n",
+ "def configure_optimizers(self):\n",
+ " return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n",
+ "\n",
+ "LitMLP.training_step = training_step\n",
+ "LitMLP.configure_optimizers = configure_optimizers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JXquZkggVCKJ"
+ },
+ "source": [
+ "## ➕ Optional methods for even better logging\n",
+ "\n",
+ "The code above will log our model's performance,\n",
+ "system metrics, and more to W&B.\n",
+ "\n",
+ "If we want to take our logging to the next level,\n",
+ "we need to make use of PyTorch Lightning's callback system.\n",
+ "\n",
+ "> _Note_: thanks to the clean design of PyTorch Lightning,\n",
+ "the training code below will run with or without any\n",
+ "of this extra logging code. Nice!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TE8qYwtOVCKJ"
+ },
+ "source": [
+ "The other callbacks we'll make use of fall into two categories:\n",
+ "* methods that trigger on each batch for a dataset: `validation_step` and `test_step`\n",
+ "* methods that trigger at the end of an epoch,\n",
+ "or a full pass over a given dataset: `on_{train, validation, test}_epoch_end`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Z8JylsGXVCKJ"
+ },
+ "source": [
+ "### 💾 `test`ing and saving the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AHy-SYLDVCKJ"
+ },
+ "source": [
+ "We use the test set to evaluate the performance of the final model,\n",
+ "so the `test` callbacks will be called at the end of the training pipeline.\n",
+ "\n",
+ "For performance on the `test` and `validation` sets,\n",
+ "we're typically less concerned about how\n",
+ "we do on intermediate steps and more\n",
+ "with how we did overall.\n",
+ "That's why below, we pass in\n",
+ "`on_step=False` and `on_epoch=True`\n",
+ "so that we log only `epoch`-wise metrics.\n",
+ "\n",
+ "> _Note_: That's actually the default behavior for `.log` when it's called inside of a `validation` or a `test` loop -- but not when it's called inside a `training` loop! Check out the table of default behaviors for `.log` [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yeVmedTTVCKJ"
+ },
+ "outputs": [],
+ "source": [
+ "def test_step(self, batch, batch_idx):\n",
+ " xs, ys = batch\n",
+ " logits, loss = self.loss(xs, ys)\n",
+ " preds = torch.argmax(logits, 1)\n",
+ "\n",
+ " self.test_acc(preds, ys)\n",
+ " self.log(\"test/loss_epoch\", loss, on_step=False, on_epoch=True)\n",
+ " self.log(\"test/acc_epoch\", self.test_acc, on_step=False, on_epoch=True)\n",
+ "\n",
+ "LitMLP.test_step = test_step"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qk50ovviVCKJ"
+ },
+ "source": [
+ "We'll also take the opportunity to save the model in the\n",
+ "[portable `ONNX` format](https://onnx.ai/).\n",
+ "\n",
+ "\n",
+ "Later,\n",
+ "we'll see that this allows us to use the\n",
+ "[Netron model viewer](https://github.com/lutzroeder/netron) in W&B."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QNr-9bEpVCKJ"
+ },
+ "outputs": [],
+ "source": [
+ "def on_test_epoch_end(self): # args are defined as part of pl API\n",
+ " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n",
+ " model_filename = \"model_final.onnx\"\n",
+ " self.to_onnx(model_filename, dummy_input, export_params=True)\n",
+ " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n",
+ " artifact.add_file(model_filename)\n",
+ " wandb.log_artifact(artifact)\n",
+ "\n",
+ "LitMLP.on_test_epoch_end = on_test_epoch_end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JosFHArGVCKK"
+ },
+ "source": [
+ "### 📊 Logging `Histograms`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FxViSjO4VCKK"
+ },
+ "source": [
+ "For the `validation_data`,\n",
+ "let's track not only the `acc`uracy and `loss`,\n",
+ "but also the `logits`:\n",
+ "the un-normalized class probabilities.\n",
+ "That way, we can track if our network\n",
+ "is becoming more or less confident over time.\n",
+ "\n",
+ "There's a problem though:\n",
+ "`.log` wants to average,\n",
+ "but we'd rather look at a distribution.\n",
+ "\n",
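+ "So instead, on every `validation_step`,\n",
+ "we'll append the `logits` to a list stored on the module,\n",
+ "`self.validation_step_outputs`,\n",
+ "rather than `log`ging them.\n",
+ "We reset that list at the start of every validation epoch.\n",
+ "\n",
+ "Then, when we reach the `end`\n",
+ "of the `validation_epoch`,\n",
+ "the `logits` from every step are available in\n",
+ "`self.validation_step_outputs` -- a list.\n",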
+ "\n",
+ "So to log we'll take those `logits`,\n",
+ "concatenate them together,\n",
+ "and turn them into a histogram with [`wandb.Histogram`](https://docs.wandb.com/library/log#histograms).\n",
+ "\n",
+ "Because we're no longer using Lightning's `.log` interface and are instead using `wandb`,\n",
+ "we need to drop down a level and use\n",
+ "`self.logger.experiment.log`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VrPs4tAJVCKK"
+ },
+ "outputs": [],
+ "source": [
+ "def on_validation_epoch_start(self):\n",
+ " self.validation_step_outputs = []\n",
+ "\n",
+ "def validation_step(self, batch, batch_idx):\n",
+ " xs, ys = batch\n",
+ " logits, loss = self.loss(xs, ys)\n",
+ " preds = torch.argmax(logits, 1)\n",
+ " self.valid_acc(preds, ys)\n",
+ "\n",
+ " self.log(\"valid/loss_epoch\", loss) # default on val/test is on_epoch only\n",
+ " self.log('valid/acc_epoch', self.valid_acc)\n",
+ "\n",
+ " self.validation_step_outputs.append(logits)\n",
+ "\n",
+ " return logits\n",
+ "\n",
+ "def on_validation_epoch_end(self):\n",
+ "\n",
+ " validation_step_outputs = self.validation_step_outputs\n",
+ "\n",
+ " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n",
+ " model_filename = f\"model_{str(self.global_step).zfill(5)}.onnx\"\n",
+ " torch.onnx.export(self, dummy_input, model_filename, opset_version=11)\n",
+ " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n",
+ " artifact.add_file(model_filename)\n",
+ " self.logger.experiment.log_artifact(artifact)\n",
+ "\n",
+ " flattened_logits = torch.flatten(torch.cat(validation_step_outputs))\n",
+ " self.logger.experiment.log(\n",
+ " {\"valid/logits\": wandb.Histogram(flattened_logits.to(\"cpu\")),\n",
+ " \"global_step\": self.global_step})\n",
+ "\n",
+ "LitMLP.on_validation_epoch_start = on_validation_epoch_start\n",
+ "LitMLP.validation_step = validation_step\n",
+ "LitMLP.on_validation_epoch_end = on_validation_epoch_end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "617Vro8jVCKK"
+ },
+ "source": [
+ "Note that we're once again saving\n",
+ "the model in ONNX format.\n",
+ "That way, we can roll back our model to any given epoch --\n",
+ "useful in case the evaluation on the test set reveals we've overfit."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yMe2e8FEVCKK"
+ },
+ "source": [
+ "### 📲 `Callback`s for extra-fancy logging"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "X-DfYBQrVCKL"
+ },
+ "source": [
+ "What we've done so far\n",
+ "will tell us how well our model\n",
+ "is using our system resources,\n",
+ "how well our model is training and generalizing,\n",
+ "and how confident it is.\n",
+ "\n",
+ "But DNNs often fail in pernicious and silent ways.\n",
+ "Often, the only way to notice these failures\n",
+ "is to look at how the model is doing\n",
+ "on specific examples.\n",
+ "\n",
+ "So let's additionally log some detailed information on some specific examples:\n",
+ "the inputs, outputs,\n",
+ "and `pred`ictions.\n",
+ "\n",
+ "We'll do this by writing our own `Callback` --\n",
+ "one that, after every `validation_epoch` ends,\n",
+ "logs input images and output predictions\n",
+ "using W&B's `Image` logger.\n",
+ "\n",
+ "> _Note_:\n",
+ "For more on the W&B media toolkit, read the [docs](https://docs.wandb.com/library/log#media)\n",
+ "or check out\n",
+ "[this Colab](http://wandb.me/media-colab)\n",
+ "to see everything it's capable of."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VM9hxnNtVCKL"
+ },
+ "outputs": [],
+ "source": [
+ "class ImagePredictionLogger(pl.Callback):\n",
+ " def __init__(self, val_samples, num_samples=32):\n",
+ " super().__init__()\n",
+ " self.val_imgs, self.val_labels = val_samples\n",
+ " self.val_imgs = self.val_imgs[:num_samples]\n",
+ " self.val_labels = self.val_labels[:num_samples]\n",
+ "\n",
+ " def on_validation_epoch_end(self, trainer, pl_module):\n",
+ " val_imgs = self.val_imgs.to(device=pl_module.device)\n",
+ "\n",
+ " logits = pl_module(val_imgs)\n",
+ " preds = torch.argmax(logits, 1)\n",
+ "\n",
+ " trainer.logger.experiment.log({\n",
+ " \"examples\": [wandb.Image(x, caption=f\"Pred:{pred}, Label:{y}\")\n",
+ " for x, pred, y in zip(val_imgs, preds, self.val_labels)],\n",
+ " \"global_step\": trainer.global_step\n",
+ " })"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6z2bcdwZVCKL"
+ },
+ "source": [
+ "# 🛒 Loading data\n",
+ "\n",
+ "Data pipelines can be created with:\n",
+ "* 🍦 Vanilla PyTorch `DataLoaders`\n",
+ "* ⚡ PyTorch Lightning `DataModules`\n",
+ "\n",
+ "`DataModules` are a more structured definition, which allows for additional optimizations such as automated distribution of workload between CPU & GPU.\n",
+ "Using `DataModules` is recommended whenever possible!\n",
+ "\n",
+ "A `DataModule` is also defined by an interface:\n",
+ "* `prepare_data` (optional) which is called only once and on 1 GPU -- typically something like the data download step we have below\n",
+ "* `setup`, which is called on each GPU separately and accepts `stage` to define if we are at `fit` or `test` step\n",
+ "* `train_dataloader`, `val_dataloader` and `test_dataloader` to load each dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pO5M1XEnVCKL"
+ },
+ "outputs": [],
+ "source": [
+ "class MNISTDataModule(pl.LightningDataModule):\n",
+ "\n",
+ " def __init__(self, data_dir='./', batch_size=128):\n",
+ " super().__init__()\n",
+ " self.data_dir = data_dir\n",
+ " self.batch_size = batch_size\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.1307,), (0.3081,))])\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " # download data, train then test\n",
+ " MNIST(self.data_dir, train=True, download=True)\n",
+ " MNIST(self.data_dir, train=False, download=True)\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ "\n",
+ " # we set up only relevant datasets when stage is specified\n",
+ " if stage == 'fit' or stage is None:\n",
+ " mnist = MNIST(self.data_dir, train=True, transform=self.transform)\n",
+ " self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])\n",
+ " if stage == 'test' or stage is None:\n",
+ " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
+ "\n",
+ " # we define a separate DataLoader for each of train/val/test\n",
+ " def train_dataloader(self):\n",
+ " mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)\n",
+ " return mnist_train\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " mnist_val = DataLoader(self.mnist_val, batch_size=10 * self.batch_size)\n",
+ " return mnist_val\n",
+ "\n",
+ " def test_dataloader(self):\n",
+ " mnist_test = DataLoader(self.mnist_test, batch_size=10 * self.batch_size)\n",
+ " return mnist_test\n",
+ "\n",
+ "# setup data\n",
+ "mnist = MNISTDataModule()\n",
+ "mnist.prepare_data()\n",
+ "mnist.setup()\n",
+ "\n",
+ "# grab samples to log predictions on\n",
+ "samples = next(iter(mnist.val_dataloader()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6Ord5Pu3VCKM"
+ },
+ "source": [
+ "# 👟 Making a `Trainer`\n",
+ "\n",
+ "The `DataLoader` and the `LightningModule`\n",
+ "are brought together by a `Trainer`,\n",
+ "which orchestrates data loading,\n",
+ "gradient calculation,\n",
+ "optimizer logic,\n",
+ "and logging.\n",
+ "\n",
+ "Luckily, we don't need to sub-class the `Trainer`,\n",
+ "we just need to configure it with keyword arguments."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MDDjVOSEVCKM"
+ },
+ "source": [
+ "And that is where we'll use the `pytorch_lightning.loggers.WandbLogger` to connect our logging to W&B."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3nibeocmVCKM"
+ },
+ "outputs": [],
+ "source": [
+ "wandb_logger = WandbLogger(project=\"lit-wandb\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YDEryNsoVCKM"
+ },
+ "source": [
+ "> _Note_: Check out [the documentation](https://docs.wandb.com/library/integrations/lightning) for customization options. I like `group`s and `tag`s!\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NLfJc41aVCKM"
+ },
+ "source": [
+ "We can then set up our `Trainer` and customize several options, such as gradient accumulation, half precision training and distributed computing.\n",
+ "\n",
+ "We'll stick to the basics for this example,\n",
+ "but half-precision training and easy scaling to distributed settings are two of the major reasons why folks like PyTorch Lightning!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0dciVNKHVCKN"
+ },
+ "outputs": [],
+ "source": [
+ "trainer = pl.Trainer(\n",
+ " logger=wandb_logger, # W&B integration\n",
+ " log_every_n_steps=50, # set the logging frequency\n",
+ " max_epochs=5, # number of epochs\n",
+ " deterministic=True, # keep it deterministic\n",
+ " callbacks=[ImagePredictionLogger(samples)] # see Callbacks section\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Fs6hyyU2VCKN"
+ },
+ "source": [
+ "# 🏃♀️ Running our Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LzVnPc6_VCKN"
+ },
+ "source": [
+ "Now, let's make it all happen:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "q7VxL_o6VCKN"
+ },
+ "outputs": [],
+ "source": [
+ "# setup model\n",
+ "model = LitMLP(in_dims=(1, 28, 28))\n",
+ "\n",
+ "# fit the model\n",
+ "trainer.fit(model, mnist)\n",
+ "\n",
+ "# evaluate the model on a test set\n",
+ "trainer.test(datamodule=mnist,\n",
+ " ckpt_path=None) # uses last-saved model\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cryJvmKOVCKN"
+ },
+ "source": [
+ "> _Note_: In notebooks, we need to call `wandb.finish()` to indicate when we've finished our run. This isn't necessary in scripts."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t9jsRs-oVCKN"
+ },
+ "source": [
+ "## Viewing the results on wandb.ai\n",
+ "\n",
+ "Among the outputs from W&B,\n",
+ "you will have noticed a few URLs.\n",
+ "One of these is the\n",
+ "[run page](https://docs.wandb.ai/ref/app/pages/run-page),\n",
+ "which has a dashboard with all of the information logged in this run, complete with smart default charts\n",
+ "and more.\n",
+ "The run page is printed both at the start and end of training, and ends with `lit-wandb/runs/{run_id}`.\n",
+ "\n",
+ ">_Note_: When visiting your run page, it is recommended to use `global_step` as x-axis to correctly superimpose metrics logged in different stages.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fgPkQGamVCKN"
+ },
+ "source": [
+ "\n",
+ "![image.png]()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb
index 3a4a744b..776bc3b1 100644
--- a/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb
+++ b/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb
@@ -38,7 +38,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install wandb pytorch-lightning -qqq"
+ "!pip install wandb lightning torchvision -qqq"
]
},
{
@@ -56,9 +56,9 @@
"source": [
"import os\n",
"\n",
- "import pytorch_lightning as pl\n",
+ "import lightning.pytorch as pl\n",
"# your favorite machine learning tracking tool\n",
- "from pytorch_lightning.loggers import WandbLogger\n",
+ "from lightning.pytorch.loggers import WandbLogger\n",
"\n",
"import torch\n",
"from torch import nn\n",
diff --git a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb
index fdd961eb..32a5df77 100644
--- a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb
+++ b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb
@@ -1,673 +1,737 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- ""
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ws0nlkuGOpDy"
+ },
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aU1q92uCOpD1"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# W&B Tutorial with PyTorch Lightning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mHCKzEzSOpD1"
+ },
+ "source": [
+ "## 🛠️ Install `wandb` and `lightning`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RnyGWvwDOpD1"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q lightning wandb torchvision"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G5wpAEAoOpD2"
+ },
+ "source": [
+ "## Login to W&B either through Python or CLI\n",
+ "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n",
+ "\n",
+ "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n",
+ "```\n",
+ "import os\n",
+ "import wandb\n",
+ "\n",
+ "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n",
+ "```\n",
+ "You can also login via the CLI with:\n",
+ "```\n",
+ "wandb login --host <your-wandb-host-url>\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YjXGiBLQOpD2"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "r5HflvclOpD2"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zNuxn8BjOpD2"
+ },
+ "source": [
+ "## ⚱ Logging the Raw Training Data as an Artifact"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "X9mkr942OpD2"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Enter your W&B project and entity\n",
+ "\n",
+ "# FORM VARIABLES\n",
+ "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n",
+ "ENTITY = \"wandb\"#@param {type:\"string\"}\n",
+ "\n",
+ "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n",
+ "# to select one of these four datasets\n",
+ "# TINY dataset: 100 images, 30MB\n",
+ "# SMALL dataset: 1000 images, 312MB\n",
+ "# MEDIUM dataset: 5000 images, 1.5GB\n",
+ "# LARGE dataset: 12,000 images, 3.6GB\n",
+ "\n",
+ "SIZE = \"TINY\"\n",
+ "\n",
+ "if SIZE == \"TINY\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n",
+ " src_zip = \"nature_100.zip\"\n",
+ " DATA_SRC = \"nature_100\"\n",
+ " IMAGES_PER_LABEL = 10\n",
+ " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n",
+ "elif SIZE == \"SMALL\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n",
+ " src_zip = \"nature_1K.zip\"\n",
+ " DATA_SRC = \"nature_1K\"\n",
+ " IMAGES_PER_LABEL = 100\n",
+ " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n",
+ "elif SIZE == \"MEDIUM\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
+ " src_zip = \"nature_12K.zip\"\n",
+ " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
+ " IMAGES_PER_LABEL = 500\n",
+ " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n",
+ "elif SIZE == \"LARGE\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
+ " src_zip = \"nature_12K.zip\"\n",
+ " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
+ " IMAGES_PER_LABEL = 1000\n",
+ " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "o8nApIhdOpD3"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!curl -SL $src_url > $src_zip\n",
+ "!unzip $src_zip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ ],
+ "metadata": {
+ "id": "ALmdQ7wISLaA"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XQ_Kwsg9OpD3"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "import pandas as pd\n",
+ "import os\n",
+ "\n",
+ "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n",
+ " img_paths = []\n",
+ " for root, dirs, files in os.walk('nature_100', topdown=False):\n",
+ " for name in files:\n",
+ " img_path = os.path.join(root, name)\n",
+ " label = img_path.split('/')[1]\n",
+ " img_paths.append([img_path, label])\n",
+ "\n",
+ " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n",
+ " index_df.to_csv('index.csv', index=False)\n",
+ "\n",
+ " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n",
+ " train_art.add_dir('nature_100')\n",
+ "\n",
+ " # Also adding a csv indicating the labels of each image\n",
+ " train_art.add_file('index.csv')\n",
+ " wandb.log_artifact(train_art)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1eJpOk_VOpD3"
+ },
+ "source": [
+ "## Using Artifacts in PyTorch Lightning `DataModule`s and PyTorch `Dataset`s\n",
+ "- Makes it easy for your DataLoaders to interoperate with new versions of datasets\n",
+ "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Z2g9JRrwOpD3"
+ },
+ "outputs": [],
+ "source": [
+ "from torchvision import transforms\n",
+ "import lightning.pytorch as pl\n",
+ "import torch\n",
+ "from torch.utils.data import Dataset, DataLoader, random_split\n",
+ "from skimage import io, transform\n",
+ "from torchvision import transforms, utils, models\n",
+ "import math\n",
+ "\n",
+ "class NatureDataset(Dataset):\n",
+ " def __init__(self,\n",
+ " wandb_run,\n",
+ " artifact_name_alias=\"Nature_100:latest\",\n",
+ " local_target_dir=\"Nature_100:latest\",\n",
+ " transform=None):\n",
+ " self.local_target_dir = local_target_dir\n",
+ " self.transform = transform\n",
+ "\n",
+ " # Pull down the artifact locally to load it into memory\n",
+ " art = wandb_run.use_artifact(artifact_name_alias)\n",
+ " path_at = art.download(root=self.local_target_dir)\n",
+ "\n",
+ " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n",
+ " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n",
+ " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n",
+ " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.ref_df)\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " if torch.is_tensor(idx):\n",
+ " idx = idx.tolist()\n",
+ "\n",
+ " img_path = self.ref_df.iloc[idx, 0]\n",
+ "\n",
+ " image = io.imread(img_path)\n",
+ " label = self.ref_df.iloc[idx, 1]\n",
+ " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n",
+ "\n",
+ " if self.transform:\n",
+ " image = self.transform(image)\n",
+ "\n",
+ " return image, label\n",
+ "\n",
+ "\n",
+ "class NatureDatasetModule(pl.LightningDataModule):\n",
+ " def __init__(self,\n",
+ " wandb_run,\n",
+ " artifact_name_alias: str = \"Nature_100:latest\",\n",
+ " local_target_dir: str = \"Nature_100:latest\",\n",
+ " batch_size: int = 16,\n",
+ " input_size: int = 224,\n",
+ " seed: int = 42):\n",
+ " super().__init__()\n",
+ " self.wandb_run = wandb_run\n",
+ " self.artifact_name_alias = artifact_name_alias\n",
+ " self.local_target_dir = local_target_dir\n",
+ " self.batch_size = batch_size\n",
+ " self.input_size = input_size\n",
+ " self.seed = seed\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ " self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,\n",
+ " artifact_name_alias=self.artifact_name_alias,\n",
+ " local_target_dir=self.local_target_dir,\n",
+ " transform=transforms.Compose([transforms.ToTensor(),\n",
+ " transforms.CenterCrop(self.input_size),\n",
+ " transforms.Normalize((0.485, 0.456, 0.406),\n",
+ " (0.229, 0.224, 0.225))]))\n",
+ "\n",
+ " nature_length = len(self.nature_dataset)\n",
+ " train_size = math.floor(0.8 * nature_length)\n",
+ " val_size = math.floor(0.2 * nature_length)\n",
+ " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n",
+ " [train_size, val_size],\n",
+ " generator=torch.Generator().manual_seed(self.seed))\n",
+ " return self\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.nature_train, batch_size=self.batch_size)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.nature_val, batch_size=self.batch_size)\n",
+ "\n",
+ " def predict_dataloader(self):\n",
+ " pass\n",
+ "\n",
+ " def teardown(self, stage: str):\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NLhjdvF-OpD3"
+ },
+ "source": [
+ "## How logging in your PyTorch `LightningModule` works:\n",
+ "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as a `logger`.\n",
+ "\n",
+ "```\n",
+ "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\")\n",
+ "trainer = Trainer(logger=wandb_logger)\n",
+ "```\n",
+ "\n",
+ "\n",
+ "You can always use `wandb.log` as normal throughout the module. When the `WandbLogger` is used, `self.log` will also log metrics to W&B.\n",
+ "- To access the current run from within the `LightningModule`, you can access `Trainer.logger.experiment`, which is a `wandb.Run` object"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Oi1NpQs7OpD4"
+ },
+ "source": [
+ "### Some helper functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HLgYji2oOpD4"
+ },
+ "outputs": [],
+ "source": [
+ "# Some helper functions\n",
+ "\n",
+ "def set_parameter_requires_grad(model, feature_extracting):\n",
+ " if feature_extracting:\n",
+ " for param in model.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n",
+ " # Initialize these variables which will be set in this if statement. Each of these\n",
+ " # variables is model specific.\n",
+ " model_ft = None\n",
+ " input_size = 0\n",
+ "\n",
+ " if model_name == \"resnet\":\n",
+ " \"\"\" Resnet18\n",
+ " \"\"\"\n",
+ " model_ft = models.resnet18(pretrained=use_pretrained)\n",
+ " set_parameter_requires_grad(model_ft, feature_extract)\n",
+ " num_ftrs = model_ft.fc.in_features\n",
+ " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n",
+ " input_size = 224\n",
+ "\n",
+ " elif model_name == \"squeezenet\":\n",
+ " \"\"\" Squeezenet\n",
+ " \"\"\"\n",
+ " model_ft = models.squeezenet1_0(pretrained=use_pretrained)\n",
+ " set_parameter_requires_grad(model_ft, feature_extract)\n",
+ " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n",
+ " model_ft.num_classes = num_classes\n",
+ " input_size = 224\n",
+ "\n",
+ " elif model_name == \"densenet\":\n",
+ " \"\"\" Densenet\n",
+ " \"\"\"\n",
+ " model_ft = models.densenet121(pretrained=use_pretrained)\n",
+ " set_parameter_requires_grad(model_ft, feature_extract)\n",
+ " num_ftrs = model_ft.classifier.in_features\n",
+ " model_ft.classifier = torch.nn.Linear(num_ftrs, num_classes)\n",
+ " input_size = 224\n",
+ "\n",
+ " else:\n",
+ " print(\"Invalid model name, exiting...\")\n",
+ " exit()\n",
+ "\n",
+ " return model_ft, input_size"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LLRK0S6oOpD4"
+ },
+ "source": [
+ "### Writing the `LightningModule`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "k0tTtK5zOpD4"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.nn import Linear, CrossEntropyLoss, functional as F\n",
+ "from torch.optim import Adam\n",
+ "from torchmetrics.functional import accuracy\n",
+ "from lightning.pytorch import LightningModule\n",
+ "from torchvision import models\n",
+ "\n",
+ "class NatureLitModule(LightningModule):\n",
+ " def __init__(self,\n",
+ " model_name,\n",
+ " num_classes=10,\n",
+ " feature_extract=True,\n",
+ " lr=0.01):\n",
+ " '''method used to define our model parameters'''\n",
+ " super().__init__()\n",
+ "\n",
+ " self.model_name = model_name\n",
+ " self.num_classes = num_classes\n",
+ " self.feature_extract = feature_extract\n",
+ " self.model, self.input_size = initialize_model(model_name=self.model_name,\n",
+ " num_classes=self.num_classes,\n",
+ " feature_extract=True)\n",
+ "\n",
+ " # loss\n",
+ " self.loss = CrossEntropyLoss()\n",
+ "\n",
+ " # optimizer parameters\n",
+ " self.lr = lr\n",
+ "\n",
+ " # save hyper-parameters to self.hparams (auto-logged by W&B)\n",
+ " self.save_hyperparameters()\n",
+ "\n",
+ " # Record the gradients of all the layers\n",
+ " wandb.watch(self.model)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " '''method used for inference input -> output'''\n",
+ " x = self.model(x)\n",
+ "\n",
+ " return x\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " '''needs to return a loss from a single batch'''\n",
+ " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
+ "\n",
+ " # Log loss and metric\n",
+ " self.log('train/loss', loss)\n",
+ " self.log('train/accuracy', acc)\n",
+ "\n",
+ " return loss\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " '''used for logging metrics'''\n",
+ " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
+ "\n",
+ " # Log loss and metric\n",
+ " self.log('validation/loss', loss)\n",
+ " self.log('validation/accuracy', acc)\n",
+ "\n",
+ " # Let's return preds to use it in a custom callback\n",
+ " return preds, y\n",
+ "\n",
+ " def test_step(self, batch, batch_idx):\n",
+ " '''used for logging metrics'''\n",
+ " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
+ "\n",
+ " # Log loss and metric\n",
+ " self.log('test/loss', loss)\n",
+ " self.log('test/accuracy', acc)\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " '''defines model optimizer'''\n",
+ " return Adam(self.parameters(), lr=self.lr)\n",
+ "\n",
+ "\n",
+ " def _get_preds_loss_accuracy(self, batch):\n",
+ " '''convenience function since train/valid/test steps are similar'''\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " loss = self.loss(logits, y)\n",
+ " acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n",
+ " return preds, y, loss, acc"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bvZHZuZzOpD4"
+ },
+ "source": [
+ "### Instrument Callbacks to log additional things at certain points in your code"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "j9pHqaw1OpD5"
+ },
+ "outputs": [],
+ "source": [
+ "from lightning.pytorch.callbacks import Callback\n",
+ "\n",
+ "class LogPredictionsCallback(Callback):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ "\n",
+ "\n",
+ " def on_validation_epoch_start(self, trainer, pl_module):\n",
+ " self.batch_dfs = []\n",
+ " self.image_list = []\n",
+ " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n",
+ "\n",
+ "\n",
+ " def on_validation_batch_end(\n",
+ " self, trainer, pl_module, outputs, batch, batch_idx):\n",
+ " \"\"\"Called when the validation batch ends.\"\"\"\n",
+ "\n",
+ " # Append validation predictions and ground truth to log in confusion matrix\n",
+ " x, y = batch\n",
+ " preds, y = outputs\n",
+ " self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.cpu().numpy(), \"Predictions\": preds.cpu().numpy()}))\n",
+ "\n",
+ " # Add wandb.Image to a table to log at the end of validation\n",
+ " x = x.cpu().numpy().transpose(0, 2, 3, 1)\n",
+ " for x_i, y_i, y_pred in list(zip(x, y, preds)):\n",
+ " self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))\n",
+ " self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)\n",
+ "\n",
+ "\n",
+ " def on_validation_epoch_end(self, trainer, pl_module):\n",
+ " # Collect statistics for whole validation set and log\n",
+ " class_names = trainer.datamodule.nature_dataset.class_names\n",
+ " val_df = pd.concat(self.batch_dfs)\n",
+ " wandb.log({\"validation_table\": self.val_table,\n",
+ " \"images_over_time\": self.image_list,\n",
+ " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(),\n",
+ " preds=val_df[\"Predictions\"].tolist(),\n",
+ " class_names=class_names)}, step=trainer.global_step)\n",
+ "\n",
+ " del self.batch_dfs\n",
+ " del self.val_table\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jAiH-kb1OpD5"
+ },
+ "source": [
+ "## 🏋️ Main Training Loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FrYTQ7uaOpD5"
+ },
+ "outputs": [],
+ "source": [
+ "from lightning.pytorch.callbacks import ModelCheckpoint\n",
+ "from lightning.pytorch.loggers import WandbLogger\n",
+ "from lightning.pytorch import Trainer\n",
+ "\n",
+ "wandb.init(project=PROJECT_NAME,\n",
+ " entity=ENTITY,\n",
+ " job_type='training',\n",
+ " config={\n",
+ " \"model_name\": \"squeezenet\",\n",
+ " \"batch_size\": 16\n",
+ " })\n",
+ "\n",
+ "wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}')\n",
+ "\n",
+ "log_predictions_callback = LogPredictionsCallback()\n",
+ "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
+ "\n",
+ "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n",
+ "\n",
+ "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
+ " artifact_name_alias = \"Nature_100:latest\",\n",
+ " local_target_dir = \"Nature_100:latest\",\n",
+ " batch_size=wandb.config['batch_size'],\n",
+ " input_size=model.input_size)\n",
+ "nature_module.setup()\n",
+ "\n",
+ "trainer = Trainer(logger=wandb_logger, # W&B integration\n",
+ " callbacks=[log_predictions_callback, checkpoint_callback],\n",
+ " max_epochs=5,\n",
+ " log_every_n_steps=5)\n",
+ "trainer.fit(model, datamodule=nature_module)\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uEII8J7UOpD5"
+ },
+ "source": [
+ "### Syncing with W&B Offline\n",
+ "If for some reason, network communication is lost during the course of training, you can always sync progress with `wandb sync`\n",
+ "\n",
+ "The W&B SDK caches all logged data in a local directory `wandb`, and when you call `wandb sync`, this syncs your local state with the web app."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xi-EhHsxOpD5"
+ },
+ "source": [
+ "## Retrieve a model checkpoint artifact and resume training\n",
+ "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jiE1Bk7fOpD5"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Enter which checkpoint you want to resume training from:\n",
+ "\n",
+ "# FORM VARIABLES\n",
+ "ARTIFACT_NAME_ALIAS = \"nature-oyxk79m1:v4\" #@param {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jZZXWRatOpD5"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(project=PROJECT_NAME,\n",
+ " entity=ENTITY,\n",
+ " job_type='resume_training')\n",
+ "\n",
+ "# Retrieve model checkpoint artifact and restore previous hyperparameters\n",
+ "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n",
+ "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n",
+ "logging_run = model_chkpt_art.logged_by()\n",
+ "wandb.config = logging_run.config\n",
+ "\n",
+ "# Can create a new artifact name or continue logging to the old one\n",
+ "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n",
+ "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name)\n",
+ "\n",
+ "log_predictions_callback = LogPredictionsCallback()\n",
+ "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
+ "\n",
+ "model = NatureLitModule.load_from_checkpoint(f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt')\n",
+ "\n",
+ "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
+ " artifact_name_alias = \"Nature_100:latest\",\n",
+ " local_target_dir = \"Nature_100:latest\",\n",
+ " batch_size=wandb.config['batch_size'],\n",
+ " input_size=model.input_size)\n",
+ "nature_module.setup()\n",
+ "\n",
+ "\n",
+ "\n",
+ "trainer = Trainer(logger=wandb_logger, # W&B integration\n",
+ " callbacks=[log_predictions_callback, checkpoint_callback],\n",
+ " max_epochs=10,\n",
+ " log_every_n_steps=5)\n",
+ "trainer.fit(model, datamodule=nature_module)\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KTforR6EOpD5"
+ },
+ "source": [
+ "## Model Registry\n",
+ "After logging a bunch of checkpoints across multiple runs during experimentation, now comes time to hand-off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n",
+ "\n",
+ "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n",
+ "1. Access your team's model registry by going to the team page and selecting `Model Registry`\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n",
+ "\n",
+ "2. Create a new Registered Model.\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n",
+ "\n",
+ "3. Go to the artifacts tab of the project that holds all your model checkpoints\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n",
+ "\n",
+ "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively you can [link a model via api](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n",
+ "\n",
+ "**A note on linking:** The process of linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, this increments the version of the Registered Model. This helps delineate the model development side of the workflow from the model deployment/consumption side. The globally understood version/alias of a model should be unpolluted from all the experimental versions being generated in R&D and thus the versioning of a Registered Model increments according to new \"bookmarked\" models as opposed to model checkpoint logging.\n",
+ "\n",
+ "\n",
+ "### Create a Centralized Hub for all your models\n",
+ "- Add a model card, tags, Slack notifications to your Registered Model\n",
+ "- Change aliases to reflect when models move through different phases\n",
+ "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "\n",
- "\n",
- "# W&B Tutorial with Pytorch Lightning"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🛠️ Install `wandb` and `pytorch-lightning`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q pytorch-lightning wandb"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Login to W&B either through Python or CLI\n",
- "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n",
- "\n",
- "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n",
- "```\n",
- "import os\n",
- "import wandb \n",
- "\n",
- "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n",
- "```\n",
- "You can also login via the CLI with: \n",
- "```\n",
- "wandb login --host \n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import wandb\n",
- "wandb.login()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## ⚱ Logging the Raw Training Data as an Artifact"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#@title Enter your W&B project and entity\n",
- "\n",
- "# FORM VARIABLES\n",
- "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n",
- "ENTITY = \"wandb\"#@param {type:\"string\"}\n",
- "\n",
- "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n",
- "# to select one of these three datasets\n",
- "# TINY dataset: 100 images, 30MB\n",
- "# SMALL dataset: 1000 images, 312MB\n",
- "# MEDIUM dataset: 5000 images, 1.5GB\n",
- "# LARGE dataset: 12,000 images, 3.6GB\n",
- "\n",
- "SIZE = \"TINY\"\n",
- "\n",
- "if SIZE == \"TINY\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n",
- " src_zip = \"nature_100.zip\"\n",
- " DATA_SRC = \"nature_100\"\n",
- " IMAGES_PER_LABEL = 10\n",
- " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n",
- "elif SIZE == \"SMALL\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n",
- " src_zip = \"nature_1K.zip\"\n",
- " DATA_SRC = \"nature_1K\"\n",
- " IMAGES_PER_LABEL = 100\n",
- " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n",
- "elif SIZE == \"MEDIUM\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
- " src_zip = \"nature_12K.zip\"\n",
- " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
- " IMAGES_PER_LABEL = 500\n",
- " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n",
- "elif SIZE == \"LARGE\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
- " src_zip = \"nature_12K.zip\"\n",
- " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
- " IMAGES_PER_LABEL = 1000\n",
- " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture\n",
- "!curl -SL $src_url > $src_zip\n",
- "!unzip $src_zip"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import wandb\n",
- "import pandas as pd\n",
- "import os\n",
- "\n",
- "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n",
- " img_paths = []\n",
- " for root, dirs, files in os.walk('nature_100', topdown=False):\n",
- " for name in files:\n",
- " img_path = os.path.join(root, name)\n",
- " label = img_path.split('/')[1]\n",
- " img_paths.append([img_path, label])\n",
- "\n",
- " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n",
- " index_df.to_csv('index.csv', index=False)\n",
- "\n",
- " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n",
- " train_art.add_dir('nature_100')\n",
- "\n",
- " # Also adding a csv indicating the labels of each image\n",
- " train_art.add_file('index.csv')\n",
- " wandb.log_artifact(train_art)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Using Artifacts in Pytorch Lightning `DataModule`'s and Pytorch `Dataset`'s\n",
- "- Makes it easy to interopt your DataLoaders with new versions of datasets\n",
- "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from torchvision import transforms\n",
- "import pytorch_lightning as pl\n",
- "import torch\n",
- "from torch.utils.data import Dataset, DataLoader, random_split\n",
- "from skimage import io, transform\n",
- "from torchvision import transforms, utils, models\n",
- "import math\n",
- "\n",
- "class NatureDataset(Dataset):\n",
- " def __init__(self, \n",
- " wandb_run, \n",
- " artifact_name_alias=\"Nature_100:latest\", \n",
- " local_target_dir=\"Nature_100:latest\", \n",
- " transform=None):\n",
- " self.local_target_dir = local_target_dir\n",
- " self.transform = transform\n",
- "\n",
- " # Pull down the artifact locally to load it into memory\n",
- " art = wandb_run.use_artifact(artifact_name_alias)\n",
- " path_at = art.download(root=self.local_target_dir)\n",
- "\n",
- " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n",
- " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n",
- " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n",
- " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n",
- "\n",
- " def __len__(self):\n",
- " return len(self.ref_df)\n",
- "\n",
- " def __getitem__(self, idx):\n",
- " if torch.is_tensor(idx):\n",
- " idx = idx.tolist()\n",
- "\n",
- " img_path = self.ref_df.iloc[idx, 0]\n",
- "\n",
- " image = io.imread(img_path)\n",
- " label = self.ref_df.iloc[idx, 1]\n",
- " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n",
- "\n",
- " if self.transform:\n",
- " image = self.transform(image)\n",
- "\n",
- " return image, label\n",
- "\n",
- "\n",
- "class NatureDatasetModule(pl.LightningDataModule):\n",
- " def __init__(self,\n",
- " wandb_run,\n",
- " artifact_name_alias: str = \"Nature_100:latest\",\n",
- " local_target_dir: str = \"Nature_100:latest\",\n",
- " batch_size: int = 16,\n",
- " input_size: int = 224,\n",
- " seed: int = 42):\n",
- " super().__init__()\n",
- " self.wandb_run = wandb_run\n",
- " self.artifact_name_alias = artifact_name_alias\n",
- " self.local_target_dir = local_target_dir\n",
- " self.batch_size = batch_size\n",
- " self.input_size = input_size\n",
- " self.seed = seed\n",
- "\n",
- " def setup(self, stage=None):\n",
- " self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,\n",
- " artifact_name_alias=self.artifact_name_alias,\n",
- " local_target_dir=self.local_target_dir,\n",
- " transform=transforms.Compose([transforms.ToTensor(),\n",
- " transforms.CenterCrop(self.input_size),\n",
- " transforms.Normalize((0.485, 0.456, 0.406),\n",
- " (0.229, 0.224, 0.225))]))\n",
- "\n",
- " nature_length = len(self.nature_dataset)\n",
- " train_size = math.floor(0.8 * nature_length)\n",
- " val_size = math.floor(0.2 * nature_length)\n",
- " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n",
- " [train_size, val_size],\n",
- " generator=torch.Generator().manual_seed(self.seed))\n",
- " return self\n",
- "\n",
- " def train_dataloader(self):\n",
- " return DataLoader(self.nature_train, batch_size=self.batch_size)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return DataLoader(self.nature_val, batch_size=self.batch_size)\n",
- "\n",
- " def predict_dataloader(self):\n",
- " pass\n",
- "\n",
- " def teardown(self, stage: str):\n",
- " pass"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "##How Logging in your Pytorch `LightningModule`works:\n",
- "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as a `logger`. \n",
- " \n",
- "```\n",
- "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\") \n",
- "trainer = Trainer(logger=wandb_logger) \n",
- "```\n",
- "\n",
- "\n",
- "You can always use `wandb.log` as normal throughout the module. When the `WandbLogger` is used, `self.log` will also log metrics to W&B. \n",
- "- To access the current run from within the `LightningModule`, you can access `Trainer.logger.experiment`, which is a `wandb.Run` object"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Some helper functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Some helper functions\n",
- "\n",
- "def set_parameter_requires_grad(model, feature_extracting):\n",
- " if feature_extracting:\n",
- " for param in model.parameters():\n",
- " param.requires_grad = False\n",
- "\n",
- "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n",
- " # Initialize these variables which will be set in this if statement. Each of these\n",
- " # variables is model specific.\n",
- " model_ft = None\n",
- " input_size = 0\n",
- "\n",
- " if model_name == \"resnet\":\n",
- " \"\"\" Resnet18\n",
- " \"\"\"\n",
- " model_ft = models.resnet18(pretrained=use_pretrained)\n",
- " set_parameter_requires_grad(model_ft, feature_extract)\n",
- " num_ftrs = model_ft.fc.in_features\n",
- " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n",
- " input_size = 224\n",
- "\n",
- " elif model_name == \"squeezenet\":\n",
- " \"\"\" Squeezenet\n",
- " \"\"\"\n",
- " model_ft = models.squeezenet1_0(pretrained=use_pretrained)\n",
- " set_parameter_requires_grad(model_ft, feature_extract)\n",
- " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n",
- " model_ft.num_classes = num_classes\n",
- " input_size = 224\n",
- "\n",
- " elif model_name == \"densenet\":\n",
- " \"\"\" Densenet\n",
- " \"\"\"\n",
- " model_ft = models.densenet121(pretrained=use_pretrained)\n",
- " set_parameter_requires_grad(model_ft, feature_extract)\n",
- " num_ftrs = model_ft.classifier.in_features\n",
- " model_ft.classifier = torch.nn.Linear(num_ftrs, num_classes)\n",
- " input_size = 224\n",
- "\n",
- " else:\n",
- " print(\"Invalid model name, exiting...\")\n",
- " exit()\n",
- "\n",
- " return model_ft, input_size"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Writing the `LightningModule`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "from torch.nn import Linear, CrossEntropyLoss, functional as F\n",
- "from torch.optim import Adam\n",
- "from torchmetrics.functional import accuracy\n",
- "from pytorch_lightning import LightningModule\n",
- "from torchvision import models\n",
- "\n",
- "class NatureLitModule(LightningModule):\n",
- " def __init__(self,\n",
- " model_name,\n",
- " num_classes=10,\n",
- " feature_extract=True,\n",
- " lr=0.01):\n",
- " '''method used to define our model parameters'''\n",
- " super().__init__()\n",
- "\n",
- " self.model_name = model_name\n",
- " self.num_classes = num_classes\n",
- " self.feature_extract = feature_extract\n",
- " self.model, self.input_size = initialize_model(model_name=self.model_name,\n",
- " num_classes=self.num_classes,\n",
- " feature_extract=True)\n",
- "\n",
- " # loss\n",
- " self.loss = CrossEntropyLoss()\n",
- "\n",
- " # optimizer parameters\n",
- " self.lr = lr\n",
- "\n",
- " # save hyper-parameters to self.hparams (auto-logged by W&B)\n",
- " self.save_hyperparameters()\n",
- "\n",
- " # Record the gradients of all the layers\n",
- " wandb.watch(self.model)\n",
- "\n",
- " def forward(self, x):\n",
- " '''method used for inference input -> output'''\n",
- " x = self.model(x)\n",
- "\n",
- " return x\n",
- "\n",
- " def training_step(self, batch, batch_idx):\n",
- " '''needs to return a loss from a single batch'''\n",
- " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
- "\n",
- " # Log loss and metric\n",
- " self.log('train/loss', loss)\n",
- " self.log('train/accuracy', acc)\n",
- "\n",
- " return loss\n",
- "\n",
- " def validation_step(self, batch, batch_idx):\n",
- " '''used for logging metrics'''\n",
- " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
- "\n",
- " # Log loss and metric\n",
- " self.log('validation/loss', loss)\n",
- " self.log('validation/accuracy', acc)\n",
- "\n",
- " # Let's return preds to use it in a custom callback\n",
- " return preds, y\n",
- "\n",
- " def validation_epoch_end(self, validation_step_outputs):\n",
- " \"\"\"Called when the validation ends.\"\"\"\n",
- " preds, y = validation_step_outputs\n",
- " all_preds = torch.stack(preds)\n",
- " all_y = torch.stack(y)\n",
- "\n",
- " def test_step(self, batch, batch_idx):\n",
- " '''used for logging metrics'''\n",
- " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
- "\n",
- " # Log loss and metric\n",
- " self.log('test/loss', loss)\n",
- " self.log('test/accuracy', acc)\n",
- "\n",
- " def configure_optimizers(self):\n",
- " '''defines model optimizer'''\n",
- " return Adam(self.parameters(), lr=self.lr)\n",
- "\n",
- "\n",
- " def _get_preds_loss_accuracy(self, batch):\n",
- " '''convenience function since train/valid/test steps are similar'''\n",
- " x, y = batch\n",
- " logits = self(x)\n",
- " preds = torch.argmax(logits, dim=1)\n",
- " loss = self.loss(logits, y)\n",
- " acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n",
- " return preds, y, loss, acc"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Instrument Callbacks to log additional things at certain points in your code"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from pytorch_lightning.callbacks import Callback\n",
- "\n",
- "class LogPredictionsCallback(Callback):\n",
- "\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- "\n",
- " \n",
- " def on_validation_epoch_start(self, trainer, pl_module):\n",
- " self.batch_dfs = []\n",
- " self.image_list = []\n",
- " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n",
- "\n",
- " \n",
- " def on_validation_batch_end(\n",
- " self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n",
- " \"\"\"Called when the validation batch ends.\"\"\"\n",
- "\n",
- " # Append validation predictions and ground truth to log in confusion matrix\n",
- " x, y = batch\n",
- " preds, y = outputs\n",
- " self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.numpy(), \"Predictions\": preds.numpy()}))\n",
- "\n",
- " # Add wandb.Image to a table to log at the end of validation\n",
- " x = x.numpy().transpose(0, 2, 3, 1)\n",
- " for x_i, y_i, y_pred in list(zip(x, y, preds)):\n",
- " self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))\n",
- " self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)\n",
- " \n",
- " \n",
- " def on_validation_epoch_end(self, trainer, pl_module):\n",
- " # Collect statistics for whole validation set and log\n",
- " class_names = trainer.datamodule.nature_dataset.class_names\n",
- " val_df = pd.concat(self.batch_dfs)\n",
- " wandb.log({\"validation_table\": self.val_table,\n",
- " \"images_over_time\": self.image_list,\n",
- " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(), \n",
- " preds=val_df[\"Predictions\"].tolist(), \n",
- " class_names=class_names)}, step=trainer.global_step)\n",
- "\n",
- " del self.batch_dfs\n",
- " del self.val_table\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🏋️ Main Training Loop"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from pytorch_lightning.callbacks import ModelCheckpoint\n",
- "from pytorch_lightning.loggers import WandbLogger\n",
- "from pytorch_lightning import Trainer\n",
- "\n",
- "wandb.init(project=PROJECT_NAME,\n",
- " entity=ENTITY,\n",
- " job_type='training',\n",
- " config={\n",
- " \"model_name\": \"squeezenet\",\n",
- " \"batch_size\": 16\n",
- " })\n",
- "\n",
- "wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}') \n",
- "\n",
- "log_predictions_callback = LogPredictionsCallback()\n",
- "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
- "\n",
- "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n",
- "\n",
- "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
- " artifact_name_alias = \"Nature_100:latest\",\n",
- " local_target_dir = \"Nature_100:latest\",\n",
- " batch_size=wandb.config['batch_size'],\n",
- " input_size=model.input_size)\n",
- "nature_module.setup()\n",
- "\n",
- "trainer = Trainer(logger=wandb_logger, # W&B integration\n",
- " callbacks=[log_predictions_callback, checkpoint_callback],\n",
- " max_epochs=5,\n",
- " log_every_n_steps=5) \n",
- "trainer.fit(model, datamodule=nature_module)\n",
- "\n",
- "wandb.finish()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Syncing with W&B Offline\n",
- "If for some reason, network communication is lost during the course of training, you can always sync progress with `wandb sync`\n",
- "\n",
- "The W&B sdk caches all logged data in a local directory `wandb` and when you call `wandb sync`, this syncs the your local state with the web app. "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Retrieve a model checkpoint artifact and resume training\n",
- "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#@title Enter which checkpoint you want to resume training from:\n",
- "\n",
- "# FORM VARIABLES\n",
- "ARTIFACT_NAME_ALIAS = \"nature-zb4swpn6:v4\" #@param {type:\"string\"}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "wandb.init(project=PROJECT_NAME,\n",
- " entity=ENTITY,\n",
- " job_type='resume_training')\n",
- "\n",
- "# Retrieve model checkpoint artifact and restore previous hyperparameters\n",
- "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n",
- "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n",
- "logging_run = model_chkpt_art.logged_by()\n",
- "wandb.config = logging_run.config\n",
- "\n",
- "# Can create a new artifact name or continue logging to the old one\n",
- "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n",
- "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name) \n",
- "\n",
- "log_predictions_callback = LogPredictionsCallback()\n",
- "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
- "\n",
- "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n",
- "\n",
- "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
- " artifact_name_alias = \"Nature_100:latest\",\n",
- " local_target_dir = \"Nature_100:latest\",\n",
- " batch_size=wandb.config['batch_size'],\n",
- " input_size=model.input_size)\n",
- "nature_module.setup()\n",
- "\n",
- "\n",
- "\n",
- "trainer = Trainer(logger=wandb_logger, # W&B integration\n",
- " resume_from_checkpoint = f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt',\n",
- " callbacks=[log_predictions_callback, checkpoint_callback],\n",
- " max_epochs=10,\n",
- " log_every_n_steps=5) \n",
- "trainer.fit(model, datamodule=nature_module)\n",
- "\n",
- "wandb.finish()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Model Registry \n",
- "After logging a bunch of checkpoints across multiple runs during experimentation, now comes time to hand-off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n",
- "\n",
- "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n",
- "1. Access your team's model registry by going the team page and selecting `Model Registry`\n",
- "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n",
- "\n",
- "2. Create a new Registered Model. \n",
- "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n",
- "\n",
- "3. Go to the artifacts tab of the project that holds all your model checkpoints\n",
- "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n",
- "\n",
- "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively you can [link a model via api](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n",
- "\n",
- "**A note on linking:** The process of linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, this increments the version of the Registered Model. This helps delineate the model development side of the workflow from the model deployment/consumption side. The globally understood version/alias of a model should be unpolluted from all the experimental versions being generated in R&D and thus the versioning of a Registered Model increments according to new \"bookmarked\" models as opposed to model checkpoint logging. \n",
- "\n",
- "\n",
- "### Create a Centralized Hub for all your models\n",
- "- Add a model card, tags, slack notifactions to your Registered Model\n",
- "- Change aliases to reflect when models move through different phases\n",
- "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n",
- "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/colabs/pytorch/How_does_adding_dropout_affect_model_performance.ipynb b/colabs/pytorch/How_does_adding_dropout_affect_model_performance.ipynb
index 41b8b3fa..be25a18c 100644
--- a/colabs/pytorch/How_does_adding_dropout_affect_model_performance.ipynb
+++ b/colabs/pytorch/How_does_adding_dropout_affect_model_performance.ipynb
@@ -48,8 +48,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W&B.ipynb b/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W&B.ipynb
index 00a27712..1ef18a83 100644
--- a/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W&B.ipynb
+++ b/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W&B.ipynb
@@ -98,8 +98,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/pytorch/Simple_PyTorch_Integration.ipynb b/colabs/pytorch/Simple_PyTorch_Integration.ipynb
index f261fd3a..9369972e 100644
--- a/colabs/pytorch/Simple_PyTorch_Integration.ipynb
+++ b/colabs/pytorch/Simple_PyTorch_Integration.ipynb
@@ -186,8 +186,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/raytune/RayTune_with_wandb.ipynb b/colabs/raytune/RayTune_with_wandb.ipynb
index 42b6e272..00dc6caf 100644
--- a/colabs/raytune/RayTune_with_wandb.ipynb
+++ b/colabs/raytune/RayTune_with_wandb.ipynb
@@ -1,329 +1,409 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- ""
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CHj4COZvzWta"
+ },
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f5SyUYlHzWtd"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oQYw5uVFzWtd"
+ },
+ "source": [
+ "\n",
+ "# 🌞 Ray/Tune and 🏋️♀️ Weights & Biases\n",
+ "\n",
+ "Both Weights and Biases and Ray/Tune are built for scale and handle millions of models every month for teams doing some of the most cutting-edge deep learning research.\n",
+ "\n",
+ "[W&B](https://wandb.com) is a toolkit with everything you need to track, reproduce, and gain insights from your models easily; [Ray/Tune](https://docs.ray.io/en/latest/tune/) provides a simple interface for scaling and running distributed experiments.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SFI8z2gWzWte"
+ },
+ "source": [
+ "\n",
+ "### 🤝 They're a natural match! 🤝\n",
+ "\n",
+ "Here are just a few reasons why our community likes Ray/Tune –\n",
+ "\n",
+ "* **Simple distributed execution**: Ray/Tune makes it easy to scale all the way from a single node on a laptop, through to multiple GPUs, and up to multiple nodes on multiple machines.\n",
+ "* **State-of-the-art algorithms**: Ray/Tune has tested implementations of a huge number of potent scheduling algorithms including\n",
+ "[Population-Based Training](https://docs.ray.io/en/latest/tune/tutorials/tune-advanced-tutorial.html),\n",
+ "[ASHA](https://docs.ray.io/en/master/tune/tutorials/tune-tutorial.html#early-stopping-with-asha),\n",
+ "and\n",
+ "[HyperBand](https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#hyperband-tune-schedulers-hyperbandscheduler)\n",
+ "* **Method agnostic**: Ray/Tune works across deep learning frameworks (including PyTorch, Keras, Tensorflow, and PyTorchLightning) and with other ML methods like gradient-boosted trees (XGBoost, LightGBM)\n",
+ "* **Fault-tolerance**: Ray/Tune is built on top of Ray, providing tolerance for failed runs out of the box.\n",
+ "\n",
+ "This Colab demonstrates how this integration works for a simple grid search over two hyperparameters. If you've got any questions about the details,\n",
+ "check out\n",
+ "[our documentation](https://docs.wandb.com/library/integrations/ray-tune)\n",
+ "or the\n",
+ "[documentation for Ray/Tune](https://docs.ray.io/en/master/tune/api_docs/integration.html#weights-and-biases-tune-integration-wandb).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Hw1mGxhlzWte"
+ },
+ "source": [
+ "W&B integrates with `ray.tune` by offering two lightweight standalone integrations:\n",
+ "\n",
+ "1. For simple cases, `WandbLoggerCallback` automatically logs metrics reported to Tune to W&B, along with the configuration of the experiment, using Tune's [`logger` interface](https://docs.ray.io/en/latest/tune/api_docs/logging.html).\n",
+ "2. The `setup_wandb` function gives you greater control over logging by letting you call `wandb.log` inside your training function, allowing you to [log custom metrics, plots, and other outputs, like media](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/wandb-log/Log_(Almost)_Anything_with_W%26B_Media.ipynb).\n",
+ "\n",
+ "These methods can be used together or independently.\n",
+ "\n",
+ "The example below demonstrates how they can be used together."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qRrP4G23zWte"
+ },
+ "source": [
+ "# 🧹 Running a hyperparameter sweep with W&B and Ray/Tune"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QfDVmu9szWte"
+ },
+ "source": [
+ "## 📥 Install, `import`, and set seeds"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XRjOn9QRzWte"
+ },
+ "source": [
+ "\n",
+ "Let's start by installing the libraries and importing everything we need.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LCVSvebNzWtf"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -Uq ray[tune] wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qBCk5SvQ1q_n"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"WANDB_API_KEY\"] = \"\" #@param {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "dfHgJfVszWtf"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "import numpy as np\n",
+ "from filelock import FileLock\n",
+ "import tempfile\n",
+ "from ray import train, tune\n",
+ "from ray.air.integrations.wandb import WandbLoggerCallback, setup_wandb\n",
+ "from ray.tune.schedulers import AsyncHyperBandScheduler\n",
+ "from ray.train import Checkpoint\n",
+ "import torch\n",
+ "import torch.optim as optim\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Y1ysNX6pzWtf",
+ "outputId": "f6a36dad-1d03-4f13-d552-e00896d919db"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mYPSEuIgzWtg"
+ },
+ "source": [
+ "We'll make use of Ray's handy [`mnist_pytorch` example code](https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/mnist_pytorch.py)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8b-t9a7pzWtg"
+ },
+ "outputs": [],
+ "source": [
+ "from ray.tune.examples.mnist_pytorch import ConvNet, get_data_loaders, test_func, train_func"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BJ-zAJWpzWtg"
+ },
+ "source": [
+ "In order to make this experiment reproducible, we'll set the seeds for random number generators of various libraries used in this experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GbXqaFYizWtg"
+ },
+ "outputs": [],
+ "source": [
+ "torch.backends.cudnn.deterministic = True\n",
+ "random.seed(2022)\n",
+ "np.random.seed(2022)\n",
+ "torch.manual_seed(2022)\n",
+ "torch.cuda.manual_seed_all(2022)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nAmSUMP4zWtg"
+ },
+ "source": [
+ "## 🤝 Integrating W&B with Ray/Tune"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OVmU1CRgzWtg"
+ },
+ "source": [
+ "Now, we define our training process, which calls `setup_wandb` so we can call `wandb.log` to log our custom metric\n",
+ "(here, just the error rate; you might also [log media](https://docs.wandb.com/library/log#media), e.g. images from the validation set, captioned by the model predictions).\n",
+ "\n",
+ "When we execute our hyperparameter sweep below,\n",
+ "this function will be called with a `config`uration dictionary\n",
+ "that contains values for any hyperparameters.\n",
+ "For simplicity, we only have two hyperparameters here:\n",
+ "the learning rate and momentum value for accelerated SGD."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nf2FKgaM21Em"
+ },
+ "outputs": [],
+ "source": [
+ "def train_mnist(config):\n",
+ " # Setup wandb\n",
+ " wandb = setup_wandb(config)\n",
+ " should_checkpoint = config.get(\"should_checkpoint\", False)\n",
+ " use_cuda = torch.cuda.is_available()\n",
+ " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
+ " train_loader, test_loader = get_data_loaders()\n",
+ " model = ConvNet().to(device)\n",
+ "\n",
+ " optimizer = optim.SGD(\n",
+ " model.parameters(), lr=config[\"lr\"], momentum=config[\"momentum\"]\n",
+ " )\n",
+ " while True:\n",
+ " train_func(model, optimizer, train_loader, device)\n",
+ " acc = test_func(model, test_loader, device)\n",
+ " metrics = {\"mean_accuracy\": acc}\n",
+ "\n",
+ " # Report metrics (and possibly a checkpoint)\n",
+ " if should_checkpoint:\n",
+ " with tempfile.TemporaryDirectory() as tempdir:\n",
+ " torch.save(model.state_dict(), os.path.join(tempdir, \"model.pt\"))\n",
+ " train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))\n",
+ " else:\n",
+ " train.report(metrics)\n",
+ " # enables logging custom metrics using wandb.log()\n",
+ " error_rate = 100 * (1 - acc)\n",
+ " wandb.log({\"error_rate\": error_rate})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FjSalAQkzWth"
+ },
+ "source": [
+ "## 🚀 Launching a Sweep with W&B and Ray/Tune"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lBS6iyXAzWth"
+ },
+ "source": [
+ "We're now almost ready to call `tune.Tuner(...).fit()` to launch our hyperparameter sweep!\n",
+ "We just need to do three things:\n",
+ "1. set up a `wandb.Run`,\n",
+ "2. give the `WandbLoggerCallback` to the `Tuner` so we can capture the output of `train.report`, and\n",
+ "3. set up our hyperparameter sweep.\n",
+ "\n",
+ "A `wandb.Run` is normally created by calling `wandb.init`.\n",
+ "`tune` will handle that for you; you just need to pass\n",
+ "the arguments to the `WandbLoggerCallback` as keyword arguments\n",
+ "(see [our documentation](https://docs.wandb.com/library/init) for details on `wandb.init`).\n",
+ "At the bare minimum, you need to pass in a `project` name --\n",
+ "sort of like a `git` repo name, but for your ML projects.\n",
+ "\n",
+ "In addition to accepting arguments for `wandb.init`,\n",
+ "the callback also takes a few special arguments, described in\n",
+ "[the documentation for the `WandbLoggerCallback`](https://docs.ray.io/en/master/tune/tutorials/tune-wandb.html)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "E6HRghT9zWth"
+ },
+ "source": [
+ "We handle steps 2 and 3 when we construct the `tune.Tuner`.\n",
+ "\n",
+ "Step 2 is handled by passing a `WandbLoggerCallback` instance in a list\n",
+ "to the `callbacks` argument of `train.RunConfig`.\n",
+ "\n",
+ "The setup of the hyperparameter sweep is handled by the\n",
+ "`param_space` argument of `tune.Tuner`.\n",
+ "For the purposes of the integration,\n",
+ "the most important part is that each sampled configuration is passed to `train_mnist`,\n",
+ "which forwards it to `setup_wandb`.\n",
+ "\n",
+ "This is also where we configure the \"meat\" of the hyperparameter sweep:\n",
+ "what are the hyperparameters we're sweeping over,\n",
+ "and how do we choose their values.\n",
+ "\n",
+ "Here, we randomly sample 50 configurations and stop unpromising trials early with the `AsyncHyperBandScheduler`, but\n",
+ "[Ray/Tune provides lots of sophisticated options](https://docs.ray.io/en/latest/tune/api_docs/suggestion.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tw4eawRD0u8y",
+ "outputId": "e4ddf59b-5445-4d55-8f69-84a9fcb35024"
+ },
+ "outputs": [],
+ "source": [
+ "# for early stopping\n",
+ "sched = AsyncHyperBandScheduler()\n",
+ "\n",
+ "resources_per_trial = {\"gpu\": 1} # set this for GPUs\n",
+ "tuner = tune.Tuner(\n",
+ " tune.with_resources(train_mnist, resources=resources_per_trial),\n",
+ " tune_config=tune.TuneConfig(\n",
+ " metric=\"mean_accuracy\",\n",
+ " mode=\"max\",\n",
+ " scheduler=sched,\n",
+ " num_samples=50,\n",
+ " ),\n",
+ " run_config=train.RunConfig(\n",
+ " name=\"exp\",\n",
+ " stop={\n",
+ " \"mean_accuracy\": 0.98,\n",
+ " \"training_iteration\": 5,\n",
+ " },\n",
+ " callbacks=[WandbLoggerCallback(project=\"raytune-colab\")]\n",
+ " ),\n",
+ " param_space={\n",
+ " \"lr\": tune.loguniform(1e-4, 1e-2),\n",
+ " \"momentum\": tune.uniform(0.1, 0.9),\n",
+ " },\n",
+ ")\n",
+ "results = tuner.fit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1f8oI7tYzWth",
+ "outputId": "9fece67a-47e7-443a-ab8f-971f3d8dfd94"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"Best config is:\", results.get_best_result().config)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
},
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "\n",
- "\n",
- "Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- ""
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "# 🌞 Ray/Tune and 🏋️♀️ Weights & Biases \n",
- "\n",
- "Both Weights and Biases and Ray/Tune are built for scale and handle millions of models every month for teams doing some of the most cutting-edge deep learning research.\n",
- "\n",
- "[W&B](https://wandb.com) is a toolkit with everything you need to track, reproduce, and gain insights from your models easily; [Ray/Tune](https://docs.ray.io/en/latest/tune/) provides a simple interface for scaling and running distributed experiments.\n"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "### 🤝 They're a natural match! 🤝\n",
- "\n",
- "Here's just a few reasons why our community likes Ray/Tune –\n",
- "\n",
- "* **Simple distributed execution**: Ray/Tune makes it easy to scale all the way from a single node on a laptop, through to multiple GPUs, and up to multiple nodes on multiple machines.\n",
- "* **State-of-the-art algorithms**: Ray/Tune has tested implementations of a huge number of potent scheduling algorithms including\n",
- "[Population-Based Training](https://docs.ray.io/en/latest/tune/tutorials/tune-advanced-tutorial.html),\n",
- "[ASHA](https://docs.ray.io/en/master/tune/tutorials/tune-tutorial.html#early-stopping-with-asha),\n",
- "and\n",
- "[HyperBand](https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#hyperband-tune-schedulers-hyperbandscheduler)\n",
- "* **Method agnostic**: Ray/Tune works across deep learning frameworks (including PyTorch, Keras, Tensorflow, and PyTorchLightning) and with other ML methods like gradient-boosted trees (XGBoost, LightGBM)\n",
- "* **Fault-tolerance**: Ray/Tune is built on top of Ray, providing tolerance for failed runs out of the box.\n",
- "\n",
- "This Colab demonstrates how this integration works for a simple grid search over two hyperparameters. If you've got any questions about the details,\n",
- "check out\n",
- "[our documentation](https://docs.wandb.com/library/integrations/ray-tune)\n",
- "or the\n",
- "[documentation for Ray/Tune](https://docs.ray.io/en/master/tune/api_docs/integration.html#weights-and-biases-tune-integration-wandb).\n"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "W&B integrates with `ray.tune` by offering two lightweight standalone integrations:\n",
- "\n",
- "1. For simple cases, `WandbLoggerCallback` automatically logs metrics reported to Tune to W&B, along with the configuration of the experiment, using Tune's [`logger` interface](https://docs.ray.io/en/latest/tune/api_docs/logging.html).\n",
- "2. The `@wandb_mixin` decorator gives you greater control over logging by letting you call `wandb.log` inside the decorated function, allowing you to [log custom metrics, plots, and other outputs, like media](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/wandb-log/Log_(Almost)_Anything_with_W%26B_Media.ipynb).\n",
- "\n",
- "These methods can be used together or independently.\n",
- "\n",
- "The example below demonstrates how they can be used together."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 🧹 Running a hyperparameter sweep with W&B and Ray/Tune"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 📥 Install, `import`, and set seeds"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "Let's start by installing the libraries and importing everything we need.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -Uq ray[tune] wandb"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import random\n",
- "import numpy as np\n",
- "from ray import tune\n",
- "from ray.air.integrations.wandb import WandbLoggerCallback, setup_wandb\n",
- "import torch\n",
- "import torch.optim as optim\n",
- "import wandb"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "wandb.login()"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We'll make use of Ray's handy [`mnist_pytorch` example code](https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/mnist_pytorch.py)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from ray.tune.examples.mnist_pytorch import ConvNet, get_data_loaders, test, train"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In order to make this experiment reproducible, we'll set the seeds for random number generators of various libraries used in this experiment."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "torch.backends.cudnn.deterministic = True\n",
- "random.seed(2022)\n",
- "np.random.seed(2022)\n",
- "torch.manual_seed(2022)\n",
- "torch.cuda.manual_seed_all(2022)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🤝 Integrating W&B with Ray/Tune"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, we define our training process, decorated with `@wandb_mixin` so we can call `wandb.log` to log our custom metric\n",
- "(here, just the error rate; you might also [log media](https://docs.wandb.com/library/log#media), e.g. images from the validation set, captioned by the model predictions).\n",
- "\n",
- "When we execute our hyperparameter sweep below,\n",
- "this function will be called with a `config`uration dictionary\n",
- "that contains values for any hyperparameters.\n",
- "For simplicity, we only have two hyperparameters here:\n",
- "the learning rate and momentum value for accelerated SGD."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def train_mnist(config):\n",
- " # Setup wandb\n",
- " wandb = setup_wandb(config)\n",
- " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
- " train_loader, test_loader = get_data_loaders()\n",
- "\n",
- " model = ConvNet()\n",
- " model.to(device)\n",
- "\n",
- " optimizer = optim.SGD(model.parameters(),\n",
- " lr=config[\"lr\"], momentum=config[\"momentum\"])\n",
- " \n",
- " for i in range(10):\n",
- " train(model, optimizer, train_loader, device=device)\n",
- " acc = test(model, test_loader, device=device)\n",
- "\n",
- " # When using WandbLoggerCallback, the metrics reported to tune are also logged in the W&B dashboard\n",
- " tune.report(mean_accuracy=acc)\n",
- "\n",
- " # enables logging custom metrics using wandb.log()\n",
- " error_rate = 100 * (1 - acc)\n",
- " wandb.log({\"error_rate\": error_rate})"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🚀 Launching a Sweep with W&B and Ray/Tune"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We're now almost ready to call `tune.run` to launch our hyperparameter sweep!\n",
- "We just need to do three things:\n",
- "1. set up a `wandb.Run`,\n",
- "2. give the `WandbLoggerCallback` to `tune.run` so we can capture the output of `tune.report`, and\n",
- "3. set up our hyperparameter sweep.\n",
- "\n",
- "A `wandb.Run` is normally created by calling `wandb.init`.\n",
- "`tune` will handle that for you, you just need to pass\n",
- "the arguments as a dictionary\n",
- "(see [our documentation](https://docs.wandb.com/library/init) for details on `wandb.init`).\n",
- "At the bare minimum, you need to pass in a `project` name --\n",
- "sort of like a `git` repo name, but for your ML projects.\n",
- "\n",
- "In addition to holding arguments for `wandb.init`,\n",
- "that dictionary also has a few special keys, described in\n",
- "[the documentation for the `WandbLoggerCallback`](https://docs.ray.io/en/master/tune/tutorials/tune-wandb.html)."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We handle steps 2 and 3 when we invoke `tune.run`.\n",
- "\n",
- "Step 2 is handled by passing in the `WandbLoggerCallback` class in a list\n",
- "to the `loggers` argument of `tune.run`.\n",
- "\n",
- "The setup of the hyperparameter sweep is handled by the\n",
- "`config` argument of `tune.run`.\n",
- "For the purposes of the integration,\n",
- "the most important part is that this is where we pass in the `wandb_init`\n",
- "dictionary.\n",
- "\n",
- "This is also where we configure the \"meat\" of the hyperparameter sweep:\n",
- "what are the hyperparameters we're sweeping over,\n",
- "and how do we choose their values.\n",
- "\n",
- "Here, we do a simple grid search, but\n",
- "[Ray/Tune provides lots of sophisticated options](https://docs.ray.io/en/latest/tune/api_docs/suggestion.html)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis = tune.run(\n",
- " train_mnist,\n",
- " callbacks=[WandbLoggerCallback(project=\"raytune-colab\")], # WandbLoggerCallback uses tune.run's logger interface\n",
- " resources_per_trial={\"gpu\": 1},\n",
- " config={\n",
- " # hyperparameters are set by keyword arguments\n",
- " \"lr\": tune.grid_search([0.0001, 0.001, 0.1]),\n",
- " \"momentum\": tune.grid_search([0.9, 0.99])\n",
- " }\n",
- " )\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(\"Best config: \", analysis.get_best_config(metric=\"mean_accuracy\", mode=\"max\"))"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat": 4,
+ "nbformat_minor": 0
}
diff --git a/colabs/raytune/tune-wandb.ipynb b/colabs/raytune/tune-wandb.ipynb
index dd141d2c..f35865f7 100644
--- a/colabs/raytune/tune-wandb.ipynb
+++ b/colabs/raytune/tune-wandb.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "91d48998",
"metadata": {},
"source": [
"\n"
@@ -9,7 +10,7 @@
},
{
"cell_type": "markdown",
- "id": "ecad719c",
+ "id": "b9c22142",
"metadata": {},
"source": [
"# Using Weights & Biases with Tune\n",
@@ -40,7 +41,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3db798d0",
+ "id": "2969a4cb",
"metadata": {},
"outputs": [],
"source": [
@@ -50,7 +51,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "100bcf8a",
+ "id": "55026926",
"metadata": {},
"outputs": [],
"source": [
@@ -70,7 +71,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1b4dcae4",
+ "id": "e40e6ca6",
"metadata": {},
"outputs": [],
"source": [
@@ -79,7 +80,7 @@
},
{
"cell_type": "markdown",
- "id": "9346c0f6",
+ "id": "1d6c8d04",
"metadata": {},
"source": [
"Next, let's define an easy `objective` function (a Tune `Trainable`) that reports a random loss to Tune.\n",
@@ -90,7 +91,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e8b4fc4d",
+ "id": "e6e75c5c",
"metadata": {},
"outputs": [],
"source": [
@@ -102,7 +103,7 @@
},
{
"cell_type": "markdown",
- "id": "831eed42",
+ "id": "ff78ca64",
"metadata": {},
"source": [
"Given that you provide an `api_key_file` pointing to your Weights & Biases API key, you cna define a\n",
@@ -112,7 +113,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "52988599",
+ "id": "494e1c84",
"metadata": {},
"outputs": [],
"source": [
@@ -141,7 +142,7 @@
},
{
"cell_type": "markdown",
- "id": "e24c05fa",
+ "id": "de9b195d",
"metadata": {},
"source": [
"To use the `wandb_mixin` decorator, you can simply decorate the objective function from earlier.\n",
@@ -152,7 +153,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5e30d5e7",
+ "id": "e0275757",
"metadata": {},
"outputs": [],
"source": [
@@ -166,7 +167,7 @@
},
{
"cell_type": "markdown",
- "id": "04040bcb",
+ "id": "59c23cff",
"metadata": {},
"source": [
"With the `decorated_objective` defined, running a Tune experiment is as simple as providing this objective and\n",
@@ -176,7 +177,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d4fbd368",
+ "id": "e4d408df",
"metadata": {},
"outputs": [],
"source": [
@@ -201,7 +202,7 @@
},
{
"cell_type": "markdown",
- "id": "f9521481",
+ "id": "6565559f",
"metadata": {},
"source": [
"Finally, you can also define a class-based Tune `Trainable` by using the `WandbTrainableMixin` to define your objective:"
@@ -210,7 +211,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d27a7a35",
+ "id": "1b140b28",
"metadata": {},
"outputs": [],
"source": [
@@ -224,7 +225,7 @@
},
{
"cell_type": "markdown",
- "id": "fa189bb2",
+ "id": "cc42d1a4",
"metadata": {},
"source": [
"Running Tune with this `WandbTrainable` works exactly the same as with the function API.\n",
@@ -235,7 +236,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6e546cc2",
+ "id": "e2782ef4",
"metadata": {},
"outputs": [],
"source": [
@@ -260,7 +261,7 @@
},
{
"cell_type": "markdown",
- "id": "0b736172",
+ "id": "58d4914b",
"metadata": {},
"source": [
"Since you may not have an API key for Wandb, we can _mock_ the Wandb logger and test all three of our training\n",
@@ -271,7 +272,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e0e7f481",
+ "id": "fecd85ed",
"metadata": {},
"outputs": [],
"source": [
@@ -302,7 +303,7 @@
},
{
"cell_type": "markdown",
- "id": "2f6e9138",
+ "id": "fbd9728a",
"metadata": {},
"source": [
"This completes our Tune and Wandb walk-through.\n",
diff --git a/colabs/rdkit/wb_rdkit.ipynb b/colabs/rdkit/wb_rdkit.ipynb
index d23c2b03..678a516c 100644
--- a/colabs/rdkit/wb_rdkit.ipynb
+++ b/colabs/rdkit/wb_rdkit.ipynb
@@ -41,7 +41,15 @@
"outputs": [],
"source": [
"\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/scikit/Simple_Scikit_Integration.ipynb b/colabs/scikit/Simple_Scikit_Integration.ipynb
index c5a35d54..f68bb555 100644
--- a/colabs/scikit/Simple_Scikit_Integration.ipynb
+++ b/colabs/scikit/Simple_Scikit_Integration.ipynb
@@ -91,8 +91,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/scikit/wandb_decision_tree.ipynb b/colabs/scikit/wandb_decision_tree.ipynb
new file mode 100644
index 00000000..d2d34873
--- /dev/null
+++ b/colabs/scikit/wandb_decision_tree.ipynb
@@ -0,0 +1,247 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e6qvOYJV8BQw"
+ },
+ "source": [
+ "## Author: [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BL5XadIr0MbM"
+ },
+ "source": [
+ "# Packages 📦 and Basic Setup\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fZH0PUfw0OBV"
+ },
+ "source": [
+ "## Install Packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_uk-bdosSzwk"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "## Install Sklearn\n",
+ "!pip install -U scikit-learn\n",
+ "## Install the latest version of wandb client 🔥🔥\n",
+ "!pip install -q --upgrade wandb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UgQLiLtz0STo"
+ },
+ "source": [
+ "## Project Configuration using **`wandb.config`**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OvABqBlB0VhB"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "## Importing Libraries\n",
+ "from sklearn.datasets import load_iris\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize the run\n",
+ "run = wandb.init(project='simple-scikit')\n",
+ "\n",
+ "# Feel free to change these and experiment !!\n",
+ "config = wandb.config\n",
+ "config.max_depth = 5\n",
+ "config.min_samples_split = 2\n",
+ "config.clf_criterion = \"gini\"\n",
+ "config.reg_criterion = \"squared_error\"  # the old \"mse\" name was removed in scikit-learn 1.2\n",
+ "config.splitter = \"best\"\n",
+ "config.dataset = \"iris\"\n",
+ "config.test_size = 0.2\n",
+ "config.random_state = 42\n",
+ "config.labels =['setosa', 'versicolor', 'virginica']\n",
+ "\n",
+ "# Update the config\n",
+ "wandb.config.update(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OGV1c4VYJMGa"
+ },
+ "source": [
+ "# 💿 Dataset\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lWnhwnPI2NWD"
+ },
+ "outputs": [],
+ "source": [
+ "## Loading the Dataset\n",
+ "iris = load_iris(return_X_y = True, as_frame= True)\n",
+ "dataset = iris[0]\n",
+ "target = iris[1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kycxKC7ER7lW"
+ },
+ "source": [
+ "# ✍️ Model Architecture\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4wWuB4NZl3h_"
+ },
+ "source": [
+ "## Classification"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vElZpyEQtXzX"
+ },
+ "outputs": [],
+ "source": [
+ "X, y = load_iris(return_X_y=True)\n",
+ "x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = config.test_size, random_state = config.random_state)\n",
+ "\n",
+ "clf = DecisionTreeClassifier(\n",
+ " max_depth=config.max_depth,\n",
+ " min_samples_split=config.min_samples_split,\n",
+ " criterion=config.clf_criterion,\n",
+ " splitter=config.splitter\n",
+ ")\n",
+ "clf = clf.fit(x_train,y_train)\n",
+ "\n",
+ "y_pred = clf.predict(x_test)\n",
+ "\n",
+ "# Visualize Confusion Matrix\n",
+ "wandb.sklearn.plot_confusion_matrix(y_test, y_pred, config.labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BX31c0s8MVu5"
+ },
+ "source": [
+ "## Regression"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1kBLTtH4MbsY"
+ },
+ "outputs": [],
+ "source": [
+ "X, y = load_iris(return_X_y=True)\n",
+ "\n",
+ "x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = config.test_size, random_state = config.random_state)\n",
+ "\n",
+ "reg = DecisionTreeRegressor(\n",
+ " max_depth=config.max_depth,\n",
+ " min_samples_split=config.min_samples_split,\n",
+ " criterion=config.reg_criterion,\n",
+ " splitter=config.splitter\n",
+ ")\n",
+ "\n",
+ "reg = reg.fit(x_train,y_train)\n",
+ "\n",
+ "# All regression plots\n",
+ "wandb.sklearn.plot_regressor(reg, x_train, x_test, y_train, y_test, model_name='DecisionTreeRegressor')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Finish the W&B Process\n",
+ "wandb.finish()"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [
+ "bE1OZxeIwPqg",
+ "_N6qEF3cvIpN",
+ "_vdIezY_zb_9",
+ "OGV1c4VYJMGa",
+ "H8IIKXd0PFL8",
+ "kycxKC7ER7lW",
+ "4wWuB4NZl3h_",
+ "BX31c0s8MVu5",
+ "c6PoNscLNuog",
+ "3Cf152r2NK1M",
+ "a_HO0jHMyyyP",
+ "ahE27LgSzUVx",
+ "FaENg6O44dso",
+ "t3K1L5UWSZCL"
+ ],
+ "name": "Decision Trees",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/simpletransformers/SimpleTransformersQA.ipynb b/colabs/simpletransformers/SimpleTransformersQA.ipynb
index f5e22c82..9ad3a6a8 100644
--- a/colabs/simpletransformers/SimpleTransformersQA.ipynb
+++ b/colabs/simpletransformers/SimpleTransformersQA.ipynb
@@ -115,10 +115,18 @@
"source": [
"import wandb\n",
"\n",
- "wandb.login()\n",
"wandb_project = \"SimpleTransformers-QA\""
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/colabs/spacy/SpaCy_v3_and_W&B.ipynb b/colabs/spacy/SpaCy_v3_and_W&B.ipynb
index 9ffe78a2..5435b742 100644
--- a/colabs/spacy/SpaCy_v3_and_W&B.ipynb
+++ b/colabs/spacy/SpaCy_v3_and_W&B.ipynb
@@ -103,7 +103,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/spacy/spaCy_v3_and_W&B_Sweeps.ipynb b/colabs/spacy/spaCy_v3_and_W&B_Sweeps.ipynb
index d4b3e561..e27a6e08 100644
--- a/colabs/spacy/spaCy_v3_and_W&B_Sweeps.ipynb
+++ b/colabs/spacy/spaCy_v3_and_W&B_Sweeps.ipynb
@@ -84,8 +84,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/stable_baselines3/Stable_Baselines3_wandb_experiment_tracking.ipynb b/colabs/stable_baselines3/Stable_Baselines3_wandb_experiment_tracking.ipynb
index 8d0103e9..943a8471 100644
--- a/colabs/stable_baselines3/Stable_Baselines3_wandb_experiment_tracking.ipynb
+++ b/colabs/stable_baselines3/Stable_Baselines3_wandb_experiment_tracking.ipynb
@@ -123,11 +123,6 @@
"\n",
"![](https://user-images.githubusercontent.com/5555347/122989248-97b5bd00-d370-11eb-95d6-52d56cfbce19.gif)"
]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": []
}
],
"metadata": {
diff --git a/colabs/stylegan_nada/StyleGAN-NADA.ipynb b/colabs/stylegan_nada/StyleGAN-NADA.ipynb
index 3dcb606b..8db2c1a9 100644
--- a/colabs/stylegan_nada/StyleGAN-NADA.ipynb
+++ b/colabs/stylegan_nada/StyleGAN-NADA.ipynb
@@ -81,7 +81,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/super-gradients/yolo_nas.ipynb b/colabs/super-gradients/yolo_nas.ipynb
index a79ad316..2c953fdf 100644
--- a/colabs/super-gradients/yolo_nas.ipynb
+++ b/colabs/super-gradients/yolo_nas.ipynb
@@ -359,12 +359,7 @@
]
}
],
- "metadata": {
- "language_info": {
- "name": "python"
- },
- "orig_nbformat": 4
- },
+ "metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}
diff --git a/colabs/super-gradients/yolo_nas_data_analysis.ipynb b/colabs/super-gradients/yolo_nas_data_analysis.ipynb
new file mode 100644
index 00000000..e98425cf
--- /dev/null
+++ b/colabs/super-gradients/yolo_nas_data_analysis.ipynb
@@ -0,0 +1,438 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SJenmP6BkY2V"
+ },
+ "outputs": [],
+ "source": [
+ "!sudo apt install libcairo2-dev pkg-config python3-dev -qq\n",
+ "!pip install roboflow pycairo wandb sweeps -qqq\n",
+ "!pip install super_gradients"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import glob\n",
+ "import torch\n",
+ "import wandb\n",
+ "import warnings\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from matplotlib import patches\n",
+ "from google.colab import userdata\n",
+ "from torchvision.io import read_image\n",
+ "from torch.utils.data import DataLoader\n",
+ "from super_gradients.training import models, dataloaders\n",
+ "from super_gradients.training.dataloaders.dataloaders import (\n",
+ " coco_detection_yolo_format_train, coco_detection_yolo_format_val\n",
+ ")\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "os.environ[\"WANDB_API_KEY\"] = userdata.get('wandb')\n",
+ "os.environ[\"ROBOFLOW_API_KEY\"] = userdata.get('roboflow')"
+ ],
+ "metadata": {
+ "id": "zxmwpTBzpd-d"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from roboflow import Roboflow\n",
+ "rf = Roboflow(api_key=os.getenv(\"ROBOFLOW_API_KEY\"))\n",
+ "project = rf.workspace(\"easyhyeon\").project(\"trash-sea\")\n",
+ "dataset = project.version(10).download(\"yolov5\")"
+ ],
+ "metadata": {
+ "id": "77XdRx8rpfqQ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "DATASET_PATH = \"/content/trash-sea-10\"\n",
+ "WANDB_PROJECT_NAME = \"fconn-yolo-nas\"\n",
+ "ENTITY = \"ml-colabs\""
+ ],
+ "metadata": {
+ "id": "8shwiGscpglK"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dataset_params = {\n",
+ " 'data_dir':DATASET_PATH,\n",
+ " 'train_images_dir':'train/images',\n",
+ " 'train_labels_dir':'train/labels',\n",
+ " 'val_images_dir':'valid/images',\n",
+ " 'val_labels_dir':'valid/labels',\n",
+ " 'test_images_dir':'test/images',\n",
+ " 'test_labels_dir':'test/labels',\n",
+ " 'classes': [\"Buoy\", \"Can\", \"Paper\", \"Plastic Bag\", \"Plastic Bottle\"]\n",
+ "}"
+ ],
+ "metadata": {
+ "id": "13DJyeaIvKOf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from IPython.display import clear_output\n",
+ "\n",
+ "train_data = coco_detection_yolo_format_train(\n",
+ " dataset_params={\n",
+ " 'data_dir': dataset_params['data_dir'],\n",
+ " 'images_dir': dataset_params['train_images_dir'],\n",
+ " 'labels_dir': dataset_params['train_labels_dir'],\n",
+ " 'classes': dataset_params['classes'],\n",
+ " },\n",
+ " dataloader_params={\n",
+ " 'batch_size':16,\n",
+ " 'num_workers':2\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "val_data = coco_detection_yolo_format_val(\n",
+ " dataset_params={\n",
+ " 'data_dir': dataset_params['data_dir'],\n",
+ " 'images_dir': dataset_params['val_images_dir'],\n",
+ " 'labels_dir': dataset_params['val_labels_dir'],\n",
+ " 'classes': dataset_params['classes'],\n",
+ " },\n",
+ " dataloader_params={\n",
+ " 'batch_size':16,\n",
+ " 'num_workers':2\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "test_data = coco_detection_yolo_format_val(\n",
+ " dataset_params={\n",
+ " 'data_dir': dataset_params['data_dir'],\n",
+ " 'images_dir': dataset_params['test_images_dir'],\n",
+ " 'labels_dir': dataset_params['test_labels_dir'],\n",
+ " 'classes': dataset_params['classes'],\n",
+ " },\n",
+ " dataloader_params={\n",
+ " 'batch_size':16,\n",
+ " 'num_workers':2\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "train_data.dataset.transforms = train_data.dataset.transforms[5:]"
+ ],
+ "metadata": {
+ "id": "7EAeEm3xvS82"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "colors = {\n",
+ " 0: 'red',\n",
+ " 1: 'green',\n",
+ " 2: 'blue',\n",
+ " 3: 'yellow',\n",
+ " 4: 'black'\n",
+ "}\n",
+ "classes = {\n",
+ " 0:\"Buoy\",\n",
+ " 1:\"Can\",\n",
+ " 2:\"Paper\",\n",
+ " 3:\"Plastic Bag\",\n",
+ " 4:\"Plastic Bottle\"\n",
+ "}"
+ ],
+ "metadata": {
+ "id": "0XKew_u_vRxs"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def process_bounding_boxes_list(annots):\n",
+ "    \"\"\"Convert one image's annotation rows into W&B bounding-box dicts.\n",
+ "\n",
+ "    Each row is read as: index 1 = class id, 2-3 = box center (x, y),\n",
+ "    4 = width, 5 = height. Index 0 is not read here.\n",
+ "\n",
+ "    Returns:\n",
+ "        (boxes, class_count): the list of box dicts and a dict mapping every\n",
+ "        class id in the global `classes` to the number of boxes of that class.\n",
+ "    \"\"\"\n",
+ "    boxes = []\n",
+ "    # Tally per class id; sized from `classes` instead of a hard-coded 5.\n",
+ "    class_count = {class_id: 0 for class_id in classes}\n",
+ "\n",
+ "    for annotation in annots:\n",
+ "        class_id = int(annotation[1])\n",
+ "        class_count[class_id] += 1\n",
+ "        boxes.append({\n",
+ "            \"position\": {\n",
+ "                \"middle\": [float(annotation[2]), float(annotation[3])],\n",
+ "                \"width\": float(annotation[4]),\n",
+ "                \"height\": float(annotation[5])\n",
+ "            },\n",
+ "            \"domain\": \"pixel\",\n",
+ "            \"class_id\": class_id,\n",
+ "            \"box_caption\": classes[class_id]\n",
+ "        })\n",
+ "\n",
+ "    return boxes, class_count\n",
+ "\n",
+ "def populate_wandb_image_samples(train_data):\n",
+ " wandb.init(\n",
+ " project=WANDB_PROJECT_NAME,\n",
+ " entity=ENTITY,\n",
+ " id='add-image-samples',\n",
+ " job_type=\"add-tables\",\n",
+ " resume='allow'\n",
+ " )\n",
+ "\n",
+ " class_set = wandb.Classes(\n",
+ " [\n",
+ " {\"name\": \"Buoy\", \"id\": 0},\n",
+ " {\"name\": \"Can\", \"id\": 1},\n",
+ " {\"name\": \"Paper\", \"id\": 2},\n",
+ " {\"name\": \"Plastic Bag\", \"id\": 3},\n",
+ " {\"name\": \"Plastic Bottle\", \"id\": 4},\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " table = wandb.Table(\n",
+ " columns=[\n",
+ " \"Annotated-Image\", \"Number-of-objects\",\n",
+ " \"Number-Buoy\", \"Number-Can\", \"Number-Paper\",\n",
+ " \"Number-Plastic-Bag\", \"Number-Plastic-Bottle\"\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " img_count = 0\n",
+ "\n",
+ " for batch_idx, batch_sample in enumerate(train_data):\n",
+ " batch_images = batch_sample[0]\n",
+ " batch_annotations = batch_sample[1]\n",
+ "\n",
+ " annots_dict = {i:[] for i in range(0, batch_images.shape[0])}\n",
+ " for annot in batch_annotations:\n",
+ " annots_dict[int(annot[0])].append(annot)\n",
+ "\n",
+ " for idx, image in enumerate(batch_images):\n",
+ "\n",
+ " bbox, class_count = process_bounding_boxes_list(annots_dict[idx])\n",
+ "\n",
+ " image = image.flip(0)\n",
+ "\n",
+ " img = wandb.Image(\n",
+ " image,\n",
+ " boxes={\n",
+ " \"ground_truth\":{\n",
+ " \"box_data\":bbox,\n",
+ " \"class_labels\": classes,\n",
+ " }\n",
+ " },\n",
+ " classes=class_set,\n",
+ " )\n",
+ "\n",
+ " table.add_data(img, len(bbox), class_count[0],\n",
+ " class_count[1], class_count[2],\n",
+ " class_count[3], class_count[4])\n",
+ " img_count += 1\n",
+ "\n",
+ " print(f\"{img_count}/{len(train_data)*16} completed\")\n",
+ "\n",
+ " wandb.log({\"ground_truth_dataset\": table})\n",
+ " wandb.finish()\n",
+ "\n",
+ "populate_wandb_image_samples(train_data)"
+ ],
+ "metadata": {
+ "id": "ncYSU2svvd-B"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# process_bounding_boxes_list is defined in the previous code cell and reused here.\n",
+ "\n",
+ "def populate_wandb_bbox(train_data):\n",
+ " wandb.init(\n",
+ " project=WANDB_PROJECT_NAME,\n",
+ " entity=ENTITY,\n",
+ " id='add-bbox-data',\n",
+ " job_type=\"add-tables\",\n",
+ " resume='allow'\n",
+ " )\n",
+ "\n",
+ " class_set = wandb.Classes(\n",
+ " [\n",
+ " {\"name\": \"Buoy\", \"id\": 0},\n",
+ " {\"name\": \"Can\", \"id\": 1},\n",
+ " {\"name\": \"Paper\", \"id\": 2},\n",
+ " {\"name\": \"Plastic Bag\", \"id\": 3},\n",
+ " {\"name\": \"Plastic Bottle\", \"id\": 4},\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " table = wandb.Table(\n",
+ " columns=[\n",
+ " \"Image-Id\",\n",
+ " \"BBox-Height\",\n",
+ " \"BBox-Width\",\n",
+ " \"Class-Id\"\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " img_count = 0\n",
+ "\n",
+ " for batch_idx, batch_sample in enumerate(train_data):\n",
+ " batch_images = batch_sample[0]\n",
+ " batch_annotations = batch_sample[1]\n",
+ "\n",
+ " annots_dict = {i:[] for i in range(0, batch_images.shape[0])}\n",
+ " for annot in batch_annotations:\n",
+ " annots_dict[int(annot[0])].append(annot)\n",
+ "\n",
+ " for idx, image in enumerate(batch_images):\n",
+ "\n",
+ " result, class_count = process_bounding_boxes_list(annots_dict[idx])\n",
+ "\n",
+ " for bbox in result:\n",
+ " height = bbox[\"position\"][\"height\"]\n",
+ " width = bbox[\"position\"][\"width\"]\n",
+ " class_id = bbox[\"class_id\"]\n",
+ " table.add_data(img_count, height, width, classes[class_id])\n",
+ "\n",
+ " img_count += 1\n",
+ "\n",
+ " print(f\"{img_count}/{len(train_data)*16} completed\")\n",
+ "\n",
+ " wandb.log({\"bounding_box_information\": table})\n",
+ " wandb.finish()\n",
+ "\n",
+ "populate_wandb_bbox(train_data)"
+ ],
+ "metadata": {
+ "id": "-8W-zcs9vh93"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def populate_wandb_spatial_heatmaps(train_data):\n",
+ "    \"\"\"Log a per-class spatial heatmap of bounding boxes to W&B.\n",
+ "\n",
+ "    Iterates over `train_data`, reads the annotation rows from index 1 of each\n",
+ "    batch and adds 1 to the class heatmap over every box's area. Row layout as\n",
+ "    read here: index 1 = class id, 2-3 = box center (x, y), 4 = width,\n",
+ "    5 = height, each truncated with int(). One table row per class is logged\n",
+ "    under \"spatial_heatmap_information\" before the run is finished.\n",
+ "    \"\"\"\n",
+ "    wandb.init(\n",
+ "        project=WANDB_PROJECT_NAME,\n",
+ "        entity=ENTITY,\n",
+ "        id='add-heatmap',\n",
+ "        job_type=\"add-tables\",\n",
+ "        resume='allow'\n",
+ "    )\n",
+ "\n",
+ "    # NOTE(review): 224x224 presumably matches the image size the dataloader yields - confirm.\n",
+ "    heatmaps = [np.zeros((224, 224, 1), dtype=np.float32) for _ in classes]\n",
+ "    annotation_counts = {i: 0 for i in range(len(classes))}\n",
+ "\n",
+ "    table = wandb.Table(columns=[\"Class-Id\", \"Class-Name\", \"Spatial-Heatmap\",\n",
+ "                                 \"Num-Total-Objects\"])\n",
+ "\n",
+ "    for batch_idx, batch_sample in enumerate(train_data):\n",
+ "        # Only the annotations (index 1) are needed; the images are not read.\n",
+ "        for annot in batch_sample[1]:\n",
+ "            class_idx = int(annot[1])\n",
+ "\n",
+ "            midpoint_x = int(annot[2])\n",
+ "            midpoint_y = int(annot[3])\n",
+ "            half_width = int(annot[4]) // 2\n",
+ "            half_height = int(annot[5]) // 2\n",
+ "\n",
+ "            # Clamp every bound at 0: a negative slice index counts from the end\n",
+ "            # of the axis, so boxes overlapping the top/left border used to paint\n",
+ "            # nothing (or the wrong region). Bounds past the far edge need no\n",
+ "            # clamping because slicing already stops at the array size.\n",
+ "            x_min = max(midpoint_x - half_width, 0)\n",
+ "            x_max = max(midpoint_x + half_width, 0)\n",
+ "\n",
+ "            y_min = max(midpoint_y - half_height, 0)\n",
+ "            y_max = max(midpoint_y + half_height, 0)\n",
+ "\n",
+ "            heatmaps[class_idx][y_min:y_max, x_min:x_max] += 1\n",
+ "\n",
+ "            annotation_counts[class_idx] += 1\n",
+ "\n",
+ "        print(f\"{batch_idx+1}/{len(train_data)} batches completed\")\n",
+ "\n",
+ "    for class_idx in range(len(classes)):\n",
+ "        heatmap = wandb.Image(\n",
+ "            heatmaps[class_idx],\n",
+ "            caption=classes[class_idx]\n",
+ "        )\n",
+ "        table.add_data(class_idx, classes[class_idx], heatmap, annotation_counts[class_idx])\n",
+ "\n",
+ "    wandb.log({\"spatial_heatmap_information\": table})\n",
+ "    wandb.finish()\n",
+ "\n",
+ "populate_wandb_spatial_heatmaps(train_data)"
+ ],
+ "metadata": {
+ "id": "AaN1pUvJvqjA"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/colabs/super-gradients/yolo_nas_sweep_run.ipynb b/colabs/super-gradients/yolo_nas_sweep_run.ipynb
new file mode 100644
index 00000000..93ab971c
--- /dev/null
+++ b/colabs/super-gradients/yolo_nas_sweep_run.ipynb
@@ -0,0 +1,345 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Installation and Imports"
+ ],
+ "metadata": {
+ "id": "Fe3Dp6NDzPyr"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KXwTjgsFkZnm"
+ },
+ "outputs": [],
+ "source": [
+ "!sudo apt install libcairo2-dev pkg-config python3-dev -qq\n",
+ "!pip install roboflow pycairo wandb sweeps -qqq\n",
+ "!pip install super_gradients"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import glob\n",
+ "import torch\n",
+ "import wandb\n",
+ "import warnings\n",
+ "import pandas as pd\n",
+ "\n",
+ "from google.colab import userdata\n",
+ "from torchvision.io import read_image\n",
+ "from torch.utils.data import DataLoader\n",
+ "from IPython.display import clear_output\n",
+ "\n",
+ "from super_gradients.training import models, Trainer, dataloaders\n",
+ "from super_gradients.training.losses import PPYoloELoss\n",
+ "from super_gradients.training.metrics import DetectionMetrics_050\n",
+ "from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback\n",
+ "from super_gradients.training.dataloaders.dataloaders import coco_detection_yolo_format_train, coco_detection_yolo_format_val\n",
+ "\n",
+ "os.environ[\"WANDB_API_KEY\"] = userdata.get('wandb')\n",
+ "os.environ[\"ROBOFLOW_API_KEY\"] = userdata.get('roboflow')"
+ ],
+ "metadata": {
+ "id": "PpsF3k8awWbP"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Definitions"
+ ],
+ "metadata": {
+ "id": "qRjY7eYPzScK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "seed = 42\n",
+ "torch.manual_seed(seed)\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ " torch.cuda.manual_seed(seed)\n",
+ " torch.cuda.manual_seed_all(seed)\n",
+ "\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "torch.backends.cudnn.benchmark = False"
+ ],
+ "metadata": {
+ "id": "nVJIsdFOwXX2"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Download and Register dataset"
+ ],
+ "metadata": {
+ "id": "FhcF4opozT8c"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from roboflow import Roboflow\n",
+ "rf = Roboflow(api_key=os.getenv(\"ROBOFLOW_API_KEY\"))\n",
+ "project = rf.workspace(\"easyhyeon\").project(\"trash-sea\")\n",
+ "dataset = project.version(10).download(\"yolov5\")"
+ ],
+ "metadata": {
+ "id": "BcSFy0iewaJM"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ENTITY = \"ml-colabs\"\n",
+ "SWEEP_NUM_RUNS = 100\n",
+ "WANDB_PROJECT_NAME = \"fconn-yolo-nas\"\n",
+ "# Read by the sweep configuration, wandb.init and the Trainer in later cells;\n",
+ "# it was never defined, so those cells raised NameError on a fresh kernel.\n",
+ "# NOTE(review): the value is a placeholder matching the sweep project name - adjust as needed.\n",
+ "WANDB_EXP_NAME = \"yolo-nas-sweep\"\n",
+ "DATASET_PATH = \"/content/trash-sea-10\""
+ ],
+ "metadata": {
+ "id": "QIjvNsTowbKO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Define Sweep Configuration and functions"
+ ],
+ "metadata": {
+ "id": "VO0p4iM0zW1N"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "sweep_configuration = {\n",
+ " \"name\": WANDB_EXP_NAME,\n",
+ " \"metric\": {\"name\": \"Valid_mAP@0.50\", \"goal\": \"maximize\"},\n",
+ " \"method\": \"bayes\",\n",
+ " \"parameters\": {\n",
+ " \"batch_size\": {\"values\": [16, 24, 32]},\n",
+ " \"optimizer\": {\"values\": [\"Adam\", \"SGD\", \"RMSProp\", \"AdamW\"]},\n",
+ " \"ema_decay\": {\"min\":0.5, \"max\":0.9},\n",
+ " \"ema_decay_type\": {\"values\": [\"constant\", \"threshold\"]},\n",
+ " \"cosine_lr_ratio\": {\"min\": 0.01, \"max\": 0.4},\n",
+ " \"iou_loss_weight\": {\"min\": 0.25, \"max\": 2.0},\n",
+ " \"dfl_loss_weight\": {\"min\": 0.25, \"max\": 2.0},\n",
+ " \"classification_loss_weight\": {\"min\": 0.25, \"max\": 2.0},\n",
+ " \"model_flavor\": {\"values\": [\"yolo_nas_s\", \"yolo_nas_m\", \"yolo_nas_l\"]},\n",
+ " \"weight_decay\": {\"min\": 0.0001, \"max\": 0.01},\n",
+ " },\n",
+ "}"
+ ],
+ "metadata": {
+ "id": "8kMUNQ6HwoFa"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def main_call():\n",
+ "\n",
+ " CHECKPOINT_DIR = 'checkpoints'\n",
+ "\n",
+ " wandb.init(\n",
+ " project=WANDB_PROJECT_NAME,\n",
+ " entity=ENTITY,\n",
+ " resume=\"allow\",\n",
+ " save_code=True,\n",
+ " id=WANDB_EXP_NAME\n",
+ " )\n",
+ "\n",
+ " config = wandb.config\n",
+ "\n",
+ " dataset_params = {\n",
+ " 'data_dir':DATASET_PATH,\n",
+ " 'train_images_dir':'train/images',\n",
+ " 'train_labels_dir':'train/labels',\n",
+ " 'val_images_dir':'valid/images',\n",
+ " 'val_labels_dir':'valid/labels',\n",
+ " 'test_images_dir':'test/images',\n",
+ " 'test_labels_dir':'test/labels',\n",
+ " 'classes': [\"Buoy\", \"Can\", \"Paper\", \"Plastic Bag\", \"Plastic Bottle\"]\n",
+ " }\n",
+ "\n",
+ " train_data = coco_detection_yolo_format_train(\n",
+ " dataset_params={\n",
+ " 'data_dir': dataset_params['data_dir'],\n",
+ " 'images_dir': dataset_params['train_images_dir'],\n",
+ " 'labels_dir': dataset_params['train_labels_dir'],\n",
+ " 'classes': dataset_params['classes'],\n",
+ " },\n",
+ " dataloader_params={\n",
+ " 'batch_size':config[\"batch_size\"],\n",
+ " 'num_workers':4\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " val_data = coco_detection_yolo_format_val(\n",
+ " dataset_params={\n",
+ " 'data_dir': dataset_params['data_dir'],\n",
+ " 'images_dir': dataset_params['val_images_dir'],\n",
+ " 'labels_dir': dataset_params['val_labels_dir'],\n",
+ " 'classes': dataset_params['classes'],\n",
+ " },\n",
+ " dataloader_params={\n",
+ " 'batch_size':config[\"batch_size\"],\n",
+ " 'num_workers':4\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " test_data = coco_detection_yolo_format_val(\n",
+ " dataset_params={\n",
+ " 'data_dir': dataset_params['data_dir'],\n",
+ " 'images_dir': dataset_params['test_images_dir'],\n",
+ " 'labels_dir': dataset_params['test_labels_dir'],\n",
+ " 'classes': dataset_params['classes'],\n",
+ " },\n",
+ " dataloader_params={\n",
+ " 'batch_size':config[\"batch_size\"],\n",
+ " 'num_workers':4\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " train_data.dataset.transforms = train_data.dataset.transforms[1:]\n",
+ "\n",
+ " model = models.get(\n",
+ " config[\"model_flavor\"],\n",
+ " num_classes=len(dataset_params['classes']),\n",
+ " pretrained_weights=\"coco\"\n",
+ " )\n",
+ "\n",
+ " train_params = {\n",
+ " 'silent_mode': False,\n",
+ " \"average_best_models\":True,\n",
+ " \"warmup_mode\": \"linear_epoch_step\",\n",
+ " \"warmup_initial_lr\": 1e-6,\n",
+ " \"lr_warmup_epochs\": 3,\n",
+ " \"initial_lr\": 1e-3,\n",
+ " \"lr_mode\": \"cosine\",\n",
+ " \"cosine_final_lr_ratio\": config[\"cosine_lr_ratio\"],\n",
+ " \"optimizer\": config[\"optimizer\"],\n",
+ " \"optimizer_params\": {\n",
+ " \"weight_decay\": config[\"weight_decay\"]\n",
+ " },\n",
+ " \"zero_weight_decay_on_bias_and_bn\": True,\n",
+ " \"ema\": True,\n",
+ " \"ema_params\": {\n",
+ " \"decay\": config[\"ema_decay\"],\n",
+ " \"decay_type\": config[\"ema_decay_type\"]\n",
+ " },\n",
+ " \"max_epochs\": 5,\n",
+ " \"mixed_precision\": False,\n",
+ " \"loss\": PPYoloELoss(\n",
+ " use_static_assigner=False,\n",
+ " num_classes=len(dataset_params['classes']),\n",
+ " reg_max=16,\n",
+ " iou_loss_weight=config[\"iou_loss_weight\"],\n",
+ " dfl_loss_weight=config[\"dfl_loss_weight\"],\n",
+ " classification_loss_weight=config[\"classification_loss_weight\"]\n",
+ " ),\n",
+ " \"valid_metrics_list\": [\n",
+ " DetectionMetrics_050(\n",
+ " score_thres=0.1,\n",
+ " top_k_predictions=300,\n",
+ " num_cls=len(dataset_params['classes']),\n",
+ " normalize_targets=True,\n",
+ " post_prediction_callback=PPYoloEPostPredictionCallback(\n",
+ " score_threshold=0.01,\n",
+ " nms_top_k=1000,\n",
+ " max_predictions=300,\n",
+ " nms_threshold=0.7\n",
+ " )\n",
+ " )\n",
+ " ],\n",
+ " \"metric_to_watch\": 'mAP@0.50',\n",
+ " \"sg_logger\": \"wandb_sg_logger\",\n",
+ " \"sg_logger_params\": {\n",
+ " \"project_name\": WANDB_PROJECT_NAME,\n",
+ " \"save_checkpoints_remote\": True,\n",
+ " \"save_tensorboard_remote\": True,\n",
+ " \"save_logs_remote\": True,\n",
+ " \"entity\": ENTITY\n",
+ " }\n",
+ " }\n",
+ "\n",
+ " trainer = Trainer(\n",
+ " experiment_name=WANDB_EXP_NAME,\n",
+ " ckpt_root_dir=CHECKPOINT_DIR\n",
+ " )\n",
+ "\n",
+ " trainer.train(\n",
+ " model=model,\n",
+ " training_params=train_params,\n",
+ " train_loader=train_data,\n",
+ " valid_loader=val_data\n",
+ " )\n",
+ "\n",
+ " wandb.finish()"
+ ],
+ "metadata": {
+ "id": "nnZzxVmEwb-U"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Execute Sweep"
+ ],
+ "metadata": {
+ "id": "3Xp0UU73zajw"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "sweep_id = wandb.sweep(\n",
+ " sweep=sweep_configuration,\n",
+ " project=\"yolo-nas-sweep\"\n",
+ ")\n",
+ "\n",
+ "wandb.agent(sweep_id, function=main_call, count=SWEEP_NUM_RUNS)"
+ ],
+ "metadata": {
+ "id": "9r6LrkhQwxdA"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/colabs/tables/AlphaFold_with_W&B_Align,_Fold,_Log.ipynb b/colabs/tables/AlphaFold_with_W&B_Align,_Fold,_Log.ipynb
index fce4e1f3..ceb8e2c6 100644
--- a/colabs/tables/AlphaFold_with_W&B_Align,_Fold,_Log.ipynb
+++ b/colabs/tables/AlphaFold_with_W&B_Align,_Fold,_Log.ipynb
@@ -89,8 +89,15 @@
"outputs": [],
"source": [
"# now run this cell (again by pressing the play button)\n",
- "# to log in to Weights & Biases (by following any prompts)\n",
- "\n",
+ "# to log in to Weights & Biases (by following any prompts)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/tables/Log_Tables_Incrementally.ipynb b/colabs/tables/Log_Tables_Incrementally.ipynb
index fdb44e7b..ab826995 100644
--- a/colabs/tables/Log_Tables_Incrementally.ipynb
+++ b/colabs/tables/Log_Tables_Incrementally.ipynb
@@ -38,8 +38,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/tables/W&B_Tables_Quickstart.ipynb b/colabs/tables/W&B_Tables_Quickstart.ipynb
index 8814ca0a..185c4766 100644
--- a/colabs/tables/W&B_Tables_Quickstart.ipynb
+++ b/colabs/tables/W&B_Tables_Quickstart.ipynb
@@ -46,7 +46,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/tensorboard/Accelerator_W&B_Tensorboard.ipynb b/colabs/tensorboard/Accelerator_W&B_Tensorboard.ipynb
index 5f8806be..f6c28aac 100644
--- a/colabs/tensorboard/Accelerator_W&B_Tensorboard.ipynb
+++ b/colabs/tensorboard/Accelerator_W&B_Tensorboard.ipynb
@@ -88,7 +88,15 @@
"source": [
"import wandb\n",
"\n",
- "# log to Weights and biases\n",
+ "# log to Weights and biases"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/tensorboard/TensorBoard_and_Weights_and_Biases.ipynb b/colabs/tensorboard/TensorBoard_and_Weights_and_Biases.ipynb
index 0bb947f4..c3f953b5 100644
--- a/colabs/tensorboard/TensorBoard_and_Weights_and_Biases.ipynb
+++ b/colabs/tensorboard/TensorBoard_and_Weights_and_Biases.ipynb
@@ -152,13 +152,6 @@
"\n",
"You can find additional documentation of how to use [Weights & Biases with Tensorboard here](https://docs.wandb.ai/guides/integrations/tensorboard)"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/tensorflow/Hyperparameter_Optimization_in_TensorFlow_using_W&B_Sweeps.ipynb b/colabs/tensorflow/Hyperparameter_Optimization_in_TensorFlow_using_W&B_Sweeps.ipynb
index bec47201..d45de2a0 100644
--- a/colabs/tensorflow/Hyperparameter_Optimization_in_TensorFlow_using_W&B_Sweeps.ipynb
+++ b/colabs/tensorflow/Hyperparameter_Optimization_in_TensorFlow_using_W&B_Sweeps.ipynb
@@ -141,8 +141,15 @@
"outputs": [],
"source": [
"import wandb\n",
- "from wandb.keras import WandbCallback\n",
- "\n",
+ "from wandb.keras import WandbCallback\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/tensorflow/Simple_TensorFlow_Integration.ipynb b/colabs/tensorflow/Simple_TensorFlow_Integration.ipynb
index 70313d0d..a42a5a0c 100644
--- a/colabs/tensorflow/Simple_TensorFlow_Integration.ipynb
+++ b/colabs/tensorflow/Simple_TensorFlow_Integration.ipynb
@@ -136,8 +136,15 @@
"outputs": [],
"source": [
"import wandb\n",
- "from wandb.keras import WandbCallback\n",
- "\n",
+ "from wandb.keras import WandbCallback\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/tensorflow/convert_imagenette_tfrecord.ipynb b/colabs/tensorflow/convert_imagenette_tfrecord.ipynb
index cb82bc7d..36503727 100644
--- a/colabs/tensorflow/convert_imagenette_tfrecord.ipynb
+++ b/colabs/tensorflow/convert_imagenette_tfrecord.ipynb
@@ -264,13 +264,6 @@
"source": [
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/ultralytics/00_inference.ipynb b/colabs/ultralytics/00_inference.ipynb
index 7814eed7..684762db 100644
--- a/colabs/ultralytics/00_inference.ipynb
+++ b/colabs/ultralytics/00_inference.ipynb
@@ -37,12 +37,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# Install Ultralytics\n",
- "!pip install -q ultralytics\n",
- "\n",
- "# Install the `feat/ultralytics`` branch from W&B,\n",
- "# which currently houses the out-of-the-box integration for Ultralytics.\n",
- "!pip install -q git+https://github.com/wandb/wandb@feat/ultralytics"
+ "# Install WandB and Ultralytics\n",
+ "!pip install -q -U wandb ultralytics"
]
},
{
@@ -153,8 +149,7 @@
"metadata": {
"language_info": {
"name": "python"
- },
- "orig_nbformat": 4
+ }
},
"nbformat": 4,
"nbformat_minor": 2
diff --git a/colabs/ultralytics/01_train_val.ipynb b/colabs/ultralytics/01_train_val.ipynb
index 22021e16..463d9559 100644
--- a/colabs/ultralytics/01_train_val.ipynb
+++ b/colabs/ultralytics/01_train_val.ipynb
@@ -30,12 +30,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# Install Ultralytics\n",
- "!pip install -q ultralytics\n",
- "\n",
- "# Install the `feat/ultralytics`` branch from W&B,\n",
- "# which currently houses the out-of-the-box integration for Ultralytics.\n",
- "!pip install -q git+https://github.com/wandb/wandb@feat/ultralytics"
+ "# Install WandB and Ultralytics\n",
+ "!pip install -q -U wandb ultralytics"
]
},
{
@@ -142,8 +138,7 @@
"metadata": {
"language_info": {
"name": "python"
- },
- "orig_nbformat": 4
+ }
},
"nbformat": 4,
"nbformat_minor": 2
diff --git a/colabs/wandb-artifacts/Artifacts_Quickstart_with_W&B.ipynb b/colabs/wandb-artifacts/Artifacts_Quickstart_with_W&B.ipynb
index a18817ac..abfd131b 100644
--- a/colabs/wandb-artifacts/Artifacts_Quickstart_with_W&B.ipynb
+++ b/colabs/wandb-artifacts/Artifacts_Quickstart_with_W&B.ipynb
@@ -126,7 +126,15 @@
"outputs": [],
"source": [
"!pip install wandb -qq\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-artifacts/Basic_Artifacts_with_W&B.ipynb b/colabs/wandb-artifacts/Basic_Artifacts_with_W&B.ipynb
index a5a6acf5..2a566b88 100644
--- a/colabs/wandb-artifacts/Basic_Artifacts_with_W&B.ipynb
+++ b/colabs/wandb-artifacts/Basic_Artifacts_with_W&B.ipynb
@@ -82,7 +82,15 @@
"outputs": [],
"source": [
"!pip install wandb -qq\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-artifacts/W&B_artifacts_for_auditing_purposes.ipynb b/colabs/wandb-artifacts/W&B_artifacts_for_auditing_purposes.ipynb
index 5818a0a7..c7ee4332 100644
--- a/colabs/wandb-artifacts/W&B_artifacts_for_auditing_purposes.ipynb
+++ b/colabs/wandb-artifacts/W&B_artifacts_for_auditing_purposes.ipynb
@@ -84,8 +84,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-artifacts/train_val_test_split_with_tabular_data.ipynb b/colabs/wandb-artifacts/train_val_test_split_with_tabular_data.ipynb
index e7375e69..a23c6672 100644
--- a/colabs/wandb-artifacts/train_val_test_split_with_tabular_data.ipynb
+++ b/colabs/wandb-artifacts/train_val_test_split_with_tabular_data.ipynb
@@ -60,7 +60,6 @@
"source": [
"!pip install --upgrade wandb -qqq\n",
"import wandb\n",
- "wandb.login()\n",
"\n",
"import random\n",
"from collections import OrderedDict\n",
@@ -90,6 +89,15 @@
"set_seeds(0)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
diff --git a/colabs/wandb-log/Configs_in_W&B.ipynb b/colabs/wandb-log/Configs_in_W&B.ipynb
index 819736e4..67469509 100644
--- a/colabs/wandb-log/Configs_in_W&B.ipynb
+++ b/colabs/wandb-log/Configs_in_W&B.ipynb
@@ -47,7 +47,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-log/Customize_metric_logging_with_define_metric.ipynb b/colabs/wandb-log/Customize_metric_logging_with_define_metric.ipynb
index d9d4a932..0769958e 100644
--- a/colabs/wandb-log/Customize_metric_logging_with_define_metric.ipynb
+++ b/colabs/wandb-log/Customize_metric_logging_with_define_metric.ipynb
@@ -41,8 +41,15 @@
"source": [
"\n",
"import wandb\n",
- "import random\n",
- "\n",
+ "import random\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-log/Generate_gifs_from_logged_images_on_wandb.ipynb b/colabs/wandb-log/Generate_gifs_from_logged_images_on_wandb.ipynb
index 7c1c2f5e..5d69b6e6 100644
--- a/colabs/wandb-log/Generate_gifs_from_logged_images_on_wandb.ipynb
+++ b/colabs/wandb-log/Generate_gifs_from_logged_images_on_wandb.ipynb
@@ -46,7 +46,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -224,13 +232,6 @@
"run = api.run(RUN_PATH)\n",
"make_and_display_gifs(run)"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/wandb-log/Image_Logging_de_duplication.ipynb b/colabs/wandb-log/Image_Logging_de_duplication.ipynb
index 25af66a9..77d3e89d 100644
--- a/colabs/wandb-log/Image_Logging_de_duplication.ipynb
+++ b/colabs/wandb-log/Image_Logging_de_duplication.ipynb
@@ -80,13 +80,6 @@
"wandb.log({\"image\": img_1})\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/wandb-log/Log_a_Confusion_Matrix_with_W&B.ipynb b/colabs/wandb-log/Log_a_Confusion_Matrix_with_W&B.ipynb
index 4bfc7979..d628ce43 100644
--- a/colabs/wandb-log/Log_a_Confusion_Matrix_with_W&B.ipynb
+++ b/colabs/wandb-log/Log_a_Confusion_Matrix_with_W&B.ipynb
@@ -86,7 +86,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-log/Logging_Strategies_for_High_Frequency_Data.ipynb b/colabs/wandb-log/Logging_Strategies_for_High_Frequency_Data.ipynb
index eb45d0cc..15aa99cc 100644
--- a/colabs/wandb-log/Logging_Strategies_for_High_Frequency_Data.ipynb
+++ b/colabs/wandb-log/Logging_Strategies_for_High_Frequency_Data.ipynb
@@ -86,8 +86,15 @@
"metadata": {},
"outputs": [],
"source": [
- "# login\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-log/Plot_Precision_Recall_Curves_with_W&B.ipynb b/colabs/wandb-log/Plot_Precision_Recall_Curves_with_W&B.ipynb
index 5ecd9eb5..0e8d8498 100644
--- a/colabs/wandb-log/Plot_Precision_Recall_Curves_with_W&B.ipynb
+++ b/colabs/wandb-log/Plot_Precision_Recall_Curves_with_W&B.ipynb
@@ -83,7 +83,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-log/Plot_ROC_Curves_with_W&B.ipynb b/colabs/wandb-log/Plot_ROC_Curves_with_W&B.ipynb
index 87bd5971..6bda70c0 100644
--- a/colabs/wandb-log/Plot_ROC_Curves_with_W&B.ipynb
+++ b/colabs/wandb-log/Plot_ROC_Curves_with_W&B.ipynb
@@ -83,8 +83,15 @@
"metadata": {},
"outputs": [],
"source": [
- "# login\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-log/Run_names_visualized_using_min_dalle.ipynb b/colabs/wandb-log/Run_names_visualized_using_min_dalle.ipynb
index 18c9c64e..072ea979 100644
--- a/colabs/wandb-log/Run_names_visualized_using_min_dalle.ipynb
+++ b/colabs/wandb-log/Run_names_visualized_using_min_dalle.ipynb
@@ -39,10 +39,18 @@
"! git lfs install\n",
"! git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 /content/min-dalle/pretrained/vqgan\n",
"!pip install torch flax==0.4.2 wandb\n",
- "! wandb login\n",
"! wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0"
]
},
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import wandb\n",
+    "wandb.login()"
+   ]
+  },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/colabs/wandb-log/Set_Alerts_with_W_&_B.ipynb b/colabs/wandb-log/Set_Alerts_with_W_&_B.ipynb
index 19fedfe8..137fd17b 100644
--- a/colabs/wandb-log/Set_Alerts_with_W_&_B.ipynb
+++ b/colabs/wandb-log/Set_Alerts_with_W_&_B.ipynb
@@ -53,7 +53,15 @@
"outputs": [],
"source": [
"# Log in to your W&B account\n",
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/colabs/wandb-model-registry/Model_Registry_E2E.ipynb b/colabs/wandb-model-registry/Model_Registry_E2E.ipynb
index 5ecae223..1124c5c3 100644
--- a/colabs/wandb-model-registry/Model_Registry_E2E.ipynb
+++ b/colabs/wandb-model-registry/Model_Registry_E2E.ipynb
@@ -75,7 +75,15 @@
"source": [
"import wandb\n",
"\n",
- "# Login to W&B\n",
+ "# Login to W&B"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
@@ -594,13 +602,6 @@
"artifact_dir = artifact.download()\n",
"wandb.finish()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/colabs/wandb-model-registry/New_Model_Logging_in_W&B.ipynb b/colabs/wandb-model-registry/New_Model_Logging_in_W&B.ipynb
new file mode 100644
index 00000000..449e9d7e
--- /dev/null
+++ b/colabs/wandb-model-registry/New_Model_Logging_in_W&B.ipynb
@@ -0,0 +1,416 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4",
+ "authorship_tag": "ABX9TyOlvkZjseUluPYEdqGH1TjK",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Logging and Registering Models in W&B\n",
+ "It's never been easier to log your model checkpoints, keep track of the best ones, and maintain lineage of runs and results!\n",
+ "\n",
+ "W&B is introducing a few convenience methods to make logging models and linking them to the registry simple:\n",
+ "- `log_model`\n",
+ "- `use_model`\n",
+ "- `link_model`"
+ ],
+ "metadata": {
+ "id": "saeQIrcF155i"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Imports"
+ ],
+ "metadata": {
+ "id": "_ewwM5T_OWTr"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BbkuyVKRwCbi"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -qqq wandb einops"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "from einops import rearrange, repeat\n",
+ "from einops.layers.torch import Rearrange\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from torchvision import transforms\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torch.optim import Adam\n",
+ "from torch.utils.data import DataLoader\n",
+ "from torchvision import datasets, transforms\n",
+ "\n",
+ "import wandb"
+ ],
+ "metadata": {
+ "id": "6OsMa64TOacf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Log in to W&B\n",
+ "- You can explicitly login using `wandb login` or `wandb.login()` (See below)\n",
+ "- Alternatively you can set environment variables. There are several env variables which you can set to change the behavior of W&B logging. The most important are:\n",
+ " - `WANDB_API_KEY` - find this in your \"Settings\" section under your profile\n",
+ " - `WANDB_BASE_URL` - this is the url of the W&B server\n",
+ "- Find your API Token in \"Profile\" -> \"Settings\" in the W&B App\n",
+ "\n",
+ "![api_token](https://drive.google.com/uc?export=view&id=1Xn7hnn0rfPu_EW0A_-32oCXqDmpA0-kx)"
+ ],
+ "metadata": {
+ "id": "m3Q_QG14Ovvz"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "wandb.login()"
+ ],
+ "metadata": {
+ "id": "5-r3O3gFwfcr"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Define the Model and Dataset\n",
+ "This is a simple implementation of a Vision Transformer (ViT) and utilizes a random dataset for training.\n",
+ "- Credit to https://github.com/lucidrains/vit-pytorch"
+ ],
+ "metadata": {
+ "id": "ZuL3yZ17qIGE"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Define some config for the model and dataset"
+ ],
+ "metadata": {
+ "id": "WE4j_CI7Tt9O"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Define the number of samples, classes, and image size\n",
+ "num_samples = 100\n",
+ "num_classes = 10\n",
+ "image_size = 256\n",
+ "batch_size = 32"
+ ],
+ "metadata": {
+ "id": "D7sVekyVTlxN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# helpers\n",
+ "def pair(t):\n",
+ " return t if isinstance(t, tuple) else (t, t)\n",
+ "\n",
+ "# classes\n",
+ "class FeedForward(nn.Module):\n",
+ " def __init__(self, dim, hidden_dim, dropout = 0.):\n",
+ " super().__init__()\n",
+ " self.net = nn.Sequential(\n",
+ " nn.LayerNorm(dim),\n",
+ " nn.Linear(dim, hidden_dim),\n",
+ " nn.GELU(),\n",
+ " nn.Dropout(dropout),\n",
+ " nn.Linear(hidden_dim, dim),\n",
+ " nn.Dropout(dropout)\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.net(x)\n",
+ "\n",
+ "class Attention(nn.Module):\n",
+ " def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):\n",
+ " super().__init__()\n",
+ " inner_dim = dim_head * heads\n",
+ " project_out = not (heads == 1 and dim_head == dim)\n",
+ "\n",
+ " self.heads = heads\n",
+ " self.scale = dim_head ** -0.5\n",
+ "\n",
+ " self.norm = nn.LayerNorm(dim)\n",
+ "\n",
+ " self.attend = nn.Softmax(dim = -1)\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ "\n",
+ " self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)\n",
+ "\n",
+ " self.to_out = nn.Sequential(\n",
+ " nn.Linear(inner_dim, dim),\n",
+ " nn.Dropout(dropout)\n",
+ " ) if project_out else nn.Identity()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.norm(x)\n",
+ "\n",
+ " qkv = self.to_qkv(x).chunk(3, dim = -1)\n",
+ " q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)\n",
+ "\n",
+ " dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale\n",
+ "\n",
+ " attn = self.attend(dots)\n",
+ " attn = self.dropout(attn)\n",
+ "\n",
+ " out = torch.matmul(attn, v)\n",
+ " out = rearrange(out, 'b h n d -> b n (h d)')\n",
+ " return self.to_out(out)\n",
+ "\n",
+ "class Transformer(nn.Module):\n",
+ " def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):\n",
+ " super().__init__()\n",
+ " self.norm = nn.LayerNorm(dim)\n",
+ " self.layers = nn.ModuleList([])\n",
+ " for _ in range(depth):\n",
+ " self.layers.append(nn.ModuleList([\n",
+ " Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),\n",
+ " FeedForward(dim, mlp_dim, dropout = dropout)\n",
+ " ]))\n",
+ "\n",
+ " def forward(self, x):\n",
+ " for attn, ff in self.layers:\n",
+ " x = attn(x) + x\n",
+ " x = ff(x) + x\n",
+ "\n",
+ " return self.norm(x)\n",
+ "\n",
+ "class ViT(nn.Module):\n",
+ " def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):\n",
+ " super().__init__()\n",
+ " image_height, image_width = pair(image_size)\n",
+ " patch_height, patch_width = pair(patch_size)\n",
+ "\n",
+ " assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'\n",
+ "\n",
+ " num_patches = (image_height // patch_height) * (image_width // patch_width)\n",
+ " patch_dim = channels * patch_height * patch_width\n",
+ " assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'\n",
+ "\n",
+ " self.to_patch_embedding = nn.Sequential(\n",
+ " Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),\n",
+ " nn.LayerNorm(patch_dim),\n",
+ " nn.Linear(patch_dim, dim),\n",
+ " nn.LayerNorm(dim),\n",
+ " )\n",
+ "\n",
+ " self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))\n",
+ " self.cls_token = nn.Parameter(torch.randn(1, 1, dim))\n",
+ " self.dropout = nn.Dropout(emb_dropout)\n",
+ "\n",
+ " self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)\n",
+ "\n",
+ " self.pool = pool\n",
+ " self.to_latent = nn.Identity()\n",
+ "\n",
+ " self.mlp_head = nn.Linear(dim, num_classes)\n",
+ "\n",
+ " def forward(self, img):\n",
+ " x = self.to_patch_embedding(img)\n",
+ " b, n, _ = x.shape\n",
+ "\n",
+ " cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)\n",
+ " x = torch.cat((cls_tokens, x), dim=1)\n",
+ " x += self.pos_embedding[:, :(n + 1)]\n",
+ " x = self.dropout(x)\n",
+ "\n",
+ " x = self.transformer(x)\n",
+ "\n",
+ " x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]\n",
+ "\n",
+ " x = self.to_latent(x)\n",
+ " return self.mlp_head(x)\n",
+ "\n",
+ "\n",
+ "# Define a custom dataset\n",
+ "class RandomImageDataset(Dataset):\n",
+ " def __init__(self, num_samples, num_classes, image_size):\n",
+ " self.num_samples = num_samples\n",
+ " self.num_classes = num_classes\n",
+ " self.image_size = image_size\n",
+ "\n",
+ " def __len__(self):\n",
+ " return self.num_samples\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " # Generate a random image tensor\n",
+ " image = torch.randn(3, self.image_size, self.image_size) # 3 channels, image_size x image_size\n",
+ " # Generate a random label\n",
+ " label = torch.randint(0, self.num_classes, (1,)).item()\n",
+ " return image, label\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Create the dataset\n",
+ "dataset = RandomImageDataset(num_samples=num_samples, num_classes=num_classes, image_size=image_size)\n",
+ "\n",
+ "# Create a DataLoader\n",
+ "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
+ ],
+ "metadata": {
+ "id": "TntUlCT2ppV2"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Log Model Checkpoints to W&B with 1 Line!\n",
+ "\n",
+ "Use the `log_model` method to log a model artifact containing the contents inside the ‘path’ to a run. It also marks it as an output to the run. You can see the full lineage graph of the model artifact by accessing the [lineage](https://docs.wandb.ai/guides/artifacts/explore-and-traverse-an-artifact-graph#docusaurus_skipToContent_fallback) tab inside the Artifacts view.\n",
+ "\n",
+ "`log_model()` takes as input:\n",
+ "\n",
+ "- `path`: A path to the model file(s), which can be a local file (of the form `/local/directory/file.txt`), directory (of the form `/local/directory`), or reference path to S3 (`s3://bucket/path`).\n",
+ "- `name`: An optional name for the model artifact the files will be logged to. Note that if no name is specified, this will default to the basename of the input path prepended with the run ID.\n",
+ "- `aliases`: An optional list of aliases, which can be thought of as semantic ‘nicknames’ or identifiers for a model version. For example, if this model yielded the best accuracy, you might add the alias ‘highest-accuracy’ or ‘best’."
+ ],
+ "metadata": {
+ "id": "ZkomIMW6VNky"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "run = wandb.init(project=\"new_model_logging\",\n",
+ " job_type=\"training\")\n",
+ "\n",
+ "v = ViT(\n",
+ " image_size = image_size,\n",
+ " patch_size = 32,\n",
+ " num_classes = num_classes,\n",
+ " dim = 128,\n",
+ " depth = 3,\n",
+ " heads = 2,\n",
+ " mlp_dim = 256,\n",
+ " dropout = 0.1,\n",
+ " emb_dropout = 0.1\n",
+ ")\n",
+ "\n",
+ "# Define the loss function and optimizer\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimizer = Adam(v.parameters(), lr=0.003)\n",
+ "\n",
+ "# Training loop\n",
+ "best_accuracy = 0\n",
+ "for epoch in range(5): # number of epochs\n",
+ " for images, labels in dataloader:\n",
+ " # Forward pass\n",
+ " preds = v(images)\n",
+ " loss = criterion(preds, labels)\n",
+ "\n",
+ " # Backward pass and optimization\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " wandb.log({\"train/loss\": loss})\n",
+ "\n",
+ " # Model evaluation after each epoch (using a validation set)\n",
+ " # Here you would write your validation loop and calculate accuracy\n",
+ " val_accuracy = 0.5 # Assume this is the validation accuracy you compute\n",
+ " model_path = 'model_vit.pth'\n",
+ " torch.save(v.state_dict(), model_path)\n",
+ "\n",
+ " # Check if this is the best model so far\n",
+ " if val_accuracy > best_accuracy:\n",
+ " best_accuracy = val_accuracy\n",
+ " # Log the model to your W&B run\n",
+ " wandb.log_model(name=f\"model_vit-{wandb.run.id}\", path=model_path, aliases=[\"best\", f\"epoch_{epoch}\"])\n",
+ " else:\n",
+ " wandb.log_model(name=f\"model_vit-{wandb.run.id}\", path=model_path, aliases=[f\"epoch_{epoch}\"])"
+ ],
+ "metadata": {
+ "id": "48khrxaKt-rm"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Link Your Best Models to the Model Registry\n",
+ "You can bookmark your best model checkpoints and centralize them across your team. The Model Registry allows you to organize your best models by task, manage model lifecycle, facilitate easy tracking and auditing throughout the ML lifecycle, and automate downstream actions with webhooks or jobs. You can do this via the API through `link_model()`, which takes as input:\n",
+ "\n",
+ "- `path`: A path to the model file(s), which can be a local file (of the form `/local/directory/file.txt`), directory (of the form `/local/directory`), or reference path to S3 (`s3://bucket/path`).\n",
+ "- `registered_model_name`: the name of the Registered Model - a collection of linked model versions in the Model Registry, typically representing a team’s ML task - that the model should be linked to. If no Registered Model with the given name exists, a new one will be created with this name.\n",
+ "- `name`: An **optional** name for the model artifact the files will be logged to. Note that if no name is specified, this will default to the basename of the input path prepended with the run ID.\n",
+ "- `aliases`: An **optional** list of aliases, which can be thought of as semantic ‘nicknames’ or identifiers for a linked model version. For example, since this model is being linked, or published, to the Model Registry, you might add an alias “staging” or “QA”."
+ ],
+ "metadata": {
+ "id": "km2nU7IUVg63"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Link the best model to the W&B Model Registry (after all epochs are finished)\n",
+ "artifact_name = f\"model_vit-{wandb.run.id}:best\"\n",
+ "best_model_path = wandb.use_model(artifact_name)\n",
+ "\n",
+ "# Link the best model to the registry\n",
+ "wandb.link_model(path=best_model_path,\n",
+ " registered_model_name=\"Industrial ViT\",\n",
+ " aliases=[\"staging\"])\n",
+ "run.finish()"
+ ],
+ "metadata": {
+ "id": "ckH-yexDVgJI"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/colabs/yolox/Train_and_Debug_YOLOX_Models_with_Weights_&_Biases.ipynb b/colabs/yolox/Train_and_Debug_YOLOX_Models_with_Weights_&_Biases.ipynb
index 8442213f..4f9c3028 100644
--- a/colabs/yolox/Train_and_Debug_YOLOX_Models_with_Weights_&_Biases.ipynb
+++ b/colabs/yolox/Train_and_Debug_YOLOX_Models_with_Weights_&_Biases.ipynb
@@ -89,7 +89,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/examples/jax/jax-llm/create_tokenizer.ipynb b/examples/jax/jax-llm/create_tokenizer.ipynb
index d3c1fd83..090e18ef 100644
--- a/examples/jax/jax-llm/create_tokenizer.ipynb
+++ b/examples/jax/jax-llm/create_tokenizer.ipynb
@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "eedce872-2204-49f8-8a52-ef338f7a6654",
+ "id": "20860c23",
"metadata": {},
"outputs": [],
"source": [
@@ -22,7 +22,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2f30468d-36f7-448c-8a09-6e2637dc5c3c",
+ "id": "202d2317",
"metadata": {},
"outputs": [],
"source": [
@@ -32,7 +32,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "db2f7eeb-9178-4fa9-9156-1a86072add3e",
+ "id": "8e472e8d",
"metadata": {},
"outputs": [],
"source": [
@@ -43,7 +43,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "80520842-7460-4f0b-bf6c-c586506f8bc8",
+ "id": "8a4be9b2",
"metadata": {},
"outputs": [],
"source": [
@@ -53,7 +53,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7216dd5d-25bf-4b6d-b677-157eb7e85756",
+ "id": "9c9cfb89",
"metadata": {},
"outputs": [],
"source": [
@@ -71,7 +71,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "49ff186a-ecfc-4492-b457-053734ed5970",
+ "id": "c280c1eb",
"metadata": {},
"outputs": [],
"source": [
@@ -81,7 +81,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2a2fffc0-8bd5-4558-b0db-655271ced251",
+ "id": "3da91241",
"metadata": {},
"outputs": [],
"source": [
@@ -103,7 +103,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6784b4d7-22dc-4726-bedf-a655a7ae2982",
+ "id": "25df71ba",
"metadata": {},
"outputs": [],
"source": [
@@ -116,7 +116,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7d9804dd-3fb1-42f5-80cc-9d392c9a229a",
+ "id": "bedeed59",
"metadata": {},
"outputs": [],
"source": [
@@ -126,7 +126,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e2298511-f2ba-4863-95dd-f592c78499ed",
+ "id": "a1d2330a",
"metadata": {},
"outputs": [],
"source": [
@@ -136,7 +136,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b5ed4e7f-3218-4f72-b407-4369dd9df3d2",
+ "id": "b1af6ef5",
"metadata": {},
"outputs": [],
"source": [
@@ -146,7 +146,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "aa0fb355-304b-48be-8585-ba246a553ad4",
+ "id": "f2ff15a1",
"metadata": {},
"outputs": [],
"source": [
@@ -159,23 +159,6 @@
"display_name": "Python 3",
"language": "python",
"name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.10"
- },
- "vscode": {
- "interpreter": {
- "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
- }
}
},
"nbformat": 4,
diff --git a/examples/keras/keras-tensorflow2/TensorFlow_2_0_+_Keras_Crash_Course_+_W&B.ipynb b/examples/keras/keras-tensorflow2/TensorFlow_2_0_+_Keras_Crash_Course_+_W&B.ipynb
index 37560e3a..2a1cc50f 100644
--- a/examples/keras/keras-tensorflow2/TensorFlow_2_0_+_Keras_Crash_Course_+_W&B.ipynb
+++ b/examples/keras/keras-tensorflow2/TensorFlow_2_0_+_Keras_Crash_Course_+_W&B.ipynb
@@ -1679,11 +1679,6 @@
"model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": []
- },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/examples/llama-cpp/.dockerignore b/examples/llama-cpp/.dockerignore
new file mode 100644
index 00000000..6ea88749
--- /dev/null
+++ b/examples/llama-cpp/.dockerignore
@@ -0,0 +1 @@
+models/
\ No newline at end of file
diff --git a/examples/llama-cpp/.gitignore b/examples/llama-cpp/.gitignore
new file mode 100644
index 00000000..4592db0a
--- /dev/null
+++ b/examples/llama-cpp/.gitignore
@@ -0,0 +1,2 @@
+wandb/
+models/
\ No newline at end of file
diff --git a/examples/llama-cpp/Dockerfile b/examples/llama-cpp/Dockerfile
new file mode 100644
index 00000000..adc8d65c
--- /dev/null
+++ b/examples/llama-cpp/Dockerfile
@@ -0,0 +1,30 @@
+ARG CUDA_IMAGE="12.1.1-devel-ubuntu22.04"
+FROM nvidia/cuda:${CUDA_IMAGE}
+
+# We need to set the host to 0.0.0.0 to allow outside access
+ENV HOST=0.0.0.0
+
+RUN apt-get update && apt-get upgrade -y \
+    && apt-get install -y git build-essential \
+    python3 python3-pip gcc wget \
+    ocl-icd-opencl-dev opencl-headers clinfo \
+    libclblast-dev libopenblas-dev \
+    && mkdir -p /etc/OpenCL/vendors && echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd
+
+# setting build related env vars
+ENV CUDA_DOCKER_ARCH=all
+ENV LLAMA_CUBLAS=1
+ENV PYTHONUNBUFFERED=1
+
+# Install dependencies
+RUN python3 -m pip install --no-cache-dir --upgrade pip pytest cmake scikit-build setuptools fastapi \
+    uvicorn sse-starlette pydantic-settings starlette-context openai wandb timeout-decorator
+
+# Install llama-cpp-python (build with cuda)
+RUN CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install --no-cache-dir llama-cpp-python
+
+COPY app /app
+WORKDIR /app
+
+# Run evaluation (exec form, so python runs as PID 1 and receives stop signals directly)
+CMD ["python3", "evaluate.py"]
\ No newline at end of file
diff --git a/examples/llama-cpp/README.md b/examples/llama-cpp/README.md
new file mode 100644
index 00000000..86dbe4d3
--- /dev/null
+++ b/examples/llama-cpp/README.md
@@ -0,0 +1,76 @@
+# Overview
+
+This script automates the evaluation of different LLMs with W&B. It was originally used at Dockercon '23. The default dataset attempts to convert English commands into Docker CLI commands. See `eval.jsonl`. All of the logic is in `evaluate.py`.
+
+# Documentation
+
+## Setup W&B
+
+```bash
+pip install wandb
+
+# Find your api key at https://wandb.ai/authorize
+export WANDB_API_KEY=XXX
+# Find your openai api key at https://platform.openai.com/account/api-keys
+export OPENAI_API_KEY=XXX
+```
+
+## Download Models
+
+```bash
+python app/download_models.py
+```
+
+## Nvidia/CUDA
+
+### Build the docker container
+
+```bash
+docker build -t wandb/eval-llm:cuda .
+```
+
+### Run evaluation
+
+```bash
+docker run --gpus=all --cap-add SYS_RESOURCE -e USE_MLOCK=0 -e WANDB_API_KEY -e OPENAI_API_KEY -e MODEL=mistral-7b-instruct-v0.1.Q5_K_M.gguf -e TEMP=0.3 -v $(pwd)/models:/var/models wandb/eval-llm:cuda
+```
+
+## Environment variables
+
+* `TEMP` - temperature _(0.5)_
+* `MAX_TOKENS` - maximum number of tokens to emit _(128)_
+* `SYSTEM_PROMPT` - instructions for the model _(You're a Docker expert. Translate the following sentence to a simple docker command.)_
+* `MODEL` - name of a gguf file, or an OpenAI model whose name starts with `gpt` (e.g. gpt-3.5-turbo, gpt-4) _(codellama-13b-instruct.Q4_K_M.gguf)_
+* `EVAL_PATH` - the path to a jsonl file with "input" and "ideal" keys _(eval.jsonl)_
+* `VERBOSE` - print verbose info from llama-cpp-python _(False)_
+* `DIFF_THRESHOLD` - the percentage threshold for considering a response accurate _(0.7)_
+* `REPITITION_PENALTY` - how much to penalize repeated tokens _(1.1)_
+* `GPU_LAYERS` - the number of layers to offload to the gpu _(-1 for CUDA, 0 for CPU)_
+
+# W&B Launch Setup
+
+## Create a queue
+
+Go to https://wandb.ai/vanpelt/launch and create a queue named "llm-eval-cuda". Set its config to:
+
+> Note: replace `/home/jupyter` with whatever `pwd` returns in your current directory.
+
+```json
+{
+ "env": ["USE_MLOCK=0", "OPENAI_API_KEY"],
+ "gpus": "all",
+ "volume": "/home/jupyter/models:/var/models",
+ "cap-add": "SYS_RESOURCE"
+}
+```
+
+## Create a docker job
+
+```bash
+wandb job create --project "llm-eval" --name "llm-eval-cuda" image wandb/eval-llm:cuda
+```
+
+## Run an agent
+
+```bash
+wandb launch-agent -q llm-eval-cuda
\ No newline at end of file
diff --git a/examples/llama-cpp/app/download_models.py b/examples/llama-cpp/app/download_models.py
new file mode 100644
index 00000000..6af41b40
--- /dev/null
+++ b/examples/llama-cpp/app/download_models.py
@@ -0,0 +1,27 @@
+import os
+import requests
+
+models = [
+ "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q5_K_S.gguf",
+ "https://huggingface.co/TheBloke/CodeLlama-13B-Instruct-GGUF/resolve/main/codellama-13b-instruct.Q4_K_M.gguf",
+ "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q5_K_M.gguf"
+]
+for url in models:
+ dest = f"models/{url.split('/')[-1]}"
+ if os.path.exists(dest):
+ print(f"Skipping {dest}, already exists")
+ continue
+ print(f"Downloading {url} to {dest}...")
+ response = requests.get(url, stream=True)
+ if response.status_code == 200:
+ with open(dest, 'wb') as f:
+ total_downloaded = 0
+ for chunk in response.iter_content(chunk_size=1024):
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ total_downloaded += len(chunk)
+ if total_downloaded >= 10485760: # 10 MB
+ print('.', end='', flush=True)
+ total_downloaded = 0
+ print("\nDownload complete.")
+
\ No newline at end of file
diff --git a/examples/llama-cpp/app/eval.jsonl b/examples/llama-cpp/app/eval.jsonl
new file mode 100644
index 00000000..2a0b036e
--- /dev/null
+++ b/examples/llama-cpp/app/eval.jsonl
@@ -0,0 +1,10 @@
+{"input": "Spin up a new container named busy_server_1 with the latest version of the busybox image and place it under the high_priority cgroup parent in the background.", "ideal": "docker run -d --name busy_server_1 --cgroup-parent /high_priority busybox"}
+{"input": "Please display the images in a table format with the repository, tag, ID, and size included.", "ideal": "docker images --format \"table {{.Repository}},{{.Tag}},{{.ID}},{{.Size}}\""}
+{"input": "Can you show me the containers that are running and have port 8080 published?", "ideal": "docker ps --filter 'publish=8080'"}
+{"input": "I need just the IDs of Docker images.", "ideal": "docker images --quiet"}
+{"input": "Display the containers that are both running and healthy.", "ideal": "docker ps --filter 'status=running' --filter 'health=healthy'"}
+{"input": "Get the images labeled with \"maintainer=nginx\" and show their repository, tag, and ID.", "ideal": "docker images --filter \"label=maintainer=nginx\" --format \"{{.Repository}},{{.Tag}},{{.ID}}\""}
+{"input": "Display all the running containers.", "ideal": "docker ps"}
+{"input": "Display the details of the most recent container execution now!", "ideal": "docker ps -l"}
+{"input": "Please run httpd as an Apache server on port 8080 using the latest image.", "ideal": "docker run --name apache_server -p 8080:80 httpd"}
+{"input": "I need to see the running tasks in Docker.", "ideal": "docker ps --filter 'is-task=true'"}
\ No newline at end of file
diff --git a/examples/llama-cpp/app/evaluate.py b/examples/llama-cpp/app/evaluate.py
new file mode 100644
index 00000000..e4ec62f2
--- /dev/null
+++ b/examples/llama-cpp/app/evaluate.py
@@ -0,0 +1,113 @@
+import difflib
+import json
+from llama_cpp import Llama
+import openai
+import os
+import re
+import subprocess
+import time
+import wandb
+
+config = {
+ 'max_tokens': int(os.getenv("MAX_TOKENS", 128)),
+ 'repetition_penalty': float(os.getenv("REPITITION_PENALTY", 1.1)),
+ 'temperature': float(os.getenv("TEMP", 0.5)),
+ 'gpu_layers': int(os.getenv("GPU_LAYERS", 0)),
+}
+
+model_path = os.getenv("MODEL", "codellama-13b-instruct.Q4_K_M.gguf")
+eval_path = os.getenv("EVAL_PATH", "eval.jsonl")
+system_prompt = os.getenv("SYSTEM_PROMPT", "You're a Docker expert. Translate the following sentence to a simple docker command.")
+diff_threshold = float(os.getenv("DIFF_THRESHOLD", 0.7))
+
+def is_cuda_available():
+ try:
+ subprocess.check_output(["nvidia-smi"])
+ return True
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ return False
+
+wandb_config = {"model": model_path, "eval": eval_path,
+ "system_prompt": system_prompt, **config}
+
+# Set WANDB_MODE=disabled when running this files in tests
+wandb.init(project="llm-eval-v2", config=wandb_config)
+
+if wandb.config["model"].startswith("gpt"):
+ def llm(prompt):
+ res = openai.ChatCompletion.create(
+ model=wandb.config["model"],
+ messages=[
+ {"role": "system", "content": wandb.config["system_prompt"]},
+ {"role": "user", "content": prompt}
+ ],
+ temperature=wandb.config["temperature"],
+ max_tokens=wandb.config["max_tokens"],
+ frequency_penalty=wandb.config["repetition_penalty"],
+ )
+ return res.choices[0].message.content, res.usage.total_tokens
+else:
+ default_gpu = -1 if is_cuda_available() else 0
+ cpp = Llama(f"/var/models/{wandb.config['model']}",
+ verbose=bool(os.getenv("VERBOSE", False)),
+ n_gpu_layers=int(os.getenv("GPU_LAYERS", default_gpu)))
+ def llm(prompt):
+ res = cpp.create_chat_completion(
+ messages=[
+ {"role": "system", "content": wandb.config["system_prompt"]},
+ {"role": "user", "content": f"Q: {prompt}"}
+ ],
+ max_tokens=wandb.config["max_tokens"], stop=["Q:"],
+ repeat_penalty=wandb.config["repetition_penalty"],
+ temperature=wandb.config["temperature"],
+ )
+ return res["choices"][0]["message"]["content"], res["usage"]["total_tokens"]
+
+print(f"Evaluating {wandb.config['model']}")
+table = wandb.Table(columns=["prompt", "output", "ideal", "score", "latency", "tokens"])
+
+codeblock_pattern = re.compile(r'(docker.+)$', re.MULTILINE)
+def fmt(s):
+ return f"`{s}`"
+
+total_score = 0
+total_latency = 0
+total_tokens = 0
+correct = 0.0
+total = 0.0
+with open(eval_path, "r") as f:
+ for line in f:
+ data = json.loads(line)
+ total += 1.0
+ prompt = data["input"]
+ print(prompt)
+ start = time.time()
+ output, tokens = llm(prompt)
+ latency = time.time() - start
+ total_latency += latency
+ matches = codeblock_pattern.findall(output)
+ if len(matches) == 0:
+ print("\t!!! No code generated:")
+ for l in output.split("\n"):
+ print(f"\t> {l}")
+ continue
+ command = matches[0].split("`")[0]
+ score = difflib.SequenceMatcher(None, data["ideal"], command).ratio()
+ print(f"\t({score:.2f}) {command}")
+ total_score += score
+ total_tokens += tokens
+ if score > diff_threshold:
+ correct += 1.0
+ table.add_data(prompt, fmt(command), fmt(data["ideal"]), score, latency, tokens)
+
+wandb.log({
+ "accuracy": correct / total,
+ "diff_score": total_score / total,
+ "avg_tokens": total_tokens / total,
+ "latency": total_latency / total,
+ "eval": table
+})
+print("\nConfig:\n")
+print(json.dumps(dict(wandb.config), indent=4))
+print(f"Accuracy: {wandb.run.summary['accuracy']}")
+print(f"Average diff score: {wandb.run.summary['diff_score']}")
\ No newline at end of file
diff --git a/examples/llama-cpp/sweep.yaml b/examples/llama-cpp/sweep.yaml
new file mode 100644
index 00000000..a79a8e7d
--- /dev/null
+++ b/examples/llama-cpp/sweep.yaml
@@ -0,0 +1,22 @@
+program: evaluate.py  # script executed for each sweep run
+method: random  # parameter values are sampled at random for each run
+metric:
+ goal: maximize
+ name: diff_score  # presumably the "diff_score" key that evaluate.py passes to wandb.log - confirm
+parameters:
+ model:  # evaluate.py appears to read this as wandb.config['model']
+ distribution: categorical
+ values:
+ - codellama-13b-instruct.Q4_K_M.gguf
+ - codellama-7b.Q5_K_S.gguf
+ - mistral-7b-instruct-v0.1.Q5_K_M.gguf
+ - gpt-3.5-turbo
+ system_prompt:  # evaluate.py appears to send this as the system message of each chat completion
+ distribution: categorical
+ values:
+ - You're a Docker expert. Translate the following sentence to a simple docker command.
+ - You'll be asked a question about Docker. Your job is to convert this question to a succinct docker command. Only provide a single command and limit your use of pipeing to other unix tools.
+ temperature:  # sampling temperature, drawn between min and max below
+ distribution: uniform
+ max: 0.6
+ min: 0
\ No newline at end of file
diff --git a/examples/model-evaluation/requirements.txt b/examples/model-evaluation/requirements.txt
index 89530871..8253d0e1 100644
--- a/examples/model-evaluation/requirements.txt
+++ b/examples/model-evaluation/requirements.txt
@@ -1,5 +1,5 @@
wandb
keras
tensorflow
-sklearn
-numpy
\ No newline at end of file
+scikit-learn
+numpy
diff --git a/examples/pytorch/pytorch-cifar10-sagemaker/train.ipynb b/examples/pytorch/pytorch-cifar10-sagemaker/train.ipynb
index b9ec4e33..e1b9dd71 100644
--- a/examples/pytorch/pytorch-cifar10-sagemaker/train.ipynb
+++ b/examples/pytorch/pytorch-cifar10-sagemaker/train.ipynb
@@ -32,7 +32,6 @@
"metadata": {},
"outputs": [],
"source": [
- "# login\n",
"wandb.login()"
]
},
diff --git a/examples/pytorch/pytorch-cifar10-sagemaker/train.py b/examples/pytorch/pytorch-cifar10-sagemaker/train.py
index fe6d0b5b..cae1944e 100644
--- a/examples/pytorch/pytorch-cifar10-sagemaker/train.py
+++ b/examples/pytorch/pytorch-cifar10-sagemaker/train.py
@@ -50,13 +50,14 @@
estimator = PyTorch(entry_point="cifar10.py",
source_dir=os.getcwd() + "/source",
role=role,
- framework_version='1.7.1',
- py_version='py3',
+ framework_version='1.12.1',
+ py_version='py38',
instance_count=1,
instance_type='ml.c5.xlarge',
hyperparameters={
'epochs': args.train_epochs,
- 'momentum': 0.9
+ 'momentum': 0.9,
+ 'num_round': 10,
})
hyperparameter_ranges = {
diff --git a/examples/pytorch/pytorch-intro/intro.ipynb b/examples/pytorch/pytorch-intro/intro.ipynb
index 29e80825..c2a9d5d0 100755
--- a/examples/pytorch/pytorch-intro/intro.ipynb
+++ b/examples/pytorch/pytorch-intro/intro.ipynb
@@ -112,8 +112,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
- "\n",
+ "import wandb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/examples/pytorch/pytorch-mnist-sagemaker/pytorch_mnist.ipynb b/examples/pytorch/pytorch-mnist-sagemaker/pytorch_mnist.ipynb
index b6dca598..2e20fdf9 100644
--- a/examples/pytorch/pytorch-mnist-sagemaker/pytorch_mnist.ipynb
+++ b/examples/pytorch/pytorch-mnist-sagemaker/pytorch_mnist.ipynb
@@ -491,11 +491,19 @@
],
"source": [
"import wandb\n",
- "wandb.login()\n",
"settings = wandb.setup().settings\n",
"current_api_key = wandb.wandb_lib.apikey.api_key(settings=settings)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
diff --git a/examples/sagemaker/text_classification/text_classification.ipynb b/examples/sagemaker/text_classification/text_classification.ipynb
index b9bf3487..19e7431a 100644
--- a/examples/sagemaker/text_classification/text_classification.ipynb
+++ b/examples/sagemaker/text_classification/text_classification.ipynb
@@ -98,7 +98,15 @@
"metadata": {},
"outputs": [],
"source": [
- "import wandb\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"wandb.login()"
]
},
diff --git a/examples/streamlit/requirements.txt b/examples/streamlit/requirements.txt
index caffbe21..9566dd96 100644
--- a/examples/streamlit/requirements.txt
+++ b/examples/streamlit/requirements.txt
@@ -21,14 +21,14 @@ defusedxml==0.7.1
docker-pycreds==0.4.0
entrypoints==0.3
gitdb==4.0.9
-GitPython==3.1.34
+GitPython==3.1.41
idna==3.3
ipykernel==6.5.0
ipython==8.10.0
ipython-genutils==0.2.0
ipywidgets==7.6.5
jedi==0.18.0
-Jinja2==3.0.2
+Jinja2==3.1.3
jsonschema==4.1.2
jupyter-client==7.0.6
jupyter-core==4.11.2
@@ -50,14 +50,14 @@ parso==0.8.2
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
-Pillow==9.0.1
+Pillow==10.0.1
prometheus-client==0.12.0
promise==2.3
prompt-toolkit==3.0.21
protobuf==3.19.5
psutil==5.8.0
ptyprocess==0.7.0
-pyarrow==6.0.0
+pyarrow==14.0.1
pycodestyle==2.8.0
pycparser==2.20
pydeck==0.7.1
@@ -76,7 +76,7 @@ sentry-sdk==1.14.0
shortuuid==1.0.1
six==1.16.0
smmap==5.0.0
-streamlit==1.11.1
+streamlit==1.30.0
subprocess32==3.5.4
termcolor==1.1.0
terminado==0.12.1
@@ -88,7 +88,7 @@ traitlets==5.1.1
typing-extensions==3.10.0.2
tzdata==2021.5
tzlocal==4.1
-urllib3==1.26.7
+urllib3==1.26.18
validators==0.18.2
wandb==0.12.6
watchdog==2.1.6