Skip to content

Commit

Permalink
Secure Source of Randomness (#6)
Browse files Browse the repository at this point in the history
Co-authored-by: pixeebot[bot] <104101892+pixeebot[bot]@users.noreply.github.com>
  • Loading branch information
pixeebot[bot] authored May 3, 2024
1 parent 1865f27 commit 0d7e3ee
Show file tree
Hide file tree
Showing 33 changed files with 140 additions and 140 deletions.
6 changes: 3 additions & 3 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
import re
from collections import defaultdict
from datetime import datetime, timedelta
Expand Down Expand Up @@ -39,6 +38,7 @@
from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
import secrets

_task_type_and_reaction = (
(
Expand Down Expand Up @@ -688,9 +688,9 @@ def fetch_random_conversation(

if last_message_role:
conv_messages = [m for m in messages_tree if m.role == last_message_role]
conv_messages = [random.choice(conv_messages)]
conv_messages = [secrets.choice(conv_messages)]
else:
conv_messages = [random.choice(messages_tree)]
conv_messages = [secrets.choice(messages_tree)]
messages_tree = {m.id: m for m in messages_tree}

while True:
Expand Down
30 changes: 15 additions & 15 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from datetime import datetime, timedelta
from enum import Enum
from http import HTTPStatus
Expand Down Expand Up @@ -37,6 +36,7 @@
from oasst_shared.utils import utcnow
from sqlalchemy.sql.functions import coalesce
from sqlmodel import Session, and_, func, not_, or_, text, update
import secrets


class TaskType(Enum):
Expand Down Expand Up @@ -302,7 +302,7 @@ def activate_one(db: Session) -> int:
weights = [data["reply_ranked_1"] + 1 for data in author_data]

# first select an author
prompt_author_id: UUID = random.choices(author_ids, weights=weights)[0]
prompt_author_id: UUID = secrets.SystemRandom().choices(author_ids, weights=weights)[0]
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_data)} candidates.")

# select random prompt of author
Expand All @@ -325,7 +325,7 @@ def activate_one(db: Session) -> int:
logger.warning("No prompt candidates of selected author found.")
return False

winner_prompt = random.choice(prompt_candidates)
winner_prompt = secrets.choice(prompt_candidates)
message: Message = winner_prompt.Message
logger.info(f"Prompt lottery winner: {message.id=}")

Expand Down Expand Up @@ -514,7 +514,7 @@ def next_task(
incomplete_rankings = list(filter(lambda m: m.role == "assistant", incomplete_rankings))

if len(incomplete_rankings) > 0:
ranking_parent_id = random.choice(incomplete_rankings).parent_id
ranking_parent_id = secrets.choice(incomplete_rankings).parent_id

messages = self.pr.fetch_message_conversation(ranking_parent_id)
assert len(messages) > 0 and messages[-1].id == ranking_parent_id
Expand All @@ -524,7 +524,7 @@ def next_task(
replies = self.pr.fetch_message_children(ranking_parent_id, review_result=True, deleted=False)

assert len(replies) > 1
random.shuffle(replies) # hand out replies in random order
secrets.SystemRandom().shuffle(replies) # hand out replies in random order
reply_messages = prepare_conversation_message_list(replies)
if any(not m.synthetic for m in reply_messages):
reveal_synthetic = False
Expand Down Expand Up @@ -565,7 +565,7 @@ def next_task(
replies_need_review = list(filter(lambda m: m.role == "assistant", replies_need_review))

if len(replies_need_review) > 0:
random_reply_message = random.choice(replies_need_review)
random_reply_message = secrets.choice(replies_need_review)
messages = self.pr.fetch_message_conversation(random_reply_message)

conversation = prepare_conversation(messages)
Expand All @@ -580,7 +580,7 @@ def next_task(
valid_labels = self.cfg.labels_assistant_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
and secrets.SystemRandom().random() > self.cfg.p_full_labeling_review_reply_assistant
):
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
Expand All @@ -605,7 +605,7 @@ def next_task(
valid_labels = self.cfg.labels_prompter_reply
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
and secrets.SystemRandom().random() > self.cfg.p_full_labeling_review_reply_prompter
):
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
Expand Down Expand Up @@ -647,11 +647,11 @@ def next_task(
if 0 < p.active_children_count < self.cfg.lonely_children_count
and p.parent_role == "prompter"
]
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
random_parent = random.choice(lonely_children_parents)
if len(lonely_children_parents) > 0 and secrets.SystemRandom().random() < self.cfg.p_lonely_child_extension:
random_parent = secrets.choice(lonely_children_parents)

if random_parent is None:
random_parent = random.choice(extendible_parents)
random_parent = secrets.choice(extendible_parents)

# fetch random conversation to extend
logger.debug(f"selected {random_parent=}")
Expand All @@ -672,14 +672,14 @@ def next_task(

case TaskType.LABEL_PROMPT:
assert len(prompts_need_review) > 0
message = random.choice(prompts_need_review)
message = secrets.choice(prompts_need_review)
message = self.pr.fetch_message(message.id) # re-fetch message including emojis

label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self.cfg.labels_initial_prompt

if random.random() > self.cfg.p_full_labeling_review_prompt:
if secrets.SystemRandom().random() > self.cfg.p_full_labeling_review_prompt:
valid_labels = self.cfg.mandatory_labels_initial_prompt.copy()
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
Expand Down Expand Up @@ -839,7 +839,7 @@ def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False)
if root_msg and was_active:
if random.random() < self.cfg.p_activate_backlog_tree:
if secrets.SystemRandom().random() < self.cfg.p_activate_backlog_tree:
self.activate_backlog_tree(lang=root_msg.lang)

if self.cfg.min_active_rankings_per_lang > 0:
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def _insert_default_state(
) -> MessageTreeState:
if goal_tree_size is None:
if self.cfg.random_goal_tree_size and self.cfg.min_goal_tree_size < self.cfg.goal_tree_size:
goal_tree_size = random.randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size)
goal_tree_size = secrets.SystemRandom().randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size)
else:
goal_tree_size = self.cfg.goal_tree_size
return self._insert_tree_state(
Expand Down
6 changes: 3 additions & 3 deletions data/datasets/TSSB-3M/generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
"""

import json
import random
import re
from os.path import join

from tqdm import tqdm
import secrets

INSTRUCTIONS_LIST = [
"Find the bug in the following code:",
Expand Down Expand Up @@ -48,12 +48,12 @@


def gen_instruction():
idx = random.randint(0, len(INSTRUCTIONS_LIST) - 1)
idx = secrets.SystemRandom().randint(0, len(INSTRUCTIONS_LIST) - 1)
return INSTRUCTIONS_LIST[idx]


def gen_response_prefix():
idx = random.randint(0, len(RESPONSE_PREFIX_WORDS) - 1)
idx = secrets.SystemRandom().randint(0, len(RESPONSE_PREFIX_WORDS) - 1)
return RESPONSE_PREFIX_WORDS[idx]


Expand Down
4 changes: 2 additions & 2 deletions data/datasets/logicreference_OA/generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@


import os
import random

import rules
import splits
import tensorflow as tf
from absl import app
import secrets

# Generation parameters:
# TARGET_FOLDER = "/path/to/generate/dataset/"
Expand Down Expand Up @@ -74,7 +74,7 @@ def main(_):

# Generate each of the splits:
print("IID:")
random.seed(RANDOM_SEED)
secrets.SystemRandom().seed(RANDOM_SEED)
(train_examples, test_examples) = splits.generate_training_and_test_sets_iid(
N_INFERENCE_PROBLEMS,
N_VARIATIONS,
Expand Down
4 changes: 2 additions & 2 deletions data/datasets/mt_note_generation/prepare.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
import math
import os
import random
import re
import sys
from string import punctuation

import kaggle
import pandas as pd
import secrets

CLINICAL_NOTE_GENERATION_TEMPLATE = """User: Write a clinical note about a patient with the following {section}: {section_information}.
Rosey: {note}"""
Expand Down Expand Up @@ -65,7 +65,7 @@ def main(output_dir: str = "data"):
kaggle.api.dataset_download_files("tboyle10/medicaltranscriptions", "data", unzip=True)
mt_samples = preprocess(pd.read_csv("data/mtsamples.csv"))
conversations = get_conversations(mt_samples)
random.shuffle(conversations)
secrets.SystemRandom().shuffle(conversations)
train_limit = math.ceil(len(conversations) * 0.6)
dev_limit = math.ceil(len(conversations) * 0.8)
train, validation, test = (
Expand Down
14 changes: 7 additions & 7 deletions data/datasets/poetry_instruction/prepare.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import os
import random

import kaggle
import pandas as pd
import secrets

# Authenticate the Kaggle API client
kaggle.api.authenticate()
Expand Down Expand Up @@ -116,15 +116,15 @@
author = row["Poet"]

# Variables to store to instruction, reply, source, and metadata.
instruction = random.choice(writing_prompts_topic).replace("$topic", str(topics))
reply = random.choice(replies_topic).replace("$topic", str(topics)).replace("$title", title).replace("$poem", poem)
instruction = secrets.choice(writing_prompts_topic).replace("$topic", str(topics))
reply = secrets.choice(replies_topic).replace("$topic", str(topics)).replace("$title", title).replace("$poem", poem)
source = "PoetryFoundation.org" + " - " + author
metadata = {"author": author, "title": title, "tags": str(topics), "task_type": "writing"}

# If the entry has an empty value for the topic, use the non-topic prompts and replies.
if pd.isna(topics):
instruction = random.choice(writing_prompts_notTopic)
reply = random.choice(replies_notTopic).replace("$title", title).replace("$poem", poem)
instruction = secrets.choice(writing_prompts_notTopic)
reply = secrets.choice(replies_notTopic).replace("$title", title).replace("$poem", poem)

# Create a dictionary entry for the entry and append it to the list.
entry = {"INSTRUCTION": instruction, "RESPONSE": reply, "SOURCE": source, "METADATA": json.dumps(metadata)}
Expand All @@ -139,8 +139,8 @@
author = row["Poet"]

# Variables to store to instruction, reply, source, and metadata.
instruction = random.choice(titling_prompts).replace("$poem", poem)
reply = random.choice(titling_replies).replace("$title", title)
instruction = secrets.choice(titling_prompts).replace("$poem", poem)
reply = secrets.choice(titling_replies).replace("$title", title)
source = "PoetryFoundation.org" + " - " + author
metadata = {"author": author, "title": title, "tags": str(topics), "task_type": "titling"}

Expand Down
8 changes: 4 additions & 4 deletions data/datasets/reasoning_gsm_qna_oa/data_process.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import json
import random
import re
from dataclasses import dataclass

import pandas as pd
from datasets import load_dataset
import secrets

random.seed(42)
secrets.SystemRandom().seed(42)

random_list_python = [
"Make a python code.",
Expand All @@ -29,9 +29,9 @@

def qna_wrapper(source, random_list_python, random_list_answer):
def create_qna(row):
instruction = row["question"] if source == "gsm8k" else row["input"] + " " + random.choice(random_list_python)
instruction = row["question"] if source == "gsm8k" else row["input"] + " " + secrets.choice(random_list_python)
response = (
re.sub(r"(<<[\d\.\-\+\*=/\\]+>>)", "", row["answer"].replace("####", random.choice(random_list_answer)))
re.sub(r"(<<[\d\.\-\+\*=/\\]+>>)", "", row["answer"].replace("####", secrets.choice(random_list_answer)))
+ "."
if source == "gsm8k"
else row["code"]
Expand Down
6 changes: 3 additions & 3 deletions data/datasets/semantics_ws_qna_oa/data_process.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import json
import random
from dataclasses import dataclass

import pandas as pd
import random_stuff
from datasets import load_dataset
import secrets

random.seed(42)
secrets.SystemRandom().seed(42)


# format to QnA
def qna_wrapper():
def create_qna(row):
# make a random number
random_num = random.randint(0, 2)
random_num = secrets.SystemRandom().randint(0, 2)

# extract rows' vals
lang = row["Language"]
Expand Down
Loading

0 comments on commit 0d7e3ee

Please sign in to comment.