diff --git a/oasis/social_platform/platform.py b/oasis/social_platform/platform.py index 80d55b3..8031826 100644 --- a/oasis/social_platform/platform.py +++ b/oasis/social_platform/platform.py @@ -252,67 +252,68 @@ async def refresh(self, agent_id: int): datetime.now(), self.start_time) else: current_time = os.environ["SANDBOX_TIME"] - # try: - user_id = agent_id - # Retrieve all post_ids for a given user_id from the rec table - rec_query = "SELECT post_id FROM rec WHERE user_id = ?" - self.pl_utils._execute_db_command(rec_query, (user_id, )) - rec_results = self.db_cursor.fetchall() - - post_ids = [row[0] for row in rec_results] - selected_post_ids = post_ids - # If the number of post_ids >= self.refresh_rec_post_count, - # randomly select a specified number of post_ids - if len(selected_post_ids) >= self.refresh_rec_post_count: - selected_post_ids = random.sample(selected_post_ids, - self.refresh_rec_post_count) - - if self.recsys_type != RecsysType.REDDIT: - # Retrieve posts from following (in network) - # Modify the SQL query so that the refresh gets posts from - # people the user follows, sorted by the number of likes on - # Twitter - query_following_post = ( - "SELECT post.post_id, post.user_id, post.content, " - "post.created_at, post.num_likes FROM post " - "JOIN follow ON post.user_id = follow.followee_id " - "WHERE follow.follower_id = ? " - "ORDER BY post.num_likes DESC " - "LIMIT ?") - self.pl_utils._execute_db_command( - query_following_post, - ( - user_id, - self.following_post_count, - ), - ) - - following_posts = self.db_cursor.fetchall() - following_posts_ids = [row[0] for row in following_posts] - - selected_post_ids = following_posts_ids + selected_post_ids - selected_post_ids = list(set(selected_post_ids)) - - placeholders = ", ".join("?" for _ in selected_post_ids) - - post_query = ( - f"SELECT post_id, user_id, original_post_id, content, " - f"quote_content, created_at, num_likes, num_dislikes, " - f"num_shares FROM post WHERE post_id IN ({placeholders})") - self.pl_utils._execute_db_command(post_query, selected_post_ids) - results = self.db_cursor.fetchall() - if not results: - return {"success": False, "message": "No posts found."} - results_with_comments = self.pl_utils._add_comments_to_posts(results) + try: + user_id = agent_id + # Retrieve all post_ids for a given user_id from the rec table + rec_query = "SELECT post_id FROM rec WHERE user_id = ?" + self.pl_utils._execute_db_command(rec_query, (user_id, )) + rec_results = self.db_cursor.fetchall() + + post_ids = [row[0] for row in rec_results] + selected_post_ids = post_ids + # If the number of post_ids >= self.refresh_rec_post_count, + # randomly select a specified number of post_ids + if len(selected_post_ids) >= self.refresh_rec_post_count: + selected_post_ids = random.sample(selected_post_ids, + self.refresh_rec_post_count) + + if self.recsys_type != RecsysType.REDDIT: + # Retrieve posts from following (in network) + # Modify the SQL query so that the refresh gets posts from + # people the user follows, sorted by the number of likes on + # Twitter + query_following_post = ( + "SELECT post.post_id, post.user_id, post.content, " + "post.created_at, post.num_likes FROM post " + "JOIN follow ON post.user_id = follow.followee_id " + "WHERE follow.follower_id = ? " + "ORDER BY post.num_likes DESC " + "LIMIT ?") + self.pl_utils._execute_db_command( + query_following_post, + ( + user_id, + self.following_post_count, + ), + ) + + following_posts = self.db_cursor.fetchall() + following_posts_ids = [row[0] for row in following_posts] + + selected_post_ids = following_posts_ids + selected_post_ids + selected_post_ids = list(set(selected_post_ids)) + + placeholders = ", ".join("?" for _ in selected_post_ids) + + post_query = ( + f"SELECT post_id, user_id, original_post_id, content, " + f"quote_content, created_at, num_likes, num_dislikes, " + f"num_shares FROM post WHERE post_id IN ({placeholders})") + self.pl_utils._execute_db_command(post_query, selected_post_ids) + results = self.db_cursor.fetchall() + if not results: + return {"success": False, "message": "No posts found."} + results_with_comments = self.pl_utils._add_comments_to_posts( + results) - action_info = {"posts": results_with_comments} - twitter_log.info(action_info) - self.pl_utils._record_trace(user_id, ActionType.REFRESH.value, - action_info, current_time) + action_info = {"posts": results_with_comments} + twitter_log.info(action_info) + self.pl_utils._record_trace(user_id, ActionType.REFRESH.value, + action_info, current_time) - return {"success": True, "posts": results_with_comments} - # except Exception as e: - # return {"success": False, "error": str(e)} + return {"success": True, "posts": results_with_comments} + except Exception as e: + return {"success": False, "error": str(e)} async def update_rec_table(self): # Recsys(trace/user/post table), refresh rec table @@ -428,8 +429,8 @@ async def repost(self, agent_id: int, post_id: int): } post_type_result = self.pl_utils._get_post_type(post_id) - post_insert_query = ("INSERT INTO post (user_id, original_post_id)" - "VALUES (?, ?)") + post_insert_query = ("INSERT INTO post (user_id, original_post_id" + ", created_at) VALUES (?, ?, ?)") # Update num_shares for the found post update_shares_query = ( "UPDATE post SET num_shares = num_shares + 1 WHERE post_id = ?" @@ -439,9 +440,9 @@ async def repost(self, agent_id: int, post_id: int): return {"success": False, "error": "Post not found."} elif (post_type_result['type'] == 'common' or post_type_result['type'] == 'quote'): - self.pl_utils._execute_db_command(post_insert_query, - (user_id, post_id), - commit=True) + self.pl_utils._execute_db_command( + post_insert_query, (user_id, post_id, current_time), + commit=True) self.pl_utils._execute_db_command(update_shares_query, (post_id, ), commit=True) @@ -460,11 +461,10 @@ async def repost(self, agent_id: int, post_id: int): "error": "Repost record already exists." } - self.pl_utils._execute_db_command(post_insert_query, ( - user_id, - post_type_result['root_post_id'], - ), - commit=True) + self.pl_utils._execute_db_command( + post_insert_query, + (user_id, post_type_result['root_post_id'], current_time), + commit=True) self.pl_utils._execute_db_command( update_shares_query, (post_type_result['root_post_id'], ), commit=True) @@ -545,6 +545,9 @@ async def like_post(self, agent_id: int, post_id: int): else: current_time = os.environ["SANDBOX_TIME"] try: + post_type_result = self.pl_utils._get_post_type(post_id) + if post_type_result['type'] == 'repost': + post_id = post_type_result['root_post_id'] user_id = agent_id # Check if a like record already exists like_check_query = ("SELECT * FROM 'like' WHERE post_id = ? AND " @@ -582,6 +585,7 @@ async def like_post(self, agent_id: int, post_id: int): like_id = self.db_cursor.lastrowid # Record the action in the trace table + # if post has been reposted, record the root post id into trace action_info = {"post_id": post_id, "like_id": like_id} self.pl_utils._record_trace(user_id, ActionType.LIKE_POST.value, action_info, current_time) @@ -591,6 +595,9 @@ async def like_post(self, agent_id: int, post_id: int): async def unlike_post(self, agent_id: int, post_id: int): try: + post_type_result = self.pl_utils._get_post_type(post_id) + if post_type_result['type'] == 'repost': + post_id = post_type_result['root_post_id'] user_id = agent_id # Check if a like record already exists @@ -642,6 +649,9 @@ async def dislike_post(self, agent_id: int, post_id: int): else: current_time = os.environ["SANDBOX_TIME"] try: + post_type_result = self.pl_utils._get_post_type(post_id) + if post_type_result['type'] == 'repost': + post_id = post_type_result['root_post_id'] user_id = agent_id # Check if a dislike record already exists like_check_query = ( @@ -689,6 +699,9 @@ async def dislike_post(self, agent_id: int, post_id: int): async def undo_dislike_post(self, agent_id: int, post_id: int): try: + post_type_result = self.pl_utils._get_post_type(post_id) + if post_type_result['type'] == 'repost': + post_id = post_type_result['root_post_id'] user_id = agent_id # Check if a dislike record already exists @@ -1049,6 +1062,9 @@ async def create_comment(self, agent_id: int, comment_message: tuple): else: current_time = os.environ["SANDBOX_TIME"] try: + post_type_result = self.pl_utils._get_post_type(post_id) + if post_type_result['type'] == 'repost': + post_id = post_type_result['root_post_id'] user_id = agent_id # Insert the comment record diff --git a/oasis/social_platform/platform_utils.py b/oasis/social_platform/platform_utils.py index 2f5ab95..7da95f0 100644 --- a/oasis/social_platform/platform_utils.py +++ b/oasis/social_platform/platform_utils.py @@ -77,6 +77,7 @@ def _add_comments_to_posts(self, posts_results): self.db_cursor.execute(original_user_id_query, (original_post_id, )) original_user_id = self.db_cursor.fetchone()[0] + original_post_id = post_id post_id = post_type_result["root_post_id"] self.db_cursor.execute( "SELECT content, quote_content, created_at, num_likes, " @@ -140,7 +141,8 @@ def _add_comments_to_posts(self, posts_results): # Add post information and corresponding comments to the posts list posts.append({ "post_id": - post_id, + post_id + if post_type_result["type"] != "repost" else original_post_id, "user_id": user_id, "content": diff --git a/oasis/social_platform/schema/post.sql b/oasis/social_platform/schema/post.sql index 0452474..1a28469 100644 --- a/oasis/social_platform/schema/post.sql +++ b/oasis/social_platform/schema/post.sql @@ -4,7 +4,7 @@ CREATE TABLE post ( post_id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER, original_post_id INTEGER, -- NULL if this is an original post - content TEXT, + content TEXT DEFAULT '', -- DEFAULT '' for initial posts quote_content TEXT, -- NULL if this is an original post or a repost created_at DATETIME, num_likes INTEGER DEFAULT 0, diff --git a/test/infra/database/test_post.py b/test/infra/database/test_post.py index 0543da9..2935534 100644 --- a/test/infra/database/test_post.py +++ b/test/infra/database/test_post.py @@ -59,24 +59,27 @@ async def receive_from(self): return ("id_", (2, 1, "repost")) elif self.call_count == 8: self.call_count += 1 - return ("id_", (2, 1, "repost")) + return ("id_", (2, 2, "like_post")) elif self.call_count == 9: self.call_count += 1 - return ("id_", (2, 3, "repost")) + return ("id_", (2, 1, "repost")) elif self.call_count == 10: self.call_count += 1 - return ("id_", (3, 2, "repost")) + return ("id_", (2, 3, "repost")) elif self.call_count == 11: self.call_count += 1 - return ("id_", (1, (1, 'I like the post.'), "quote_post")) + return ("id_", (3, 2, "repost")) elif self.call_count == 12: + self.call_count += 1 + return ("id_", (1, (1, 'I like the post.'), "quote_post")) + elif self.call_count == 13: self.call_count += 1 return ("id_", (2, (2, 'I quote to the reposted post.'), "quote_post")) - elif self.call_count == 13: + elif self.call_count == 14: self.call_count += 1 return ("id_", (1, 4, "repost")) - elif self.call_count == 14: + elif self.call_count == 15: self.call_count += 1 return ("id_", (2, (4, 'I quote to the quoted post.'), "quote_post")) @@ -116,18 +119,17 @@ async def send_to(self, message): assert message[2]["success"] is True assert "post_id" in message[2] elif self.call_count == 9: + assert message[2]["success"] is True + assert "like_id" in message[2] + elif self.call_count == 10: # Assert the success message for a repost assert message[2]["success"] is False assert message[2]["error"] == "Repost record already exists." - elif self.call_count == 10: + elif self.call_count == 11: # Assert the success message for a repost assert message[2]["success"] is False assert message[2]["error"] == "Post not found." - elif self.call_count == 11: - assert message[2]["success"] is True - assert "post_id" in message[2] elif self.call_count == 12: - # Assert the success message for a repost assert message[2]["success"] is True assert "post_id" in message[2] elif self.call_count == 13: @@ -135,9 +137,13 @@ async def send_to(self, message): assert message[2]["success"] is True assert "post_id" in message[2] elif self.call_count == 14: + # Assert the success message for a repost assert message[2]["success"] is True assert "post_id" in message[2] elif self.call_count == 15: + assert message[2]["success"] is True + assert "post_id" in message[2] + elif self.call_count == 16: # Assert the success message for a repost assert message[2]["success"] is True assert "post_id" in message[2] @@ -198,19 +204,23 @@ async def test_create_repost_like_unlike_post(setup_platform): post = posts[0] assert post[1] == 1 # Assuming user ID is 1 assert post[3] == "This is a test post" - assert post[6] == 1 # num_likes + assert post[6] == 2 # num_likes assert post[7] == 1 # num_dislikes assert post[8] == 5 # num_shares repost = posts[1] assert repost[1] == 2 # Repost user ID is 2 assert repost[2] == 1 # Original post ID is 1 - assert repost[3] is None # Reposted post has no content + assert repost[3] == '' # Reposted post is empty + print('created_at:', repost[5]) + assert repost[5] is not None # created_at + assert repost[6] == 0 # num_likes repost_2 = posts[2] assert repost_2[1] == 3 # Repost user ID is 2 assert repost_2[2] == 1 # Original post ID is 1 - assert repost_2[3] is None # Reposted post has no content + assert repost_2[3] == '' # Reposted post is empty + assert repost[5] is not None # created_at quote_post = posts[3] assert quote_post[1] == 1 # Repost user ID is @@ -234,7 +244,7 @@ async def test_create_repost_like_unlike_post(setup_platform): # Verify the like table has the correct data inserted cursor.execute("SELECT * FROM like") likes = cursor.fetchall() - assert len(likes) == 1 + assert len(likes) == 2 # Verify the dislike table has the correct data inserted cursor.execute("SELECT * FROM dislike") @@ -251,7 +261,7 @@ async def test_create_repost_like_unlike_post(setup_platform): cursor.execute("SELECT * FROM trace WHERE action='like_post'") results = cursor.fetchall() assert results is not None, "Like post action not traced" - assert len(results) == 2 + assert len(results) == 3 cursor.execute("SELECT * FROM trace WHERE action='unlike_post'") results = cursor.fetchall() diff --git a/test/infra/database/test_quote_repost_refresh.py b/test/infra/database/test_quote_repost_refresh.py index 3748ff2..cf5091d 100644 --- a/test/infra/database/test_quote_repost_refresh.py +++ b/test/infra/database/test_quote_repost_refresh.py @@ -123,7 +123,7 @@ async def send_to(self, message): assert posts[0]['comments'][0]['content'] == 'a comment' # Post 2 - assert posts[1]['post_id'] == 1 + assert posts[1]['post_id'] == 2 assert posts[1]['user_id'] == 2 assert posts[1]['content'] == ( 'User 2 reposted a post from User 1. Repost content: This is ' @@ -147,7 +147,7 @@ async def send_to(self, message): assert posts[3]['num_likes'] == 0 # Post 5 - assert posts[4]['post_id'] == 4 + assert posts[4]['post_id'] == 5 assert posts[4]['user_id'] == 1 assert posts[4]['content'] == ( 'User 1 reposted a post from User 2. Repost content: This is '