diff --git a/README.md b/README.md index 5e7b5dae..52a34738 100644 --- a/README.md +++ b/README.md @@ -159,10 +159,10 @@ task.evaluate(val_pred, val_table) # Tutorials To get started with RelBench, we provide some helpful Colab notebook tutorials. These tutorials cover (i) how to load data using RelBench, focusing on providing users with the understanding of RelBench data logic needed to use RelBench data freely with any desired ML models, and (ii) training a GNN predictive model to solve tasks in RelBench. Please refer to the code for more detailed documentation. -| Name | Colab | Description | -|-------|-------|---------------------------------------------------------| -| Loading Data | [](https://colab.research.google.com/drive/1PAOktBqh_3QzgAKi53F4JbQxoOuBsUBY?usp=sharing) | How to load and explore RelBench data -| Training models | [](https://colab.research.google.com/drive/1_z0aKcs5XndEacX1eob6csDuR4DYhGQU?usp=sharing)| Train your first GNN-based model on RelBench. | +| Name | Notebook | Try on Colab | Description | +|-------|----------|--------------|---------------------------------------------------------| +| Loading Data | [load_data.ipynb](tutorials/load_data.ipynb) | [](https://colab.research.google.com/github/snap-stanford/relbench/blob/main/tutorials/load_data.ipynb) | How to load and explore RelBench data +| Training models | [train_model.ipynb](tutorials/train_model.ipynb) | [](https://colab.research.google.com/github/snap-stanford/relbench/blob/main/tutorials/train_model.ipynb)| Train your first GNN-based model on RelBench. 
| diff --git a/tutorials/load_data.ipynb b/tutorials/load_data.ipynb new file mode 100644 index 00000000..23067119 --- /dev/null +++ b/tutorials/load_data.ipynb @@ -0,0 +1,1259 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "First, `pip install`" + ], + "metadata": { + "id": "9hVIN7QW2etU" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QkzAsRon2ZIw", + "outputId": "4dca4023-9f6b-4124-a629-3b9257844ba8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting relbench==1.0.0rc1\n", + " Downloading relbench-1.0.0rc1-py3-none-any.whl (57 kB)\n", + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/57.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.4/57.4 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (2.0.3)\n", + "Requirement already satisfied: pooch in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (1.8.2)\n", + "Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (14.0.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (1.25.2)\n", + "Requirement already satisfied: duckdb in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (0.10.3)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (1.2.2)\n", + 
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (4.12.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from relbench==1.0.0rc1) (4.66.4)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->relbench==1.0.0rc1) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->relbench==1.0.0rc1) (2023.4)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->relbench==1.0.0rc1) (2024.1)\n", + "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.10/dist-packages (from pooch->relbench==1.0.0rc1) (4.2.2)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from pooch->relbench==1.0.0rc1) (24.1)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from pooch->relbench==1.0.0rc1) (2.31.0)\n", + "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->relbench==1.0.0rc1) (1.11.4)\n", + "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->relbench==1.0.0rc1) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->relbench==1.0.0rc1) (3.5.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->relbench==1.0.0rc1) (1.16.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench==1.0.0rc1) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench==1.0.0rc1) 
(3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench==1.0.0rc1) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench==1.0.0rc1) (2024.6.2)\n", + "Installing collected packages: relbench\n", + "Successfully installed relbench-1.0.0rc1\n" + ] + } + ], + "source": [ + "!pip install relbench==1.0.0rc1" + ] + }, + { + "cell_type": "markdown", + "source": [ + "To start we can check all of the databases currently available in RelBench by printing:" + ], + "metadata": { + "id": "yLfAfxNO-LBV" + } + }, + { + "cell_type": "code", + "source": [ + "from relbench.datasets import get_dataset_names\n", + "\n", + "get_dataset_names()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ApqEZu5--HJy", + "outputId": "f45a7623-ce90-4dd9-c088-004290d6519a" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "['rel-amazon',\n", + " 'rel-avito',\n", + " 'rel-event',\n", + " 'rel-f1',\n", + " 'rel-hm',\n", + " 'rel-stack',\n", + " 'rel-trial']" + ] + }, + "metadata": {}, + "execution_count": 2 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Get dataset\n", + "\n", + "Let's start with the F1 dataset since it's the smallest and is easy to work with. 
All it takes is one line!\n" + ], + "metadata": { + "id": "foOR0gDJ2j1s" + } + }, + { + "cell_type": "code", + "source": [ + "from relbench.datasets import get_dataset\n", + "\n", + "dataset = get_dataset(name=\"rel-f1\", download=True)" + ], + "metadata": { + "id": "RsAPkjOk2hDn", + "outputId": "837cbd40-167b-4bb1-968e-7120de1fa028", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading file 'rel-f1/db.zip' from 'https://relbench.stanford.edu/download/rel-f1/db.zip' to '/root/.cache/relbench'.\n", + "100%|████████████████████████████████████████| 704k/704k [00:00<00:00, 174MB/s]\n", + "Unzipping contents of '/root/.cache/relbench/rel-f1/db.zip' to '/root/.cache/relbench/rel-f1/.'\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Use `download=True` the first time you load a particular dataset to automatically download the data from the RelBench server onto your machine.\n", + "\n", + "Now we have loaded the database, let's start poking around to see what's inside." 
+ ], + "metadata": { + "id": "M3dJVd8A2sOt" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Val / Test cutoffs\n", + "\n", + "We can check the val/test time cutoffs as follows:" + ], + "metadata": { + "id": "lgeqc7x82596" + } + }, + { + "cell_type": "code", + "source": [ + "dataset.val_timestamp, dataset.test_timestamp" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "239e0MQB2yMO", + "outputId": "07b2cf05-a8d4-4c6c-b5cd-a755617a5baa" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(Timestamp('2005-01-01 00:00:00'), Timestamp('2010-01-01 00:00:00'))" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "This means that information up to 2005 can be used for training, and up to 2010 can be used for validation.\n", + "\n", + "Note that it is a RelBench design choice to make the validation and test cutoffs a dataset property, _not_ a task-specific property. In other words, all tasks for a given database use the same time splits.\n", + "\n", + "\n", + "# Load database\n", + "\n", + "Next we check out the database itself..." + ], + "metadata": { + "id": "-75pIuTV3Ae6" + } + }, + { + "cell_type": "code", + "source": [ + "db = dataset.get_db()" + ], + "metadata": { + "id": "dm1f8YDF273p" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*This* returns a RelBench `Database` object. By default, the rows with timestamp > test_timestamp are excluded to prevent accidental test set leakage.\n", + "The complete database can be loaded with `dataset.get_db(upto_test_timestamp=False)`." 
+ ], + "metadata": { + "id": "kpBRvvqb3FXQ" + } + }, + { + "cell_type": "markdown", + "source": [ + "With this we can double check the full timespan of the database:" + ], + "metadata": { + "id": "H9mRPjk33JfQ" + } + }, + { + "cell_type": "code", + "source": [ + "db.min_timestamp, db.max_timestamp" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ahk12agU3HLA", + "outputId": "6d93a289-6063-4cd6-fc66-72076c7a8bd1" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(Timestamp('1950-05-13 00:00:00'), Timestamp('2009-11-01 11:00:00'))" + ] + }, + "metadata": {}, + "execution_count": 6 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "1950 is the first season for F1! So we have data for the full history of F1. Note that the `max_timestamp` is the same as `test_timestamp`." + ], + "metadata": { + "id": "oKmC3R-H3OIU" + } + }, + { + "cell_type": "markdown", + "source": [ + "Next let's check out the tables in the database.\n", + "\n", + "More info on the schemas for F1 and all other datasets can be found at https://relbench.stanford.edu/." + ], + "metadata": { + "id": "zmGcKvxX3SO5" + } + }, + { + "cell_type": "markdown", + "source": [ + "We have the following tables:" + ], + "metadata": { + "id": "7RobvEoY3XQt" + } + }, + { + "cell_type": "code", + "source": [ + "db.table_dict.keys()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LyAdmIry3UFl", + "outputId": "dde952af-3953-47d5-8831-98676018cc65" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "dict_keys(['constructors', 'results', 'standings', 'constructor_results', 'drivers', 'qualifying', 'races', 'circuits', 'constructor_standings'])" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "That's 9 tables total! 
Let's look more closely at one of them." + ], + "metadata": { + "id": "3sdJ6OgukDrL" + } + }, + { + "cell_type": "code", + "source": [ + "table = db.table_dict[\"drivers\"]\n", + "table" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "n0OGtbMy3Y2n", + "outputId": "42f6ee62-f1ba-4519-bfec-32808baba1f3" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Table(df=\n", + " driverId driverRef code forename surname dob \\\n", + "0 0 hamilton HAM Lewis Hamilton 1985-01-07 \n", + "1 1 heidfeld HEI Nick Heidfeld 1977-05-10 \n", + "2 2 rosberg ROS Nico Rosberg 1985-06-27 \n", + "3 3 alonso ALO Fernando Alonso 1981-07-29 \n", + "4 4 kovalainen KOV Heikki Kovalainen 1981-10-19 \n", + ".. ... ... ... ... ... ... \n", + "852 852 mick_schumacher MSC Mick Schumacher 1999-03-22 \n", + "853 853 zhou ZHO Guanyu Zhou 1999-05-30 \n", + "854 854 de_vries DEV Nyck de Vries 1995-02-06 \n", + "855 855 piastri PIA Oscar Piastri 2001-04-06 \n", + "856 856 sargeant SAR Logan Sargeant 2000-12-31 \n", + "\n", + " nationality \n", + "0 British \n", + "1 German \n", + "2 German \n", + "3 Spanish \n", + "4 Finnish \n", + ".. ... \n", + "852 German \n", + "853 Chinese \n", + "854 Dutch \n", + "855 Australian \n", + "856 American \n", + "\n", + "[857 rows x 7 columns],\n", + " fkey_col_to_pkey_table={},\n", + " pkey_col=driverId,\n", + " time_col=None)" + ] + }, + "metadata": {}, + "execution_count": 8 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "The `drivers` table stores information on all F1 drivers that ever competed in a race. 
Note that the table comes with multiple bits of information:\n", + "- The table itself, `table.df` which is simply a Pandas DataFrame.\n", + "- The primary key column, `table.pkey_col`, which indicates that the `driverId` column holds the primary key for this particular table in the database.\n", + "- The primary time column, `table.time_col` which, if the entity is an event, records the time an event happened. In the case of drivers, they are non-temporal entities, so `table.time_col=None`.\n", + "- The other tables that foreign keys point to, `table.fkey_col_to_pkey_table`. If the table has any foreign key columns, then this dict indicates which table each foreign key corresponds to. Again in the case of drivers this is not applicable.\n", + "\n", + "We can start to explore the data a little, e.g., check out the oldest and youngest ever F1 drivers, spanning 3 centuries!" + ], + "metadata": { + "id": "ZPEZo1TNkMaU" + } + }, + { + "cell_type": "code", + "source": [ + "table.df.iloc[table.df[\"dob\"].idxmax()]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "paE62mcKkDCF", + "outputId": "f6535bd4-059f-4f87-bb60-b5ac05cc3889" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "driverId 855\n", + "driverRef piastri\n", + "code PIA\n", + "forename Oscar\n", + "surname Piastri\n", + "dob 2001-04-06 00:00:00\n", + "nationality Australian\n", + "Name: 855, dtype: object" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ] + }, + { + "cell_type": "code", + "source": [ + "table.df.iloc[table.df[\"dob\"].idxmin()]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pqxywakH29js", + "outputId": "a5c2d066-3700-43da-83c1-0fd6828ceca8" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "driverId 741\n", + "driverRef etancelin\n", + "code \\N\n", + "forename 
Philippe\n", + "surname Étancelin\n", + "dob 1896-12-28 00:00:00\n", + "nationality French\n", + "Name: 741, dtype: object" + ] + }, + "metadata": {}, + "execution_count": 10 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Going back to the `table.time_col` and `table.fkey_col_to_pkey_table`, the `results` table contains a non-trivial example." + ], + "metadata": { + "id": "VJ3Se_Irkff1" + } + }, + { + "cell_type": "code", + "source": [ + "table = db.table_dict[\"results\"]\n", + "table.df" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 791 + }, + "id": "DSmDQc3Pke5c", + "outputId": "726f3491-5436-4a95-a905-607279afd7b6" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " resultId raceId driverId constructorId number grid position \\\n", + "0 0 0 660 152 18.0 21 11.0 \n", + "1 1 0 790 149 8.0 12 NaN \n", + "2 2 0 579 49 1.0 3 NaN \n", + "3 3 0 661 149 9.0 10 NaN \n", + "4 4 0 789 152 17.0 7 NaN \n", + "... ... ... ... ... ... ... ... \n", + "20318 20318 819 1 1 6.0 8 5.0 \n", + "20319 20319 819 21 22 23.0 4 4.0 \n", + "20320 20320 819 17 22 22.0 5 3.0 \n", + "20321 20321 819 16 8 14.0 3 2.0 \n", + "20322 20322 819 2 2 16.0 9 9.0 \n", + "\n", + " positionOrder points laps milliseconds fastestLap rank statusId \\\n", + "0 11 0.0 64 NaN NaN NaN 16 \n", + "1 21 0.0 2 NaN NaN NaN 126 \n", + "2 12 0.0 62 NaN NaN NaN 44 \n", + "3 20 0.0 5 NaN NaN NaN 6 \n", + "4 19 0.0 8 NaN NaN NaN 51 \n", + "... ... ... ... ... ... ... ... \n", + "20318 5 4.0 55 5669667.0 54.0 7.0 1 \n", + "20319 4 5.0 55 5666149.0 54.0 4.0 1 \n", + "20320 3 6.0 55 5661881.0 49.0 6.0 1 \n", + "20321 2 8.0 55 5661271.0 14.0 5.0 1 \n", + "20322 9 0.0 55 5689355.0 49.0 15.0 1 \n", + "\n", + " date \n", + "0 1950-05-13 00:00:00 \n", + "1 1950-05-13 00:00:00 \n", + "2 1950-05-13 00:00:00 \n", + "3 1950-05-13 00:00:00 \n", + "4 1950-05-13 00:00:00 \n", + "... ... 
\n", + "20318 2009-11-01 11:00:00 \n", + "20319 2009-11-01 11:00:00 \n", + "20320 2009-11-01 11:00:00 \n", + "20321 2009-11-01 11:00:00 \n", + "20322 2009-11-01 11:00:00 \n", + "\n", + "[20323 rows x 15 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + "
resultIdraceIddriverIdconstructorIdnumbergridpositionpositionOrderpointslapsmillisecondsfastestLaprankstatusIddate
00066015218.02111.0110.064NaNNaNNaN161950-05-13 00:00:00
1107901498.012NaN210.02NaNNaNNaN1261950-05-13 00:00:00
220579491.03NaN120.062NaNNaNNaN441950-05-13 00:00:00
3306611499.010NaN200.05NaNNaNNaN61950-05-13 00:00:00
44078915217.07NaN190.08NaNNaNNaN511950-05-13 00:00:00
................................................
2031820318819116.085.054.0555669667.054.07.012009-11-01 11:00:00
2031920319819212223.044.045.0555666149.054.04.012009-11-01 11:00:00
2032020320819172222.053.036.0555661881.049.06.012009-11-01 11:00:00
203212032181916814.032.028.0555661271.014.05.012009-11-01 11:00:00
20322203228192216.099.090.0555689355.049.015.012009-11-01 11:00:00
\n", + "

20323 rows × 15 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "summary": "{\n \"name\": \"table\",\n \"rows\": 20323,\n \"fields\": [\n {\n \"column\": \"resultId\",\n \"properties\": {\n \"dtype\": \"Int64\",\n \"num_unique_values\": 20323,\n \"samples\": [\n 8086,\n 10911,\n 3018\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"raceId\",\n \"properties\": {\n \"dtype\": \"Int64\",\n \"num_unique_values\": 820,\n \"samples\": [\n 641,\n 333,\n 67\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"driverId\",\n \"properties\": {\n \"dtype\": \"Int64\",\n \"num_unique_values\": 806,\n \"samples\": [\n 141,\n 600,\n 561\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"constructorId\",\n \"properties\": {\n \"dtype\": \"Int64\",\n \"num_unique_values\": 199,\n \"samples\": [\n 33,\n 105,\n 80\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"number\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 12.618109254667905,\n \"min\": 0.0,\n \"max\": 208.0,\n \"num_unique_values\": 128,\n \"samples\": [\n 58.0,\n 62.0,\n 23.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"grid\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 7,\n \"min\": 0,\n \"max\": 34,\n \"num_unique_values\": 35,\n \"samples\": [\n 29,\n 9,\n 28\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"position\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 4.368745850543276,\n \"min\": 1.0,\n \"max\": 33.0,\n \"num_unique_values\": 33,\n \"samples\": [\n 26.0,\n 19.0,\n 31.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"positionOrder\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 8,\n \"min\": 1,\n \"max\": 39,\n \"num_unique_values\": 39,\n \"samples\": [\n 34,\n 38,\n 19\n ],\n 
\"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"points\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 2.3409589523698164,\n \"min\": 0.0,\n \"max\": 10.0,\n \"num_unique_values\": 23,\n \"samples\": [\n 4.14,\n 7.0,\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"laps\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 32,\n \"min\": 0,\n \"max\": 200,\n \"num_unique_values\": 172,\n \"samples\": [\n 80,\n 146,\n 134\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"milliseconds\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1863478.9976882094,\n \"min\": 1474899.0,\n \"max\": 15090540.0,\n \"num_unique_values\": 4340,\n \"samples\": [\n 7276534.0,\n 6101558.0,\n 5954029.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"fastestLap\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 18.115517488881853,\n \"min\": 2.0,\n \"max\": 77.0,\n \"num_unique_values\": 76,\n \"samples\": [\n 10.0,\n 18.0,\n 29.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"rank\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 5.790800780283663,\n \"min\": 1.0,\n \"max\": 22.0,\n \"num_unique_values\": 22,\n \"samples\": [\n 14.0,\n 3.0,\n 16.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"statusId\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 27,\n \"min\": 1,\n \"max\": 127,\n \"num_unique_values\": 125,\n \"samples\": [\n 67,\n 80,\n 31\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"date\",\n \"properties\": {\n \"dtype\": \"date\",\n \"min\": \"1950-05-13 00:00:00\",\n \"max\": \"2009-11-01 11:00:00\",\n \"num_unique_values\": 820,\n \"samples\": [\n \"1999-08-29 00:00:00\",\n \"1980-05-18 00:00:00\",\n \"1958-05-30 00:00:00\"\n ],\n \"semantic_type\": \"\",\n 
\"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 11 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Here we start to notice certain data artifacts that might be good to keep in mind for later when doing ML modeling. For instance, the `milliseconds` and `fastestLap` columns seem to only have been collected for more recent races, with `NaN` features for earlier races." + ], + "metadata": { + "id": "pc23JvAykj7c" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Loading a task\n", + "\n", + "Each RelBench dataset comes with multiple pre-defined predictive tasks. For any given RelBench dataset, you can check all the associated tasks with:" + ], + "metadata": { + "id": "NBzQPRJvkqUh" + } + }, + { + "cell_type": "code", + "source": [ + "from relbench.tasks import get_task_names, get_task\n", + "\n", + "get_task_names(\"rel-f1\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TyHERgUAkjU8", + "outputId": "b53a58ec-c657-4403-ab66-0735499b31b6" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "['driver-position', 'driver-dnf', 'driver-top3']" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Check out https://relbench.stanford.edu/ for detailed descriptions of what each task is. 
As an example, let's use `driver-top3` where the task is, for a given driver and a given timestamp, to predict whether that driver will finish in the top 3 in some race in the next 30 days.\n", + "\n", + "The task itself is instantiated by calling:" + ], + "metadata": { + "id": "aV6Ks72nku6y" + } + }, + { + "cell_type": "code", + "source": [ + "task = get_task(\"rel-f1\", \"driver-top3\", download=True)" + ], + "metadata": { + "id": "IbVQxxdxktyU", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "540d4d51-421d-45b8-ddf1-e93b413a970e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading file 'rel-f1/tasks/driver-top3.zip' from 'https://relbench.stanford.edu/download/rel-f1/tasks/driver-top3.zip' to '/root/.cache/relbench'.\n", + "100%|█████████████████████████████████████| 10.3k/10.3k [00:00<00:00, 1.46MB/s]\n", + "Unzipping contents of '/root/.cache/relbench/rel-f1/tasks/driver-top3.zip' to '/root/.cache/relbench/rel-f1/tasks/.'\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Next we load the train / val / test labels. Each task table contains triples (timestamp, Id, label) indicating the entity the label is associated to, the timepoint at which the prediction is made, and the label itself. The task table also indicates which database table it is \"attached\" to - in this case the `drivers` table." 
+ ], + "metadata": { + "id": "hkyvGxbSk0Tr" + } + }, + { + "cell_type": "code", + "source": [ + "task.get_table(\"train\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "k90T-Oq1kx89", + "outputId": "63462965-c459-4d90-e945-f4a53ea39939" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Table(df=\n", + " date driverId qualifying\n", + "0 2004-08-04 12 0\n", + "1 2004-08-04 20 0\n", + "2 2004-07-05 10 0\n", + "3 2004-07-05 47 0\n", + "4 2004-06-05 31 0\n", + "... ... ... ...\n", + "1348 1994-03-30 80 0\n", + "1349 1994-03-30 48 0\n", + "1350 1994-03-30 77 0\n", + "1351 1994-02-28 43 0\n", + "1352 1994-02-28 56 0\n", + "\n", + "[1353 rows x 3 columns],\n", + " fkey_col_to_pkey_table={'driverId': 'drivers'},\n", + " pkey_col=None,\n", + " time_col=date)" + ] + }, + "metadata": {}, + "execution_count": 19 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "The test table is handled differently, with the labels being hidden by default to prevent accidental test set leakage." + ], + "metadata": { + "id": "-GuWKV6Hk4Pz" + } + }, + { + "cell_type": "code", + "source": [ + "task.get_table(\"test\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "o3m1ROIDk2NC", + "outputId": "656ed6c9-11f1-42e8-d1bd-16fd091f6ac9" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Table(df=\n", + " date driverId\n", + "0 2013-03-16 814\n", + "1 2012-11-16 9\n", + "2 2012-11-16 17\n", + "3 2012-10-17 0\n", + "4 2012-09-17 816\n", + ".. ... 
...\n", + "721 2010-07-30 14\n", + "722 2010-06-30 154\n", + "723 2010-06-30 14\n", + "724 2010-05-01 14\n", + "725 2010-05-01 154\n", + "\n", + "[726 rows x 2 columns],\n", + " fkey_col_to_pkey_table={'driverId': 'drivers'},\n", + " pkey_col=None,\n", + " time_col=date)" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "If strictly needed, test labels can be retrieved by calling:" + ], + "metadata": { + "id": "QgnTWc2Kk9gY" + } + }, + { + "cell_type": "code", + "source": [ + "task.get_table(\"test\", mask_input_cols=False)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dhh0H7rBk6GL", + "outputId": "5651517f-ba75-4e01-d3d2-71d0536181db" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Table(df=\n", + " date driverId qualifying\n", + "0 2013-03-16 814 0\n", + "1 2012-11-16 9 0\n", + "2 2012-11-16 17 1\n", + "3 2012-10-17 0 1\n", + "4 2012-09-17 816 0\n", + ".. ... ... 
...\n", + "721 2010-07-30 14 0\n", + "722 2010-06-30 154 0\n", + "723 2010-06-30 14 0\n", + "724 2010-05-01 14 0\n", + "725 2010-05-01 154 0\n", + "\n", + "[726 rows x 3 columns],\n", + " fkey_col_to_pkey_table={'driverId': 'drivers'},\n", + " pkey_col=None,\n", + " time_col=date)" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "XCyVAj1Q6E0Z" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/tutorials/train_model.ipynb b/tutorials/train_model.ipynb new file mode 100644 index 00000000..c3ac2b52 --- /dev/null +++ b/tutorials/train_model.ipynb @@ -0,0 +1,1444 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zNziUzq9nTdU", + "outputId": "edb40abf-b984-4fec-8033-1ed92fbdb128" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "2.3.0+cu121\n", + "Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html\n", + "Collecting pyg-lib\n", + " Downloading https://data.pyg.org/whl/torch-2.3.0%2Bcu121/pyg_lib-0.4.0%2Bpt23cu121-cp310-cp310-linux_x86_64.whl (2.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.5/2.5 MB\u001b[0m \u001b[31m22.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: pyg-lib\n", + "Successfully installed pyg-lib-0.4.0+pt23cu121\n", + "Collecting git+https://github.com/pyg-team/pytorch_geometric.git\n", + " Cloning https://github.com/pyg-team/pytorch_geometric.git to /tmp/pip-req-build-0v5d62c1\n", + " Running command git clone --filter=blob:none --quiet 
https://github.com/pyg-team/pytorch_geometric.git /tmp/pip-req-build-0v5d62c1\n", + " Resolved https://github.com/pyg-team/pytorch_geometric.git to commit fbafbc4fc9181e8759ec1f39d9618992793b5fe1\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (3.9.5)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (2023.6.0)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (3.1.4)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (1.25.2)\n", + "Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (5.9.5)\n", + "Requirement already satisfied: pyparsing in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (3.1.2)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (2.31.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torch-geometric==2.6.0) (4.66.4)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch-geometric==2.6.0) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch-geometric==2.6.0) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch-geometric==2.6.0) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from 
aiohttp->torch-geometric==2.6.0) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch-geometric==2.6.0) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch-geometric==2.6.0) (4.0.3)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch-geometric==2.6.0) (2.1.5)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric==2.6.0) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric==2.6.0) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric==2.6.0) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric==2.6.0) (2024.6.2)\n", + "Building wheels for collected packages: torch-geometric\n", + " Building wheel for torch-geometric (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for torch-geometric: filename=torch_geometric-2.6.0-py3-none-any.whl size=1122975 sha256=942596ab7c5d81703af08d193615963fbfde8806a6840d5f3ccef502768d7d18\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-9vz0psfl/wheels/d3/78/eb/9e26525b948d19533f1688fb6c209cec8a0ba793d39b49ae8f\n", + "Successfully built torch-geometric\n", + "Installing collected packages: torch-geometric\n", + "Successfully installed torch-geometric-2.6.0\n", + "Collecting pytorch_frame[full]\n", + " Downloading pytorch_frame-0.2.2-py3-none-any.whl (140 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (1.25.2)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (2.0.3)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (2.3.0+cu121)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (4.66.4)\n", + "Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (14.0.2)\n", + "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (9.4.0)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (1.2.2)\n", + "Collecting xgboost<2.0.0,>=1.7.0 (from pytorch_frame[full])\n", + " Downloading xgboost-1.7.6-py3-none-manylinux2014_x86_64.whl (200.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.3/200.3 MB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting optuna>=3.0.0 
(from pytorch_frame[full])\n", + " Downloading optuna-3.6.1-py3-none-any.whl (380 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m380.1/380.1 kB\u001b[0m \u001b[31m50.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: mpmath==1.3.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (1.3.0)\n", + "Collecting catboost (from pytorch_frame[full])\n", + " Downloading catboost-1.2.5-cp310-cp310-manylinux2014_x86_64.whl (98.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.2/98.2 MB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: lightgbm in /usr/local/lib/python3.10/dist-packages (from pytorch_frame[full]) (4.1.0)\n", + "Collecting datasets (from pytorch_frame[full])\n", + " Downloading datasets-2.20.0-py3-none-any.whl (547 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m547.8/547.8 kB\u001b[0m \u001b[31m45.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting alembic>=1.5.0 (from optuna>=3.0.0->pytorch_frame[full])\n", + " Downloading alembic-1.13.2-py3-none-any.whl (232 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m233.0/233.0 kB\u001b[0m \u001b[31m30.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting colorlog (from optuna>=3.0.0->pytorch_frame[full])\n", + " Downloading colorlog-6.8.2-py3-none-any.whl (11 kB)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from optuna>=3.0.0->pytorch_frame[full]) (24.1)\n", + "Requirement already satisfied: sqlalchemy>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from optuna>=3.0.0->pytorch_frame[full]) (2.0.31)\n", + "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from 
optuna>=3.0.0->pytorch_frame[full]) (6.0.1)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from xgboost<2.0.0,>=1.7.0->pytorch_frame[full]) (1.11.4)\n", + "Requirement already satisfied: graphviz in /usr/local/lib/python3.10/dist-packages (from catboost->pytorch_frame[full]) (0.20.3)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from catboost->pytorch_frame[full]) (3.7.1)\n", + "Requirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (from catboost->pytorch_frame[full]) (5.15.0)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from catboost->pytorch_frame[full]) (1.16.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->pytorch_frame[full]) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->pytorch_frame[full]) (2023.4)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->pytorch_frame[full]) (2024.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets->pytorch_frame[full]) (3.15.4)\n", + "Collecting pyarrow (from pytorch_frame[full])\n", + " Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 MB\u001b[0m \u001b[31m13.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->pytorch_frame[full]) (0.6)\n", + "Collecting dill<0.3.9,>=0.3.0 (from datasets->pytorch_frame[full])\n", + " Downloading dill-0.3.8-py3-none-any.whl (116 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m 
\u001b[31m16.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting requests>=2.32.2 (from datasets->pytorch_frame[full])\n", + " Downloading requests-2.32.3-py3-none-any.whl (64 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.9/64.9 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting xxhash (from datasets->pytorch_frame[full])\n", + " Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m19.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting multiprocess (from datasets->pytorch_frame[full])\n", + " Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: fsspec[http]<=2024.5.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets->pytorch_frame[full]) (2023.6.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->pytorch_frame[full]) (3.9.5)\n", + "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets->pytorch_frame[full]) (0.23.4)\n", + "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->pytorch_frame[full]) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->pytorch_frame[full]) (3.5.0)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->pytorch_frame[full]) (4.12.2)\n", + "Requirement already satisfied: sympy in 
/usr/local/lib/python3.10/dist-packages (from torch->pytorch_frame[full]) (1.12.1)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->pytorch_frame[full]) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->pytorch_frame[full]) (3.1.4)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", + "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", + "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", + "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", + "Collecting nvidia-curand-cu12==10.3.2.106 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", + "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", + "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", + "Collecting nvidia-nccl-cu12==2.20.5 (from 
torch->pytorch_frame[full])\n", + " Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n", + "Collecting nvidia-nvtx-cu12==12.1.105 (from torch->pytorch_frame[full])\n", + " Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", + "Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch->pytorch_frame[full]) (2.3.0)\n", + "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch->pytorch_frame[full])\n", + " Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl (21.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.3/21.3 MB\u001b[0m \u001b[31m71.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting Mako (from alembic>=1.5.0->optuna>=3.0.0->pytorch_frame[full])\n", + " Downloading Mako-1.3.5-py3-none-any.whl (78 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.6/78.6 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->pytorch_frame[full]) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->pytorch_frame[full]) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->pytorch_frame[full]) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->pytorch_frame[full]) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->pytorch_frame[full]) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from 
aiohttp->datasets->pytorch_frame[full]) (4.0.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->pytorch_frame[full]) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->pytorch_frame[full]) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->pytorch_frame[full]) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->pytorch_frame[full]) (2024.6.2)\n", + "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from sqlalchemy>=1.3.0->optuna>=3.0.0->pytorch_frame[full]) (3.0.3)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->pytorch_frame[full]) (2.1.5)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost->pytorch_frame[full]) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost->pytorch_frame[full]) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost->pytorch_frame[full]) (4.53.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost->pytorch_frame[full]) (1.4.5)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost->pytorch_frame[full]) (3.1.2)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly->catboost->pytorch_frame[full]) (8.4.2)\n", + "Installing collected packages: xxhash, requests, 
pyarrow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, Mako, dill, colorlog, xgboost, nvidia-cusparse-cu12, nvidia-cudnn-cu12, multiprocess, alembic, optuna, nvidia-cusolver-cu12, catboost, datasets, pytorch_frame\n", + " Attempting uninstall: requests\n", + " Found existing installation: requests 2.31.0\n", + " Uninstalling requests-2.31.0:\n", + " Successfully uninstalled requests-2.31.0\n", + " Attempting uninstall: pyarrow\n", + " Found existing installation: pyarrow 14.0.2\n", + " Uninstalling pyarrow-14.0.2:\n", + " Successfully uninstalled pyarrow-14.0.2\n", + " Attempting uninstall: xgboost\n", + " Found existing installation: xgboost 2.0.3\n", + " Uninstalling xgboost-2.0.3:\n", + " Successfully uninstalled xgboost-2.0.3\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 16.1.0 which is incompatible.\n", + "google-colab 1.0.0 requires requests==2.31.0, but you have requests 2.32.3 which is incompatible.\n", + "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 16.1.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed Mako-1.3.5 alembic-1.13.2 catboost-1.2.5 colorlog-6.8.2 datasets-2.20.0 dill-0.3.8 multiprocess-0.70.16 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.5.82 nvidia-nvtx-cu12-12.1.105 optuna-3.6.1 pyarrow-16.1.0 pytorch_frame-0.2.2 requests-2.32.3 xgboost-1.7.6 xxhash-3.4.1\n", + "Collecting relbench[full]==1.0.0rc1\n", + " Downloading relbench-1.0.0rc1-py3-none-any.whl (57 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.4/57.4 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (2.0.3)\n", + "Requirement already satisfied: pooch in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (1.8.2)\n", + "Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (16.1.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (1.25.2)\n", + "Requirement already satisfied: duckdb in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (0.10.3)\n", + "Requirement already satisfied: scikit-learn in 
/usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (1.2.2)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (4.12.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (4.66.4)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (2.3.0+cu121)\n", + "Requirement already satisfied: pytorch_frame>=0.2.2 in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (0.2.2)\n", + "Requirement already satisfied: torch_geometric in /usr/local/lib/python3.10/dist-packages (from relbench[full]==1.0.0rc1) (2.6.0)\n", + "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from pytorch_frame>=0.2.2->relbench[full]==1.0.0rc1) (9.4.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->relbench[full]==1.0.0rc1) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->relbench[full]==1.0.0rc1) (2023.4)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->relbench[full]==1.0.0rc1) (2024.1)\n", + "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.10/dist-packages (from pooch->relbench[full]==1.0.0rc1) (4.2.2)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from pooch->relbench[full]==1.0.0rc1) (24.1)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from pooch->relbench[full]==1.0.0rc1) (2.32.3)\n", + "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->relbench[full]==1.0.0rc1) (1.11.4)\n", + "Requirement already satisfied: joblib>=1.1.1 in 
/usr/local/lib/python3.10/dist-packages (from scikit-learn->relbench[full]==1.0.0rc1) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->relbench[full]==1.0.0rc1) (3.5.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (3.15.4)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (1.12.1)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (2023.6.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in 
/usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (2.20.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (12.1.105)\n", + "Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch->relbench[full]==1.0.0rc1) (2.3.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->relbench[full]==1.0.0rc1) (12.5.82)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from torch_geometric->relbench[full]==1.0.0rc1) (3.9.5)\n", + "Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from torch_geometric->relbench[full]==1.0.0rc1) (5.9.5)\n", + "Requirement already satisfied: pyparsing in /usr/local/lib/python3.10/dist-packages (from torch_geometric->relbench[full]==1.0.0rc1) (3.1.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->relbench[full]==1.0.0rc1) (1.16.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench[full]==1.0.0rc1) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench[full]==1.0.0rc1) 
(3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench[full]==1.0.0rc1) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch->relbench[full]==1.0.0rc1) (2024.6.2)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric->relbench[full]==1.0.0rc1) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric->relbench[full]==1.0.0rc1) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric->relbench[full]==1.0.0rc1) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric->relbench[full]==1.0.0rc1) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric->relbench[full]==1.0.0rc1) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric->relbench[full]==1.0.0rc1) (4.0.3)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->relbench[full]==1.0.0rc1) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->relbench[full]==1.0.0rc1) (1.3.0)\n", + "Installing collected packages: relbench\n", + "Successfully installed relbench-1.0.0rc1\n" + ] + } + ], + "source": [ + "# Install required packages.\n", + "import os\n", + "import torch\n", + "\n", + "os.environ[\"TORCH\"] = torch.__version__\n", + "print(torch.__version__)\n", + "\n", + "!pip install pyg-lib -f 
https://data.pyg.org/whl/torch-${TORCH}.html # PyG for working with graphs\n", + "!pip install git+https://github.com/pyg-team/pytorch_geometric.git # more PyG\n", + "!pip install pytorch_frame[full] #PyTorch Frame for working with tabular data\n", + "!pip install relbench[full]==1.0.0rc1" + ] + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "from torch.nn import BCEWithLogitsLoss, L1Loss\n", + "from relbench.datasets import get_dataset\n", + "from relbench.tasks import get_task\n", + "\n", + "dataset = get_dataset(\"rel-f1\", download=True)\n", + "task = get_task(\"rel-f1\", \"driver-position\", download=True)\n", + "\n", + "train_table = task.get_table(\"train\")\n", + "val_table = task.get_table(\"val\")\n", + "test_table = task.get_table(\"test\")\n", + "\n", + "out_channels = 1\n", + "loss_fn = L1Loss()\n", + "tune_metric = \"mae\"\n", + "higher_is_better = False" + ], + "metadata": { + "id": "6DWB-Kf6nl2y" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's check out the training table just to make sure it looks fine." + ], + "metadata": { + "id": "UKFT5H51j_Um" + } + }, + { + "cell_type": "code", + "source": [ + "train_table" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ABN_fdN3kAB9", + "outputId": "03d4a31a-124d-45c7-b4dc-9713e5e4b942" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Table(df=\n", + " date driverId position\n", + "0 2004-07-05 10 10.75\n", + "1 2004-07-05 47 12.00\n", + "2 2004-03-07 7 15.00\n", + "3 2004-01-07 10 9.00\n", + "4 2003-09-09 52 13.00\n", + "... ... ... 
...\n", + "7448 1995-08-22 96 15.75\n", + "7449 1975-06-08 228 8.00\n", + "7450 1965-05-31 418 16.00\n", + "7451 1961-08-20 467 37.00\n", + "7452 1954-05-29 677 30.00\n", + "\n", + "[7453 rows x 3 columns],\n", + " fkey_col_to_pkey_table={'driverId': 'drivers'},\n", + " pkey_col=None,\n", + " time_col=date)" + ] + }, + "metadata": {}, + "execution_count": 6 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Note that to load the data we did not require any deep learning libraries. Now we introduce the PyTorch Frame library, which is useful for encoding individual tables into initial node features." + ], + "metadata": { + "id": "qQhuHIdHkOxv" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import math\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "import torch\n", + "import torch_geometric\n", + "import torch_frame\n", + "\n", + "# Some book keeping\n", + "from torch_geometric.seed import seed_everything\n", + "\n", + "seed_everything(42)\n", + "\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(device) # check that it's cuda if you want it to run in reasonable time!\n", + "root_dir = \"./data\"" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qNzfdwsrkPIo", + "outputId": "c985185e-b785-405e-bd46-ebf2f48e3ac6" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "cuda\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "The first big move is to build a graph out of the database. Here we use our pre-prepared conversion function.\n", + "\n", + "The source code can be found at: https://github.com/snap-stanford/relbench/blob/main/relbench/modeling/graph.py\n", + "\n", + "Each node in the graph corresonds to a single row in the database. 
Crucially, PyTorch Frame stores whole tables as objects in a way that is compatibile with PyG minibatch sampling, meaning we can sample subgraphs as in https://arxiv.org/abs/1706.02216, and retrieve the relevant raw features.\n", + "\n", + "PyTorch Frame also stores the `stype` (i.e., modality) of each column, and any specialized feature encoders (e.g., text encoders) to be used later. So we need to configure the `stype` for each column, for which we use a function that tries to automatically detect the `stype`." + ], + "metadata": { + "id": "0Y79g5H0kVjX" + } + }, + { + "cell_type": "code", + "source": [ + "from relbench.modeling.utils import get_stype_proposal\n", + "\n", + "db = dataset.get_db()\n", + "col_to_stype_dict = get_stype_proposal(db)\n", + "col_to_stype_dict" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kiV3TGI-kRuy", + "outputId": "98e88ec3-ab38-4a14-8dd8-24f3ea349893" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Loading Database object from /root/.cache/relbench/rel-f1/db...\n", + "Done in 0.04 seconds.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'drivers': {'driverId': ,\n", + " 'driverRef': ,\n", + " 'code': ,\n", + " 'forename': ,\n", + " 'surname': ,\n", + " 'dob': ,\n", + " 'nationality': },\n", + " 'races': {'raceId': ,\n", + " 'year': ,\n", + " 'round': ,\n", + " 'circuitId': ,\n", + " 'name': ,\n", + " 'date': ,\n", + " 'time': },\n", + " 'constructor_standings': {'constructorStandingsId': ,\n", + " 'raceId': ,\n", + " 'constructorId': ,\n", + " 'points': ,\n", + " 'position': ,\n", + " 'wins': ,\n", + " 'date': },\n", + " 'constructor_results': {'constructorResultsId': ,\n", + " 'raceId': ,\n", + " 'constructorId': ,\n", + " 'points': ,\n", + " 'date': },\n", + " 'results': {'resultId': ,\n", + " 'raceId': ,\n", + " 'driverId': ,\n", + " 'constructorId': ,\n", + " 'number': ,\n", + " 'grid': 
,\n", + " 'position': ,\n", + " 'positionOrder': ,\n", + " 'points': ,\n", + " 'laps': ,\n", + " 'milliseconds': ,\n", + " 'fastestLap': ,\n", + " 'rank': ,\n", + " 'statusId': ,\n", + " 'date': },\n", + " 'qualifying': {'qualifyId': ,\n", + " 'raceId': ,\n", + " 'driverId': ,\n", + " 'constructorId': ,\n", + " 'number': ,\n", + " 'position': ,\n", + " 'date': },\n", + " 'circuits': {'circuitId': ,\n", + " 'circuitRef': ,\n", + " 'name': ,\n", + " 'location': ,\n", + " 'country': ,\n", + " 'lat': ,\n", + " 'lng': ,\n", + " 'alt': },\n", + " 'constructors': {'constructorId': ,\n", + " 'constructorRef': ,\n", + " 'name': ,\n", + " 'nationality': },\n", + " 'standings': {'driverStandingsId': ,\n", + " 'raceId': ,\n", + " 'driverId': ,\n", + " 'points': ,\n", + " 'position': ,\n", + " 'wins': ,\n", + " 'date': }}" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "If trying a new dataset, you should definitely check through this dict of `stype`s to check that look right, and manually change any mistakes by the auto-detection function.\n", + "\n", + "Next we also define our text encoding model, which we use GloVe embeddings for speed and convenience. Feel free to try alternatives here." 
+ ], + "metadata": { + "id": "Sm3uYXqXkbZt" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -U sentence-transformers # we need another package for text encoding\n", + "from typing import List, Optional\n", + "from sentence_transformers import SentenceTransformer\n", + "from torch import Tensor\n", + "\n", + "\n", + "class GloveTextEmbedding:\n", + " def __init__(self, device: Optional[torch.device\n", + " ] = None):\n", + " self.model = SentenceTransformer(\n", + " \"sentence-transformers/average_word_embeddings_glove.6B.300d\",\n", + " device=device,\n", + " )\n", + "\n", + " def __call__(self, sentences: List[str]) -> Tensor:\n", + " return torch.from_numpy(self.model.encode(sentences))\n", + "\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QQHYmgIxkX1j", + "outputId": "857b70dd-e7eb-4b09-a5cd-394fccef758a" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting sentence-transformers\n", + " Downloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)\n", + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/227.1 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m225.3/227.1 kB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m227.1/227.1 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: transformers<5.0.0,>=4.34.0 in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (4.41.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (4.66.4)\n", + "Requirement already satisfied: torch>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from 
sentence-transformers) (2.3.0+cu121)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (1.25.2)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (1.2.2)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (1.11.4)\n", + "Requirement already satisfied: huggingface-hub>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (0.23.4)\n", + "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (9.4.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (3.15.4)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (2023.6.0)\n", + "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (6.0.1)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (2.32.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (1.12.1)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (3.3)\n", + "Requirement already satisfied: jinja2 in 
/usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (3.1.4)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (2.20.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n", + "Requirement already satisfied: triton==2.3.0 in 
/usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (2.3.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.11.0->sentence-transformers) (12.5.82)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers<5.0.0,>=4.34.0->sentence-transformers) (2024.5.15)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers<5.0.0,>=4.34.0->sentence-transformers) (0.19.1)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers<5.0.0,>=4.34.0->sentence-transformers) (0.4.3)\n", + "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence-transformers) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence-transformers) (3.5.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.11.0->sentence-transformers) (2.1.5)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence-transformers) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence-transformers) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence-transformers) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence-transformers) (2024.6.2)\n", + "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in 
/usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.11.0->sentence-transformers) (1.3.0)\n", + "Installing collected packages: sentence-transformers\n", + "Successfully installed sentence-transformers-3.0.1\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from torch_frame.config.text_embedder import TextEmbedderConfig\n", + "from relbench.modeling.graph import make_pkey_fkey_graph\n", + "\n", + "text_embedder_cfg = TextEmbedderConfig(\n", + " text_embedder=GloveTextEmbedding(device=device), batch_size=256\n", + ")\n", + "\n", + "data, col_stats_dict = make_pkey_fkey_graph(\n", + " db,\n", + " col_to_stype_dict=col_to_stype_dict, # specified column types\n", + " text_embedder_cfg=text_embedder_cfg, # our chosen text encoder\n", + " cache_dir=os.path.join(\n", + " root_dir, f\"rel-f1_materialized_cache\"\n", + " ), # store materialized graph for convenience\n", + ")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "L-BBpUrakdwY", + "outputId": "b152bf13-f47d-4728-d58b-fc65f738b03d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 6.11it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 177.19it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 185.70it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 184.06it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 157.68it/s]\n", + "/usr/local/lib/python3.10/dist-packages/torch_frame/data/stats.py:177: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`.
To ensure parsing is consistent and as-expected, please specify a format.\n", + " ser = pd.to_datetime(ser, format=time_format)\n", + "Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 144.99it/s]\n", + "/usr/local/lib/python3.10/dist-packages/torch_frame/data/mapper.py:290: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n", + " ser = pd.to_datetime(ser, format=self.format, errors='coerce')\n", + "Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 151.76it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 172.90it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 207.29it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 148.61it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 42.85it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 95.30it/s]\n", + "Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 93.10it/s]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "We can now check out `data`, our main graph object. `data` is a heterogeneous and temporal graph, with node types given by the table it originates from." 
+ ], + "metadata": { + "id": "mwQejmg0kzOg" + } + }, + { + "cell_type": "code", + "source": [ + "data" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Gt4a8lw1kufy", + "outputId": "4117959f-6f0d-4c31-9489-49db7d5f3c5d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "HeteroData(\n", + " drivers={ tf=TensorFrame([857, 6]) },\n", + " races={\n", + " tf=TensorFrame([820, 5]),\n", + " time=[820],\n", + " },\n", + " constructor_standings={\n", + " tf=TensorFrame([10170, 4]),\n", + " time=[10170],\n", + " },\n", + " constructor_results={\n", + " tf=TensorFrame([9408, 2]),\n", + " time=[9408],\n", + " },\n", + " results={\n", + " tf=TensorFrame([20323, 11]),\n", + " time=[20323],\n", + " },\n", + " qualifying={\n", + " tf=TensorFrame([4082, 3]),\n", + " time=[4082],\n", + " },\n", + " circuits={ tf=TensorFrame([77, 7]) },\n", + " constructors={ tf=TensorFrame([211, 3]) },\n", + " standings={\n", + " tf=TensorFrame([28115, 4]),\n", + " time=[28115],\n", + " },\n", + " (races, f2p_circuitId, circuits)={ edge_index=[2, 820] },\n", + " (circuits, rev_f2p_circuitId, races)={ edge_index=[2, 820] },\n", + " (constructor_standings, f2p_raceId, races)={ edge_index=[2, 10170] },\n", + " (races, rev_f2p_raceId, constructor_standings)={ edge_index=[2, 10170] },\n", + " (constructor_standings, f2p_constructorId, constructors)={ edge_index=[2, 10170] },\n", + " (constructors, rev_f2p_constructorId, constructor_standings)={ edge_index=[2, 10170] },\n", + " (constructor_results, f2p_raceId, races)={ edge_index=[2, 9408] },\n", + " (races, rev_f2p_raceId, constructor_results)={ edge_index=[2, 9408] },\n", + " (constructor_results, f2p_constructorId, constructors)={ edge_index=[2, 9408] },\n", + " (constructors, rev_f2p_constructorId, constructor_results)={ edge_index=[2, 9408] },\n", + " (results, f2p_raceId, races)={ edge_index=[2, 20323] },\n", + " (races, rev_f2p_raceId, 
results)={ edge_index=[2, 20323] },\n", + " (results, f2p_driverId, drivers)={ edge_index=[2, 20323] },\n", + " (drivers, rev_f2p_driverId, results)={ edge_index=[2, 20323] },\n", + " (results, f2p_constructorId, constructors)={ edge_index=[2, 20323] },\n", + " (constructors, rev_f2p_constructorId, results)={ edge_index=[2, 20323] },\n", + " (qualifying, f2p_raceId, races)={ edge_index=[2, 4082] },\n", + " (races, rev_f2p_raceId, qualifying)={ edge_index=[2, 4082] },\n", + " (qualifying, f2p_driverId, drivers)={ edge_index=[2, 4082] },\n", + " (drivers, rev_f2p_driverId, qualifying)={ edge_index=[2, 4082] },\n", + " (qualifying, f2p_constructorId, constructors)={ edge_index=[2, 4082] },\n", + " (constructors, rev_f2p_constructorId, qualifying)={ edge_index=[2, 4082] },\n", + " (standings, f2p_raceId, races)={ edge_index=[2, 28115] },\n", + " (races, rev_f2p_raceId, standings)={ edge_index=[2, 28115] },\n", + " (standings, f2p_driverId, drivers)={ edge_index=[2, 28115] },\n", + " (drivers, rev_f2p_driverId, standings)={ edge_index=[2, 28115] }\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 13 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "We can also check out the TensorFrame for one table like this:" + ], + "metadata": { + "id": "yd6DqCXgk41x" + } + }, + { + "cell_type": "code", + "source": [ + "data[\"races\"].tf" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-mMQTQeLk1rl", + "outputId": "04d698af-a4f4-4a98-8321-b5f8dc6ee10d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TensorFrame(\n", + " num_cols=5,\n", + " num_rows=820,\n", + " categorical (1): ['year'],\n", + " numerical (1): ['round'],\n", + " timestamp (2): ['date', 'time'],\n", + " embedding (1): ['name'],\n", + " has_target=False,\n", + " device='cpu',\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ] + }, + { + "cell_type": "markdown", + "source": 
[ + "This may be a little confusing at first, as in graph ML it is more standard to associate to the graph object `data` a tensor, e.g., `data.x` for which `data.x[idx]` is a 1D array/tensor storing all the features for node with index `idx`.\n", + "\n", + "But actually this `data` object behaves similarly. For a given node type, e.g., `races` again, `data['races']` stores two pieces of information\n" + ], + "metadata": { + "id": "1kbysKXMk-3X" + } + }, + { + "cell_type": "code", + "source": [ + "list(data[\"races\"].keys())" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cDIcp7L5k6pU", + "outputId": "be742ecb-02db-43e6-9c12-53fb00e51fec" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "['tf', 'time']" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "A `TensorFrame` object, and a timestamp for each node. The `TensorFrame` object acts analogously to the usual tensor of node features, and you can simply use indexing to retrieve the features of a single row (node), or group of nodes." 
+ ], + "metadata": { + "id": "Z18qPRPllB1H" + } + }, + { + "cell_type": "code", + "source": [ + "data[\"races\"].tf[10]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Im8bhNh5lFG6", + "outputId": "7ae1e0bd-746a-4164-fc95-124e302c816d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TensorFrame(\n", + " num_cols=5,\n", + " num_rows=1,\n", + " categorical (1): ['year'],\n", + " numerical (1): ['round'],\n", + " timestamp (2): ['date', 'time'],\n", + " embedding (1): ['name'],\n", + " has_target=False,\n", + " device='cpu',\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ] + }, + { + "cell_type": "code", + "source": [ + "data[\"races\"].tf[10:20]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eYZ28pzNlG4s", + "outputId": "1066a167-2e3b-4ad6-d929-02acf18cad0e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TensorFrame(\n", + " num_cols=5,\n", + " num_rows=10,\n", + " categorical (1): ['year'],\n", + " numerical (1): ['round'],\n", + " timestamp (2): ['date', 'time'],\n", + " embedding (1): ['name'],\n", + " has_target=False,\n", + " device='cpu',\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 19 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "We can also check the edge indices between two different node types, such as `races` and `circuits`. Note that the edges are also heterogeneous, so we also need to specify which edge type we want to look at. Here we look at `f2p_circuitId`, which are the directed edges pointing _from_ a race (the `f` stands for `foreign key`), _to_ the circuit at which the race happened (the `p` stands for `primary key`)."
+ ], + "metadata": { + "id": "Ql15svcelK3A" + } + }, + { + "cell_type": "code", + "source": [ + "data[(\"races\", \"f2p_circuitId\", \"circuits\")]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TynkD36QlInL", + "outputId": "abc2f80d-5ff4-42b1-f9e3-bd9f004d84ee" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'edge_index': tensor([[ 0, 1, 2, ..., 817, 818, 819],\n", + " [ 8, 5, 18, ..., 21, 17, 23]])}" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now we are ready to instantiate our data loaders. For this we will need to import PyTorch Geometric, our GNN library. Whilst we're at it let's add a seed.\n" + ], + "metadata": { + "id": "Xx4V5KCelNxl" + } + }, + { + "cell_type": "code", + "source": [ + "from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph\n", + "from torch_geometric.loader import NeighborLoader\n", + "\n", + "loader_dict = {}\n", + "\n", + "for split, table in [\n", + " (\"train\", train_table),\n", + " (\"val\", val_table),\n", + " (\"test\", test_table),\n", + "]:\n", + " table_input = get_node_train_table_input(\n", + " table=table,\n", + " task=task,\n", + " )\n", + " entity_table = table_input.nodes[0]\n", + " loader_dict[split] = NeighborLoader(\n", + " data,\n", + " num_neighbors=[\n", + " 128 for i in range(2)\n", + " ], # we sample subgraphs of depth 2, 128 neighbors per node.\n", + " time_attr=\"time\",\n", + " input_nodes=table_input.nodes,\n", + " input_time=table_input.time,\n", + " transform=table_input.transform,\n", + " batch_size=512,\n", + " temporal_strategy=\"uniform\",\n", + " shuffle=split == \"train\",\n", + " num_workers=0,\n", + " persistent_workers=False,\n", + " )" + ], + "metadata": { + "id": "HUHVG-g6lM-b" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we need 
our model...\n", + "\n", + "\n" + ], + "metadata": { + "id": "BQc8BWsGludR" + } + }, + { + "cell_type": "code", + "source": [ + "from torch.nn import BCEWithLogitsLoss\n", + "import copy\n", + "from typing import Any, Dict, List\n", + "\n", + "import torch\n", + "from torch import Tensor\n", + "from torch.nn import Embedding, ModuleDict\n", + "from torch_frame.data.stats import StatType\n", + "from torch_geometric.data import HeteroData\n", + "from torch_geometric.nn import MLP\n", + "from torch_geometric.typing import NodeType\n", + "\n", + "from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder\n", + "\n", + "\n", + "class Model(torch.nn.Module):\n", + "\n", + " def __init__(\n", + " self,\n", + " data: HeteroData,\n", + " col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],\n", + " num_layers: int,\n", + " channels: int,\n", + " out_channels: int,\n", + " aggr: str,\n", + " norm: str,\n", + " # List of node types to add shallow embeddings to input\n", + " shallow_list: List[NodeType] = [],\n", + " # ID awareness\n", + " id_awareness: bool = False,\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.encoder = HeteroEncoder(\n", + " channels=channels,\n", + " node_to_col_names_dict={\n", + " node_type: data[node_type].tf.col_names_dict\n", + " for node_type in data.node_types\n", + " },\n", + " node_to_col_stats=col_stats_dict,\n", + " )\n", + " self.temporal_encoder = HeteroTemporalEncoder(\n", + " node_types=[\n", + " node_type for node_type in data.node_types if \"time\" in data[node_type]\n", + " ],\n", + " channels=channels,\n", + " )\n", + " self.gnn = HeteroGraphSAGE(\n", + " node_types=data.node_types,\n", + " edge_types=data.edge_types,\n", + " channels=channels,\n", + " aggr=aggr,\n", + " num_layers=num_layers,\n", + " )\n", + " self.head = MLP(\n", + " channels,\n", + " out_channels=out_channels,\n", + " norm=norm,\n", + " num_layers=1,\n", + " )\n", + " self.embedding_dict = ModuleDict(\n", + " {\n", + " 
node: Embedding(data.num_nodes_dict[node], channels)\n", + " for node in shallow_list\n", + " }\n", + " )\n", + "\n", + " self.id_awareness_emb = None\n", + " if id_awareness:\n", + " self.id_awareness_emb = torch.nn.Embedding(1, channels)\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " self.encoder.reset_parameters()\n", + " self.temporal_encoder.reset_parameters()\n", + " self.gnn.reset_parameters()\n", + " self.head.reset_parameters()\n", + " for embedding in self.embedding_dict.values():\n", + " torch.nn.init.normal_(embedding.weight, std=0.1)\n", + " if self.id_awareness_emb is not None:\n", + " self.id_awareness_emb.reset_parameters()\n", + "\n", + " def forward(\n", + " self,\n", + " batch: HeteroData,\n", + " entity_table: NodeType,\n", + " ) -> Tensor:\n", + " seed_time = batch[entity_table].seed_time\n", + " x_dict = self.encoder(batch.tf_dict)\n", + "\n", + " rel_time_dict = self.temporal_encoder(\n", + " seed_time, batch.time_dict, batch.batch_dict\n", + " )\n", + "\n", + " for node_type, rel_time in rel_time_dict.items():\n", + " x_dict[node_type] = x_dict[node_type] + rel_time\n", + "\n", + " for node_type, embedding in self.embedding_dict.items():\n", + " x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)\n", + "\n", + " x_dict = self.gnn(\n", + " x_dict,\n", + " batch.edge_index_dict,\n", + " batch.num_sampled_nodes_dict,\n", + " batch.num_sampled_edges_dict,\n", + " )\n", + "\n", + " return self.head(x_dict[entity_table][: seed_time.size(0)])\n", + "\n", + " def forward_dst_readout(\n", + " self,\n", + " batch: HeteroData,\n", + " entity_table: NodeType,\n", + " dst_table: NodeType,\n", + " ) -> Tensor:\n", + " if self.id_awareness_emb is None:\n", + " raise RuntimeError(\n", + " \"id_awareness must be set True to use forward_dst_readout\"\n", + " )\n", + " seed_time = batch[entity_table].seed_time\n", + " x_dict = self.encoder(batch.tf_dict)\n", + " # Add ID-awareness to the root node\n", + 
" x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight\n", + "\n", + " rel_time_dict = self.temporal_encoder(\n", + " seed_time, batch.time_dict, batch.batch_dict\n", + " )\n", + "\n", + " for node_type, rel_time in rel_time_dict.items():\n", + " x_dict[node_type] = x_dict[node_type] + rel_time\n", + "\n", + " for node_type, embedding in self.embedding_dict.items():\n", + " x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)\n", + "\n", + " x_dict = self.gnn(\n", + " x_dict,\n", + " batch.edge_index_dict,\n", + " )\n", + "\n", + " return self.head(x_dict[dst_table])\n", + "\n", + "\n", + "model = Model(\n", + " data=data,\n", + " col_stats_dict=col_stats_dict,\n", + " num_layers=2,\n", + " channels=128,\n", + " out_channels=1,\n", + " aggr=\"sum\",\n", + " norm=\"batch_norm\",\n", + ").to(device)\n", + "\n", + "\n", + "# if you try out different RelBench tasks you will need to change these\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n", + "epochs = 10" + ], + "metadata": { + "id": "u3m3jEqClQnw" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We also need standard train/test loops" + ], + "metadata": { + "id": "Vl-6So7Llb-p" + } + }, + { + "cell_type": "code", + "source": [ + "def train() -> float:\n", + " model.train()\n", + "\n", + " loss_accum = count_accum = 0\n", + " for batch in tqdm(loader_dict[\"train\"]):\n", + " batch = batch.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " pred = model(\n", + " batch,\n", + " task.entity_table,\n", + " )\n", + " pred = pred.view(-1) if pred.size(1) == 1 else pred\n", + "\n", + " loss = loss_fn(pred.float(), batch[entity_table].y.float())\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " loss_accum += loss.detach().item() * pred.size(0)\n", + " count_accum += pred.size(0)\n", + "\n", + " return loss_accum / count_accum\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def test(loader: 
NeighborLoader) -> np.ndarray:\n", + " model.eval()\n", + "\n", + " pred_list = []\n", + " for batch in loader:\n", + " batch = batch.to(device)\n", + " pred = model(\n", + " batch,\n", + " task.entity_table,\n", + " )\n", + " pred = pred.view(-1) if pred.size(1) == 1 else pred\n", + " pred_list.append(pred.detach().cpu())\n", + " return torch.cat(pred_list, dim=0).numpy()" + ], + "metadata": { + "id": "SAHRIr15lVs6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we are ready to train!" + ], + "metadata": { + "id": "4s-p7dW1ledd" + } + }, + { + "cell_type": "code", + "source": [ + "state_dict = None\n", + "best_val_metric = -math.inf if higher_is_better else math.inf\n", + "for epoch in range(1, epochs + 1):\n", + " train_loss = train()\n", + " val_pred = test(loader_dict[\"val\"])\n", + " val_metrics = task.evaluate(val_pred, val_table)\n", + " print(f\"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}\")\n", + "\n", + " if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (\n", + " not higher_is_better and val_metrics[tune_metric] < best_val_metric\n", + " ):\n", + " best_val_metric = val_metrics[tune_metric]\n", + " state_dict = copy.deepcopy(model.state_dict())\n", + "\n", + "\n", + "model.load_state_dict(state_dict)\n", + "val_pred = test(loader_dict[\"val\"])\n", + "val_metrics = task.evaluate(val_pred, val_table)\n", + "print(f\"Best Val metrics: {val_metrics}\")\n", + "\n", + "test_pred = test(loader_dict[\"test\"])\n", + "test_metrics = task.evaluate(test_pred)\n", + "print(f\"Best test metrics: {test_metrics}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yF3W68Eqlew_", + "outputId": "a81a48dc-234a-47f3-8759-8dc9a766661c" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:03<00:00, 4.55it/s]\n" + ] + }, + { + "output_type": 
"stream", + "name": "stdout", + "text": [ + "Epoch: 01, Train loss: 4.5940603298195954, Val metrics: {'r2': 0.26664876365429346, 'mae': 3.192621118001486, 'rmse': 3.970144408166473}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:02<00:00, 5.32it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 02, Train loss: 4.568261575973963, Val metrics: {'r2': 0.28401349234585715, 'mae': 3.1377921566297466, 'rmse': 3.9228590932193326}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:02<00:00, 5.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 03, Train loss: 4.508331975344442, Val metrics: {'r2': 0.2779053222548966, 'mae': 3.145869382063229, 'rmse': 3.9395567561975118}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:02<00:00, 5.21it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 04, Train loss: 4.454095084475588, Val metrics: {'r2': 0.2748767008953672, 'mae': 3.1624791348545886, 'rmse': 3.9478097883334575}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:03<00:00, 3.98it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 05, Train loss: 4.428226253993716, Val metrics: {'r2': 0.2624319702044985, 'mae': 3.2122242543724435, 'rmse': 3.9815422768300364}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:02<00:00, 5.23it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 06, Train loss: 4.3808791889818695, Val metrics: {'r2': 0.23813335297839988, 'mae': 3.2412406909282634, 'rmse': 4.046595277497622}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:03<00:00, 
4.95it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 07, Train loss: 4.338853529823013, Val metrics: {'r2': 0.2456946012502239, 'mae': 3.247610264136621, 'rmse': 4.026464715586905}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:02<00:00, 5.10it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 08, Train loss: 4.280953785724709, Val metrics: {'r2': 0.2494808324189115, 'mae': 3.227723037073751, 'rmse': 4.016346595592636}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:03<00:00, 4.67it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 09, Train loss: 4.242680662992767, Val metrics: {'r2': 0.25660482780398064, 'mae': 3.185533819345132, 'rmse': 3.997239384230069}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 15/15 [00:02<00:00, 5.23it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 10, Train loss: 4.173513173641142, Val metrics: {'r2': 0.24525833749507364, 'mae': 3.2292002695755078, 'rmse': 4.027628930176982}\n", + "Best Val metrics: {'r2': 0.2855097827935169, 'mae': 3.132982980591819, 'rmse': 3.918757894088351}\n", + "Best test metrics: {'r2': -0.026182202542586408, 'mae': 4.38193179076178, 'rmse': 5.278154456184378}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "58i-5Z508liB" + }, + "execution_count": null, + "outputs": [] + } + ] +}