Skip to content

Commit

Permalink
Remove train_start_timestamp (#211)
Browse files Browse the repository at this point in the history
.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rishabh-ranjan and pre-commit-ci[bot] authored Jul 3, 2024
1 parent b5cee90 commit 30af0a3
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 16 deletions.
9 changes: 9 additions & 0 deletions relbench/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def upto(self, time_stamp: pd.Timestamp) -> Self:
}
)

def from_(self, time_stamp: pd.Timestamp) -> Self:
r"""Returns a database with all rows from time_stamp."""

return Database(
table_dict={
name: table.from_(time_stamp) for name, table in self.table_dict.items()
}
)

def reindex_pkeys_and_fkeys(self) -> None:
r"""Mapping primary and foreign keys into indices according to
the ordering in the primary key tables.
Expand Down
6 changes: 0 additions & 6 deletions relbench/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class Dataset:
def __init__(
self,
db: Database,
train_start_timestamp: Optional[pd.Timestamp],
val_timestamp: pd.Timestamp,
test_timestamp: pd.Timestamp,
max_eval_time_frames: int,
Expand All @@ -29,15 +28,12 @@ def __init__(
Args:
db (Database): The database object.
train_start_timestamp (pd.Timestamp, optional): If specified, we create
train table after the specified time.
val_timestamp (pd.Timestamp): The first timestamp for making val table.
test_timestamp (pd.Timestamp): The first timestamp for making test table.
max_eval_time_frames (int): The maximum number of unique timestamps used to build test and val tables.
"""
self._full_db = db
self.train_start_timestamp = train_start_timestamp
self.val_timestamp = val_timestamp
self.test_timestamp = test_timestamp
self.max_eval_time_frames = max_eval_time_frames
Expand Down Expand Up @@ -74,7 +70,6 @@ def validate_and_correct_db(self):

class RelBenchDataset(Dataset):
name: str
train_start_timestamp: Optional[pd.Timestamp] = None
val_timestamp: pd.Timestamp
test_timestamp: pd.Timestamp

Expand Down Expand Up @@ -111,7 +106,6 @@ def __init__(self, process=None) -> None:

super().__init__(
db,
self.train_start_timestamp,
self.val_timestamp,
self.test_timestamp,
self.max_eval_time_frames,
Expand Down
13 changes: 13 additions & 0 deletions relbench/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ def upto(self, time_stamp: pd.Timestamp) -> Self:
time_col=self.time_col,
)

def from_(self, time_stamp: pd.Timestamp) -> Self:
r"""Returns a table with all rows from time."""

if self.time_col is None:
return self

return Table(
df=self.df.query(f"{self.time_col} >= @time_stamp"),
fkey_col_to_pkey_table=self.fkey_col_to_pkey_table,
pkey_col=self.pkey_col,
time_col=self.time_col,
)

@property
@lru_cache(maxsize=None)
def min_timestamp(self) -> pd.Timestamp:
Expand Down
2 changes: 1 addition & 1 deletion relbench/data/task_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def train_table(self) -> Table:
if "train" not in self._cached_table_dict:
timestamps = pd.date_range(
start=self.dataset.val_timestamp - self.timedelta,
end=self.dataset.train_start_timestamp or self.dataset.db.min_timestamp,
end=self.dataset.db.min_timestamp,
freq=-self.timedelta,
)
if len(timestamps) < 3:
Expand Down
7 changes: 5 additions & 2 deletions relbench/datasets/amazon.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class AmazonDataset(RelBenchDataset):
name = "rel-amazon"
val_timestamp = pd.Timestamp("2015-10-01")
test_timestamp = pd.Timestamp("2016-01-01")
train_start_timestamp = pd.Timestamp("2008-01-01")

max_eval_time_frames = 1

Expand Down Expand Up @@ -206,7 +205,7 @@ def fix_column(value):
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")

return Database(
db = Database(
table_dict={
"product": Table(
df=pdf,
Expand All @@ -231,3 +230,7 @@ def fix_column(value):
),
}
)

db = db.from_(pd.Timestamp("2008-01-01"))

return db
7 changes: 5 additions & 2 deletions relbench/datasets/avito.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class AvitoDataset(RelBenchDataset):
)

# search stream ranges from 2015-04-25 to 2015-05-20
train_start_timestamp = pd.Timestamp("2015-04-25")
val_timestamp = pd.Timestamp("2015-05-08")
test_timestamp = pd.Timestamp("2015-05-14")
max_eval_time_frames = 1
Expand Down Expand Up @@ -138,4 +137,8 @@ def make_db(self) -> Database:
},
time_col="ViewDate",
)
return Database(tables)
db = Database(tables)

db = db.from_(pd.Timestamp("2015-04-25"))

return db
7 changes: 5 additions & 2 deletions relbench/datasets/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class EventDataset(RelBenchDataset):
"kaggle competitions download -c event-recommendation-engine-challenge"
)

train_start_timestamp = pd.Timestamp("2012-06-20")
val_timestamp = pd.Timestamp("2012-11-21")
test_timestamp = pd.Timestamp("2012-11-29")
max_eval_time_frames = 1
Expand Down Expand Up @@ -153,7 +152,7 @@ def make_db(self) -> Database:
subset=["user_id"]
)

return Database(
db = Database(
table_dict={
"users": Table(
df=users_df,
Expand Down Expand Up @@ -198,3 +197,7 @@ def make_db(self) -> Database:
),
}
)

db = db.from_(pd.Timestamp("2012-06-20"))

return db
1 change: 0 additions & 1 deletion relbench/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(
max_eval_time_frames = 1
super().__init__(
db=db,
train_start_timestamp=None,
val_timestamp=val_timestamp,
test_timestamp=test_timestamp,
max_eval_time_frames=max_eval_time_frames,
Expand Down
7 changes: 5 additions & 2 deletions relbench/datasets/hm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class HMDataset(RelBenchDataset):
)
# Train for the most recent 1 year out of 2 years of the original
# time period
train_start_timestamp = pd.Timestamp("2019-09-07")
val_timestamp = pd.Timestamp("2020-09-07")
test_timestamp = pd.Timestamp("2020-09-14")
max_eval_time_frames = 1
Expand Down Expand Up @@ -54,7 +53,7 @@ def make_db(self) -> Database:
transactions_df["t_dat"], format="%Y-%m-%d"
)

return Database(
db = Database(
table_dict={
"article": Table(
df=articles_df,
Expand All @@ -76,3 +75,7 @@ def make_db(self) -> Database:
),
}
)

db = db.from_(pd.Timestamp("2019-09-07"))

return db

0 comments on commit 30af0a3

Please sign in to comment.