From c98f94b4d3d3fb1d8aba3afc19aa4b3ddc00d650 Mon Sep 17 00:00:00 2001 From: Prashant Dixit <54981696+PrashantDixit0@users.noreply.github.com> Date: Fri, 1 Mar 2024 10:15:57 +0530 Subject: [PATCH] fix (#148) * assests and app name * update README * demo gifs * talk with github codespaces * talk with github codespaces * gitignore * linted * added version * link fix * added local llm tag * crag * link fix * lint * llm tags * non-clickable badge * non-clickable badge * fix * tutorial llm tags * added instructions and fix --- examples/audio_search/main.ipynb | 1800 ++++--- examples/product-recommender/main.ipynb | 5985 ++++++++++++----------- 2 files changed, 3965 insertions(+), 3820 deletions(-) diff --git a/examples/audio_search/main.ipynb b/examples/audio_search/main.ipynb index 49cab5c..ce19c4c 100644 --- a/examples/audio_search/main.ipynb +++ b/examples/audio_search/main.ipynb @@ -1,959 +1,951 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "3lhhVh6TWRjq" - }, - "source": [ - "# Audio Similarity Search using Vector Embeddings\n", - "This notebook demonstrates how to create vector embeddings of audio files to store into the LanceDB vector store, and then to find similar audio files.\n", - "We will be using [panns_inference package](https://github.com/qiuqiangkong/panns_inference) to tag the audio and create embeddings. We'll also be using this [HuggingFace dataset](https://huggingface.co/datasets/ashraq/esc50) for the audio files. The dataset contains 2,000 sounds and labels." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Installing dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Pks8RDrdWRjt", - "outputId": "387f9c04-f6c5-42ec-f7ba-87a3ae654162" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting datasets\n", - " Downloading datasets-2.14.6-py3-none-any.whl (493 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m493.7/493.7 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", - "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", - "Collecting dill<0.3.8,>=0.3.0 (from datasets)\n", - " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", - "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", - "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n", - "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", - "Collecting multiprocess (from datasets)\n", - " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: 
fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", - "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.6)\n", - "Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets)\n", - " Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m302.0/302.0 kB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", - "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.12.4)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from 
huggingface-hub<1.0.0,>=0.14.0->datasets) (4.5.0)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", - "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", - "Installing collected packages: dill, multiprocess, huggingface-hub, datasets\n", - "Successfully installed datasets-2.14.6 dill-0.3.7 huggingface-hub-0.18.0 multiprocess-0.70.15\n", - "Collecting lancedb\n", - " Downloading lancedb-0.3.1-py3-none-any.whl (60 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.4/60.4 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting deprecation (from lancedb)\n", - " Downloading deprecation-2.1.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pylance==0.8.3 (from lancedb)\n", - " Downloading pylance-0.8.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.3/21.3 MB\u001b[0m \u001b[31m38.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ratelimiter~=1.0 (from lancedb)\n", - " Downloading ratelimiter-1.2.0.post0-py3-none-any.whl (6.6 kB)\n", - "Collecting retry>=0.9.2 (from lancedb)\n", - " Downloading retry-0.9.2-py2.py3-none-any.whl 
(8.0 kB)\n", - "Requirement already satisfied: tqdm>=4.1.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (4.66.1)\n", - "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from lancedb) (3.8.6)\n", - "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb) (1.10.13)\n", - "Requirement already satisfied: attrs>=21.3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (23.1.0)\n", - "Collecting semver>=3.0 (from lancedb)\n", - " Downloading semver-3.0.2-py3-none-any.whl (17 kB)\n", - "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb) (5.3.1)\n", - "Requirement already satisfied: pyyaml>=6.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (6.0.1)\n", - "Requirement already satisfied: click>=8.1.7 in /usr/local/lib/python3.10/dist-packages (from lancedb) (8.1.7)\n", - "Requirement already satisfied: requests>=2.31.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.31.0)\n", - "Collecting pyarrow>=10 (from pylance==0.8.3->lancedb)\n", - " Downloading pyarrow-13.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.0/40.0 MB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from pylance==0.8.3->lancedb) (1.23.5)\n", - "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (4.5.0)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.3.0)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.4)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in 
/usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (2.0.7)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (2023.7.22)\n", - "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (4.4.2)\n", - "Collecting py<2.0.0,>=1.4.26 (from retry>=0.9.2->lancedb)\n", - " Downloading py-1.11.0-py2.py3-none-any.whl (98 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->lancedb) (6.0.4)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->lancedb) (4.0.3)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->lancedb) (1.9.2)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->lancedb) (1.4.0)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->lancedb) (1.3.1)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from deprecation->lancedb) (23.2)\n", - "Installing collected packages: ratelimiter, semver, pyarrow, py, deprecation, retry, pylance, lancedb\n", - " Attempting uninstall: pyarrow\n", - " Found existing installation: pyarrow 9.0.0\n", - " Uninstalling pyarrow-9.0.0:\n", - " Successfully uninstalled pyarrow-9.0.0\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "ibis-framework 6.2.0 requires pyarrow<13,>=2, but you have pyarrow 13.0.0 which is incompatible.\n", - "pandas-gbq 0.17.9 requires pyarrow<10.0dev,>=3.0.0, but you have pyarrow 13.0.0 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed deprecation-2.1.0 lancedb-0.3.1 py-1.11.0 pyarrow-13.0.0 pylance-0.8.3 ratelimiter-1.2.0.post0 retry-0.9.2 semver-3.0.2\n" - ] - } - ], - "source": [ - "!pip install panns-inference tqdm --q\n", - "!pip3 install datasets\n", - "!pip install lancedb" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Importing all the libraries" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "hToUqkBBWto1" - }, - "outputs": [], - "source": [ - "import lancedb" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fF08IHEDalKU" - }, - "source": [ - "**NOTE** : if you get any error while importing lancedb just you need to restart runtime" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "0jIb2Gr8WRju" - }, - "outputs": [], - "source": [ - "from datasets import load_dataset\n", - "from panns_inference import AudioTagging\n", - "from tqdm import tqdm\n", - "from IPython.display import Audio, display\n", - "import numpy as np" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x6QfsfHlWRju" - }, - "source": [ - "On devices that have CUDA installed, you may be able to install torch's CUDA supported version.\n", - "```bash\n", - "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n", - "```\n", - "If you don't have CUDA or a GPU (or different os), you can install torch here: https://pytorch.org/get-started/locally/" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { 
- "base_uri": "https://localhost:8080/" - }, - "id": "fyjp-ffQWRjv", - "outputId": "edb7fdfa-27e7-4b00-fa2d-409bbf1d23b8" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Repo card metadata block was not found. Setting CardData to empty.\n", - "WARNING:huggingface_hub.repocard:Repo card metadata block was not found. Setting CardData to empty.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Checkpoint path: /root/panns_data/Cnn14_mAP=0.431.pth\n", - "GPU number: 1\n" - ] - } - ], - "source": [ - "dataset = load_dataset(\"ashraq/esc50\", split=\"train\")\n", - "at = AudioTagging(checkpoint_path=None, device=\"cuda\") # device=\"cpu\" for CPU inference" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Zm9Qz9WVWRjv", - "outputId": "4cfd5f6d-3d83-4930-ceaf-9cd4c80eb774" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take', 'audio'],\n", - " num_rows: 2000\n", - "})" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "INND51clWRjv" - }, - "source": [ - "### Create Embeddings\n", - "Now, to create the data embeddings! We can start by creating batches of 70 for the data, keeping track of the most important columns: `category` and `audio`." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "VKflK56YWRjv" - }, - "outputs": [], - "source": [ - "batches = [batch[\"audio\"] for batch in dataset.iter(50)]\n", - "meta_batches = [batch[\"category\"] for batch in dataset.iter(50)]\n", - "audio_data = [np.array([audio[\"array\"] for audio in batch]) for batch in batches]\n", - "meta_data = [np.array([meta for meta in batch]) for batch in meta_batches]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B4mB3sa2WRjw" - }, - "source": [ - "We now want to iterate through these batches, and for each audio file, we want to use the AudioTagging embedder to extract the embedding. Then, we can store these embeddings, audio files, and category name into a list of dictionaries. Each dictionary has to contain a `vector` column in order to add to the LanceDB table, if no embedding function is provided." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "pdt1n8S7WRjw", - "outputId": "96d4b5c6-b1c2-497f-c35f-d5905548f6f0" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 40/40 [00:19<00:00, 2.06it/s]\n" - ] - } - ], - "source": [ - "for i in tqdm(range(len(audio_data))):\n", - " (_, embedding) = at.inference(audio_data[i])\n", - " data = [\n", - " {\n", - " \"audio\": x[0][\"array\"],\n", - " \"vector\": x[1],\n", - " \"sampling_rate\": x[0][\"sampling_rate\"],\n", - " \"category\": meta_data[i][j],\n", - " }\n", - " for j, x in enumerate(zip(batches[i], embedding))\n", - " ]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CRpnHjJbWRjw" - }, - "source": [ - "Once we have this data list, we can create a LanceDB table by first connecting to a certain directory before, and then calling `db.create_table()`. If the table already exists, we open the table and add the data." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Add the VectorStore" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "3lh_d6m3WRjw" - }, - "outputs": [], - "source": [ - "# Connect to directory at the top of the file\n", - "db = lancedb.connect(\"data/audio-lancedb\")\n", - "table_name = \"audio-search\"\n", - "\n", - "if table_name not in db.table_names():\n", - " tbl = db.create_table(table_name, data)\n", - "else:\n", - " tbl = db.open_table(table_name)\n", - " tbl.add(data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "m7WfeIv8WRjw" - }, - "source": [ - "We can now combine all of this into a single function:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Composite function" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "XnCHvlLsWRjw" - }, - "outputs": [], - "source": [ - "def insert_audio():\n", - " batches = [batch[\"audio\"] for batch in dataset.iter(20)]\n", - " meta_batches = [batch[\"category\"] for batch in dataset.iter(20)]\n", - " audio_data = [np.array([audio[\"array\"] for audio in batch]) for batch in batches]\n", - " meta_data = [np.array([meta for meta in batch]) for batch in meta_batches]\n", - " print(\"Start\")\n", - " for i in tqdm(range(len(audio_data))):\n", - " (_, embedding) = at.inference(audio_data[i])\n", - " data = [\n", - " {\n", - " \"audio\": x[0][\"array\"],\n", - " \"vector\": x[1],\n", - " \"sampling_rate\": x[0][\"sampling_rate\"],\n", - " \"category\": meta_data[i][j],\n", - " }\n", - " for j, x in enumerate(zip(batches[i], embedding))\n", - " ]\n", - " if table_name not in db.table_names():\n", - " tbl = db.create_table(table_name, data)\n", - " else:\n", - " tbl = db.open_table(table_name)\n", - " tbl.add(data)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "UvEhnuLyWRjw" - }, - "outputs": [], - "source": [ - "import shutil\n", - 
"\n", - "shutil.rmtree(\"data/audio-lancedb/audio-search.lance\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TXxGHwZdgZrG" - }, - "outputs": [], - "source": [ - "insert_audio()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vr9LehNiiUNb" - }, - "source": [ - "NOTE: if you get out of ram .next time simply run all cells & uncomment this lines #insert_audio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mPBphF19WRjx" - }, - "source": [ - "Great! We now have a fully populated table with all the necessary information. The next step would be to query the table and find those similar audio files. We can do this by first opening the table, and then getting the specific audio file we want to search for." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Query the database" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 93 - }, - "id": "ZsGYl6YSWRjx", - "outputId": "8cc83527-0540-47aa-99b5-054530cf5615" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "3lhhVh6TWRjq" + }, + "source": [ + "# Audio Similarity Search using Vector Embeddings\n", + "This notebook demonstrates how to create vector embeddings of audio files to store into the LanceDB vector store, and then to find similar audio files.\n", + "We will be using [panns_inference package](https://github.com/qiuqiangkong/panns_inference) to tag the audio and create embeddings. We'll also be using this [HuggingFace dataset](https://huggingface.co/datasets/ashraq/esc50) for the audio files. The dataset contains 2,000 sounds and labels." 
] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Category: water_drops\n" - ] - } - ], - "source": [ - "tbl = db.open_table(table_name)\n", - "audio = dataset[50][\"audio\"][\"array\"]\n", - "category = dataset[50][\"category\"]\n", - "display(Audio(audio, rate=dataset[50][\"audio\"][\"sampling_rate\"]))\n", - "print(\"Category:\", category)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Et2C9t87WRjx" - }, - "source": [ - "Next, we call the embedding function again to create those embeddings, which would allow us to search our table." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ZmXOqB2FWRjx", - "outputId": "05659b4c-acb6-4514-e3c9-d96ecdf84f1a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " audio \\\n", - "0 [0.00506591796875, 0.00653076171875, 0.0051574... \n", - "1 [-0.157318115234375, -0.122344970703125, -0.17... \n", - "2 [-0.0162353515625, -0.015716552734375, -0.0150... \n", - "3 [-0.0008544921875, -0.000762939453125, -0.0005... \n", - "4 [-0.003753662109375, -0.004119873046875, -0.00... \n", - "\n", - " vector sampling_rate \\\n", - "0 [0.0, 0.70255554, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", - "1 [0.0, 0.68818694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", - "2 [0.0, 0.58163136, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", - "3 [0.0, 1.0475253, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,... 44100 \n", - "4 [0.0, 0.45124823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", - "\n", - " category _distance \n", - "0 water_drops 52.260368 \n", - "1 water_drops 57.536537 \n", - "2 water_drops 75.637558 \n", - "3 drinking_sipping 76.979111 \n", - "4 water_drops 77.981865 \n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":2: DeprecatedWarning: to_df is deprecated as of 0.3.1 and will be removed in 0.4.0. 
Use the bar function instead\n", - " result = tbl.search(embedding[0]).limit(5).to_df()\n" - ] - } - ], - "source": [ - "(_, embedding) = at.inference(audio[None, :])\n", - "result = tbl.search(embedding[0]).limit(5).to_df()\n", - "print(result)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 396 - }, - "id": "enl39Zp8WRjx", - "outputId": "305805b6-1540-4708-8345-071083221c80" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0. Category: water_drops\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": { + "id": "is73uCkAZLBj" + }, + "source": [ + "### Installing dependencies" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "1. Category: water_drops\n" - ] + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Pks8RDrdWRjt", + "outputId": "c66c58b2-4f84-4b9c-e563-acda96e620cd" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.17.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.13.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.25.2)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (14.0.2)\n", + "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in 
/usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.2)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.3)\n", + "Requirement already satisfied: huggingface-hub>=0.19.4 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.20.3)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: 
typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.4->datasets) (4.10.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2024.2.2)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.4)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", + "Requirement already satisfied: lancedb in /usr/local/lib/python3.10/dist-packages (0.6.1)\n", + "Requirement already satisfied: deprecation in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.1.0)\n", + "Requirement already satisfied: pylance==0.10.1 in /usr/local/lib/python3.10/dist-packages (from lancedb) (0.10.1)\n", + "Requirement already satisfied: ratelimiter~=1.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (1.2.0.post0)\n", + "Requirement already satisfied: retry>=0.9.2 in /usr/local/lib/python3.10/dist-packages (from lancedb) (0.9.2)\n", + "Requirement already satisfied: tqdm>=4.27.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (4.66.2)\n", + "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.6.3)\n", + "Requirement already satisfied: attrs>=21.3.0 in 
/usr/local/lib/python3.10/dist-packages (from lancedb) (23.2.0)\n", + "Requirement already satisfied: semver>=3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (3.0.2)\n", + "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb) (5.3.3)\n", + "Requirement already satisfied: pyyaml>=6.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (6.0.1)\n", + "Requirement already satisfied: click>=8.1.7 in /usr/local/lib/python3.10/dist-packages (from lancedb) (8.1.7)\n", + "Requirement already satisfied: requests>=2.31.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.31.0)\n", + "Requirement already satisfied: overrides>=0.7 in /usr/local/lib/python3.10/dist-packages (from lancedb) (7.7.0)\n", + "Requirement already satisfied: pyarrow>=12 in /usr/local/lib/python3.10/dist-packages (from pylance==0.10.1->lancedb) (14.0.2)\n", + "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from pylance==0.10.1->lancedb) (1.25.2)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (0.6.0)\n", + "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (2.16.3)\n", + "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (4.10.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from 
requests>=2.31.0->lancedb) (2024.2.2)\n", + "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (4.4.2)\n", + "Requirement already satisfied: py<2.0.0,>=1.4.26 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (1.11.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from deprecation->lancedb) (23.2)\n" + ] + } + ], + "source": [ + "!pip install panns-inference tqdm --q\n", + "!pip3 install datasets\n", + "!pip install lancedb" + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": { + "id": "ZJsz8MnDZLBn" + }, + "source": [ + "### Importing all the libraries" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "2. Category: water_drops\n" - ] + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "hToUqkBBWto1" + }, + "outputs": [], + "source": [ + "import lancedb" + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": { + "id": "fF08IHEDalKU" + }, + "source": [ + "**NOTE** : if you get any error while importing lancedb just you need to restart runtime" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "3. 
Category: drinking_sipping\n" - ] + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "0jIb2Gr8WRju" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from panns_inference import AudioTagging\n", + "from tqdm import tqdm\n", + "from IPython.display import Audio, display\n", + "import numpy as np" + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": { + "id": "x6QfsfHlWRju" + }, + "source": [ + "On devices that have CUDA installed, you may be able to install torch's CUDA supported version.\n", + "```bash\n", + "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n", + "```\n", + "If you don't have CUDA or a GPU (or different os), you can install torch here: https://pytorch.org/get-started/locally/" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "4. Category: water_drops\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "mMy-7PPNZLBr" + }, + "source": [ + "### Load data" + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fyjp-ffQWRjv", + "outputId": "0eb8ecb4-aed5-453a-96e4-d956645e4555" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or 
datasets.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/repocard.py:105: UserWarning: Repo card metadata block was not found. Setting CardData to empty.\n", + " warnings.warn(\"Repo card metadata block was not found. Setting CardData to empty.\")\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Checkpoint path: /root/panns_data/Cnn14_mAP=0.431.pth\n", + "GPU number: 1\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "dataset = load_dataset(\"ashraq/esc50\", split=\"train\")\n", + "at = AudioTagging(checkpoint_path=None, device=\"cuda\") # device=\"cpu\" for CPU inference" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "for i in range(len(result)):\n", - " print(str(i) + \". Category:\", result[\"category\"][i])\n", - " display(Audio(result[\"audio\"][i], rate=result[\"sampling_rate\"][i]))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mZtR0bxXWRjx" - }, - "source": [ - "Nice! It seems to be working! We can compile this into another function here, that takes an `id` of the audio from 0 to 1,999." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Search Audio using IDs" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "wc1X8MuDWRjx" - }, - "outputs": [], - "source": [ - "def search_audio(id):\n", - " tbl = db.open_table(table_name)\n", - " audio = dataset[id][\"audio\"][\"array\"]\n", - " category = dataset[id][\"category\"]\n", - " display(Audio(audio, rate=dataset[id][\"audio\"][\"sampling_rate\"]))\n", - " print(\"Category:\", category)\n", - "\n", - " (_, embedding) = at.inference(audio[None, :])\n", - " result = tbl.search(embedding[0]).limit(5).to_df()\n", - " print(result)\n", - " for i in range(len(result)):\n", - " print(str(i) + \". 
Category:\", result[\"category\"][i])\n", - " display(Audio(result[\"audio\"][i], rate=result[\"sampling_rate\"][i]))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 873 - }, - "id": "dQYVac1kWRjx", - "outputId": "4dd2f8c9-dfb0-475d-97e3-3a82398ee0fd" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Zm9Qz9WVWRjv", + "outputId": "dcbdce06-309d-45cb-997c-37c89d9b6cc3" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take', 'audio'],\n", + " num_rows: 2000\n", + "})" + ] + }, + "metadata": {}, + "execution_count": 5 + } ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Category: car_horn\n", - " audio \\\n", - "0 [-0.022979736328125, -0.021820068359375, -0.02... \n", - "1 [0.313934326171875, 0.312774658203125, 0.31698... \n", - "2 [0.0655517578125, 0.011505126953125, -0.024536... \n", - "3 [0.063690185546875, 0.065216064453125, 0.07296... \n", - "4 [-0.006866455078125, -0.007476806640625, -0.00... \n", - "\n", - " vector sampling_rate \\\n", - "0 [0.0, 0.12407931, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", - "1 [0.0, 0.5878662, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,... 44100 \n", - "2 [0.0, 0.7369921, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,... 44100 \n", - "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 44100 \n", - "4 [0.0, 0.42053863, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", - "\n", - " category _distance \n", - "0 airplane 85.660744 \n", - "1 washing_machine 91.059021 \n", - "2 vacuum_cleaner 110.453613 \n", - "3 clapping 111.933456 \n", - "4 footsteps 115.770416 \n", - "0. 
Category: airplane\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":9: DeprecatedWarning: to_df is deprecated as of 0.3.1 and will be removed in 0.4.0. Use the bar function instead\n", - " result = tbl.search(embedding[0]).limit(5).to_df()\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " + "source": [ + "dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "INND51clWRjv" + }, + "source": [ + "### Create Embeddings\n", + "Now, to create the data embeddings! We can start by creating batches of 50 for the data, keeping track of the most important columns: `category` and `audio`." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "VKflK56YWRjv" + }, + "outputs": [], + "source": [ + "batches = [batch[\"audio\"] for batch in dataset.iter(50)]\n", + "meta_batches = [batch[\"category\"] for batch in dataset.iter(50)]\n", + "audio_data = [np.array([audio[\"array\"] for audio in batch]) for batch in batches]\n", + "meta_data = [np.array([meta for meta in batch]) for batch in meta_batches]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B4mB3sa2WRjw" + }, + "source": [ + "We now want to iterate through these batches, and for each audio file, we want to use the AudioTagging embedder to extract the embedding. Then, we can store these embeddings, audio files, and category name into a list of dictionaries. Each dictionary has to contain a `vector` column in order to add to the LanceDB table, if no embedding function is provided." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pdt1n8S7WRjw", + "outputId": "26abb853-33b2-4a86-a41a-6f188e7d4d46" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 40/40 [00:13<00:00, 2.99it/s]\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "for i in tqdm(range(len(audio_data))):\n", + " (_, embedding) = at.inference(audio_data[i])\n", + " data = [\n", + " {\n", + " \"audio\": x[0][\"array\"],\n", + " \"vector\": x[1],\n", + " \"sampling_rate\": x[0][\"sampling_rate\"],\n", + " \"category\": meta_data[i][j],\n", + " }\n", + " for j, x in enumerate(zip(batches[i], embedding))\n", + " ]" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "1. Category: washing_machine\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "CRpnHjJbWRjw" + }, + "source": [ + "Once we have this data list, we can create a LanceDB table by first connecting to a directory and then calling `db.create_table()`. If the table already exists, we open the table and add the data instead." 
+ ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " + "cell_type": "markdown", + "metadata": { + "id": "PDGjLT4UZLBu" + }, + "source": [ + "### Add the VectorStore" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "3lh_d6m3WRjw", + "outputId": "691acc64-5791-42b4-9f0f-3992f54b62da", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Created Table\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "# Connect to directory at the top of the file\n", + "db = lancedb.connect(\"data/audio-lancedb\")\n", + "table_name = \"audio-search\"\n", + "\n", + "if table_name not in db.table_names():\n", + " print(\"Created Table\")\n", + " tbl = db.create_table(table_name, data)\n", + "else:\n", + " print(\"Inserting data\")\n", + " tbl = db.open_table(table_name)\n", + " tbl.add(data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m7WfeIv8WRjw" + }, + "source": [ + "We can now combine all of this into a single function:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ARFCQPPjZLBu" + }, + "source": [ + "### Composite function" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "XnCHvlLsWRjw" + }, + "outputs": [], + "source": [ + "def insert_audio():\n", + " batches = [batch[\"audio\"] for batch in dataset.iter(20)]\n", + " meta_batches = [batch[\"category\"] for batch in dataset.iter(20)]\n", + " audio_data = [np.array([audio[\"array\"] for audio in batch]) for batch in batches]\n", + " meta_data = [np.array([meta for meta in batch]) for batch in meta_batches]\n", + " print(\"Start\")\n", + " for i in tqdm(range(len(audio_data))):\n", + " (_, embedding) = at.inference(audio_data[i])\n", + " data = [\n", + " {\n", + " \"audio\": x[0][\"array\"],\n", + " \"vector\": x[1],\n", + " \"sampling_rate\": x[0][\"sampling_rate\"],\n", + " \"category\": meta_data[i][j],\n", + " }\n", + 
" for j, x in enumerate(zip(batches[i], embedding))\n", + " ]\n", + " if table_name not in db.table_names():\n", + " tbl = db.create_table(table_name, data)\n", + " else:\n", + " tbl = db.open_table(table_name)\n", + " tbl.add(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "UvEhnuLyWRjw" + }, + "outputs": [], + "source": [ + "import shutil\n", + "\n", + "shutil.rmtree(\"data/audio-lancedb/audio-search.lance\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vr9LehNiiUNb" + }, + "source": [ + "NOTE: If you run out of memory, then on the next run, run all cells and uncomment the `# insert_audio()` line below." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "2. Category: vacuum_cleaner\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TXxGHwZdgZrG" + }, + "outputs": [], + "source": [ + "# insert_audio()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mPBphF19WRjx" + }, + "source": [ + "Great! We now have a fully populated table with all the necessary information. The next step would be to query the table and find those similar audio files. We can do this by first opening the table, and then getting the specific audio file we want to search for." 
+ ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " + "cell_type": "markdown", + "metadata": { + "id": "7B-mrGM6ZLBy" + }, + "source": [ + "### Query the database" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 93 + }, + "id": "ZsGYl6YSWRjx", + "outputId": "7a743814-a168-4fb7-84d4-4c303c55ccea" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Category: water_drops\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "tbl = db.open_table(table_name)\n", + "audio = dataset[50][\"audio\"][\"array\"]\n", + "category = dataset[50][\"category\"]\n", + "display(Audio(audio, rate=dataset[50][\"audio\"][\"sampling_rate\"]))\n", + "print(\"Category:\", category)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "3. Category: clapping\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "Et2C9t87WRjx" + }, + "source": [ + "Next, we call the embedding function again to create those embeddings, which would allow us to search our table." + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZmXOqB2FWRjx", + "outputId": "ed6c36a6-66a7-440d-f8fa-c693e61df0b2" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " audio \\\n", + "0 [0.00506591796875, 0.00653076171875, 0.0051574... \n", + "1 [-0.157318115234375, -0.122344970703125, -0.17... \n", + "2 [-0.0162353515625, -0.015716552734375, -0.0150... \n", + "3 [-0.0008544921875, -0.000762939453125, -0.0005... \n", + "4 [-0.003753662109375, -0.004119873046875, -0.00... 
\n", + "\n", + " vector sampling_rate \\\n", + "0 [0.0, 0.70255554, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", + "1 [0.0, 0.68818694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", + "2 [0.0, 0.58163136, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", + "3 [0.0, 1.0475253, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,... 44100 \n", + "4 [0.0, 0.45124823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", + "\n", + " category _distance \n", + "0 water_drops 52.260319 \n", + "1 water_drops 57.536579 \n", + "2 water_drops 75.637405 \n", + "3 drinking_sipping 76.979073 \n", + "4 water_drops 77.981728 \n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":2: UnsupportedWarning: to_df is unsupported as of 0.4.0. Use to_pandas() instead\n", + " result = tbl.search(embedding[0]).limit(5).to_df()\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "(_, embedding) = at.inference(audio[None, :])\n", + "result = tbl.search(embedding[0]).limit(5).to_df()\n", + "print(result)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "4. Category: footsteps\n" - ] + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 396 + }, + "id": "enl39Zp8WRjx", + "outputId": "296de741-d483-4471-92f4-a263abf1d262" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0. Category: water_drops\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "1. Category: water_drops\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "2. 
Category: water_drops\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "3. Category: drinking_sipping\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "4. Category: water_drops\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + } + ], + "source": [ + "for i in range(len(result)):\n", + " print(str(i) + \". Category:\", result[\"category\"][i])\n", + " display(Audio(result[\"audio\"][i], rate=result[\"sampling_rate\"][i]))" + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " + "cell_type": "markdown", + "metadata": { + "id": "mZtR0bxXWRjx" + }, + "source": [ + "Nice! It seems to be working! We can compile this into another function here, that takes an `id` of the audio from 0 to 1,999." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OPrn-NAYZLB0" + }, + "source": [ + "### Search Audio using IDs" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "wc1X8MuDWRjx" + }, + "outputs": [], + "source": [ + "def search_audio(id):\n", + " tbl = db.open_table(table_name)\n", + " audio = dataset[id][\"audio\"][\"array\"]\n", + " category = dataset[id][\"category\"]\n", + " display(Audio(audio, rate=dataset[id][\"audio\"][\"sampling_rate\"]))\n", + " print(\"Category:\", category)\n", + "\n", + " (_, embedding) = at.inference(audio[None, :])\n", + " result = tbl.search(embedding[0]).limit(5).to_df()\n", + " print(result)\n", + " for i in range(len(result)):\n", + " print(str(i) + \". 
Category:\", result[\"category\"][i])\n", + " display(Audio(result[\"audio\"][i], rate=result[\"sampling_rate\"][i]))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 853 + }, + "id": "dQYVac1kWRjx", + "outputId": "a1ea8e7d-acee-4bb4-a008-6ee90d097cc8" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Category: car_horn\n", + " audio \\\n", + "0 [-0.022979736328125, -0.021820068359375, -0.02... \n", + "1 [0.313934326171875, 0.312774658203125, 0.31698... \n", + "2 [0.0655517578125, 0.011505126953125, -0.024536... \n", + "3 [0.063690185546875, 0.065216064453125, 0.07296... \n", + "4 [-0.006866455078125, -0.007476806640625, -0.00... \n", + "\n", + " vector sampling_rate \\\n", + "0 [0.0, 0.12407931, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", + "1 [0.0, 0.5878662, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,... 44100 \n", + "2 [0.0, 0.7369921, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,... 44100 \n", + "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 44100 \n", + "4 [0.0, 0.42053863, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 44100 \n", + "\n", + " category _distance \n", + "0 airplane 85.660736 \n", + "1 washing_machine 91.059029 \n", + "2 vacuum_cleaner 110.453621 \n", + "3 clapping 111.933441 \n", + "4 footsteps 115.770401 \n", + "0. Category: airplane\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":9: UnsupportedWarning: to_df is unsupported as of 0.4.0. Use to_pandas() instead\n", + " result = tbl.search(embedding[0]).limit(5).to_df()\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "1. 
Category: washing_machine\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "2. Category: vacuum_cleaner\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "3. Category: clapping\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "4. Category: footsteps\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + } ], - "text/plain": [ - "" + "source": [ + "search_audio(125)" ] - }, - "metadata": {}, - "output_type": "display_data" } - ], - "source": [ - "search_audio(125)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3X3pePawWRjx" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, 
- "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.1" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/product-recommender/main.ipynb b/examples/product-recommender/main.ipynb index 66c5c68..48ee5b9 100644 --- a/examples/product-recommender/main.ipynb +++ b/examples/product-recommender/main.ipynb @@ -1,2979 +1,3132 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "YmdWGrw4t5G2" - }, - "source": [ - "# Product Recommender using Collaborative Filtering and LanceDB\n", - "\n", - "We are going to use **LanceDB** and **Collaborative Filtering** to recommend products based on a user's past buying history. We used the **Instacart dataset** as our data for this example.\n", - "\n", - "![picture](https://daxg39y63pxwu.cloudfront.net/images/blog/product-recommendation-system-projects/Product_Recommendation_System_Project_Ideas_and_Examples.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lXd46ecEt5G7" - }, - "source": [ - "To run this example, you must first create a Kaggle account. Then, go to the 'Account' tab of your user profile and select 'Create New Token'. This will trigger the download of kaggle.json, a file containing your API credentials.\n", - "\n", - "Add Kaggle credentials to `~/.kaggle/kaggle.json` on Linux, OSX, and other UNIX-based operating systems or `C:\\Users\\\\.kaggle\\kaggle.json` for Window's users.\n", - "\n", - "In Google Colab, run the snippet below." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "l6TTPIF_omEy", - "outputId": "d2cf1685-103e-4b62-bae3-a16d171a928f", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "outputs": [ + "cells": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "Kaggle API key file created and moved successfully.\n" - ] - } - ], - "source": [ - "import json\n", - "import os\n", - "\n", - "# Set the file path\n", - "kaggle_json_path = \"/content/kaggle.json\"\n", - "\n", - "# Write Kaggle API key to the file\n", - "with open(kaggle_json_path, \"w\") as fp:\n", - " json.dump({\"username\": \"\", \"key\": \"\"}, fp)\n", - "\n", - "# Move the file to the correct location\n", - "os.system(\"mkdir -p ~/.kaggle\")\n", - "os.system(f\"mv {kaggle_json_path} ~/.kaggle/kaggle.json\")\n", - "\n", - "# Set permissions\n", - "os.system(\"chmod 600 ~/.kaggle/kaggle.json\")\n", - "\n", - "print(\"Kaggle API key file created and moved successfully.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c6G45HrUqNx5" - }, - "source": [ - "### Install dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "YmdWGrw4t5G2" + }, + "source": [ + "# Product Recommender using Collaborative Filtering and LanceDB\n", + "\n", + "We are going to use **LanceDB** and **Collaborative Filtering** to recommend products based on a user's past buying history. 
We used the **Instacart dataset** as our data for this example.\n", + "\n", + "![picture](https://daxg39y63pxwu.cloudfront.net/images/blog/product-recommendation-system-projects/Product_Recommendation_System_Project_Ideas_and_Examples.png)" + ] }, - "id": "R3_Hq2VC4_zT", - "outputId": "ee47bbd5-d1c3-4900-894e-2530190e17e7" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.23.5)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)\n", - "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.11.4)\n", - "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.5.16)\n", - "Collecting implicit\n", - " Downloading implicit-0.7.2-cp310-cp310-manylinux2014_x86_64.whl (8.9 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m18.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0+cu121)\n", - "Collecting lancedb\n", - " Downloading lancedb-0.5.0-py3-none-any.whl (87 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.4/87.4 kB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1)\n", - "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n", - "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from kaggle) (2023.11.17)\n", - "Requirement already satisfied: requests in 
/usr/local/lib/python3.10/dist-packages (from kaggle) (2.31.0)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.1)\n", - "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.1)\n", - "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.0.7)\n", - "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n", - "Requirement already satisfied: threadpoolctl in /usr/local/lib/python3.10/dist-packages (from implicit) (3.2.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n", - "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n", - "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n", - "Collecting deprecation (from lancedb)\n", - " Downloading deprecation-2.1.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pylance==0.9.6 (from lancedb)\n", - " Downloading pylance-0.9.6-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.6/18.6 MB\u001b[0m \u001b[31m58.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ratelimiter~=1.0 (from lancedb)\n", - " Downloading ratelimiter-1.2.0.post0-py3-none-any.whl (6.6 kB)\n", - "Collecting 
retry>=0.9.2 (from lancedb)\n", - " Downloading retry-0.9.2-py2.py3-none-any.whl (8.0 kB)\n", - "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb) (1.10.13)\n", - "Requirement already satisfied: attrs>=21.3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (23.2.0)\n", - "Collecting semver>=3.0 (from lancedb)\n", - " Downloading semver-3.0.2-py3-none-any.whl (17 kB)\n", - "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb) (5.3.2)\n", - "Requirement already satisfied: pyyaml>=6.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (6.0.1)\n", - "Requirement already satisfied: click>=8.1.7 in /usr/local/lib/python3.10/dist-packages (from lancedb) (8.1.7)\n", - "Collecting overrides>=0.7 (from lancedb)\n", - " Downloading overrides-7.6.0-py3-none-any.whl (17 kB)\n", - "Collecting pyarrow>=12 (from pylance==0.9.6->lancedb)\n", - " Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.3/38.3 MB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.6)\n", - "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (4.4.2)\n", - "Collecting py<2.0.0,>=1.4.26 (from retry>=0.9.2->lancedb)\n", - " Downloading py-1.11.0-py2.py3-none-any.whl (98 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: webencodings in 
/usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from deprecation->lancedb) (23.2)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", - "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n", - "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", - "Installing collected packages: ratelimiter, semver, pyarrow, py, overrides, deprecation, retry, pylance, implicit, lancedb\n", - " Attempting uninstall: pyarrow\n", - " Found existing installation: pyarrow 10.0.1\n", - " Uninstalling pyarrow-10.0.1:\n", - " Successfully uninstalled pyarrow-10.0.1\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "ibis-framework 7.1.0 requires pyarrow<15,>=2, but you have pyarrow 15.0.0 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed deprecation-2.1.0 implicit-0.7.2 lancedb-0.5.0 overrides-7.6.0 py-1.11.0 pyarrow-15.0.0 pylance-0.9.6 ratelimiter-1.2.0.post0 retry-0.9.2 semver-3.0.2\n" - ] - } - ], - "source": [ - "!pip install numpy pandas scipy kaggle implicit torch lancedb" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i_eatRhaIGIz" - }, - "source": [ - "### Importing libraries" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "emp_MSXZt5G8" - }, - "outputs": [], - "source": [ - "import zipfile\n", - "import numpy as np\n", - "import pandas as pd\n", - "import scipy.sparse\n", - "import torch\n", - "import implicit\n", - "from implicit import evaluation\n", - "import pydantic\n", - "import lancedb\n", - "from lancedb.pydantic import pydantic_to_schema, vector" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bUGkN85V4_zY" - }, - "source": [ - "### Load the dataset\n", - "Now we can download the dataset. You will need to accept the rules of the `instacart-market-basket-analysis` competition, which you can do so [here](https://www.kaggle.com/competitions/instacart-market-basket-analysis/rules)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "lXd46ecEt5G7" + }, + "source": [ + "To download the dataset used in this example, you must have a Kaggle account.\n", + "\n", + "To get the Kaggle API credentials:\n", + "\n", + "Go to Your Profile -> Settings -> Create Token\n", + "\n", + "This will download `kaggle.json`, a file containing your API credentials.\n", + "\n", + "Upload the Kaggle credentials file `kaggle.json` to Google Colab, then run the snippet below." 
+ ] }, - "id": "09gdQyBu4_zY", - "outputId": "bb92fb9e-df75-47a5-b50d-290ed0555ef4" - }, - "outputs": [ { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading instacart-market-basket-analysis.zip to /content\n", - " 92% 181M/196M [00:01<00:00, 81.3MB/s]\n", - "100% 196M/196M [00:01<00:00, 105MB/s] \n" - ] - } - ], - "source": [ - "!kaggle competitions download -c instacart-market-basket-analysis" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K4Q4cOX-4_zY" - }, - "source": [ - "We must now extract the zip files." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "f3g296nL4_zZ" - }, - "outputs": [], - "source": [ - "files = [\n", - " \"instacart-market-basket-analysis.zip\",\n", - " \"order_products__train.csv.zip\",\n", - " \"order_products__prior.csv.zip\",\n", - " \"products.csv.zip\",\n", - " \"orders.csv.zip\",\n", - "]\n", - "\n", - "for filename in files:\n", - " with zipfile.ZipFile(filename, \"r\") as zip_ref:\n", - " zip_ref.extractall(\"./\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oLgkRIfq4_zZ" - }, - "source": [ - "Now we can move on to loading the dataset. We'll first read the csv files and create dataframes." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "cBbbR7Rut5G_" - }, - "outputs": [], - "source": [ - "products = pd.read_csv(\"products.csv\")\n", - "orders = pd.read_csv(\"orders.csv\")\n", - "order_products = pd.concat(\n", - " [pd.read_csv(\"order_products__train.csv\"), pd.read_csv(\"order_products__prior.csv\")]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5FV_GGjst5HA" - }, - "source": [ - "Since there isn't a user rating attribute, we'll gather \"confidence\" data by looking at the frequency of each item purchased by a user, and store this in the `data` dataframe." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YNgjd2nnqNx7" - }, - "source": [ - "### Data Manipulation" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "ZjRh7RYpt5HB" - }, - "outputs": [], - "source": [ - "customer_order_products = pd.merge(orders, order_products, how=\"inner\", on=\"order_id\")\n", - "\n", - "# create confidence table\n", - "data = (\n", - " customer_order_products.groupby([\"user_id\", \"product_id\"])[[\"order_id\"]]\n", - " .count()\n", - " .reset_index()\n", - ")\n", - "data.columns = [\"user_id\", \"product_id\", \"total_orders\"]\n", - "data.product_id = data.product_id.astype(\"int64\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "77lvwm0St5HC" - }, - "source": [ - "Let's create a couple of test users to examine the recommendations later:\n", - "- 1st test user: buys 50 sodas: **Zero Calorie Cola**\n", - "- 2nd test user: buys organic produce: **Organic Whole Milk** and **Organic Blackberries**" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 206 + "cell_type": "code", + "source": [ + "! pip install kaggle\n", + "! mkdir ~/.kaggle\n", + "! cp kaggle.json ~/.kaggle/\n", + "! 
chmod 600 ~/.kaggle/kaggle.json" + ], + "metadata": { + "id": "N3WSkW3kmjyF", + "outputId": "26294f7b-350e-41f9-afe0-e34c9dac3b9e", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.5.16)\n", + "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.2.2)\n", + "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.31.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.2)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.4)\n", + "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.0.7)\n", + "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n", + "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n", + "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.6)\n" + ] + } + ] }, - "id": "A06EfAf-t5HC", - "outputId": "95a1f51f-ced1-437a-8b62-569bb915262c" - }, - "outputs": [ { - "output_type": "execute_result", - "data": { - "text/plain": 
[ - " user_id product_id total_orders\n", - "13863744 206209 48697 1\n", - "13863745 206209 48742 2\n", - "13863746 206210 46149 50\n", - "13863747 206211 27845 49\n", - "13863748 206211 26604 32" + "cell_type": "markdown", + "metadata": { + "id": "c6G45HrUqNx5" + }, + "source": [ + "### Install dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R3_Hq2VC4_zT", + "outputId": "752f8e45-ea8b-4b57-8a2b-0c7cb77f5f6c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.25.2)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.11.4)\n", + "Collecting implicit\n", + " Downloading implicit-0.7.2-cp310-cp310-manylinux2014_x86_64.whl (8.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0+cu121)\n", + "Collecting lancedb\n", + " Downloading lancedb-0.6.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.3/21.3 MB\u001b[0m \u001b[31m20.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.4)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from implicit) (4.66.2)\n", + "Requirement already satisfied: threadpoolctl in 
/usr/local/lib/python3.10/dist-packages (from implicit) (3.3.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.10.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n", + "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n", + "Collecting deprecation (from lancedb)\n", + " Downloading deprecation-2.1.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pylance==0.10.1 (from lancedb)\n", + " Downloading pylance-0.10.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.5/21.5 MB\u001b[0m \u001b[31m28.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ratelimiter~=1.0 (from lancedb)\n", + " Downloading ratelimiter-1.2.0.post0-py3-none-any.whl (6.6 kB)\n", + "Collecting retry>=0.9.2 (from lancedb)\n", + " Downloading retry-0.9.2-py2.py3-none-any.whl (8.0 kB)\n", + "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.6.3)\n", + "Requirement already satisfied: attrs>=21.3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (23.2.0)\n", + "Collecting semver>=3.0 (from lancedb)\n", + " Downloading semver-3.0.2-py3-none-any.whl (17 kB)\n", + "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb) (5.3.3)\n", + "Requirement 
already satisfied: pyyaml>=6.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (6.0.1)\n", + "Requirement already satisfied: click>=8.1.7 in /usr/local/lib/python3.10/dist-packages (from lancedb) (8.1.7)\n", + "Requirement already satisfied: requests>=2.31.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.31.0)\n", + "Collecting overrides>=0.7 (from lancedb)\n", + " Downloading overrides-7.7.0-py3-none-any.whl (17 kB)\n", + "Requirement already satisfied: pyarrow>=12 in /usr/local/lib/python3.10/dist-packages (from pylance==0.10.1->lancedb) (14.0.2)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (0.6.0)\n", + "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (2.16.3)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (2024.2.2)\n", + "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (4.4.2)\n", + "Collecting py<2.0.0,>=1.4.26 (from retry>=0.9.2->lancedb)\n", + " Downloading py-1.11.0-py2.py3-none-any.whl (98 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 
+ "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from deprecation->lancedb) (23.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Installing collected packages: ratelimiter, semver, py, overrides, deprecation, retry, pylance, implicit, lancedb\n", + "Successfully installed deprecation-2.1.0 implicit-0.7.2 lancedb-0.6.1 overrides-7.7.0 py-1.11.0 pylance-0.10.1 ratelimiter-1.2.0.post0 retry-0.9.2 semver-3.0.2\n" + ] + } ], - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idproduct_idtotal_orders
13863744206209486971
13863745206209487422
138637462062104614950
138637472062112784549
138637482062112660432
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n" + "source": [ + "!pip install numpy pandas scipy implicit torch lancedb" ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "data_new = pd.DataFrame(\n", - " [\n", - " [data.user_id.max() + 1, 46149, 50],\n", - " [data.user_id.max() + 2, 27845, 49],\n", - " [data.user_id.max() + 2, 26604, 32],\n", - " ],\n", - " columns=[\"user_id\", \"product_id\", \"total_orders\"],\n", - ")\n", - "data = pd.concat([data, data_new]).reset_index(drop=True)\n", - "data.tail()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xBC-8PFTt5HD" - }, - "source": [ - "In the next step, we will extract user and product unique ids, in order to create a `CSR (Compressed Sparse Row)` matrix. This will allow us to perform collaborative filtering.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "v2_2R7zmt5HE" - }, - "outputs": [], - "source": [ - "# extract unique user and product ids\n", - "unique_users = list(np.sort(data.user_id.unique()))\n", - "unique_products = list(np.sort(products.product_id.unique()))\n", - "purchases = list(data.total_orders)\n", - "\n", - "# create zero-based index position <-> user/item ID mappings\n", - "index_to_user = pd.Series(unique_users)\n", - "\n", - "# create reverse mappings from user/item ID to index positions\n", - "user_to_index = pd.Series(data=index_to_user.index + 1, index=index_to_user.values)\n", - "\n", - "# create row and column for user and product ids\n", - "users_rows = data.user_id.astype(int)\n", - "products_cols = data.product_id.astype(int)\n", - "\n", - "# create CSR matrix\n", - "matrix = scipy.sparse.csr_matrix(\n", - " (purchases, (users_rows, products_cols)),\n", - " shape=(len(unique_users) + 1, len(unique_products) + 1),\n", - ")\n", - "matrix.data = np.nan_to_num(matrix.data, copy=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "II6wOH96t5HF" - }, - "source": [ - "Let's now create a recommender model using the **implicit** 
library. The recommendation model is based off the algorithms described in the paper [Collaborative Filtering for Implicit Feedback Datasets](https://www.researchgate.net/publication/220765111_Collaborative_Filtering_for_Implicit_Feedback_Datasets) with performance optimizations described in [Applications of the Conjugate Gradient Method for Implicit Feedback Collaborative Filtering](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.379.6473&rep=rep1&type=pdf).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JDwIxGMnqNx8" - }, - "source": [ - "# Difference between colloborative and content filtering\n", - "\n", - "![picture](https://miro.medium.com/v2/resize:fit:1400/0*R8qw_CXxCc4600bQ.png)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 104, - "referenced_widgets": [ - "51febb09c3d54a1a9cf5dd896f3a24f6", - "91b083fde4f14c39bbafb6fd099d44bd", - "84fca55b676b4ef2add284492c8f4c3c", - "bb2c985a09564562b6f040e31d817f07", - "cc06b425a9364b6eb07ef77c4ff6fc48", - "e2e92925bbb442f8a77e2d55886bfbfa", - "bc7f6859319f455da1f552b66a6cf026", - "66396eb857864cc8af94d7e2ced3102c", - "38ddb81c475a472d8439dcf72261b727", - "c095ad1b03a34c4e8b2077e373c82a5b", - "692c702c31904e058c809ae772f1579a" - ] }, - "id": "k0GW99kxt5HF", - "outputId": "548c2514-6194-43e4-dd24-6861f1808f5b" - }, - "outputs": [ { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.10/dist-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 2 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. 
Having OpenBLAS use a threadpool can lead to severe performance issues here.\n", - " check_blas_config()\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "i_eatRhaIGIz" + }, + "source": [ + "### Importing libraries" + ] }, { - "output_type": "display_data", - "data": { - "text/plain": [ - " 0%| | 0/50 [00:00\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_namevector_distance
046149Zero Calorie Cola[-0.014371638, -0.016776536, -0.026950998, -0....36.209068
1196Soda[-0.031917833, -0.050772455, 0.013827451, -0.0...36.464764
240939Drinking Water[-0.013426425, 0.0053616967, -0.01992105, -0.0...36.504112
322802Mineral Water[-0.0062663523, -0.00076926383, -0.013624842, ...36.615498
437710Trail Mix[-0.01988333, -0.014069387, -0.021995109, -0.0...36.650448
542500Orange & Lemon Flavor Variety Pack Sparkling F...[-0.009584657, -0.023491196, -0.033104196, -0....36.696648
611759Organic Simply Naked Pita Chips[-0.009341286, -0.014609524, -0.0064758006, -0...36.705814
741400Crunchy Oats 'n Honey Granola Bars[-0.013461881, -0.021371827, -0.02064814, -0.0...36.709579
846061Popcorn[0.0019679032, 0.00719048, -0.01262015, -0.005...36.714954
926348Mixed Fruit Fruit Snacks[-0.0017672281, 0.0020188452, 0.012172974, -0....36.716858
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - " \n" + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "cBbbR7Rut5G_" + }, + "outputs": [], + "source": [ + "products = pd.read_csv(\"products.csv\")\n", + "orders = pd.read_csv(\"orders.csv\")\n", + "order_products = pd.concat(\n", + " [pd.read_csv(\"order_products__train.csv\"), pd.read_csv(\"order_products__prior.csv\")]\n", + ")" ] - }, - "metadata": {} }, { - "output_type": "display_data", - "data": { - "text/plain": [ - " product_id product_name total_orders\n", - "0 46149 Zero Calorie Cola 50" - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_nametotal_orders
046149Zero Calorie Cola50
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "
\n", - "
\n" + "cell_type": "markdown", + "metadata": { + "id": "5FV_GGjst5HA" + }, + "source": [ + "Since there isn't a user rating attribute, we'll gather \"confidence\" data by looking at the frequency of each item purchased by a user, and store this in the `data` dataframe." ] - }, - "metadata": {} }, { - "output_type": "display_data", - "data": { - "text/plain": [ - " product_id product_name \\\n", - "0 26604 Organic Blackberries \n", - "1 43352 Raspberries \n", - "2 27845 Organic Whole Milk \n", - "3 21288 Blackberries \n", - "4 27966 Organic Raspberries \n", - "5 9076 Blueberries \n", - "6 11777 Red Raspberries \n", - "7 39275 Organic Blueberries \n", - "8 21137 Organic Strawberries \n", - "9 13176 Bag of Organic Bananas \n", - "\n", - " vector _distance \n", - "0 [0.045252558, 0.04258531, 0.011869884, -0.0111... 17.445852 \n", - "1 [0.059606433, 0.014409931, 0.008712215, -0.007... 17.617174 \n", - "2 [-0.03977351, 0.012210161, 0.024828656, 0.0155... 17.692816 \n", - "3 [0.030181486, 0.049021076, 0.003293778, -0.038... 17.696075 \n", - "4 [0.020116415, 0.045062356, 0.00675044, 0.01640... 17.872534 \n", - "5 [0.0482006, 0.06329333, -0.015093377, 0.000180... 17.879623 \n", - "6 [0.05492493, 0.008120705, 0.020613482, 0.00779... 17.931437 \n", - "7 [0.005109854, 0.032895964, -0.013481544, 0.010... 17.970798 \n", - "8 [0.0017651353, 0.033547334, -0.005775958, 0.02... 17.986570 \n", - "9 [0.004607136, 0.02749164, -0.006206838, 0.0187... 18.092993 " - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_namevector_distance
026604Organic Blackberries[0.045252558, 0.04258531, 0.011869884, -0.0111...17.445852
143352Raspberries[0.059606433, 0.014409931, 0.008712215, -0.007...17.617174
227845Organic Whole Milk[-0.03977351, 0.012210161, 0.024828656, 0.0155...17.692816
321288Blackberries[0.030181486, 0.049021076, 0.003293778, -0.038...17.696075
427966Organic Raspberries[0.020116415, 0.045062356, 0.00675044, 0.01640...17.872534
59076Blueberries[0.0482006, 0.06329333, -0.015093377, 0.000180...17.879623
611777Red Raspberries[0.05492493, 0.008120705, 0.020613482, 0.00779...17.931437
739275Organic Blueberries[0.005109854, 0.032895964, -0.013481544, 0.010...17.970798
821137Organic Strawberries[0.0017651353, 0.033547334, -0.005775958, 0.02...17.986570
913176Bag of Organic Bananas[0.004607136, 0.02749164, -0.006206838, 0.0187...18.092993
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n" + "cell_type": "markdown", + "metadata": { + "id": "YNgjd2nnqNx7" + }, + "source": [ + "### Data Manipulation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "ZjRh7RYpt5HB" + }, + "outputs": [], + "source": [ + "customer_order_products = pd.merge(orders, order_products, how=\"inner\", on=\"order_id\")\n", + "\n", + "# create confidence table\n", + "data = (\n", + " customer_order_products.groupby([\"user_id\", \"product_id\"])[[\"order_id\"]]\n", + " .count()\n", + " .reset_index()\n", + ")\n", + "data.columns = [\"user_id\", \"product_id\", \"total_orders\"]\n", + "data.product_id = data.product_id.astype(\"int64\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "77lvwm0St5HC" + }, + "source": [ + "Let's create a couple of test users to examine the recommendations later:\n", + "- 1st test user: buys 50 sodas: **Zero Calorie Cola**\n", + "- 2nd test user: buys organic produce: **Organic Whole Milk** and **Organic Blackberries**" ] - }, - "metadata": {} }, { - "output_type": "display_data", - "data": { - "text/plain": [ - " product_id product_name total_orders\n", - "0 27845 Organic Whole Milk 49\n", - "1 26604 Organic Blackberries 32" + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "A06EfAf-t5HC", + "outputId": "48ef0f5d-7c7a-4087-fd4b-8d3fa5ebaca1" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " user_id product_id total_orders\n", + "13863744 206209 48697 1\n", + "13863745 206209 48742 2\n", + "13863746 206210 46149 50\n", + "13863747 206211 27845 49\n", + "13863748 206211 26604 32" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idproduct_idtotal_orders
13863744206209486971
13863745206209487422
138637462062104614950
138637472062112784549
138637482062112660432
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "summary": "{\n \"name\": \"data\",\n \"rows\": 5,\n \"fields\": [\n {\n \"column\": \"user_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 206209,\n \"max\": 206211,\n \"num_unique_values\": 3,\n \"samples\": [\n 206209,\n 206210,\n 206211\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"product_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 11361,\n \"min\": 26604,\n \"max\": 48742,\n \"num_unique_values\": 5,\n \"samples\": [\n 48742,\n 26604,\n 46149\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"total_orders\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 24,\n \"min\": 1,\n \"max\": 50,\n \"num_unique_values\": 5,\n \"samples\": [\n 2,\n 32,\n 50\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 8 + } ], - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_nametotal_orders
027845Organic Whole Milk49
126604Organic Blackberries32
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n" + "source": [ + "data_new = pd.DataFrame(\n", + " [\n", + " [data.user_id.max() + 1, 46149, 50],\n", + " [data.user_id.max() + 2, 27845, 49],\n", + " [data.user_id.max() + 2, 26604, 32],\n", + " ],\n", + " columns=[\"user_id\", \"product_id\", \"total_orders\"],\n", + ")\n", + "data = pd.concat([data, data_new]).reset_index(drop=True)\n", + "data.tail()" ] - }, - "metadata": {} - } - ], - "source": [ - "# Query by user factors\n", - "test_user_embeddings = test_user_factors.tolist()\n", - "for embedding, id in zip(test_user_embeddings, test_user_ids):\n", - " results = tbl.search(embedding).limit(10).to_pandas()\n", - " display(results)\n", - " display(products_bought_by_user_in_the_past(id, top=15))" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "-kWR644v1ZJp" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.1" - }, - "vscode": { - "interpreter": { - "hash": "5fe10bf018ef3e697f9035d60bf60847932a12bface18908407fd371fe880db9" - } - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "51febb09c3d54a1a9cf5dd896f3a24f6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - 
"IPY_MODEL_91b083fde4f14c39bbafb6fd099d44bd", - "IPY_MODEL_84fca55b676b4ef2add284492c8f4c3c", - "IPY_MODEL_bb2c985a09564562b6f040e31d817f07" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xBC-8PFTt5HD" + }, + "source": [ + "In the next step, we will extract user and product unique ids, in order to create a `CSR (Compressed Sparse Row)` matrix. This will allow us to perform collaborative filtering.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "v2_2R7zmt5HE" + }, + "outputs": [], + "source": [ + "# extract unique user and product ids\n", + "unique_users = list(np.sort(data.user_id.unique()))\n", + "unique_products = list(np.sort(products.product_id.unique()))\n", + "purchases = list(data.total_orders)\n", + "\n", + "# create zero-based index position <-> user/item ID mappings\n", + "index_to_user = pd.Series(unique_users)\n", + "\n", + "# create reverse mappings from user/item ID to index positions\n", + "user_to_index = pd.Series(data=index_to_user.index + 1, index=index_to_user.values)\n", + "\n", + "# create row and column for user and product ids\n", + "users_rows = data.user_id.astype(int)\n", + "products_cols = data.product_id.astype(int)\n", + "\n", + "# create CSR matrix\n", + "matrix = scipy.sparse.csr_matrix(\n", + " (purchases, (users_rows, products_cols)),\n", + " shape=(len(unique_users) + 1, len(unique_products) + 1),\n", + ")\n", + "matrix.data = np.nan_to_num(matrix.data, copy=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "II6wOH96t5HF" + }, + "source": [ + "Let's now create a recommender model using the **implicit** library. 
The recommendation model is based off the algorithms described in the paper [Collaborative Filtering for Implicit Feedback Datasets](https://www.researchgate.net/publication/220765111_Collaborative_Filtering_for_Implicit_Feedback_Datasets) with performance optimizations described in [Applications of the Conjugate Gradient Method for Implicit Feedback Collaborative Filtering](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.379.6473&rep=rep1&type=pdf).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JDwIxGMnqNx8" + }, + "source": [ + "# Difference between colloborative and content filtering\n", + "\n", + "![picture](https://miro.medium.com/v2/resize:fit:1400/0*R8qw_CXxCc4600bQ.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 104, + "referenced_widgets": [ + "c159a1c6fc204d239b5ff7713d3c68fe", + "6e3b621f67554d6cbcaa50717008821f", + "1e5f629b939247c088b275a72310cfe0", + "cfde2bc68d9c448b823c690e15c4a169", + "8668f98cebeb4b548e87f2c4e68c9cbf", + "7ebca3dced8e4c029398db02169b868e", + "28400c62e971452b865e70af4e410afc", + "c45f8ded7dc84c18b479c3c427c29463", + "301f4f324d594ff2a63dc2f43ba4391f", + "0e3594636fbf4263b32d195f31fd29c0", + "adf0848d8d8440f18dbd001572772fce" + ] + }, + "id": "k0GW99kxt5HF", + "outputId": "fd9c03c5-c668-4ddd-8fea-1b3e737b8ad6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 2 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n", + " check_blas_config()\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " 0%| | 0/50 [00:00\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_namevector_distance
046149Zero Calorie Cola[0.037515923, -0.030325921, 0.004221245, -0.00...38.190578
1196Soda[0.04531822, -0.04450815, -0.0022076364, -0.02...38.340080
222802Mineral Water[0.030236538, -0.0041136313, 0.015683502, -0.0...38.593525
340939Drinking Water[0.03287196, -0.017454194, 0.009911481, -0.004...38.606468
431651Extra Fancy Unsalted Mixed Nuts[0.037796307, -0.009871203, -0.0020715303, -0....38.642967
537710Trail Mix[0.05062829, -0.017916694, 0.0027849572, 0.001...38.668938
641400Crunchy Oats 'n Honey Granola Bars[0.028622035, -0.013106515, -0.0072577046, -0....38.703171
726348Mixed Fruit Fruit Snacks[0.011525251, -0.032522, -0.021976499, 0.01198...38.709934
846061Popcorn[0.039293304, -0.016017294, -0.0010792917, 0.0...38.713402
939657Milk Chocolate Almonds[0.030015469, -0.00927157, 0.0061932686, 0.000...38.748997
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + " \n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "results", + "summary": "{\n \"name\": \"results\",\n \"rows\": 10,\n \"fields\": [\n {\n \"column\": \"product_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14019,\n \"min\": 196,\n \"max\": 46149,\n \"num_unique_values\": 10,\n \"samples\": [\n 46061,\n 196,\n 37710\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"product_name\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 10,\n \"samples\": [\n \"Popcorn\",\n \"Soda\",\n \"Trail Mix\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"vector\",\n \"properties\": {\n \"dtype\": \"object\",\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"_distance\",\n \"properties\": {\n \"dtype\": \"float32\",\n \"num_unique_values\": 10,\n \"samples\": [\n 38.713401794433594,\n 38.34008026123047,\n 38.66893768310547\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " product_id product_name total_orders\n", + "0 46149 Zero Calorie Cola 50" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_nametotal_orders
046149Zero Calorie Cola50
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "summary": "{\n \"name\": \" display(products_bought_by_user_in_the_past(id, top=15))\",\n \"rows\": 1,\n \"fields\": [\n {\n \"column\": \"product_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": null,\n \"min\": 46149,\n \"max\": 46149,\n \"num_unique_values\": 1,\n \"samples\": [\n 46149\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"product_name\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"Zero Calorie Cola\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"total_orders\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": null,\n \"min\": 50,\n \"max\": 50,\n \"num_unique_values\": 1,\n \"samples\": [\n 50\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " product_id product_name \\\n", + "0 26604 Organic Blackberries \n", + "1 27845 Organic Whole Milk \n", + "2 27966 Organic Raspberries \n", + "3 43352 Raspberries \n", + "4 21288 Blackberries \n", + "5 39275 Organic Blueberries \n", + "6 11777 Red Raspberries \n", + "7 9076 Blueberries \n", + "8 21137 Organic Strawberries \n", + "9 11422 Plain Greek Yogurt \n", + "\n", + " vector _distance \n", + "0 [0.019478824, 0.007443799, 0.004226536, 0.0283... 16.314867 \n", + "1 [-0.03417227, -0.053161107, 0.03893201, 0.0150... 16.432335 \n", + "2 [0.024305355, -0.0063351737, 0.029324768, 0.02... 16.577738 \n", + "3 [0.020642506, 0.025494106, 0.0050161625, 0.003... 16.588812 \n", + "4 [-0.00844225, 0.01996236, -0.0148576135, 0.012... 16.672234 \n", + "5 [0.035410225, -0.0029810749, 0.014112177, 0.00... 16.684757 \n", + "6 [0.020807281, -0.015660688, 0.010914551, 0.028... 16.746056 \n", + "7 [0.033343736, 0.0068411743, 0.0028535812, 0.00... 
16.765997 \n", + "8 [0.018478896, -0.0014569649, 0.01558258, 0.009... 16.883642 \n", + "9 [0.003926732, -0.02004065, 0.059874147, 0.0318... 17.008499 " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_namevector_distance
026604Organic Blackberries[0.019478824, 0.007443799, 0.004226536, 0.0283...16.314867
127845Organic Whole Milk[-0.03417227, -0.053161107, 0.03893201, 0.0150...16.432335
227966Organic Raspberries[0.024305355, -0.0063351737, 0.029324768, 0.02...16.577738
343352Raspberries[0.020642506, 0.025494106, 0.0050161625, 0.003...16.588812
421288Blackberries[-0.00844225, 0.01996236, -0.0148576135, 0.012...16.672234
539275Organic Blueberries[0.035410225, -0.0029810749, 0.014112177, 0.00...16.684757
611777Red Raspberries[0.020807281, -0.015660688, 0.010914551, 0.028...16.746056
79076Blueberries[0.033343736, 0.0068411743, 0.0028535812, 0.00...16.765997
821137Organic Strawberries[0.018478896, -0.0014569649, 0.01558258, 0.009...16.883642
911422Plain Greek Yogurt[0.003926732, -0.02004065, 0.059874147, 0.0318...17.008499
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "results", + "summary": "{\n \"name\": \"results\",\n \"rows\": 10,\n \"fields\": [\n {\n \"column\": \"product_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 11503,\n \"min\": 9076,\n \"max\": 43352,\n \"num_unique_values\": 10,\n \"samples\": [\n 21137,\n 27845,\n 39275\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"product_name\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 10,\n \"samples\": [\n \"Organic Strawberries\",\n \"Organic Whole Milk\",\n \"Organic Blueberries\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"vector\",\n \"properties\": {\n \"dtype\": \"object\",\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"_distance\",\n \"properties\": {\n \"dtype\": \"float32\",\n \"num_unique_values\": 10,\n \"samples\": [\n 16.883642196655273,\n 16.432334899902344,\n 16.684757232666016\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " product_id product_name total_orders\n", + "0 27845 Organic Whole Milk 49\n", + "1 26604 Organic Blackberries 32" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_nametotal_orders
027845Organic Whole Milk49
126604Organic Blackberries32
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "summary": "{\n \"name\": \" display(products_bought_by_user_in_the_past(id, top=15))\",\n \"rows\": 2,\n \"fields\": [\n {\n \"column\": \"product_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 877,\n \"min\": 26604,\n \"max\": 27845,\n \"num_unique_values\": 2,\n \"samples\": [\n 26604,\n 27845\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"product_name\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Organic Blackberries\",\n \"Organic Whole Milk\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"total_orders\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 12,\n \"min\": 32,\n \"max\": 49,\n \"num_unique_values\": 2,\n \"samples\": [\n 32,\n 49\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {} + } + ], + "source": [ + "# Query by user factors\n", + "test_user_embeddings = test_user_factors.tolist()\n", + "for embedding, id in zip(test_user_embeddings, test_user_ids):\n", + " results = tbl.search(embedding).limit(10).to_pandas()\n", + " display(results)\n", + " display(products_bought_by_user_in_the_past(id, top=15))" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] }, - "2782769e3daa491385bcc8ae34f24f3b": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - 
"grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "5d41569b941445bea2497c89d3c8e6cb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" }, - "5e7dd2740d174064ac2d1cbc75cb5909": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": 
null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } + "vscode": { + "interpreter": { + "hash": "5fe10bf018ef3e697f9035d60bf60847932a12bface18908407fd371fe880db9" + } }, - "a67972dc3f264b3699816257f1ad9ed7": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "c159a1c6fc204d239b5ff7713d3c68fe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6e3b621f67554d6cbcaa50717008821f", + "IPY_MODEL_1e5f629b939247c088b275a72310cfe0", + 
"IPY_MODEL_cfde2bc68d9c448b823c690e15c4a169" + ], + "layout": "IPY_MODEL_8668f98cebeb4b548e87f2c4e68c9cbf" + } + }, + "6e3b621f67554d6cbcaa50717008821f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7ebca3dced8e4c029398db02169b868e", + "placeholder": "​", + "style": "IPY_MODEL_28400c62e971452b865e70af4e410afc", + "value": "100%" + } + }, + "1e5f629b939247c088b275a72310cfe0": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c45f8ded7dc84c18b479c3c427c29463", + "max": 50, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_301f4f324d594ff2a63dc2f43ba4391f", + "value": 50 + } + }, + "cfde2bc68d9c448b823c690e15c4a169": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_0e3594636fbf4263b32d195f31fd29c0", + "placeholder": "​", + "style": "IPY_MODEL_adf0848d8d8440f18dbd001572772fce", + "value": " 50/50 [17:12<00:00, 20.75s/it]" + } + }, + "8668f98cebeb4b548e87f2c4e68c9cbf": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7ebca3dced8e4c029398db02169b868e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + 
"grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "28400c62e971452b865e70af4e410afc": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c45f8ded7dc84c18b479c3c427c29463": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "301f4f324d594ff2a63dc2f43ba4391f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0e3594636fbf4263b32d195f31fd29c0": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": 
null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "adf0848d8d8440f18dbd001572772fce": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "142121b5c098477985d3bf5eb9560ad4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_49f9dab3bf2748a2a0811a7057c32ff7", + "IPY_MODEL_3ea9a47313cd496694180de85b51decf", + "IPY_MODEL_1cd7d3c410ed449eb88cc8d78e49e10d" + ], + "layout": "IPY_MODEL_e66f741c3e794c69a328c715cc9b56a2" + } + }, + "49f9dab3bf2748a2a0811a7057c32ff7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4a785b8e4b0d43eca0cf41c2b1cb2f35", + "placeholder": "​", + "style": "IPY_MODEL_05369b050a61407f8cd0c657afb9a6bd", + "value": "100%" + } + }, + "3ea9a47313cd496694180de85b51decf": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9ffbed3caaf84e1db7bde609b6cc06a7", + "max": 192999, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_6142e63dd35c46839b9b8cd520750844", + "value": 192999 + } + }, + "1cd7d3c410ed449eb88cc8d78e49e10d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cb770a9f4224470bba0a7488b76a24c0", + "placeholder": "​", + "style": "IPY_MODEL_2dea74cc01b04e548bb7a77bd31a2fd2", + "value": " 192999/192999 [02:18<00:00, 1522.55it/s]" + } + }, + "e66f741c3e794c69a328c715cc9b56a2": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": 
null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4a785b8e4b0d43eca0cf41c2b1cb2f35": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "05369b050a61407f8cd0c657afb9a6bd": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + 
"model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9ffbed3caaf84e1db7bde609b6cc06a7": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6142e63dd35c46839b9b8cd520750844": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"StyleView", + "bar_color": null, + "description_width": "" + } + }, + "cb770a9f4224470bba0a7488b76a24c0": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2dea74cc01b04e548bb7a77bd31a2fd2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } } - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file