diff --git a/relbench/datasets/event.py b/relbench/datasets/event.py index 003fbc62..09cc95dd 100644 --- a/relbench/datasets/event.py +++ b/relbench/datasets/event.py @@ -81,6 +81,13 @@ def make_db(self) -> Database: user=lambda df: df["user"].astype(int), friend=lambda df: df["friend"].astype(int), ) + + # Some friends are not present in the user table, so we drop those friends + # in the user_friends table + user_friends_flattened_df = user_friends_flattened_df.merge( + users_df, how="inner", left_on="friend", right_on="user_id" + ) + user_friends_flattened_df = user_friends_flattened_df[["user", "friend"]] user_friends_flattened_df.to_csv( os.path.join(path, "user_friends_flattened.csv") )