Skip to content

Commit

Permalink
fix make_db for event dataset (#244)
Browse files (browse the repository at this point in the history)
  • Loading branch information
yiweny authored Jul 25, 2024
1 parent e32756f commit 815c508
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions relbench/datasets/event.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,18 +54,18 @@ def make_db(self) -> Database:
users_df = pd.read_csv(users, dtype={"user_id": int}, parse_dates=["joinedAt"])
users_df["birthyear"] = pd.to_numeric(users_df["birthyear"], errors="coerce")
users_df["joinedAt"] = pd.to_datetime(
- users_df["joinedAt"], errors="coerce", format="mixed"
+ users_df["joinedAt"], errors="coerce", format="%Y-%m-%d %H:%M:%S.%f%z"
).dt.tz_localize(None)

events_df = pd.read_csv(events)
events_df["start_time"] = pd.to_datetime(
- events_df["start_time"], errors="coerce", format="mixed"
+ events_df["start_time"], errors="coerce", format="%Y-%m-%d %H:%M:%S.%f%z"
).dt.tz_localize(None)

train = os.path.join(path, "train.csv")
event_interest_df = pd.read_csv(train)
event_interest_df["timestamp"] = pd.to_datetime(
- event_interest_df["timestamp"], format="mixed"
+ event_interest_df["timestamp"], format="%Y-%m-%d %H:%M:%S.%f%z"
).dt.tz_localize(None)

if not os.path.exists(os.path.join(path, "user_friends_flattened.csv")):
Expand Down Expand Up @@ -152,7 +152,7 @@ def make_db(self) -> Database:
),
"events": Table(
df=events_df,
- fkey_col_to_pkey_table={"user_id": "friends"},
+ fkey_col_to_pkey_table={"user_id": "users"},
pkey_col="event_id",
time_col="start_time",
),
Expand Down

0 comments on commit 815c508

Please sign in to comment.