Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change repost action to align with twitter Recsys #31

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 85 additions & 69 deletions oasis/social_platform/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,67 +252,68 @@ async def refresh(self, agent_id: int):
datetime.now(), self.start_time)
else:
current_time = os.environ["SANDBOX_TIME"]
# try:
user_id = agent_id
# Retrieve all post_ids for a given user_id from the rec table
rec_query = "SELECT post_id FROM rec WHERE user_id = ?"
self.pl_utils._execute_db_command(rec_query, (user_id, ))
rec_results = self.db_cursor.fetchall()

post_ids = [row[0] for row in rec_results]
selected_post_ids = post_ids
# If the number of post_ids >= self.refresh_rec_post_count,
# randomly select a specified number of post_ids
if len(selected_post_ids) >= self.refresh_rec_post_count:
selected_post_ids = random.sample(selected_post_ids,
self.refresh_rec_post_count)

if self.recsys_type != RecsysType.REDDIT:
# Retrieve posts from following (in network)
# Modify the SQL query so that the refresh gets posts from
# people the user follows, sorted by the number of likes on
# Twitter
query_following_post = (
"SELECT post.post_id, post.user_id, post.content, "
"post.created_at, post.num_likes FROM post "
"JOIN follow ON post.user_id = follow.followee_id "
"WHERE follow.follower_id = ? "
"ORDER BY post.num_likes DESC "
"LIMIT ?")
self.pl_utils._execute_db_command(
query_following_post,
(
user_id,
self.following_post_count,
),
)

following_posts = self.db_cursor.fetchall()
following_posts_ids = [row[0] for row in following_posts]

selected_post_ids = following_posts_ids + selected_post_ids
selected_post_ids = list(set(selected_post_ids))

placeholders = ", ".join("?" for _ in selected_post_ids)

post_query = (
f"SELECT post_id, user_id, original_post_id, content, "
f"quote_content, created_at, num_likes, num_dislikes, "
f"num_shares FROM post WHERE post_id IN ({placeholders})")
self.pl_utils._execute_db_command(post_query, selected_post_ids)
results = self.db_cursor.fetchall()
if not results:
return {"success": False, "message": "No posts found."}
results_with_comments = self.pl_utils._add_comments_to_posts(results)
try:
user_id = agent_id
# Retrieve all post_ids for a given user_id from the rec table
rec_query = "SELECT post_id FROM rec WHERE user_id = ?"
self.pl_utils._execute_db_command(rec_query, (user_id, ))
rec_results = self.db_cursor.fetchall()

post_ids = [row[0] for row in rec_results]
selected_post_ids = post_ids
# If the number of post_ids >= self.refresh_rec_post_count,
# randomly select a specified number of post_ids
if len(selected_post_ids) >= self.refresh_rec_post_count:
selected_post_ids = random.sample(selected_post_ids,
self.refresh_rec_post_count)

if self.recsys_type != RecsysType.REDDIT:
# Retrieve posts from following (in network)
# Modify the SQL query so that the refresh gets posts from
# people the user follows, sorted by the number of likes on
# Twitter
query_following_post = (
"SELECT post.post_id, post.user_id, post.content, "
"post.created_at, post.num_likes FROM post "
"JOIN follow ON post.user_id = follow.followee_id "
"WHERE follow.follower_id = ? "
"ORDER BY post.num_likes DESC "
"LIMIT ?")
self.pl_utils._execute_db_command(
query_following_post,
(
user_id,
self.following_post_count,
),
)

following_posts = self.db_cursor.fetchall()
following_posts_ids = [row[0] for row in following_posts]

selected_post_ids = following_posts_ids + selected_post_ids
selected_post_ids = list(set(selected_post_ids))

placeholders = ", ".join("?" for _ in selected_post_ids)

post_query = (
f"SELECT post_id, user_id, original_post_id, content, "
f"quote_content, created_at, num_likes, num_dislikes, "
f"num_shares FROM post WHERE post_id IN ({placeholders})")
self.pl_utils._execute_db_command(post_query, selected_post_ids)
results = self.db_cursor.fetchall()
if not results:
return {"success": False, "message": "No posts found."}
results_with_comments = self.pl_utils._add_comments_to_posts(
results)

action_info = {"posts": results_with_comments}
twitter_log.info(action_info)
self.pl_utils._record_trace(user_id, ActionType.REFRESH.value,
action_info, current_time)
action_info = {"posts": results_with_comments}
twitter_log.info(action_info)
self.pl_utils._record_trace(user_id, ActionType.REFRESH.value,
action_info, current_time)

return {"success": True, "posts": results_with_comments}
# except Exception as e:
# return {"success": False, "error": str(e)}
return {"success": True, "posts": results_with_comments}
except Exception as e:
return {"success": False, "error": str(e)}

async def update_rec_table(self):
# Recsys(trace/user/post table), refresh rec table
Expand Down Expand Up @@ -428,8 +429,8 @@ async def repost(self, agent_id: int, post_id: int):
}

post_type_result = self.pl_utils._get_post_type(post_id)
post_insert_query = ("INSERT INTO post (user_id, original_post_id)"
"VALUES (?, ?)")
post_insert_query = ("INSERT INTO post (user_id, original_post_id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, the current_time is added. What would post["content"] be? "None"?

", created_at) VALUES (?, ?, ?)")
# Update num_shares for the found post
update_shares_query = (
"UPDATE post SET num_shares = num_shares + 1 WHERE post_id = ?"
Expand All @@ -439,9 +440,9 @@ async def repost(self, agent_id: int, post_id: int):
return {"success": False, "error": "Post not found."}
elif (post_type_result['type'] == 'common'
or post_type_result['type'] == 'quote'):
self.pl_utils._execute_db_command(post_insert_query,
(user_id, post_id),
commit=True)
self.pl_utils._execute_db_command(
post_insert_query, (user_id, post_id, current_time),
commit=True)
self.pl_utils._execute_db_command(update_shares_query,
(post_id, ),
commit=True)
Expand All @@ -460,11 +461,10 @@ async def repost(self, agent_id: int, post_id: int):
"error": "Repost record already exists."
}

self.pl_utils._execute_db_command(post_insert_query, (
user_id,
post_type_result['root_post_id'],
),
commit=True)
self.pl_utils._execute_db_command(
post_insert_query,
(user_id, post_type_result['root_post_id'], current_time),
commit=True)
self.pl_utils._execute_db_command(
update_shares_query, (post_type_result['root_post_id'], ),
commit=True)
Expand Down Expand Up @@ -545,6 +545,9 @@ async def like_post(self, agent_id: int, post_id: int):
else:
current_time = os.environ["SANDBOX_TIME"]
try:
post_type_result = self.pl_utils._get_post_type(post_id)
if post_type_result['type'] == 'repost':
post_id = post_type_result['root_post_id']
user_id = agent_id
# Check if a like record already exists
like_check_query = ("SELECT * FROM 'like' WHERE post_id = ? AND "
Expand Down Expand Up @@ -582,6 +585,7 @@ async def like_post(self, agent_id: int, post_id: int):
like_id = self.db_cursor.lastrowid

# Record the action in the trace table
# if post has been reposted, record the root post id into trace
action_info = {"post_id": post_id, "like_id": like_id}
self.pl_utils._record_trace(user_id, ActionType.LIKE_POST.value,
action_info, current_time)
Expand All @@ -591,6 +595,9 @@ async def like_post(self, agent_id: int, post_id: int):

async def unlike_post(self, agent_id: int, post_id: int):
try:
post_type_result = self.pl_utils._get_post_type(post_id)
if post_type_result['type'] == 'repost':
post_id = post_type_result['root_post_id']
user_id = agent_id

# Check if a like record already exists
Expand Down Expand Up @@ -642,6 +649,9 @@ async def dislike_post(self, agent_id: int, post_id: int):
else:
current_time = os.environ["SANDBOX_TIME"]
try:
post_type_result = self.pl_utils._get_post_type(post_id)
if post_type_result['type'] == 'repost':
post_id = post_type_result['root_post_id']
user_id = agent_id
# Check if a dislike record already exists
like_check_query = (
Expand Down Expand Up @@ -689,6 +699,9 @@ async def dislike_post(self, agent_id: int, post_id: int):

async def undo_dislike_post(self, agent_id: int, post_id: int):
try:
post_type_result = self.pl_utils._get_post_type(post_id)
if post_type_result['type'] == 'repost':
post_id = post_type_result['root_post_id']
user_id = agent_id

# Check if a dislike record already exists
Expand Down Expand Up @@ -1049,6 +1062,9 @@ async def create_comment(self, agent_id: int, comment_message: tuple):
else:
current_time = os.environ["SANDBOX_TIME"]
try:
post_type_result = self.pl_utils._get_post_type(post_id)
if post_type_result['type'] == 'repost':
post_id = post_type_result['root_post_id']
user_id = agent_id

# Insert the comment record
Expand Down
4 changes: 3 additions & 1 deletion oasis/social_platform/platform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _add_comments_to_posts(self, posts_results):
self.db_cursor.execute(original_user_id_query,
(original_post_id, ))
original_user_id = self.db_cursor.fetchone()[0]
original_post_id = post_id
post_id = post_type_result["root_post_id"]
self.db_cursor.execute(
"SELECT content, quote_content, created_at, num_likes, "
Expand Down Expand Up @@ -140,7 +141,8 @@ def _add_comments_to_posts(self, posts_results):
# Add post information and corresponding comments to the posts list
posts.append({
"post_id":
post_id,
post_id
if post_type_result["type"] != "repost" else original_post_id,
"user_id":
user_id,
"content":
Expand Down
2 changes: 1 addition & 1 deletion oasis/social_platform/schema/post.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ CREATE TABLE post (
post_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER,
original_post_id INTEGER, -- NULL if this is an original post
content TEXT,
content TEXT DEFAULT '', -- DEFAULT '' for initial posts
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Redtides0 I changed the default value of content to an empty string. I think this might be a more convenient way to implement it. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, LGTM.

quote_content TEXT, -- NULL if this is an original post or a repost
created_at DATETIME,
num_likes INTEGER DEFAULT 0,
Expand Down
42 changes: 26 additions & 16 deletions test/infra/database/test_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,27 @@ async def receive_from(self):
return ("id_", (2, 1, "repost"))
elif self.call_count == 8:
self.call_count += 1
return ("id_", (2, 1, "repost"))
return ("id_", (2, 2, "like_post"))
elif self.call_count == 9:
self.call_count += 1
return ("id_", (2, 3, "repost"))
return ("id_", (2, 1, "repost"))
elif self.call_count == 10:
self.call_count += 1
return ("id_", (3, 2, "repost"))
return ("id_", (2, 3, "repost"))
elif self.call_count == 11:
self.call_count += 1
return ("id_", (1, (1, 'I like the post.'), "quote_post"))
return ("id_", (3, 2, "repost"))
elif self.call_count == 12:
self.call_count += 1
return ("id_", (1, (1, 'I like the post.'), "quote_post"))
elif self.call_count == 13:
self.call_count += 1
return ("id_", (2, (2, 'I quote to the reposted post.'),
"quote_post"))
elif self.call_count == 13:
elif self.call_count == 14:
self.call_count += 1
return ("id_", (1, 4, "repost"))
elif self.call_count == 14:
elif self.call_count == 15:
self.call_count += 1
return ("id_", (2, (4, 'I quote to the quoted post.'),
"quote_post"))
Expand Down Expand Up @@ -116,28 +119,31 @@ async def send_to(self, message):
assert message[2]["success"] is True
assert "post_id" in message[2]
elif self.call_count == 9:
assert message[2]["success"] is True
assert "like_id" in message[2]
elif self.call_count == 10:
# Assert the success message for a repost
assert message[2]["success"] is False
assert message[2]["error"] == "Repost record already exists."
elif self.call_count == 10:
elif self.call_count == 11:
# Assert the success message for a repost
assert message[2]["success"] is False
assert message[2]["error"] == "Post not found."
elif self.call_count == 11:
assert message[2]["success"] is True
assert "post_id" in message[2]
elif self.call_count == 12:
# Assert the success message for a repost
assert message[2]["success"] is True
assert "post_id" in message[2]
elif self.call_count == 13:
# Assert the success message for a repost
assert message[2]["success"] is True
assert "post_id" in message[2]
elif self.call_count == 14:
# Assert the success message for a repost
assert message[2]["success"] is True
assert "post_id" in message[2]
elif self.call_count == 15:
assert message[2]["success"] is True
assert "post_id" in message[2]
elif self.call_count == 16:
# Assert the success message for a repost
assert message[2]["success"] is True
assert "post_id" in message[2]
Expand Down Expand Up @@ -198,19 +204,23 @@ async def test_create_repost_like_unlike_post(setup_platform):
post = posts[0]
assert post[1] == 1 # Assuming user ID is 1
assert post[3] == "This is a test post"
assert post[6] == 1 # num_likes
assert post[6] == 2 # num_likes
assert post[7] == 1 # num_dislikes
assert post[8] == 5 # num_shares

repost = posts[1]
assert repost[1] == 2 # Repost user ID is 2
assert repost[2] == 1 # Original post ID is 1
assert repost[3] is None # Reposted post has no content
assert repost[3] == '' # Reposted post is empty
print('created_at:', repost[5])
assert repost[5] is not None # created_at
assert repost[6] == 0 # num_likes

repost_2 = posts[2]
assert repost_2[1] == 3 # Repost user ID is 2
assert repost_2[2] == 1 # Original post ID is 1
assert repost_2[3] is None # Reposted post has no content
assert repost_2[3] == '' # Reposted post is empty
assert repost[5] is not None # created_at

quote_post = posts[3]
assert quote_post[1] == 1 # Repost user ID is
Expand All @@ -234,7 +244,7 @@ async def test_create_repost_like_unlike_post(setup_platform):
# Verify the like table has the correct data inserted
cursor.execute("SELECT * FROM like")
likes = cursor.fetchall()
assert len(likes) == 1
assert len(likes) == 2

# Verify the dislike table has the correct data inserted
cursor.execute("SELECT * FROM dislike")
Expand All @@ -251,7 +261,7 @@ async def test_create_repost_like_unlike_post(setup_platform):
cursor.execute("SELECT * FROM trace WHERE action='like_post'")
results = cursor.fetchall()
assert results is not None, "Like post action not traced"
assert len(results) == 2
assert len(results) == 3

cursor.execute("SELECT * FROM trace WHERE action='unlike_post'")
results = cursor.fetchall()
Expand Down
4 changes: 2 additions & 2 deletions test/infra/database/test_quote_repost_refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def send_to(self, message):
assert posts[0]['comments'][0]['content'] == 'a comment'

# Post 2
assert posts[1]['post_id'] == 1
assert posts[1]['post_id'] == 2
assert posts[1]['user_id'] == 2
assert posts[1]['content'] == (
'User 2 reposted a post from User 1. Repost content: This is '
Expand All @@ -147,7 +147,7 @@ async def send_to(self, message):
assert posts[3]['num_likes'] == 0

# Post 5
assert posts[4]['post_id'] == 4
assert posts[4]['post_id'] == 5
assert posts[4]['user_id'] == 1
assert posts[4]['content'] == (
'User 1 reposted a post from User 2. Repost content: This is '
Expand Down
Loading