Skip to content

Commit

Permalink
revert code to construct db from dataset (#218)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rishabh Ranjan <[email protected]>
  • Loading branch information
3 people authored Jul 3, 2024
1 parent a4ba6ce commit 122bfc7
Showing 1 changed file with 98 additions and 30 deletions.
128 changes: 98 additions & 30 deletions relbench/datasets/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import os.path as osp
import shutil
from pathlib import Path

Expand Down Expand Up @@ -44,48 +45,115 @@ def check_table_and_decompress_if_exists(self, table_path: str, alt_path: str =
self.err_msg.format(data=table_path, url=self.url, path=table_path)

def make_db(self) -> Database:
url = "https://relbench.stanford.edu/data/rel-event-raw.zip"
path = pooch.retrieve(
url,
known_hash="9cb01d6e5e8bd60db61c769656d69bdd0864ed8030d9932784e8338ed5d1183e",
progressbar=True,
processor=unzip_processor,
)
users_df = pd.read_csv(
os.path.join(path, "users.csv"), parse_dates=["joinedAt"]
path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data", "rel-event")
users = os.path.join(path, "users.csv")
user_friends = os.path.join(path, "user_friends.csv")
events = os.path.join(path, "events.csv")
event_attendees = os.path.join(path, "event_attendees.csv")
if not (os.path.exists(users)):
if not os.path.exists(zip):
raise RuntimeError(
self.err_msg.format(data="Dataset", url=self.url, path=zip)
)
else:
shutil.unpack_archive(zip, Path(zip).parent)
self.check_table_and_decompress_if_exists(
user_friends, os.path.join(path, "user_friends_flattened.csv")
)
friends_df = pd.read_csv(
os.path.join(path, "users.csv"), parse_dates=["joinedAt"]
self.check_table_and_decompress_if_exists(events)
self.check_table_and_decompress_if_exists(
event_attendees, os.path.join(path, "event_attendees_flattened.csv")
)
user_friends_df = pd.read_csv(os.path.join(path, "user_friends.csv"))
events_df = pd.read_csv(os.path.join(path, "events.csv"))
events_df = events_df.dropna()
events_df["user_id"] = events_df["user_id"].astype(int)
event_attendees_df = pd.read_csv(os.path.join(path, "event_attendees.csv"))
event_interest_df = pd.read_csv(os.path.join(path, "train.csv"))
users_df = pd.read_csv(users, dtype={"user_id": int}, parse_dates=["joinedAt"])
users_df["birthyear"] = pd.to_numeric(users_df["birthyear"], errors="coerce")
users_df["joinedAt"] = pd.to_datetime(
users_df["joinedAt"], errors="coerce"
).dt.tz_localize(None)
users_df["birthyear"] = pd.to_numeric(users_df["birthyear"], errors="coerce")
friends_df["joinedAt"] = pd.to_datetime(
friends_df["joinedAt"], errors="coerce"
).dt.tz_localize(None)

friends_df = pd.read_csv(
users, dtype={"user_id": int}, parse_dates=["joinedAt"]
)
friends_df["birthyear"] = pd.to_numeric(
friends_df["birthyear"], errors="coerce"
)
friends_df["joinedAt"] = pd.to_datetime(
friends_df["joinedAt"], errors="coerce"
).dt.tz_localize(None)
events_df = pd.read_csv(events)
events_df["start_time"] = pd.to_datetime(
events_df["start_time"], errors="coerce"
).dt.tz_localize(None)

train = os.path.join(path, "train.csv")
event_interest_df = pd.read_csv(train)
event_interest_df["timestamp"] = pd.to_datetime(
event_interest_df["timestamp"], errors="coerce"
event_interest_df["timestamp"]
).dt.tz_localize(None)
event_attendees_df["start_time"] = pd.to_datetime(
event_attendees_df["start_time"], errors="coerce"
)
event_attendees_df["start_time"] = (
event_attendees_df["start_time"].dt.tz_localize(None).apply(pd.Timestamp)
)

if not os.path.exists(os.path.join(path, "user_friends_flattened.csv")):
user_friends_df = pd.read_csv(user_friends)
user_friends_df = (
user_friends_df.set_index("user")["friends"]
.str.split(expand=True)
.stack()
.reset_index()
)
user_friends_df.columns = ["user", "index", "friend"]
user_friends_flattened_df = user_friends_df.drop("index", axis=1).assign(
user=lambda df: df["user"].astype(int),
friend=lambda df: df["friend"].astype(int),
)
user_friends_flattened_df.to_csv(
os.path.join(path, "user_friends_flattened.csv")
)
else:
user_friends_flattened_df = pd.read_csv(
os.path.join(path, "user_friends_flattened.csv")
)

if not os.path.exists(os.path.join(path, "event_attendees_flattened.csv")):
event_attendees_df = pd.read_csv(event_attendees)
melted_df = event_attendees_df.melt(
id_vars=["event"],
value_vars=["yes", "maybe", "invited", "no"],
var_name="status",
value_name="user_ids",
)
melted_df = melted_df.dropna()
melted_df["user_ids"] = melted_df["user_ids"].str.split()
melted_df["user_ids"] = melted_df["user_ids"].apply(
lambda x: [int(i) for i in x]
)
exploded_df = melted_df.explode("user_ids")
exploded_df["user_ids"] = exploded_df["user_ids"].astype(int)
exploded_df.rename(columns={"user_ids": "user_id"}, inplace=True)
exploded_df = pd.merge(
exploded_df,
events_df[["event_id", "start_time"]],
left_on="event",
right_on="event_id",
how="left",
)
exploded_df = exploded_df.drop("event_id", axis=1)
event_attendees_flattened_df = exploded_df.dropna(subset=["user_id"])
event_attendees_flattened_df.to_csv(
os.path.join(path, "event_attendees_flattened.csv")
)
else:
event_attendees_flattened_df = pd.read_csv(
os.path.join(path, "event_attendees_flattened.csv")
)
event_attendees_flattened_df["start_time"] = pd.to_datetime(
event_attendees_flattened_df["start_time"], errors="coerce"
)
event_attendees_flattened_df["start_time"] = (
event_attendees_flattened_df["start_time"]
.dt.tz_localize(None)
.apply(pd.Timestamp)
)
event_attendees_flattened_df = event_attendees_flattened_df.dropna(
subset=["user_id"]
)

return Database(
table_dict={
Expand All @@ -108,7 +176,7 @@ def make_db(self) -> Database:
time_col="start_time",
),
"event_attendees": Table(
df=event_attendees_df,
df=event_attendees_flattened_df,
fkey_col_to_pkey_table={
"event": "events",
"user_id": "users",
Expand All @@ -124,7 +192,7 @@ def make_db(self) -> Database:
time_col="timestamp",
),
"user_friends": Table(
df=user_friends_df,
df=user_friends_flattened_df,
fkey_col_to_pkey_table={
"user": "users",
"friend": "friends",
Expand Down

0 comments on commit 122bfc7

Please sign in to comment.