From 0a25d06729957f9f68729c33286d072782f82067 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Fri, 5 Jul 2024 13:26:51 -0700 Subject: [PATCH] More misc cleanup (#225) --- examples/gnn_link.py | 3 ++- relbench/base/database.py | 3 --- relbench/base/task_base.py | 2 -- relbench/base/task_link.py | 15 ++------------- relbench/datasets/avito.py | 10 +++++----- test/modeling/test_link_nn.py | 3 ++- 6 files changed, 11 insertions(+), 25 deletions(-) diff --git a/examples/gnn_link.py b/examples/gnn_link.py index 05e7460c..1cfba989 100644 --- a/examples/gnn_link.py +++ b/examples/gnn_link.py @@ -108,7 +108,8 @@ eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {} for split in ["val", "test"]: - seed_time = task.val_seed_time if split == "val" else task.test_seed_time + timestamp = dataset.val_timestamp if split == "val" else dataset.test_timestamp + seed_time = int(timestamp.timestamp()) target_table = task.get_table(split) src_node_indices = torch.from_numpy(target_table.df[task.src_entity_col].values) src_loader = NeighborLoader( diff --git a/relbench/base/database.py b/relbench/base/database.py index b4d73a6e..e0210b67 100644 --- a/relbench/base/database.py +++ b/relbench/base/database.py @@ -13,15 +13,12 @@ class Database: r"""A database is a collection of named tables linked by foreign key - primary key connections.""" - # TODO: maybe add a function to visualize schema in jupyter - def __init__(self, table_dict: Dict[str, Table]) -> None: r"""Creates a database from a dictionary of tables.""" self.table_dict = table_dict def __repr__(self) -> str: - # TODO: add more info return f"{self.__class__.__name__}()" def save(self, path: Union[str, os.PathLike]) -> None: diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py index de292c6e..8bbee168 100644 --- a/relbench/base/task_base.py +++ b/relbench/base/task_base.py @@ -63,8 +63,6 @@ def make_table( ) -> Table: r"""To be implemented by subclass.""" - # TODO: ensure that tasks follow the right-closed convention - raise NotImplementedError def _get_table(self, split: str) -> Table: diff --git a/relbench/base/task_link.py b/relbench/base/task_link.py index 3a80f02e..e9d3337b 100644 --- a/relbench/base/task_link.py +++ b/relbench/base/task_link.py @@ -6,8 +6,6 @@ import pandas as pd from numpy.typing import NDArray -# TODO: remove! -from ..modeling.utils import to_unix_time from .dataset import Dataset from .table import Table from .task_base import BaseTask, TaskType @@ -89,7 +87,6 @@ def evaluate( return {fn.__name__: fn(pred_isin, dst_count) for fn in metrics} - # TODO: should these be here? seed_time is confusing terminology? @property def num_src_nodes(self) -> int: return len(self.dataset.get_db().table_dict[self.src_entity_table]) @@ -98,15 +95,7 @@ def num_src_nodes(self) -> int: def num_dst_nodes(self) -> int: return len(self.dataset.get_db().table_dict[self.dst_entity_table]) - @property - def val_seed_time(self) -> int: - return to_unix_time(pd.Series([self.dataset.val_timestamp]))[0] - - @property - def test_seed_time(self) -> int: - return to_unix_time(pd.Series([self.dataset.test_timestamp]))[0] - - def stats(self) -> dict[str, dict[str, int]]: + def stats(self) -> Dict[str, Dict[str, int]]: r"""Get train / val / test table statistics for each timestamp and the whole table, including number of unique source entities, number of unique destination entities, number of destination @@ -177,7 +166,7 @@ def stats(self) -> dict[str, dict[str, int]]: ] = ratio_train_test_entity_overlap return res - def _get_stats(self, df: pd.DataFrame) -> list[int]: + def _get_stats(self, df: pd.DataFrame) -> List[int]: num_unique_src_entities = df[self.src_entity_col].nunique() num_unique_dst_entities = len( set(value for row in df[self.dst_entity_col] for value in row) diff --git a/relbench/datasets/avito.py b/relbench/datasets/avito.py index c906befc..16160b65 100644 --- a/relbench/datasets/avito.py +++ b/relbench/datasets/avito.py @@ -8,17 +8,15 @@ class AvitoDataset(Dataset): - url = "https://www.kaggle.com/competitions/avito-context-ad-clicks" - err_msg = ( - "{data} not found. Please download avito data from " - "'{url}' and move it to '{path}'." - ) + """Original data source: + https://www.kaggle.com/competitions/avito-context-ad-clicks""" # search stream ranges from 2015-04-25 to 2015-05-20 val_timestamp = pd.Timestamp("2015-05-08") test_timestamp = pd.Timestamp("2015-05-14") def make_db(self) -> Database: + # subsampled version of the original dataset # Customize path as necessary r"""Process the raw files into a database.""" url = "https://relbench.stanford.edu/data/rel-avito-raw-100k.zip" @@ -69,6 +67,8 @@ def make_db(self) -> Database: ) visit_stream_df = clean_datetime(visit_stream_df, "ViewDate") + category_df.drop(columns=["__index_level_0__"], inplace=True) + tables = {} tables["AdsInfo"] = Table( df=ads_info_df, diff --git a/test/modeling/test_link_nn.py b/test/modeling/test_link_nn.py index e4a4162b..45750432 100644 --- a/test/modeling/test_link_nn.py +++ b/test/modeling/test_link_nn.py @@ -100,7 +100,8 @@ def test_link_train_fake_product_dataset(tmp_path, share_same_time): eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {} for split in ["val", "test"]: - seed_time = task.val_seed_time if split == "val" else task.test_seed_time + timestamp = dataset.val_timestamp if split == "val" else dataset.test_timestamp + seed_time = int(timestamp.timestamp()) target_table = task.get_table(split) src_node_indices = torch.from_numpy(target_table.df[task.src_entity_col].values) src_loader = NeighborLoader(