From 54576350a52b023a8091fbfd7b5a32058a392222 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Mon, 8 Jul 2024 17:20:22 -0700 Subject: [PATCH 1/4] reached package usage section --- README.md | 33 ++++++++++++++++++++------------- pyproject.toml | 1 + 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 2e7c9db1..4f6efb21 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,11 @@ [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40RelBench)](https://twitter.com/RelBench) -**Get Started:** loading data   [](https://colab.research.google.com/drive/1PAOktBqh_3QzgAKi53F4JbQxoOuBsUBY?usp=sharing), training model   [](https://colab.research.google.com/drive/1_z0aKcs5XndEacX1eob6csDuR4DYhGQU?usp=sharing). + - [ **Website**](https://relbench.stanford.edu) | [**Vision Paper**](https://relbench.stanford.edu/paper.pdf) | [**Benchmark Paper**](https://relbench.stanford.edu/paper.pdf) | [**Mailing List**](https://groups.google.com/forum/#!forum/relbench/join) + +[**Website**](https://relbench.stanford.edu) | [**Position Paper**](https://relbench.stanford.edu/paper.pdf) | [**Benchmark Paper [TODO]**](https://relbench.stanford.edu/paper.pdf) | [**Mailing List**](https://groups.google.com/forum/#!forum/relbench/join) # Overview @@ -19,7 +20,9 @@ -Relational Deep Learning is a new approach for end-to-end representation learning on data spread across multiple tables, such as in a _relational database_ (see our [vision paper](https://relbench.stanford.edu/paper.pdf)). Relational databases are the world's most widely used database management system, and are used for industrial and scientific purposes accross many domains. RelBench is a benchmark designed to facilitate efficient, robust and reproducible research in end-to-end deep learning on relational databases. 
RelBench contains 7 realistic, large-scale, and diverse relational databases spanning domains including medical, social networks, e-commerce and sport. Each database has multiple predictive tasks (29 in total) defined, each carefully scoped to be both challenging and of domain-specific importance. It provides full support for data downloading, task specification and standardized evaluation in an ML-framework-agnostic manner. +Relational Deep Learning is a new approach for end-to-end representation learning on data spread across multiple tables, such as in a _relational database_ (see our [position paper](https://relbench.stanford.edu/paper.pdf)). Relational databases are the world's most widely used data management system, and are used for industrial and scientific purposes across many domains. RelBench is a benchmark designed to facilitate efficient, robust and reproducible research on end-to-end deep learning for relational databases. + +RelBench contains 7 realistic, large-scale, and diverse relational databases spanning domains including medical, social networks, e-commerce and sport. Each database has multiple predictive tasks (29 in total) defined, each carefully scoped to be both challenging and of domain-specific importance. It provides full support for data downloading, task specification and standardized evaluation in an ML-framework-agnostic manner. Additionally, RelBench provides a first open-source implementation of a Graph Neural Network based approach to relational deep learning. This implementation uses [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) to load the data as a graph and train GNN models, and [PyTorch Frame](https://github.com/pyg-team/pytorch-frame) to encode the various types of table columns. Finally, there is an open [leaderboard](https://huggingface.co/relbench) for tracking progress. 
@@ -27,13 +30,13 @@ Additionally, RelBench provides a first open-source implementation of a Graph Ne # Key Papers - [**RelBench Paper**](https://relbench.stanford.edu/paper.pdf) [RelBench: A Benchmark for Deep Learning +[**Benchmark Paper**](https://relbench.stanford.edu/paper.pdf) [RelBench: A Benchmark for Deep Learning on Relational Databases.] This paper details our approach to designing the RelBench benchmark. It also includes a key user study showing that relational deep learning can produce performant models with a fraction of the manual human effort required by typical data science pipelines. This paper is useful for a detailed understanding of RelBench and our initial benchmarking results. If you just want to quickly familiarize yourself with the data and tasks, the [**website**](https://relbench.stanford.edu) is a better place to start. - [**Vision Paper**](https://relbench.stanford.edu/paper.pdf) [Relational Deep Learning: Graph Representation +[**Position Paper (ICML 2024)**](https://relbench.stanford.edu/paper.pdf) [Relational Deep Learning: Graph Representation Learning on Relational Databases.] This paper outlines our proposal for how to do end-to-end deep learning on relational databases by combining graph neural networks with deep tabular models. We recommend reading this paper if you want to think about new methods for end-to-end deep learning on relational databases. The paper includes a section on possible directions for future research to give a snapshot of some of the research possibilities in this area. @@ -45,8 +48,8 @@ This paper outlines our proposal for how to do end-to-end deep learning on relat

logo

RelBench has the following main components: -1. 7 databases, each automatically downloadable for ease of use (with the exception of H&M, for which RelBench gives other instructions) -2. Easy 1-line loading of data, including loading the raw tables, and also code for constructing a graph from pkey-fkey links +1. 7 databases with a total of 30 tasks; both databases and tasks are automatically downloadable for ease of use +2. Easy data loading, and graph construction from pkey-fkey links 3. Your own model, which can use any deep learning stack since RelBench is framework-agnostic. We provide a first model implementation using PyTorch Geometric and PyTorch Frame. 4. Standardized evaluators - all you need to do is produce a list of predictions for test samples, and RelBench computes metrics to ensure standardized evaluation 5. A leaderboard you can upload your results to, to track SOTA progress. @@ -55,19 +58,23 @@ RelBench has the following main components: # Installation You can install RelBench using `pip`: - -``` +```bash +pip install relbench ``` -This will allow usage of the RelBench data and task loading functionality. To additionally use the example GNN scripts in the ```examples``` directory, and the graph-related helper functions found in ```relbench/modeling``` it is also necessary to install [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) and [PyTorch Frame](https://github.com/pyg-team/pytorch-frame). PyTorch Frame can simply be installed with +This will allow usage of the core RelBench data and task loading functionality. 
+To additionally use `relbench.modeling`, which requires [PyTorch](https://pytorch.org/), [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) and [PyTorch Frame](https://github.com/pyg-team/pytorch-frame), install these dependencies manually or do: +```bash +pip install relbench[full] ``` -pip install pytorch_frame -``` -and the PyTorch Geometric installation instructions can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html). Note that as well as ```torch_geometric```, you will also need to install the optional dependencies ```pyg_lib```, ```torch_scatter```, ```torch_sparse```. + +To run the example GNN scripts in the `examples` directory, use: +```bash +pip install relbench[example] +``` # Package Usage diff --git a/pyproject.toml b/pyproject.toml index b0bdac79..f845aa1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ full=[ example=[ "sentence-transformers", "pytorch_frame[full]", + "torch_geometric", ] test=[ "pytest", From 4661732301e5ec36bd05edab98eeca378196b707 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Mon, 8 Jul 2024 18:51:51 -0700 Subject: [PATCH 2/4] finish updating readme --- README.md | 95 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 4f6efb21..4a50650b 100644 --- a/README.md +++ b/README.md @@ -20,11 +20,11 @@ -Relational Deep Learning is a new approach for end-to-end representation learning on data spread across multiple tables, such as in a _relational database_ (see our [position paper](https://relbench.stanford.edu/paper.pdf)). Relational databases are the world's most widely used data management system, and are used for industrial and scientific purposes across many domains. RelBench is a benchmark designed to facilitate efficient, robust and reproducible research on end-to-end deep learning for relational databases. 
+Relational Deep Learning is a new approach for end-to-end representation learning on data spread across multiple tables, such as in a _relational database_ (see our [position paper](https://relbench.stanford.edu/paper.pdf)). Relational databases are the world's most widely used data management system, and are used for industrial and scientific purposes across many domains. RelBench is a benchmark designed to facilitate efficient, robust and reproducible research on end-to-end deep learning over relational databases. -RelBench contains 7 realistic, large-scale, and diverse relational databases spanning domains including medical, social networks, e-commerce and sport. Each database has multiple predictive tasks (29 in total) defined, each carefully scoped to be both challenging and of domain-specific importance. It provides full support for data downloading, task specification and standardized evaluation in an ML-framework-agnostic manner. +RelBench contains 7 realistic, large-scale, and diverse relational databases spanning domains including medical, social networks, e-commerce and sport. Each database has multiple predictive tasks (30 in total) defined, each carefully scoped to be both challenging and of domain-specific importance. It provides full support for data downloading, task specification and standardized evaluation in an ML-framework-agnostic manner. -Additionally, RelBench provides a first open-source implementation of a Graph Neural Network based approach to relational deep learning. This implementation uses [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) to load the data as a graph and train GNN models, and [PyTorch Frame](https://github.com/pyg-team/pytorch-frame) to encode the various types of table columns. Finally, there is an open [leaderboard](https://huggingface.co/relbench) for tracking progress. 
+Additionally, RelBench provides a first open-source implementation of a Graph Neural Network based approach to relational deep learning. This implementation uses [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) to load the data as a graph and train GNN models, and [PyTorch Frame](https://github.com/pyg-team/pytorch-frame) for modeling tabular data. Finally, there is an open [leaderboard](https://huggingface.co/relbench) for tracking progress. @@ -76,36 +76,69 @@ To run the example GNN scripts in the `examples` directory, use: pip install relbench[example] ``` + # Package Usage -Here we describe key functions of RelBench. RelBench provides a collection of APIs for easy access to machine-learning-ready relational databases. +This section provides a brief overview of using the RelBench package. For more in-depth coverage, see the [Tutorials](#tutorials) section. For detailed documentation, please see the code directly. + +Imports: +```python +from relbench.base import Table, Database, Dataset, NodeTask +from relbench.datasets import get_dataset +from relbench.tasks import get_task +``` -To see all available datasets: +Get a dataset, e.g., `rel-amazon`: ```python -from relbench.datasets import dataset_names -print(dataset_names) +dataset: Dataset = get_dataset("rel-amazon", download=True) ``` -For a concrete example, to obtain the `rel-stack` relational database, a database of questions and answers from Stack Exchange, do: +
+ Details on downloading and caching behavior. + +RelBench datasets (and tasks) are cached to disk (usually at `~/.cache/relbench`). If not present in cache, `download=True` downloads the data, verifies it against the known hash, and caches it. If present, `download=True` performs the verification and avoids downloading if verification succeeds. This is the recommended way. +`download=False` uses the cached data without verification, if present, or processes and caches the data from scratch / raw sources otherwise. +
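
The download, verify, and cache behavior described above can be illustrated with a minimal standard-library sketch. This is not RelBench's actual implementation: `fetch_cached`, `fetch_fn`, the SHA-256 check, and the temporary cache directory are all hypothetical choices made for illustration.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Callable


def sha256_of(path: Path) -> str:
    """Return the hex SHA-256 digest of a file's contents."""
    return hashlib.sha256(path.read_bytes()).hexdigest()


def fetch_cached(name: str, known_hash: str, cache_dir: Path,
                 fetch_fn: Callable[[], bytes]) -> Path:
    """Hypothetical helper: reuse a verified cached file, else fetch, verify, cache."""
    path = cache_dir / name
    if path.exists() and sha256_of(path) == known_hash:
        return path  # cached copy verified: skip the download
    cache_dir.mkdir(parents=True, exist_ok=True)
    data = fetch_fn()  # stands in for the real download
    if hashlib.sha256(data).hexdigest() != known_hash:
        raise ValueError(f"hash mismatch for {name}")
    path.write_bytes(data)
    return path


# Toy payload and a fake downloader that records how often it is called.
payload = b"toy dataset bytes"
expected = hashlib.sha256(payload).hexdigest()
calls = []


def fake_download() -> bytes:
    calls.append(1)
    return payload


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "relbench-demo"
    first = fetch_cached("db.bin", expected, cache, fake_download)
    second = fetch_cached("db.bin", expected, cache, fake_download)
    download_count = len(calls)  # the second call is served from the cache
    same_path = first == second
```

In this sketch the second call verifies the cached file and returns without invoking the downloader, which mirrors the `download=True` behavior when the cache is already populated.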
+ +`dataset` consists of a `Database` object and temporal splitting times `dataset.val_timestamp` and `dataset.test_timestamp`. + +To get the database: ```python -from relbench.datasets import get_dataset -dataset = get_dataset(name="rel-stack") +db: Database = dataset.get_db() ``` -To see the tasks available for this dataset: +
+ Preventing temporal leakage + +By default, rows with timestamp > `dataset.test_timestamp` are excluded to prevent accidental temporal leakage. The full database can be obtained with: ```python -print(dataset.task_names) +full_db: Database = dataset.get_db(upto_test_timestamp=False) ``` +
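
To make the temporal cutoff concrete, here is a minimal sketch of the filtering rule stated above (rows with timestamp > the test timestamp are excluded). It uses plain Python rather than RelBench objects; the `upto` helper, the `date` column, and the toy rows are assumptions for illustration only.

```python
from datetime import datetime


def upto(rows, time_col, cutoff):
    """Keep only rows whose timestamp is at or before the cutoff."""
    return [row for row in rows if row[time_col] <= cutoff]


# Made-up cutoff and rows; real values come from the dataset.
test_timestamp = datetime(2010, 1, 1)
rows = [
    {"id": 0, "date": datetime(2009, 11, 1)},
    {"id": 1, "date": datetime(2010, 1, 1)},
    {"id": 2, "date": datetime(2010, 3, 14)},
]

visible = upto(rows, "date", test_timestamp)
visible_ids = [row["id"] for row in visible]  # row 2 lies after the cutoff
```

A row stamped exactly at the cutoff is kept here, since only rows strictly later than the test timestamp are described as excluded.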
-Next, to retrieve the `posts-votes` predictive task, which is to predict the upvotes of a post it will receive in the next 2 years, simply do: +Various tasks can be defined on a dataset. For example, to get the `user-churn` task for `rel-amazon`: +```python +task: NodeTask = get_task("rel-amazon", "user-churn", download=True) +``` + +A task provides train/val/test tables: +```python +train_table: Table = task.get_table("train") +val_table: Table = task.get_table("val") +test_table: Table = task.get_table("test") +``` + +
+ Preventing test leakage +By default, the target labels are hidden from the test table to prevent accidental data leakage. The full test table can be obtained with: ```python -task = dataset.get_task("post-votes") -task.train_table, task.val_table, task.test_table # training/validation/testing tables +full_test_table: Table = task.get_table("test", mask_input_cols=False) ``` +
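
The idea of hiding target labels from the test table can be sketched with plain dictionaries. This is an illustration of the concept only, not how RelBench implements masking; `mask_columns`, the `churn` column name, and the toy rows are hypothetical.

```python
def mask_columns(rows, hidden_cols):
    """Return copies of the rows without the hidden (target) columns."""
    return [
        {key: value for key, value in row.items() if key not in hidden_cols}
        for row in rows
    ]


# Toy test table: entity id, a seed time, and a made-up target label.
full_test_rows = [
    {"user_id": 7, "timestamp": "2016-01-01", "churn": 1},
    {"user_id": 9, "timestamp": "2016-01-01", "churn": 0},
]

masked_test_rows = mask_columns(full_test_rows, hidden_cols={"churn"})
masked_keys = sorted(masked_test_rows[0])  # target column is gone
```

The original rows are left untouched, so the unmasked table remains available to an evaluator while the model only ever sees the masked copy.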
+ -The training/validation/testing tables are automatically generated using pre-defined standardized temporal split. You can then build your favorite relational deep learning model on top of it. After training and validation, you can make prediction from your model on `task.test_table`. Suppose your prediction `test_pred` is an array following the order of `task.test_table`, you can call the following to retrieve the unified evaluation metrics: +You can build your model on top of the database and the task tables. After training and validation, you can make predictions with your model on the test table. If your predictions `test_pred` are a NumPy array following the order of `test_table`, you can call the following to get the evaluation metrics: ```python task.evaluate(test_pred) @@ -113,32 +146,34 @@ task.evaluate(test_pred) Additionally, you can evaluate validation (or training) predictions as such: ```python -task.evaluate(val_pred, task.val_table) +task.evaluate(val_pred, val_table) ``` # Tutorials -To get started with RelBench, we provide some helpful Colab notebook tutorials. For now these tutorials cover (i) how to load data using RelBench, focusing on providing users with the understanding of RelBench data logic needed to use RelBench data freely with any desired ML models, and (ii) training a GNN predictive model to solve any tasks in RelBench. +To get started with RelBench, we provide some helpful Colab notebook tutorials. These tutorials cover (i) how to load data using RelBench, focusing on providing users with the understanding of RelBench data logic needed to use RelBench data freely with any desired ML models, and (ii) training a GNN predictive model to solve tasks in RelBench. Please refer to the code for more detailed documentation. 
-| Name | Description | -|-------|---------------------------------------------------------| -| Loading Data   [](https://colab.research.google.com/drive/1PAOktBqh_3QzgAKi53F4JbQxoOuBsUBY?usp=sharing) | How to load and explore RelBench data -| Training models   [](https://colab.research.google.com/drive/1_z0aKcs5XndEacX1eob6csDuR4DYhGQU?usp=sharing)| Train your first GNN-based model on RelBench. | +| Name | Colab | Description | +|-------|-------|---------------------------------------------------------| +| Loading Data | [](https://colab.research.google.com/drive/1PAOktBqh_3QzgAKi53F4JbQxoOuBsUBY?usp=sharing) | How to load and explore RelBench data +| Training models | [](https://colab.research.google.com/drive/1_z0aKcs5XndEacX1eob6csDuR4DYhGQU?usp=sharing)| Train your first GNN-based model on RelBench. | # Cite RelBench -If you use RelBench in your work, please cite our position paper and benchmark paper: -``` -@article{relationaldeeplearning, - title={Relational Deep Learning: Graph Representation Learning on Relational Tables}, - author={Matthias Fey, Weihua Hu, Kexin Huang, Jan Eric Lenssen, Rishabh Ranjan, Joshua Robinson, Rex Ying, Jiaxuan You, Jure Leskovec}, - journal={ICML Position Paper} - year={2024} +If you use RelBench in your work, please cite our position and benchmark papers: + +```bibtex +@inproceedings{rdl, + title={Position: Relational Deep Learning - Graph Representation Learning on Relational Databases}, + author={Fey, Matthias and Hu, Weihua and Huang, Kexin and Lenssen, Jan Eric and Ranjan, Rishabh and Robinson, Joshua and Ying, Rex and You, Jiaxuan and Leskovec, Jure}, + booktitle={Forty-first International Conference on Machine Learning} } ``` -``` +__[TODO: update with arxiv citation]__ + +```bibtex @article{relbench, title={RelBench: A Benchmark for Deep Learning on Relational Databases}, author={Joshua Robinson, Rishabh Ranjan, Weihua Hu, Kexin Huang, Jiaqi Han, Alejandro Dobles, Matthias Fey, Jan Eric Lenssen, Yiwen Yuan, Zecheng 
Zhang, Xinwei He, Jure Leskovec}, From 55e19cc8e407c3e4992003ab08d056213d2e4f43 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Mon, 8 Jul 2024 18:52:39 -0700 Subject: [PATCH 3/4] remove demo because we have colab tutorials now --- ...ench - database and task exploration.ipynb | 1125 ----------------- 1 file changed, 1125 deletions(-) delete mode 100644 examples/demos/RelBench - database and task exploration.ipynb diff --git a/examples/demos/RelBench - database and task exploration.ipynb b/examples/demos/RelBench - database and task exploration.ipynb deleted file mode 100644 index a2347b6e..00000000 --- a/examples/demos/RelBench - database and task exploration.ipynb +++ /dev/null @@ -1,1125 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "f8c6b340", - "metadata": {}, - "source": [ - "First, `pip install` if you haven't already" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "2d2756f0", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: relbench in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (0.1.1)\n", - "Requirement already satisfied: pandas in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (2.2.2)\n", - "Requirement already satisfied: pooch in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (1.8.0)\n", - "Requirement already satisfied: pyarrow in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (11.0.0)\n", - "Requirement already satisfied: numpy in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (1.26.4)\n", - "Requirement already satisfied: duckdb in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (1.0.0)\n", - "Requirement already satisfied: requests in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (2.31.0)\n", - "Requirement 
already satisfied: tqdm in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (4.65.0)\n", - "Requirement already satisfied: scikit-learn in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (1.3.0)\n", - "Requirement already satisfied: typing-extensions in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from relbench) (4.12.2)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from pandas->relbench) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from pandas->relbench) (2023.3.post1)\n", - "Requirement already satisfied: tzdata>=2022.7 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from pandas->relbench) (2023.3)\n", - "Requirement already satisfied: platformdirs>=2.5.0 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from pooch->relbench) (3.10.0)\n", - "Requirement already satisfied: packaging>=20.0 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from pooch->relbench) (23.1)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from requests->relbench) (2.0.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from requests->relbench) (3.4)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from requests->relbench) (1.26.16)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from requests->relbench) (2023.7.22)\n", - "Requirement already satisfied: scipy>=1.5.0 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from scikit-learn->relbench) (1.11.4)\n", - "Requirement already satisfied: joblib>=1.1.1 in 
/Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from scikit-learn->relbench) (1.2.0)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from scikit-learn->relbench) (2.2.0)\n", - "Requirement already satisfied: six>=1.5 in /Users/joshuarobinson/anaconda3/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->relbench) (1.16.0)\n" - ] - } - ], - "source": [ - "!pip install relbench" - ] - }, - { - "cell_type": "markdown", - "id": "318f9c83", - "metadata": {}, - "source": [ - "# Load database\n", - "\n", - "All it takes is one line!\n" - ] - }, - { - "cell_type": "code", - "execution_count": 70, - "id": "f099d564", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "making Database object from raw files...\n", - "done in 0.12 seconds.\n", - "reindexing pkeys and fkeys...\n", - "done in 0.01 seconds.\n", - "caching Database object to /Users/joshuarobinson/Library/Caches/relbench/rel-f1/db...\n", - "done in 0.02 seconds.\n", - "use process=False to load from cache.\n" - ] - } - ], - "source": [ - "from relbench.datasets import get_dataset\n", - "\n", - "dataset = get_dataset(\n", - " name=\"rel-f1\", process=True\n", - ") # other options to try include 'rel-amazon', 'rel-stack'" - ] - }, - { - "cell_type": "markdown", - "id": "e6123b38", - "metadata": {}, - "source": [ - "Use `process=True` the first time you load a patricular dataset to automatically download the data it's origin source onto your machine. From then on you can set `process=False` for faster loading from cache.\n", - "\n", - "\n", - "Now we have loaded the database, let's start poking around to see what's inside. 
To start, let's check the full list of attributes the dataset has...\n" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "35861bcd", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__annotations__',\n", - " '__class__',\n", - " '__delattr__',\n", - " '__dict__',\n", - " '__dir__',\n", - " '__doc__',\n", - " '__eq__',\n", - " '__format__',\n", - " '__ge__',\n", - " '__getattribute__',\n", - " '__getstate__',\n", - " '__gt__',\n", - " '__hash__',\n", - " '__init__',\n", - " '__init_subclass__',\n", - " '__le__',\n", - " '__lt__',\n", - " '__module__',\n", - " '__ne__',\n", - " '__new__',\n", - " '__reduce__',\n", - " '__reduce_ex__',\n", - " '__repr__',\n", - " '__setattr__',\n", - " '__sizeof__',\n", - " '__str__',\n", - " '__subclasshook__',\n", - " '__weakref__',\n", - " '_full_db',\n", - " 'cache_dir',\n", - " 'db',\n", - " 'db_dir',\n", - " 'get_task',\n", - " 'make_db',\n", - " 'max_eval_time_frames',\n", - " 'name',\n", - " 'pack_db',\n", - " 'task_cls_dict',\n", - " 'task_cls_list',\n", - " 'task_names',\n", - " 'test_timestamp',\n", - " 'train_start_timestamp',\n", - " 'val_timestamp',\n", - " 'validate_and_correct_db']" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(dataset)" - ] - }, - { - "cell_type": "markdown", - "id": "b7f50a31", - "metadata": {}, - "source": [ - "A lot of this list can be ignored, especially the `__blah__` attributes. There are, however, a number of attributes that we _do_ care about. 
\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "f550c912", - "metadata": {}, - "source": [ - "# Val / Test cutoffs" - ] - }, - { - "cell_type": "markdown", - "id": "0a617d65", - "metadata": {}, - "source": [ - "We can check the val/test time cutoffs as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "b130a4f0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2005-01-01 00:00:00\n", - "2010-01-01 00:00:00\n" - ] - } - ], - "source": [ - "print(dataset.val_timestamp)\n", - "print(dataset.test_timestamp)" - ] - }, - { - "cell_type": "markdown", - "id": "1adf6c04", - "metadata": {}, - "source": [ - "This shows that data before 2005 is used for training, between 2005 and 2010 for validation, and after 2010 for testing. \n", - "\n", - "Note that it is a RelBench design choice to make the validation and test cutoffs a dataset property, _not_ a task-specific property. In other words, all tasks for a given database use the same time splits.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "e8ef6f3a", - "metadata": {}, - "source": [ - "# Acessing the raw data" - ] - }, - { - "cell_type": "markdown", - "id": "32ba9316", - "metadata": {}, - "source": [ - "\n", - "Next we check out `dataset.db`, which holds the data itself..." - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "3aa3d1df", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Database()" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset.db" - ] - }, - { - "cell_type": "markdown", - "id": "555c65b9", - "metadata": {}, - "source": [ - "This returns a RelBench `Database` object. So let's go one layer deeper and check what's inside this..." 
- ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "d4a07844", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "['__class__',\n", - " '__delattr__',\n", - " '__dict__',\n", - " '__dir__',\n", - " '__doc__',\n", - " '__eq__',\n", - " '__format__',\n", - " '__ge__',\n", - " '__getattribute__',\n", - " '__getstate__',\n", - " '__gt__',\n", - " '__hash__',\n", - " '__init__',\n", - " '__init_subclass__',\n", - " '__le__',\n", - " '__lt__',\n", - " '__module__',\n", - " '__ne__',\n", - " '__new__',\n", - " '__reduce__',\n", - " '__reduce_ex__',\n", - " '__repr__',\n", - " '__setattr__',\n", - " '__sizeof__',\n", - " '__str__',\n", - " '__subclasshook__',\n", - " '__weakref__',\n", - " 'load',\n", - " 'max_timestamp',\n", - " 'min_timestamp',\n", - " 'reindex_pkeys_and_fkeys',\n", - " 'save',\n", - " 'table_dict',\n", - " 'upto']" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(dataset.db)" - ] - }, - { - "cell_type": "markdown", - "id": "d34a6e25", - "metadata": {}, - "source": [ - "With this we can double check the full timespan of the database:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "2e8c4dc6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2009-11-01 11:00:00\n", - "1950-05-13 00:00:00\n" - ] - } - ], - "source": [ - "print(dataset.db.max_timestamp)\n", - "print(dataset.db.min_timestamp)" - ] - }, - { - "cell_type": "markdown", - "id": "9572a9e6", - "metadata": {}, - "source": [ - "1950 is the first season for F1! So we have data for the full history of F1. Note that the `max_timestamp` is the cutoff date for validation data. Data from afer 2009 is used for testing, but is hidden from `dataset.db`. To see the full database including test data you can instead use `dataset._full_db`, but we advise caution when using this to avoid inadvertent time leakage. 
For instance we can check the final cutoff for test data by calling:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "ccf87b1b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-11-26 13:00:00\n" - ] - } - ], - "source": [ - "print(dataset._full_db.max_timestamp)" - ] - }, - { - "cell_type": "markdown", - "id": "c29c2140", - "metadata": {}, - "source": [ - "Next let's check out the `dataset.db.table_dict`, which contains the raw tables.\n", - "\n", - "More info on the schemas for F1 and all other datasets can be found at https://relbench.stanford.edu/." - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "09bab67c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'races': Table(df=\n", - " raceId year round circuitId name date \\\n", - " 0 0 1950 1 8 British Grand Prix 1950-05-13 00:00:00 \n", - " 1 1 1950 2 5 Monaco Grand Prix 1950-05-21 00:00:00 \n", - " 2 2 1950 3 18 Indianapolis 500 1950-05-30 00:00:00 \n", - " 3 3 1950 4 65 Swiss Grand Prix 1950-06-04 00:00:00 \n", - " 4 4 1950 5 12 Belgian Grand Prix 1950-06-18 00:00:00 \n", - " .. ... ... ... ... ... ... \n", - " 815 815 2009 13 13 Italian Grand Prix 2009-09-13 12:00:00 \n", - " 816 816 2009 14 14 Singapore Grand Prix 2009-09-27 12:00:00 \n", - " 817 817 2009 15 21 Japanese Grand Prix 2009-10-04 05:00:00 \n", - " 818 818 2009 16 17 Brazilian Grand Prix 2009-10-18 16:00:00 \n", - " 819 819 2009 17 23 Abu Dhabi Grand Prix 2009-11-01 11:00:00 \n", - " \n", - " time \n", - " 0 00:00:00 \n", - " 1 00:00:00 \n", - " 2 00:00:00 \n", - " 3 00:00:00 \n", - " 4 00:00:00 \n", - " .. ... 
\n", - " 815 12:00:00 \n", - " 816 12:00:00 \n", - " 817 05:00:00 \n", - " 818 16:00:00 \n", - " 819 11:00:00 \n", - " \n", - " [820 rows x 7 columns],\n", - " fkey_col_to_pkey_table={'circuitId': 'circuits'},\n", - " pkey_col=raceId,\n", - " time_col=date),\n", - " 'circuits': Table(df=\n", - " circuitId circuitRef name \\\n", - " 0 0 albert_park Albert Park Grand Prix Circuit \n", - " 1 1 sepang Sepang International Circuit \n", - " 2 2 bahrain Bahrain International Circuit \n", - " 3 3 catalunya Circuit de Barcelona-Catalunya \n", - " 4 4 istanbul Istanbul Park \n", - " .. ... ... ... \n", - " 72 72 portimao Autódromo Internacional do Algarve \n", - " 73 73 mugello Autodromo Internazionale del Mugello \n", - " 74 74 jeddah Jeddah Corniche Circuit \n", - " 75 75 losail Losail International Circuit \n", - " 76 76 miami Miami International Autodrome \n", - " \n", - " location country lat lng alt \n", - " 0 Melbourne Australia -37.84970 144.96800 10.0 \n", - " 1 Kuala Lumpur Malaysia 2.76083 101.73800 18.0 \n", - " 2 Sakhir Bahrain 26.03250 50.51060 7.0 \n", - " 3 Montmeló Spain 41.57000 2.26111 109.0 \n", - " 4 Istanbul Turkey 40.95170 29.40500 130.0 \n", - " .. ... ... ... ... ... \n", - " 72 Portimão Portugal 37.22700 -8.62670 108.0 \n", - " 73 Mugello Italy 43.99750 11.37190 255.0 \n", - " 74 Jeddah Saudi Arabia 21.63190 39.10440 15.0 \n", - " 75 Al Daayen Qatar 25.49000 51.45420 NaN \n", - " 76 Miami USA 25.95810 -80.23890 NaN \n", - " \n", - " [77 rows x 8 columns],\n", - " fkey_col_to_pkey_table={},\n", - " pkey_col=circuitId,\n", - " time_col=None),\n", - " 'drivers': Table(df=\n", - " driverId driverRef code forename surname dob \\\n", - " 0 0 hamilton HAM Lewis Hamilton 1985-01-07 \n", - " 1 1 heidfeld HEI Nick Heidfeld 1977-05-10 \n", - " 2 2 rosberg ROS Nico Rosberg 1985-06-27 \n", - " 3 3 alonso ALO Fernando Alonso 1981-07-29 \n", - " 4 4 kovalainen KOV Heikki Kovalainen 1981-10-19 \n", - " .. ... ... ... ... ... ... 
\n", - " 852 852 mick_schumacher MSC Mick Schumacher 1999-03-22 \n", - " 853 853 zhou ZHO Guanyu Zhou 1999-05-30 \n", - " 854 854 de_vries DEV Nyck de Vries 1995-02-06 \n", - " 855 855 piastri PIA Oscar Piastri 2001-04-06 \n", - " 856 856 sargeant SAR Logan Sargeant 2000-12-31 \n", - " \n", - " nationality \n", - " 0 British \n", - " 1 German \n", - " 2 German \n", - " 3 Spanish \n", - " 4 Finnish \n", - " .. ... \n", - " 852 German \n", - " 853 Chinese \n", - " 854 Dutch \n", - " 855 Australian \n", - " 856 American \n", - " \n", - " [857 rows x 7 columns],\n", - " fkey_col_to_pkey_table={},\n", - " pkey_col=driverId,\n", - " time_col=None),\n", - " 'results': Table(df=\n", - " resultId raceId driverId constructorId number grid position \\\n", - " 0 0 0 660 152 18.0 21 11.0 \n", - " 1 1 0 790 149 8.0 12 NaN \n", - " 2 2 0 579 49 1.0 3 NaN \n", - " 3 3 0 661 149 9.0 10 NaN \n", - " 4 4 0 789 152 17.0 7 NaN \n", - " ... ... ... ... ... ... ... ... \n", - " 20318 20318 819 1 1 6.0 8 5.0 \n", - " 20319 20319 819 21 22 23.0 4 4.0 \n", - " 20320 20320 819 17 22 22.0 5 3.0 \n", - " 20321 20321 819 16 8 14.0 3 2.0 \n", - " 20322 20322 819 2 2 16.0 9 9.0 \n", - " \n", - " positionOrder points laps milliseconds fastestLap rank statusId \\\n", - " 0 11 0.0 64 NaN NaN NaN 16 \n", - " 1 21 0.0 2 NaN NaN NaN 126 \n", - " 2 12 0.0 62 NaN NaN NaN 44 \n", - " 3 20 0.0 5 NaN NaN NaN 6 \n", - " 4 19 0.0 8 NaN NaN NaN 51 \n", - " ... ... ... ... ... ... ... ... \n", - " 20318 5 4.0 55 5669667.0 54.0 7.0 1 \n", - " 20319 4 5.0 55 5666149.0 54.0 4.0 1 \n", - " 20320 3 6.0 55 5661881.0 49.0 6.0 1 \n", - " 20321 2 8.0 55 5661271.0 14.0 5.0 1 \n", - " 20322 9 0.0 55 5689355.0 49.0 15.0 1 \n", - " \n", - " date \n", - " 0 1950-05-13 00:00:00 \n", - " 1 1950-05-13 00:00:00 \n", - " 2 1950-05-13 00:00:00 \n", - " 3 1950-05-13 00:00:00 \n", - " 4 1950-05-13 00:00:00 \n", - " ... ... 
\n", - " 20318 2009-11-01 11:00:00 \n", - " 20319 2009-11-01 11:00:00 \n", - " 20320 2009-11-01 11:00:00 \n", - " 20321 2009-11-01 11:00:00 \n", - " 20322 2009-11-01 11:00:00 \n", - " \n", - " [20323 rows x 15 columns],\n", - " fkey_col_to_pkey_table={'raceId': 'races', 'driverId': 'drivers', 'constructorId': 'constructors'},\n", - " pkey_col=resultId,\n", - " time_col=date),\n", - " 'standings': Table(df=\n", - " driverStandingsId raceId driverId points position wins \\\n", - " 0 0 0 789 0.0 20 0 \n", - " 1 1 0 640 0.0 18 0 \n", - " 2 2 0 589 0.0 19 0 \n", - " 3 3 0 669 0.0 15 0 \n", - " 4 4 0 661 0.0 22 0 \n", - " ... ... ... ... ... ... ... \n", - " 28110 28110 819 7 48.0 6 1 \n", - " 28111 28111 819 68 0.0 25 0 \n", - " 28112 28112 819 11 0.0 21 0 \n", - " 28113 28113 819 6 2.0 19 0 \n", - " 28114 28114 819 12 22.0 11 0 \n", - " \n", - " date \n", - " 0 1950-05-13 00:00:00 \n", - " 1 1950-05-13 00:00:00 \n", - " 2 1950-05-13 00:00:00 \n", - " 3 1950-05-13 00:00:00 \n", - " 4 1950-05-13 00:00:00 \n", - " ... ... \n", - " 28110 2009-11-01 11:00:00 \n", - " 28111 2009-11-01 11:00:00 \n", - " 28112 2009-11-01 11:00:00 \n", - " 28113 2009-11-01 11:00:00 \n", - " 28114 2009-11-01 11:00:00 \n", - " \n", - " [28115 rows x 7 columns],\n", - " fkey_col_to_pkey_table={'raceId': 'races', 'driverId': 'drivers'},\n", - " pkey_col=driverStandingsId,\n", - " time_col=date),\n", - " 'constructors': Table(df=\n", - " constructorId constructorRef name nationality\n", - " 0 0 mclaren McLaren British\n", - " 1 1 bmw_sauber BMW Sauber German\n", - " 2 2 williams Williams British\n", - " 3 3 renault Renault French\n", - " 4 4 toro_rosso Toro Rosso Italian\n", - " .. ... ... ... 
...\n", - " 206 206 manor Manor Marussia British\n", - " 207 207 haas Haas F1 Team American\n", - " 208 208 racing_point Racing Point British\n", - " 209 209 alphatauri AlphaTauri Italian\n", - " 210 210 alpine Alpine F1 Team French\n", - " \n", - " [211 rows x 4 columns],\n", - " fkey_col_to_pkey_table={},\n", - " pkey_col=constructorId,\n", - " time_col=None),\n", - " 'constructor_results': Table(df=\n", - " constructorResultsId raceId constructorId points date\n", - " 0 0 48 103 13.0 1956-01-22 00:00:00\n", - " 1 1 48 5 12.0 1956-01-22 00:00:00\n", - " 2 2 54 126 0.0 1956-08-05 00:00:00\n", - " 3 3 54 103 15.0 1956-08-05 00:00:00\n", - " 4 4 54 5 9.0 1956-08-05 00:00:00\n", - " ... ... ... ... ... ...\n", - " 9403 9403 819 5 0.0 2009-11-01 11:00:00\n", - " 9404 9404 819 0 0.0 2009-11-01 11:00:00\n", - " 9405 9405 819 2 0.0 2009-11-01 11:00:00\n", - " 9406 9406 819 4 1.0 2009-11-01 11:00:00\n", - " 9407 9407 819 6 5.0 2009-11-01 11:00:00\n", - " \n", - " [9408 rows x 5 columns],\n", - " fkey_col_to_pkey_table={'raceId': 'races', 'constructorId': 'constructors'},\n", - " pkey_col=constructorResultsId,\n", - " time_col=date),\n", - " 'constructor_standings': Table(df=\n", - " constructorStandingsId raceId constructorId points position wins \\\n", - " 0 0 64 103 3.0 3 0 \n", - " 1 1 64 5 6.0 2 0 \n", - " 2 2 64 85 8.0 1 1 \n", - " 3 3 65 85 16.0 1 2 \n", - " 4 4 65 31 0.0 5 0 \n", - " ... ... ... ... ... ... ... \n", - " 10165 10165 819 3 26.0 8 0 \n", - " 10166 10166 819 1 36.0 6 0 \n", - " 10167 10167 819 2 34.5 7 0 \n", - " 10168 10168 819 4 8.0 10 0 \n", - " 10169 10169 819 9 13.0 9 0 \n", - " \n", - " date \n", - " 0 1958-01-19 00:00:00 \n", - " 1 1958-01-19 00:00:00 \n", - " 2 1958-01-19 00:00:00 \n", - " 3 1958-05-18 00:00:00 \n", - " 4 1958-05-18 00:00:00 \n", - " ... ... 
\n", - " 10165 2009-11-01 11:00:00 \n", - " 10166 2009-11-01 11:00:00 \n", - " 10167 2009-11-01 11:00:00 \n", - " 10168 2009-11-01 11:00:00 \n", - " 10169 2009-11-01 11:00:00 \n", - " \n", - " [10170 rows x 7 columns],\n", - " fkey_col_to_pkey_table={'raceId': 'races', 'constructorId': 'constructors'},\n", - " pkey_col=constructorStandingsId,\n", - " time_col=date),\n", - " 'qualifying': Table(df=\n", - " qualifyId raceId driverId constructorId number position \\\n", - " 0 0 548 43 26 26 19 \n", - " 1 1 548 100 30 31 26 \n", - " 2 2 548 91 29 34 25 \n", - " 3 3 548 82 31 11 24 \n", - " 4 4 548 105 32 19 23 \n", - " ... ... ... ... ... ... ... \n", - " 4077 4077 819 21 22 23 4 \n", - " 4078 4078 819 16 8 14 3 \n", - " 4079 4079 819 19 8 15 2 \n", - " 4080 4080 819 0 0 1 1 \n", - " 4081 4081 819 66 4 12 10 \n", - " \n", - " date \n", - " 0 1994-03-26 00:00:00 \n", - " 1 1994-03-26 00:00:00 \n", - " 2 1994-03-26 00:00:00 \n", - " 3 1994-03-26 00:00:00 \n", - " 4 1994-03-26 00:00:00 \n", - " ... ... \n", - " 4077 2009-10-31 11:00:00 \n", - " 4078 2009-10-31 11:00:00 \n", - " 4079 2009-10-31 11:00:00 \n", - " 4080 2009-10-31 11:00:00 \n", - " 4081 2009-10-31 11:00:00 \n", - " \n", - " [4082 rows x 7 columns],\n", - " fkey_col_to_pkey_table={'raceId': 'races', 'driverId': 'drivers', 'constructorId': 'constructors'},\n", - " pkey_col=qualifyId,\n", - " time_col=date)}" - ] - }, - "execution_count": 56, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset.db.table_dict" - ] - }, - { - "cell_type": "markdown", - "id": "2d49e520", - "metadata": {}, - "source": [ - "So `dataset.db.table_dict` is a dict, and we can check the full list of tables in the F1 database by checking out the dict keys." 
- ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "27d3f3c0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['races', 'circuits', 'drivers', 'results', 'standings', 'constructors', 'constructor_results', 'constructor_standings', 'qualifying'])" - ] - }, - "execution_count": 58, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset.db.table_dict.keys()" - ] - }, - { - "cell_type": "markdown", - "id": "85e58011", - "metadata": {}, - "source": [ - "That's 9 tables total! Let's look more closely at one of them." - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "id": "a2c7e757", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Table(df=\n", - " driverId driverRef code forename surname dob \\\n", - "0 0 hamilton HAM Lewis Hamilton 1985-01-07 \n", - "1 1 heidfeld HEI Nick Heidfeld 1977-05-10 \n", - "2 2 rosberg ROS Nico Rosberg 1985-06-27 \n", - "3 3 alonso ALO Fernando Alonso 1981-07-29 \n", - "4 4 kovalainen KOV Heikki Kovalainen 1981-10-19 \n", - ".. ... ... ... ... ... ... \n", - "852 852 mick_schumacher MSC Mick Schumacher 1999-03-22 \n", - "853 853 zhou ZHO Guanyu Zhou 1999-05-30 \n", - "854 854 de_vries DEV Nyck de Vries 1995-02-06 \n", - "855 855 piastri PIA Oscar Piastri 2001-04-06 \n", - "856 856 sargeant SAR Logan Sargeant 2000-12-31 \n", - "\n", - " nationality \n", - "0 British \n", - "1 German \n", - "2 German \n", - "3 Spanish \n", - "4 Finnish \n", - ".. ... 
\n", - "852 German \n", - "853 Chinese \n", - "854 Dutch \n", - "855 Australian \n", - "856 American \n", - "\n", - "[857 rows x 7 columns],\n", - " fkey_col_to_pkey_table={},\n", - " pkey_col=driverId,\n", - " time_col=None)" - ] - }, - "execution_count": 68, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "table = dataset.db.table_dict[\"drivers\"]\n", - "table" - ] - }, - { - "cell_type": "markdown", - "id": "5dd9d2d5", - "metadata": {}, - "source": [ - "The `drivers` table stores information on all F1 drivers that ever competed in a race. Note that the table comes with multiple bits of information:\n", - "- The table itself, `table.df` which is simply a Pandas DataFrame.\n", - "- The primary key column, `table.pkey_col`, which indicates that the `driverId` column holds the primary key for this particular table in the database.\n", - "- The primary time column, `table.time_col` which, if the entity is an event, records the time an event happened. In the case of drivers, they are non-temporal entities, so `table.time_col=None`.\n", - "- The other tables that foreign keys points to `table.fkey_col_to_pkey_table`. If the table has any foreign key columns, then this dict indicates which table we foreign key corresponds to. Again in the case of drivers this is not applicable. \n", - "\n", - "We can start to explore the data a little, e.g., check out the oldest and youngest ever F1 drivers, spanning 3 centuries!" 
- ] - }, - { - "cell_type": "code", - "execution_count": 80, - "id": "2eb184ab", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "driverId 855\n", - "driverRef piastri\n", - "code PIA\n", - "forename Oscar\n", - "surname Piastri\n", - "dob 2001-04-06 00:00:00\n", - "nationality Australian\n", - "Name: 855, dtype: object\n", - "driverId 741\n", - "driverRef etancelin\n", - "code \\N\n", - "forename Philippe\n", - "surname Étancelin\n", - "dob 1896-12-28 00:00:00\n", - "nationality French\n", - "Name: 741, dtype: object\n" - ] - } - ], - "source": [ - "print(table.df.iloc[table.df[\"dob\"].idxmax()])\n", - "print(table.df.iloc[table.df[\"dob\"].idxmin()])" - ] - }, - { - "cell_type": "markdown", - "id": "b8c96da3", - "metadata": {}, - "source": [ - "Going back to the `table.time_col` and `table.fkey_col_to_pkey_table`, the `results` table contains a non-trivial example." - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "56dbc068", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Table(df=\n", - " resultId raceId driverId constructorId number grid position \\\n", - "0 0 0 660 152 18.0 21 11.0 \n", - "1 1 0 790 149 8.0 12 NaN \n", - "2 2 0 579 49 1.0 3 NaN \n", - "3 3 0 661 149 9.0 10 NaN \n", - "4 4 0 789 152 17.0 7 NaN \n", - "... ... ... ... ... ... ... ... \n", - "20318 20318 819 1 1 6.0 8 5.0 \n", - "20319 20319 819 21 22 23.0 4 4.0 \n", - "20320 20320 819 17 22 22.0 5 3.0 \n", - "20321 20321 819 16 8 14.0 3 2.0 \n", - "20322 20322 819 2 2 16.0 9 9.0 \n", - "\n", - " positionOrder points laps milliseconds fastestLap rank statusId \\\n", - "0 11 0.0 64 NaN NaN NaN 16 \n", - "1 21 0.0 2 NaN NaN NaN 126 \n", - "2 12 0.0 62 NaN NaN NaN 44 \n", - "3 20 0.0 5 NaN NaN NaN 6 \n", - "4 19 0.0 8 NaN NaN NaN 51 \n", - "... ... ... ... ... ... ... ... 
\n", - "20318 5 4.0 55 5669667.0 54.0 7.0 1 \n", - "20319 4 5.0 55 5666149.0 54.0 4.0 1 \n", - "20320 3 6.0 55 5661881.0 49.0 6.0 1 \n", - "20321 2 8.0 55 5661271.0 14.0 5.0 1 \n", - "20322 9 0.0 55 5689355.0 49.0 15.0 1 \n", - "\n", - " date \n", - "0 1950-05-13 00:00:00 \n", - "1 1950-05-13 00:00:00 \n", - "2 1950-05-13 00:00:00 \n", - "3 1950-05-13 00:00:00 \n", - "4 1950-05-13 00:00:00 \n", - "... ... \n", - "20318 2009-11-01 11:00:00 \n", - "20319 2009-11-01 11:00:00 \n", - "20320 2009-11-01 11:00:00 \n", - "20321 2009-11-01 11:00:00 \n", - "20322 2009-11-01 11:00:00 \n", - "\n", - "[20323 rows x 15 columns],\n", - " fkey_col_to_pkey_table={'raceId': 'races', 'driverId': 'drivers', 'constructorId': 'constructors'},\n", - " pkey_col=resultId,\n", - " time_col=date)" - ] - }, - "execution_count": 98, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset.db.table_dict[\"results\"]" - ] - }, - { - "cell_type": "markdown", - "id": "472fa6d9", - "metadata": {}, - "source": [ - "Here we start to notice certain data artifacts that might be good to keep in mind for later when doing ML modeling. For instance, the `milliseconds` and `fastestLap` columns seem to only have been collected for more recent races, with `NaN` features for earlier races." - ] - }, - { - "cell_type": "markdown", - "id": "e36f1796", - "metadata": {}, - "source": [ - "# Loading a task" - ] - }, - { - "cell_type": "markdown", - "id": "36a865b1", - "metadata": {}, - "source": [ - "Each RelBench dataset comes with multiple pre-defined predictive tasks. 
For any given RelBench dataset, you can check all the associated tasks with:" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "id": "cb910544", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['driver-position', 'driver-dnf', 'driver-top3']" - ] - }, - "execution_count": 81, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset.task_names" - ] - }, - { - "cell_type": "markdown", - "id": "b416b6df", - "metadata": {}, - "source": [ - "Check out https://relbench.stanford.edu/ for detailed descriptions of what each task is. As an example, let's use `driver-top3` where the task is, for a given driver and a given timestamp, to predict whether that driver will finish in the top 3 in some race in the next 30 days.\n", - "\n", - "The task itself is instantiated by calling:" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "id": "da9a3c23", - "metadata": {}, - "outputs": [], - "source": [ - "task = dataset.get_task(\"driver-top3\", process=True)" - ] - }, - { - "cell_type": "markdown", - "id": "8d71bcf7", - "metadata": {}, - "source": [ - "Ground truth train / val / test label are computed by calling `task.train_table` etc. Each task table contains triples (timestamp, Id, label) indicating the entity the label is associated to, the timepoint at which the prediction is made, an the label itself. The task table also indicates which database table it is \"attached\" to - in this case the the `drivers` table." - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "id": "b4cf768e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Table(df=\n", - " date driverId qualifying\n", - "0 2004-08-04 40 0\n", - "1 2004-08-04 45 0\n", - "2 2004-08-04 43 0\n", - "3 2004-06-05 17 1\n", - "4 2004-06-05 9 0\n", - "... ... ... 
...\n", - "1348 1994-03-30 80 0\n", - "1349 1994-03-30 48 0\n", - "1350 1994-03-30 77 0\n", - "1351 1994-02-28 43 0\n", - "1352 1994-02-28 56 0\n", - "\n", - "[1353 rows x 3 columns],\n", - " fkey_col_to_pkey_table={'driverId': 'drivers'},\n", - " pkey_col=None,\n", - " time_col=date)" - ] - }, - "execution_count": 90, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "task.train_table" - ] - }, - { - "cell_type": "markdown", - "id": "56c4ec5d", - "metadata": {}, - "source": [ - "The test table is handled differently, with the labels being hidden by default." - ] - }, - { - "cell_type": "code", - "execution_count": 92, - "id": "65a664be", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Table(df=\n", - " date driverId\n", - "0 2013-03-16 153\n", - "1 2013-03-16 19\n", - "2 2012-10-17 808\n", - "3 2012-10-17 818\n", - "4 2012-10-17 817\n", - ".. ... ...\n", - "721 2010-07-30 14\n", - "722 2010-06-30 154\n", - "723 2010-06-30 14\n", - "724 2010-05-01 14\n", - "725 2010-05-01 154\n", - "\n", - "[726 rows x 2 columns],\n", - " fkey_col_to_pkey_table={'driverId': 'drivers'},\n", - " pkey_col=None,\n", - " time_col=date)" - ] - }, - "execution_count": 92, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "task.test_table" - ] - }, - { - "cell_type": "markdown", - "id": "f4754484", - "metadata": {}, - "source": [ - "We have carefully designed the standardized evaluation protocol (see: XXX) so that the test labels themselves are only ever used under the hood of RelBench, so users should not need to ever see them to reduce the risk of data leakage. If strictly needed, test labels can be retrieved by calling." 
- ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "2775ee15", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Table(df=\n", - " date driverId qualifying\n", - "0 2013-03-16 153 0\n", - "1 2013-03-16 19 1\n", - "2 2012-10-17 808 0\n", - "3 2012-10-17 818 0\n", - "4 2012-10-17 817 0\n", - ".. ... ... ...\n", - "721 2010-07-30 14 0\n", - "722 2010-06-30 154 0\n", - "723 2010-06-30 14 0\n", - "724 2010-05-01 14 0\n", - "725 2010-05-01 154 0\n", - "\n", - "[726 rows x 3 columns],\n", - " fkey_col_to_pkey_table={'driverId': 'drivers'},\n", - " pkey_col=None,\n", - " time_col=date)" - ] - }, - "execution_count": 97, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "task._full_test_table" - ] - }, - { - "cell_type": "markdown", - "id": "2b2768c9", - "metadata": {}, - "source": [ - "Now we have explored the data and task, the next step is to train an ML model on the data. See XXX for our GNN-based approach!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "319217ed", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 4ec39b82711bf0a9b5d157875ea58bae5aa7ff35 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Tue, 9 Jul 2024 10:58:58 -0700 Subject: [PATCH 4/4] example to run scripts --- README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4a50650b..5e7b5dae 100644 --- a/README.md +++ b/README.md @@ -71,11 +71,18 @@ pip install relbench[full] ``` -To run the example GNN scripts in the `examples` 
directory, use: +For the scripts in the `examples` directory, use: ```bash pip install relbench[example] ``` +Then, to run a script: +```bash +git clone https://github.com/snap-stanford/relbench +cd relbench/examples +python gnn_node.py --dataset rel-f1 --task driver-position +``` + # Package Usage