From 30af0a37fa8fb85720402a439cc926b02c4dc357 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Wed, 3 Jul 2024 15:27:12 -0700 Subject: [PATCH] Remove train_start_timestamp (#211) . --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- relbench/data/database.py | 9 +++++++++ relbench/data/dataset.py | 6 ------ relbench/data/table.py | 13 +++++++++++++ relbench/data/task_base.py | 2 +- relbench/datasets/amazon.py | 7 +++++-- relbench/datasets/avito.py | 7 +++++-- relbench/datasets/event.py | 7 +++++-- relbench/datasets/fake.py | 1 - relbench/datasets/hm.py | 7 +++++-- 9 files changed, 43 insertions(+), 16 deletions(-) diff --git a/relbench/data/database.py b/relbench/data/database.py index f4a987d1..dab08b27 100644 --- a/relbench/data/database.py +++ b/relbench/data/database.py @@ -82,6 +82,15 @@ def upto(self, time_stamp: pd.Timestamp) -> Self: } ) + def from_(self, time_stamp: pd.Timestamp) -> Self: + r"""Returns a database with all rows from time_stamp.""" + + return Database( + table_dict={ + name: table.from_(time_stamp) for name, table in self.table_dict.items() + } + ) + def reindex_pkeys_and_fkeys(self) -> None: r"""Mapping primary and foreign keys into indices according to the ordering in the primary key tables. diff --git a/relbench/data/dataset.py b/relbench/data/dataset.py index 657b41b9..5a8b7087 100644 --- a/relbench/data/dataset.py +++ b/relbench/data/dataset.py @@ -20,7 +20,6 @@ class Dataset: def __init__( self, db: Database, - train_start_timestamp: Optional[pd.Timestamp], val_timestamp: pd.Timestamp, test_timestamp: pd.Timestamp, max_eval_time_frames: int, @@ -29,15 +28,12 @@ def __init__( Args: db (Database): The database object. - train_start_timestamp (pd.Timestamp, optional): If specified, we create - train table after the specified time. val_timestamp (pd.Timestamp): The first timestamp for making val table. test_timestamp (pd.Timestamp): The first timestamp for making test table. max_eval_time_frames (int): The maximum number of unique timestamps used to build test and val tables. """ self._full_db = db - self.train_start_timestamp = train_start_timestamp self.val_timestamp = val_timestamp self.test_timestamp = test_timestamp self.max_eval_time_frames = max_eval_time_frames @@ -74,7 +70,6 @@ def validate_and_correct_db(self): class RelBenchDataset(Dataset): name: str - train_start_timestamp: Optional[pd.Timestamp] = None val_timestamp: pd.Timestamp test_timestamp: pd.Timestamp @@ -111,7 +106,6 @@ def __init__(self, process=None) -> None: super().__init__( db, - self.train_start_timestamp, self.val_timestamp, self.test_timestamp, self.max_eval_time_frames, diff --git a/relbench/data/table.py b/relbench/data/table.py index 85ab7a2f..0564a9e0 100644 --- a/relbench/data/table.py +++ b/relbench/data/table.py @@ -110,6 +110,19 @@ def upto(self, time_stamp: pd.Timestamp) -> Self: time_col=self.time_col, ) + def from_(self, time_stamp: pd.Timestamp) -> Self: + r"""Returns a table with all rows from time.""" + + if self.time_col is None: + return self + + return Table( + df=self.df.query(f"{self.time_col} >= @time_stamp"), + fkey_col_to_pkey_table=self.fkey_col_to_pkey_table, + pkey_col=self.pkey_col, + time_col=self.time_col, + ) + @property @lru_cache(maxsize=None) def min_timestamp(self) -> pd.Timestamp: diff --git a/relbench/data/task_base.py b/relbench/data/task_base.py index 4fca0c79..b5533891 100644 --- a/relbench/data/task_base.py +++ b/relbench/data/task_base.py @@ -62,7 +62,7 @@ def train_table(self) -> Table: if "train" not in self._cached_table_dict: timestamps = pd.date_range( start=self.dataset.val_timestamp - self.timedelta, - end=self.dataset.train_start_timestamp or self.dataset.db.min_timestamp, + end=self.dataset.db.min_timestamp, freq=-self.timedelta, ) if len(timestamps) < 3: diff --git a/relbench/datasets/amazon.py b/relbench/datasets/amazon.py index 969ed82f..4b53ac0c 100644 --- a/relbench/datasets/amazon.py +++ b/relbench/datasets/amazon.py @@ -11,7 +11,6 @@ class AmazonDataset(RelBenchDataset): name = "rel-amazon" val_timestamp = pd.Timestamp("2015-10-01") test_timestamp = pd.Timestamp("2016-01-01") - train_start_timestamp = pd.Timestamp("2008-01-01") max_eval_time_frames = 1 @@ -206,7 +205,7 @@ def fix_column(value): toc = time.time() print(f"done in {toc - tic:.2f} seconds.") - return Database( + db = Database( table_dict={ "product": Table( df=pdf, @@ -231,3 +230,7 @@ def fix_column(value): ), } ) + + db = db.from_(pd.Timestamp("2008-01-01")) + + return db diff --git a/relbench/datasets/avito.py b/relbench/datasets/avito.py index f56a99b0..648ba3ca 100644 --- a/relbench/datasets/avito.py +++ b/relbench/datasets/avito.py @@ -16,7 +16,6 @@ class AvitoDataset(RelBenchDataset): ) # search stream ranges from 2015-04-25 to 2015-05-20 - train_start_timestamp = pd.Timestamp("2015-04-25") val_timestamp = pd.Timestamp("2015-05-08") test_timestamp = pd.Timestamp("2015-05-14") max_eval_time_frames = 1 @@ -138,4 +137,8 @@ def make_db(self) -> Database: }, time_col="ViewDate", ) - return Database(tables) + db = Database(tables) + + db = db.from_(pd.Timestamp("2015-04-25")) + + return db diff --git a/relbench/datasets/event.py b/relbench/datasets/event.py index c8dfb07a..24693094 100644 --- a/relbench/datasets/event.py +++ b/relbench/datasets/event.py @@ -20,7 +20,6 @@ class EventDataset(RelBenchDataset): "kaggle competitions download -c event-recommendation-engine-challenge" ) - train_start_timestamp = pd.Timestamp("2012-06-20") val_timestamp = pd.Timestamp("2012-11-21") test_timestamp = pd.Timestamp("2012-11-29") max_eval_time_frames = 1 @@ -153,7 +152,7 @@ def make_db(self) -> Database: subset=["user_id"] ) - return Database( + db = Database( table_dict={ "users": Table( df=users_df, @@ -198,3 +197,7 @@ def make_db(self) -> Database: ), } ) + + db = db.from_(pd.Timestamp("2012-06-20")) + + return db diff --git a/relbench/datasets/fake.py b/relbench/datasets/fake.py index 9ddab8e5..6f457a87 100644 --- a/relbench/datasets/fake.py +++ b/relbench/datasets/fake.py @@ -30,7 +30,6 @@ def __init__( max_eval_time_frames = 1 super().__init__( db=db, - train_start_timestamp=None, val_timestamp=val_timestamp, test_timestamp=test_timestamp, max_eval_time_frames=max_eval_time_frames, diff --git a/relbench/datasets/hm.py b/relbench/datasets/hm.py index 29088525..4822abcb 100644 --- a/relbench/datasets/hm.py +++ b/relbench/datasets/hm.py @@ -15,7 +15,6 @@ class HMDataset(RelBenchDataset): ) # Train for the most recent 1 year out of 2 years of the original # time period - train_start_timestamp = pd.Timestamp("2019-09-07") val_timestamp = pd.Timestamp("2020-09-07") test_timestamp = pd.Timestamp("2020-09-14") max_eval_time_frames = 1 @@ -54,7 +53,7 @@ def make_db(self) -> Database: transactions_df["t_dat"], format="%Y-%m-%d" ) - return Database( + db = Database( table_dict={ "article": Table( df=articles_df, @@ -76,3 +75,7 @@ def make_db(self) -> Database: ), } ) + + db = db.from_(pd.Timestamp("2019-09-07")) + + return db