From 122bfc74a53e3a70529e636234c9e5000e6a1e9f Mon Sep 17 00:00:00 2001 From: Yiwen Yuan Date: Wed, 3 Jul 2024 11:24:40 -0700 Subject: [PATCH] revert code to construct db from dataset (#218) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rishabh Ranjan --- relbench/datasets/event.py | 128 ++++++++++++++++++++++++++++--------- 1 file changed, 98 insertions(+), 30 deletions(-) diff --git a/relbench/datasets/event.py b/relbench/datasets/event.py index e8835fb4..be31c70e 100644 --- a/relbench/datasets/event.py +++ b/relbench/datasets/event.py @@ -1,4 +1,5 @@ import os +import os.path as osp import shutil from pathlib import Path @@ -44,48 +45,115 @@ def check_table_and_decompress_if_exists(self, table_path: str, alt_path: str = self.err_msg.format(data=table_path, url=self.url, path=table_path) def make_db(self) -> Database: - url = "https://relbench.stanford.edu/data/rel-event-raw.zip" - path = pooch.retrieve( - url, - known_hash="9cb01d6e5e8bd60db61c769656d69bdd0864ed8030d9932784e8338ed5d1183e", - progressbar=True, - processor=unzip_processor, - ) - users_df = pd.read_csv( - os.path.join(path, "users.csv"), parse_dates=["joinedAt"] + path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data", "rel-event") + users = os.path.join(path, "users.csv") + user_friends = os.path.join(path, "user_friends.csv") + events = os.path.join(path, "events.csv") + event_attendees = os.path.join(path, "event_attendees.csv") + if not (os.path.exists(users)): + if not os.path.exists(zip): + raise RuntimeError( + self.err_msg.format(data="Dataset", url=self.url, path=zip) + ) + else: + shutil.unpack_archive(zip, Path(zip).parent) + self.check_table_and_decompress_if_exists( + user_friends, os.path.join(path, "user_friends_flattened.csv") ) - friends_df = pd.read_csv( - os.path.join(path, "users.csv"), parse_dates=["joinedAt"] + self.check_table_and_decompress_if_exists(events) + self.check_table_and_decompress_if_exists( + event_attendees, os.path.join(path, "event_attendees_flattened.csv") ) - user_friends_df = pd.read_csv(os.path.join(path, "user_friends.csv")) - events_df = pd.read_csv(os.path.join(path, "events.csv")) - events_df = events_df.dropna() - events_df["user_id"] = events_df["user_id"].astype(int) - event_attendees_df = pd.read_csv(os.path.join(path, "event_attendees.csv")) - event_interest_df = pd.read_csv(os.path.join(path, "train.csv")) + users_df = pd.read_csv(users, dtype={"user_id": int}, parse_dates=["joinedAt"]) + users_df["birthyear"] = pd.to_numeric(users_df["birthyear"], errors="coerce") users_df["joinedAt"] = pd.to_datetime( users_df["joinedAt"], errors="coerce" ).dt.tz_localize(None) - users_df["birthyear"] = pd.to_numeric(users_df["birthyear"], errors="coerce") - friends_df["joinedAt"] = pd.to_datetime( - friends_df["joinedAt"], errors="coerce" - ).dt.tz_localize(None) + + friends_df = pd.read_csv( + users, dtype={"user_id": int}, parse_dates=["joinedAt"] + ) friends_df["birthyear"] = pd.to_numeric( friends_df["birthyear"], errors="coerce" ) + friends_df["joinedAt"] = pd.to_datetime( + friends_df["joinedAt"], errors="coerce" + ).dt.tz_localize(None) + events_df = pd.read_csv(events) events_df["start_time"] = pd.to_datetime( events_df["start_time"], errors="coerce" ).dt.tz_localize(None) + train = os.path.join(path, "train.csv") + event_interest_df = pd.read_csv(train) event_interest_df["timestamp"] = pd.to_datetime( - event_interest_df["timestamp"], errors="coerce" + event_interest_df["timestamp"] ).dt.tz_localize(None) - event_attendees_df["start_time"] = pd.to_datetime( - event_attendees_df["start_time"], errors="coerce" - ) - event_attendees_df["start_time"] = ( - event_attendees_df["start_time"].dt.tz_localize(None).apply(pd.Timestamp) - ) + + if not os.path.exists(os.path.join(path, "user_friends_flattened.csv")): + user_friends_df = pd.read_csv(user_friends) + user_friends_df = ( + user_friends_df.set_index("user")["friends"] + .str.split(expand=True) + .stack() + .reset_index() + ) + user_friends_df.columns = ["user", "index", "friend"] + user_friends_flattened_df = user_friends_df.drop("index", axis=1).assign( + user=lambda df: df["user"].astype(int), + friend=lambda df: df["friend"].astype(int), + ) + user_friends_flattened_df.to_csv( + os.path.join(path, "user_friends_flattened.csv") + ) + else: + user_friends_flattened_df = pd.read_csv( + os.path.join(path, "user_friends_flattened.csv") + ) + + if not os.path.exists(os.path.join(path, "event_attendees_flattened.csv")): + event_attendees_df = pd.read_csv(event_attendees) + melted_df = event_attendees_df.melt( + id_vars=["event"], + value_vars=["yes", "maybe", "invited", "no"], + var_name="status", + value_name="user_ids", + ) + melted_df = melted_df.dropna() + melted_df["user_ids"] = melted_df["user_ids"].str.split() + melted_df["user_ids"] = melted_df["user_ids"].apply( + lambda x: [int(i) for i in x] + ) + exploded_df = melted_df.explode("user_ids") + exploded_df["user_ids"] = exploded_df["user_ids"].astype(int) + exploded_df.rename(columns={"user_ids": "user_id"}, inplace=True) + exploded_df = pd.merge( + exploded_df, + events_df[["event_id", "start_time"]], + left_on="event", + right_on="event_id", + how="left", + ) + exploded_df = exploded_df.drop("event_id", axis=1) + event_attendees_flattened_df = exploded_df.dropna(subset=["user_id"]) + event_attendees_flattened_df.to_csv( + os.path.join(path, "event_attendees_flattened.csv") + ) + else: + event_attendees_flattened_df = pd.read_csv( + os.path.join(path, "event_attendees_flattened.csv") + ) + event_attendees_flattened_df["start_time"] = pd.to_datetime( + event_attendees_flattened_df["start_time"], errors="coerce" + ) + event_attendees_flattened_df["start_time"] = ( + event_attendees_flattened_df["start_time"] + .dt.tz_localize(None) + .apply(pd.Timestamp) + ) + event_attendees_flattened_df = event_attendees_flattened_df.dropna( + subset=["user_id"] + ) return Database( table_dict={ @@ -108,7 +176,7 @@ def make_db(self) -> Database: time_col="start_time", ), "event_attendees": Table( - df=event_attendees_df, + df=event_attendees_flattened_df, fkey_col_to_pkey_table={ "event": "events", "user_id": "users", @@ -124,7 +192,7 @@ def make_db(self) -> Database: time_col="timestamp", ), "user_friends": Table( - df=user_friends_df, + df=user_friends_flattened_df, fkey_col_to_pkey_table={ "user": "users", "friend": "friends",