Skip to content

Commit

Permalink
remove start_time from datasets
Browse files (browse the repository at this point in the history)
  • Loading branch information
rishabh-ranjan committed Jun 27, 2024
1 parent bb11c12 commit 6b1d8d8
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
7 changes: 5 additions & 2 deletions relbench/datasets/amazon.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,6 @@ class AmazonDataset(RelBenchDataset):
name = "rel-amazon"
val_timestamp = pd.Timestamp("2015-10-01")
test_timestamp = pd.Timestamp("2016-01-01")
train_start_timestamp = pd.Timestamp("2008-01-01")

max_eval_time_frames = 1

Expand Down Expand Up @@ -206,7 +205,7 @@ def fix_column(value):
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")

return Database(
db = Database(
table_dict={
"product": Table(
df=pdf,
Expand All @@ -231,3 +230,7 @@ def fix_column(value):
),
}
)

db = db.from_(pd.Timestamp("2008-01-01"))

return db
7 changes: 5 additions & 2 deletions relbench/datasets/avito.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,6 @@ class AvitoDataset(RelBenchDataset):
)

# search stream ranges from 2015-04-25 to 2015-05-20
train_start_timestamp = pd.Timestamp("2015-04-25")
val_timestamp = pd.Timestamp("2015-05-08")
test_timestamp = pd.Timestamp("2015-05-14")
max_eval_time_frames = 1
Expand Down Expand Up @@ -138,4 +137,8 @@ def make_db(self) -> Database:
},
time_col="ViewDate",
)
return Database(tables)
db = Database(tables)

db = db.from_(pd.Timestamp("2015-04-25"))

return db
7 changes: 5 additions & 2 deletions relbench/datasets/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class EventDataset(RelBenchDataset):
"kaggle competitions download -c event-recommendation-engine-challenge"
)

train_start_timestamp = pd.Timestamp("2012-06-20")
val_timestamp = pd.Timestamp("2012-11-21")
test_timestamp = pd.Timestamp("2012-11-29")
max_eval_time_frames = 1
Expand Down Expand Up @@ -85,7 +84,7 @@ def make_db(self) -> Database:
event_attendees_df["start_time"].dt.tz_localize(None).apply(pd.Timestamp)
)

return Database(
db = Database(
table_dict={
"users": Table(
df=users_df,
Expand Down Expand Up @@ -130,3 +129,7 @@ def make_db(self) -> Database:
),
}
)

db = db.from_(pd.Timestamp("2012-06-20"))

return db
7 changes: 5 additions & 2 deletions relbench/datasets/hm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,6 @@ class HMDataset(RelBenchDataset):
)
# Train for the most recent 1 year out of 2 years of the original
# time period
train_start_timestamp = pd.Timestamp("2019-09-07")
val_timestamp = pd.Timestamp("2020-09-07")
test_timestamp = pd.Timestamp("2020-09-14")
max_eval_time_frames = 1
Expand Down Expand Up @@ -54,7 +53,7 @@ def make_db(self) -> Database:
transactions_df["t_dat"], format="%Y-%m-%d"
)

return Database(
db = Database(
table_dict={
"article": Table(
df=articles_df,
Expand All @@ -76,3 +75,7 @@ def make_db(self) -> Database:
),
}
)

db = db.from_(pd.Timestamp("2019-09-07"))

return db

0 comments on commit 6b1d8d8

Please sign in to comment.