diff --git a/README.md b/README.md index 628fa0b1..550778b1 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Now you can [turn any raw text](#quickstart) into a high-quality custom dataset Being extensible, new pipelines can be added to Augmentoolkit incredibly easily, and there are already three of them: the original QA generation pipeline, the classifier creator, and a pipeline for generating creative writing data based on inputted fictional stories. -Augmentoolkit is an AI-powered tool that lets you create domain-specific datasets, using open-source AI. +Augmentoolkit is an AI-powered tool that lets you create domain-specific data, using open-source AI. ![](images/augmentoolkit-logo.png) diff --git a/augmentoolkit/control_flow_functions/control_flow_functions.py b/augmentoolkit/control_flow_functions/control_flow_functions.py new file mode 100644 index 00000000..1e569f2e --- /dev/null +++ b/augmentoolkit/control_flow_functions/control_flow_functions.py @@ -0,0 +1,1601 @@ +import os +import json +import re +import sys +from tqdm import asyncio as tqdmasyncio +from tqdm import tqdm +from augmentoolkit.utils.make_id import make_id +from augmentoolkit.utils.write_output_to_file import write_output_to_file +from augmentoolkit.generation_functions.safe_formatter import safe_format +from nltk.tokenize import sent_tokenize +import matplotlib.pyplot as plt +from collections import Counter +import logging +from math import ceil +import traceback +import glob +import yaml +from datasets import load_dataset +import chardet + + +from augmentoolkit.utils.create_conv_starter import create_conv_starter +from augmentoolkit.utils.extract_steps import extract_steps +from augmentoolkit.utils.escape_unescaped_quotes import escape_unescaped_quotes + +from augmentoolkit.generation_functions import ( + extract_question_answer, + identify_duplicates, + process_multiturn_functions, + extract_name, + random_name, + strip_steps, +) +from augmentoolkit.generation_functions.format_qatuples import format_qatuples + +from augmentoolkit.generation_functions.generation_step_class import GenerationStep +from augmentoolkit.generation_functions.special_instructions import special_instructions + +with open("./config.yaml", "r") as file: + obj_conf = yaml.safe_load(file) + +DEFAULT_PROMPT_PATH = obj_conf["PATH"]["DEFAULT_PROMPTS"] +HUB_PATH = obj_conf["HUGGINGFACE"]["HUB_PATH"] +PRIVATE = obj_conf["HUGGINGFACE"]["PRIVATE"] +PUSH_TO_HUB = obj_conf["HUGGINGFACE"]["PUSH_TO_HUB"] +has_pushed_yet = False + +def extract_qa_tuples(text): + pattern = r"\*\*QUESTION:\*\*\s*((?:.|\n)*?)\s*\*\*ANSWER:\*\*\s*((?:.|\n)*?)(?=\s*\*\*QUESTION:\*\*|\Z)" + matches = re.findall( + pattern, text + "\n\n**QUESTION:**", re.DOTALL + ) # The addition is a hack to get around the tricky lookahead problem + return [(question.strip(), answer.strip()) for question, answer in matches] + +import os + + +# Also used basically everywhere: +def convert_logging_to_dataset(directory): + print("entering saving mode") + global has_pushed_yet + + output_dir = os.path.join(obj_conf["PATH"]["OUTPUT"], directory) + + output_file_path = os.path.join(obj_conf["PATH"]["OUTPUT"], directory + "_DATAGEN_OUTPUT.jsonl") + + + + if not os.path.exists(output_dir): + raise Exception("ERROR!! 
Trying to convert a logging directory to a dataset, when that directory does not exist!") + + full_list_of_dicts = [] + with open(output_file_path, "w",encoding='utf-8') as f: + existing_files = glob.glob( + os.path.join(output_dir, "*.yaml") + ) + + for file in existing_files: + with open(file,'r') as file2: + file_list_of_dicts = yaml.safe_load(file2) + # print(file_list_of_dicts) + + sysprompt = {"from": "system", "value": file_list_of_dicts[0]["content"]} + input = {"from": "human", "value": file_list_of_dicts[-2]["content"]} + output = {"from": "gpt", "value": file_list_of_dicts[-1]["content"]} + + json_to_write = {"conversations": [sysprompt, input, output]} + + f.write(json.dumps(json_to_write) + "\n") + full_list_of_dicts.append(json_to_write) + print("...Converted successfully (we think)") + + dataset_with_split_output_file_path = os.path.join(obj_conf["PATH"]["OUTPUT"], directory + "_DATAGEN_OUTPUT_SPLIT.json") + with open(dataset_with_split_output_file_path, "w",encoding='utf-8') as f: + json_to_write = {"train": full_list_of_dicts} + + f.write(json.dumps(json_to_write) + "\n") + + + if PUSH_TO_HUB: + if os.path.exists(output_file_path): + dataset = load_dataset("json", data_files=dataset_with_split_output_file_path, split="train") + print("DATASET TYPE:") + print(type(dataset)) + part_nb = directory.split("_")[0] + if not has_pushed_yet: + dataset.push_to_hub(HUB_PATH, private=PRIVATE) + dataset.to_parquet(f"hf://datasets/{HUB_PATH}/train{part_nb}.parquet") + has_pushed_yet = True + else: + dataset.to_parquet(f"hf://datasets/{HUB_PATH}/train-{part_nb}.parquet") + # remove the output with split file + os.remove(dataset_with_split_output_file_path) + + + + + + +def convert_revised_questions_to_question_generation_training(qa_tuples_by_paragraph, use_filenames): + print("entering saving mode") + # found a solution to overfitting on the examples: + # TRAIN WITHOUT THEM + # This will produce a WEALTH of instruct data + # fucking awesome, hopefully + # also it's also about the domain, lmao + # so more domain knowledge + + output_file_path = os.path.join(obj_conf["PATH"]["OUTPUT"], "questions_generation_dataset.jsonl") + + if use_filenames: + question_generation_prompt = os.path.join(obj_conf["PATH"]["PROMPTS"], "qatuples_gen_filenames.yaml") + else: + question_generation_prompt = os.path.join(obj_conf["PATH"]["PROMPTS"], "qatuples_gen_no_filenames.yaml") + + with open(question_generation_prompt, "r",encoding='utf-8', errors="replace") as f: + qgen_prompt_full = yaml.safe_load(f) + + sysprompt = qgen_prompt_full[0]["content"] + input_template = qgen_prompt_full[-1]["content"] + + # revised_questions_output_path = os.path.join(obj_conf["PATH"]["OUTPUT"], "qatuples_revised") + convos = [] + with open(output_file_path, 'w',encoding='utf-8') as out_file: + for qatup_group in qa_tuples_by_paragraph: + answer = format_qatuples(qatup_group) + text = qatup_group[0][2] + + # print(text) + if not use_filenames: + input_text = safe_format(input_template, text=text) + else: + textname = qatup_group[0][3] + input_text = safe_format(input_template, text=text, textname=textname) + sysprompt_obj = {"from": "system", "value": sysprompt} + input_obj = {"from": "human", "value": input_text} + answer_obj = {"from": "gpt", "value": answer} + + convo = {"conversations": [sysprompt_obj, input_obj, answer_obj]} + out_file.write(json.dumps(convo) + "\n") + convos.append(convo) + + print("...Converted successfully (we think)") + if PUSH_TO_HUB: ## IMPORTANT STUFF FOR YOU BEGINS HERE ## + # temporarily create a 
json file with splits to load the dataset from + output_file_path = os.path.join(obj_conf["PATH"]["OUTPUT"], "questions_generation_dataset_split.json") + with open(output_file_path, 'w') as out_file_json: + json.dump({"train": convos},out_file_json) + dataset = load_dataset("json", data_files=output_file_path, split="train") # THIS APPROACH WORKS! + + with open(output_file_path[:-1], 'w') as out_file_json: + json.dump(convo,out_file_json) + dataset.to_parquet(f"hf://datasets/{HUB_PATH}/data/train-qgen.parquet") + os.remove(output_file_path) + + + + + +def extract_reasoning_from_context_check(response): + # print("\n----\/----\n RESPONSE:") + # print(response) + # print("\n\n\n---/\---\n\n") + decision_pattern = re.compile(r"Final judgment:(.+)", re.IGNORECASE) + determination = decision_pattern.search(response) + if determination: + determination = determination.group(1).strip() + if not determination: + print("LLM ISSUE: Did not contain a determination! Maybe check your LLM it is being stupid, or perhaps the input is diffuclt.") + return None, response + if "PASS" in determination: + print("Leaving be...") + return (True, response) # , completion + elif "REWORD" in determination: + print("Rewording...") + q, a = extract_question_answer.extract_question_answer(response) + print((q, a)) + if "the provided" in a.lower(): # catch infrequent cases where the reworded answer contains reference to provided information + print("'The provided' found in reworded answer -- Setting to None...") + return (False, response) + if "the reworded" in a.lower(): # Catch infrequent cases where it talks about the reworded question and answer pair + print("'The reworded' found in reworded answer -- Setting to None...") + return (False, response) + if "mention" in a.lower(): + print("'Mention' found in reworded answer -- Setting to None...") + return (False, response) + if "no information" in a.lower(): + print("'No information' found in reworded answer -- Setting to None...") + return (False, response) + if "follow the instructions in a separate" in a.lower(): + print("'Follow the instructions in a separate' found in reworded answer -- Setting to None...") + return (False, response) + return (q, a) # (q, a, qatuple[2], qatuple[3]), completion + elif "FAIL" in determination: + print("Setting to None...") + return (False, response) # , completion + else: + print("Did not contain relevant or irrelevant! Retrying") + # print("!!! 
RESPONSE !!!") + # print("\n\n\n---\/---\n\n") + # print(response) + # print("\n\n\n---/\---\n\n") + raise Exception("error in judgement extraction (ans relevancy)") + +# Postprocessing function for question/answer validation +async def repair_qatuple_context( + idx, + tup, + engine_wrapper, + writepath, + vetted_qa_tuples, + use_filenames=False, + completion_mode=None, + logging_level=logging.INFO, +): + # NOTE set up the generation step + context_repairer_path = "check_qatuple_context_no_filenames" + if use_filenames: + context_repairer_path = "check_qatuple_context_filenames" + if completion_mode: + context_repairer_path = context_repairer_path + ".txt" + else: + context_repairer_path = context_repairer_path + ".yaml" + + repair_context_regex = re.compile( + r"Reasoning and thought process \(be thorough\):(.+)", + re.DOTALL | re.IGNORECASE, + ) + context_repairer = GenerationStep( + prompt_path=context_repairer_path, + regex=repair_context_regex, + sampling_params={ + "max_tokens": 2000, + "stop": [ + "### Response", + "\n\n\n\n\n\n\n\n\n\n\n\n\n", + "", + "# Input:", + "[INST]", + "### Instruction", + "[INST", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + ], + "temperature": 0.2, + }, + completion_mode=completion_mode, + retries=1, + engine_wrapper=engine_wrapper, + logging_level=logging_level, + output_processor=extract_reasoning_from_context_check, + prompt_folder=obj_conf["PATH"]["PROMPTS"], + default_prompt_folder=DEFAULT_PROMPT_PATH, + use_stop=obj_conf["SYSTEM"]["STOP"] + ) + + # Resume normal control flow + file_path = os.path.join(writepath, f"revised_{idx}.json") + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + content = f.read() # Read the file once and store its content + print(file_path) + if content == "failed": + print("Loaded failed file") + vetted_qa_tuples[idx] = None + return None + print("Loaded file:") + print(content) + try: + data = json.loads(content) # Convert the string back to JSON + vetted_qa_tuples[idx] = (data[0], data[1], data[2], data[3], data[4], data[5], data[6]) + return None + except json.JSONDecodeError: + print("JSON decode error with the contents:", content) + + try: + revision_id = make_id() + revision, revision_output = await context_repairer.generate( + arguments={ + "textname": tup[3], + "question": tup[0], + "answer": tup[1], + } + ) + write_output_to_file( + revision_output, + obj_conf["PATH"]["OUTPUT"] + "/question_context_revision_generations", + revision_id, + ) # incidentally, identifying the problem and fixing it in the same step (without another planning step) works a lot better than identifying it and then trying to fix it in the next step. + if isinstance(revision[0], str): # if the thing was reworded + vetted_qa_tuples[idx] = ( + revision[0], + revision[1], + tup[2], + tup[3], + tup[4], + tup[5], + tup[6] + ) # replace the old tuple with the new one, revision doesn't have text name so we keep the old one + elif not revision[0]: + vetted_qa_tuples[ + idx + ] = None # prepare item for deletion later; right now we just store it as None because indexes + # else, if it passed, we just leave it be. + + # Write in-progress + if not os.path.exists(writepath): + os.makedirs(writepath) + + if vetted_qa_tuples[idx]: + with open(file_path, "w") as file: + json.dump(vetted_qa_tuples[idx], file, indent=4) + else: + with open(file_path, "w") as file: + file.write("failed") + + except Exception as e: + print("!!! 
ERROR!", e) + traceback.print_exc() + + +def parse_answer_accuracy_validation(response): + determination_pattern = re.compile( + r"Overall Accuracy Determination:(.+)", re.DOTALL + ) + try: + determination = determination_pattern.search(response).group(1).strip() + except Exception as e: + print("Error encountered, model messed up output format") + print(e) + return (False, response) + if ( + "inaccurate" in determination.lower() + or "Inaccurate" in determination.lower() + or "mostly" in determination.lower() + or "partial" in determination.lower() + or "irrelevant" in determination.lower() + ): # The "mostly" is there to catch "mostly accurate" which the model says occasionally, and which actually means inaccurate. + return (False, response) + elif "accurate" in determination.lower(): + return (True, response) + else: + print("Answer accuracy validation made a mistake") + raise Exception("answer accuracy validation did not include a judgement") + + +# Control flow helpers -- Question/Answer Validation +async def vet_answer_accuracy_loop( + qa_tuple, + run_id, + engine_wrapper=None, + double_check_counter=3, + completion_mode=None, + logging_level=None, + file_path=None, +): + # NOTE Set up answer check generation step + prompt_path_ans_accuracy_check = "check_answer" + if completion_mode: + prompt_path_ans_accuracy_check = prompt_path_ans_accuracy_check + ".txt" + else: + prompt_path_ans_accuracy_check = prompt_path_ans_accuracy_check + ".yaml" + check_ans_accuracy_regex = re.compile( + r"Reasoning and thought process \(the text is your single source of truth\):\n(.+)", + re.DOTALL, + ) + # TODO performance improvement could be gained by using async for to do the checks simultaneously + answer_accuracy_checker = GenerationStep( + prompt_path=prompt_path_ans_accuracy_check, + regex=check_ans_accuracy_regex, + sampling_params={ + "max_tokens": 1500, + "stop": [ + "### Response", + "\n\n\n\n\n", + "", + "# Input:", + "[INST]", + "### Instruction", + "[INST", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + ], + "temperature": 0.2, + }, + completion_mode=completion_mode, + retries=1, + engine_wrapper=engine_wrapper, + logging_level=logging_level, + output_processor=parse_answer_accuracy_validation, + prompt_folder=obj_conf["PATH"]["PROMPTS"], + default_prompt_folder=DEFAULT_PROMPT_PATH, + use_stop=obj_conf["SYSTEM"]["STOP"], + ) + + # Resume normal control flow code + + try: + qtuple = qa_tuple + # print( + # f"\n\nStarting ACCURACY loop for question: {qtuple[0]}, context: {qtuple[2]}" + # ) + passed_checks = 0 + times_checked = 0 + dissenting_reasoning = "" + while times_checked < double_check_counter: + check_id = make_id() + # print( + # f"\n\nACCURACY CALL CHECK ANSWER: {qtuple[0]}, context: {qtuple[2]}, retries: {total_retries}, dissenting reasoning: {dissenting_reasoning}" + # ) + judgement, answer_accuracy_output = await answer_accuracy_checker.generate( + arguments={ + "text": qtuple[2], + "question": qtuple[0], + "answer": qtuple[1], + } + ) + write_output_to_file( + answer_accuracy_output, + obj_conf["PATH"]["OUTPUT"] + "/check_answer_accuracy_generations", + run_id + "--check--" + check_id, + ) + if not judgement[0]: # if not accurate + dissenting_reasoning = judgement[1] + print("\nNegative Vote Cast! 
Here was the reasoning:\n") + print(dissenting_reasoning) + else: + passed_checks += 1 + times_checked += 1 + if passed_checks >= ceil(double_check_counter / 2): + break + failed_checks = times_checked - passed_checks + if failed_checks >= ceil(double_check_counter / 2): + break + + if passed_checks >= ceil(double_check_counter / 2): # if question checks passed + # print(f"\n\ANSWER ACCURACY CHECKS PASSED retries: {total_retries}") + return qtuple + else: + print("Answer accuracy validation failed! Tossing") + with open(file_path, "w") as file: + file.write("failed") + return + except Exception as e: + print("!!ERROR!!") + print(e) + traceback.print_exc() + + with open(file_path, "w") as file: + file.write("failed") + return + + +def parse_answer_relevancy_validation_step(thought_process): + judgement_pattern = re.compile( + r"Explanation of Judgment:(.+)", re.DOTALL | re.IGNORECASE + ) + try: + determination = judgement_pattern.search(thought_process).group(1).strip() + if ( + "irrelevant" in determination.lower() + or "mostly" in determination.lower() + or "partial" in determination.lower() + or "introduces information not present in the text" in determination.lower() + ): # Hack to get around faulty outputs + return (False, thought_process) # , completion + elif "relevant" in determination or "Relevant" in determination: + return (True, thought_process) # , completion + else: + print(f"Answer relevancy parsing failed! Retrying! {judgement_pattern}") + raise Exception("error in judgement extranction (ans relevancy)") + except Exception as e: + print("Model did not provide a judgement") + print(e) + # raise Exception("retry") + return (False, thought_process) + + +async def vet_answer_relevance_loop( + qa_tuple, + run_id, + engine_wrapper=None, + double_check_counter=3, + completion_mode=None, + logging_level=None, + file_path=None, +): + # NOTE Set up answer check generation step + prompt_path_ans_relevancy_check = "check_answer_relevancy_with_text" + check_ans_relevancy_regex = re.compile( + r"Reasoning and thought process \(be careful about extra details, even vague ones\):\n(.+)", + re.DOTALL | re.IGNORECASE, + ) + + if completion_mode: + prompt_path_ans_relevancy_check = prompt_path_ans_relevancy_check + ".txt" + else: + prompt_path_ans_relevancy_check = prompt_path_ans_relevancy_check + ".yaml" + + answer_relevancy_checker = GenerationStep( + prompt_path=prompt_path_ans_relevancy_check, + regex=check_ans_relevancy_regex, + sampling_params={ + "max_tokens": 1500, + "stop": [ + "### Response", + "\n\n\n\n\n\n", + "", + "# Input:", + "[INST]", + "### Instruction", + "[INST", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + ], + "temperature": 0.2, + }, + completion_mode=completion_mode, + retries=1, + engine_wrapper=engine_wrapper, + logging_level=logging_level, + output_processor=parse_answer_relevancy_validation_step, + prompt_folder=obj_conf["PATH"]["PROMPTS"], + default_prompt_folder=DEFAULT_PROMPT_PATH, + use_stop=obj_conf["SYSTEM"]["STOP"] + ) + + # Resume normal control flow code + try: + qtuple = qa_tuple + # print( + # f"\n\nStarting RELEVANCE loop for question: {qtuple[0]}, context: {qtuple[2]}" + # ) + passed_checks = 0 + times_checked = 0 + dissenting_reasoning = "" + while times_checked < double_check_counter: + check_id = make_id() + + # print( + # f"\n\nRELEVANCE CALL CHECK ANSWER: {qtuple[0]}, context: {qtuple[2]}, retries: {total_retries}, dissenting reasoning: {dissenting_reasoning}" + # ) + ( + judgement, + answer_relevancy_output, + ) = await 
answer_relevancy_checker.generate( + arguments={ + "text": qtuple[2], + "question": qtuple[0], + "answer": qtuple[1], + } + ) + write_output_to_file( + answer_relevancy_output, + obj_conf["PATH"]["OUTPUT"] + "/check_answer_relevancy_generations", + check_id, + ) + if not judgement[0]: # if not relevant + dissenting_reasoning = judgement[1] + print("\nNegative Vote Cast! Here was the reasoning:\n") + print(dissenting_reasoning) + else: + passed_checks += 1 + times_checked += 1 + if passed_checks >= ceil(double_check_counter / 2): + break + failed_checks = times_checked - passed_checks + if failed_checks >= ceil(double_check_counter / 2): + break + + if passed_checks >= ceil(double_check_counter / 2): # if question checks passed + # print(f"\n\ANSWER ACCURACY CHECKS PASSED retries: {total_retries}") + return await vet_answer_accuracy_loop( + qtuple, + run_id, + engine_wrapper=engine_wrapper, + double_check_counter=double_check_counter, + completion_mode=completion_mode, + logging_level=logging_level, + file_path=file_path + ) + else: + print("Answer relevancy validation failed! Tossing") + with open(file_path, "w") as file: + file.write("failed") + return + except Exception as e: + print("!!ERROR!!") + print(e) + traceback.print_exc() + + with open(file_path, "w") as file: + file.write("failed") + return + + +def parse_validation_step(response): + # print("!!! RESPONSE !!!") + # print(response) + decision_pattern = re.compile(r"Critical Evaluation and Final Judgment:(.+)", re.DOTALL | re.IGNORECASE) + determination = decision_pattern.search(response).group(1).strip() + # print("!!! DETERMINATION !!!") + # print(determination) + if ( + "irrelevant" in determination + or "Irrelevant" in determination.lower() + or "mostly" in determination.lower() + or "partial" in determination.lower() + or "introduces information not present in the text" in determination.lower() + ): + return ( + False, + response, + ) # TODO ensure that in the control flow code it passes on (False, response), completion + elif "relevant" in determination.lower(): + return (True, response) # TODO same as above(True, response), completion + else: + print("Did not contain relevant or irrelevant! Retrying") + raise Exception( + "Validation step screwed up and did not reach a conclusion! Retrying!" + ) + + +async def vet_question_loop( + qa_tuple, + question_group_id=None, + engine_wrapper=None, + qa_tuples_dir=None, # idx is qa_tuple[5]. Really should've used a dict at this point, oh well. 
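+    # (For reference, each qa_tuple here is laid out as (question, answer, paragraph_text,
+    # source_name, question_group_id, paragraph_index, question_number) -- see
+    # generate_qatuples_from_para, which builds these tuples.)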
+ vetted_qa_tuples=None, + double_check_counter=3, + completion_mode=None, + logging_level=None, +): + try: + file_path = os.path.join(qa_tuples_dir, f"para_{qa_tuple[5]}_q_{qa_tuple[6]}.json") + idx = qa_tuple[5] + # Check for existing qa tuples + existing_files = glob.glob( + os.path.join(qa_tuples_dir, f"para_{idx}_q_{qa_tuple[6]}.json") + ) # check if qs already exist + + if len(existing_files) > 0: # If files exist, skip this paragraph entirely + print(f"Loading file") + for file_path in existing_files: + with open(file_path, "r", errors="replace") as file: + file_body = file.read() + if file_body == "failed": + qa_tuple = None + else: + file.seek(0) + qa_tuple = tuple(json.loads(file_body)) + vetted_qa_tuples.append(qa_tuple) + return + + + # NOTE Set up question check generation step + prompt_path_q_check = "check_question" + check_q_regex = re.compile( + r"Reasoning and thought process \(be careful around \"how\" and \"why\" questions\):(.+)", + re.DOTALL | re.IGNORECASE, + ) + + if completion_mode: + prompt_path_q_check = prompt_path_q_check + ".txt" + else: + prompt_path_q_check = prompt_path_q_check + ".yaml" + + question_checker = GenerationStep( + prompt_path=prompt_path_q_check, + regex=check_q_regex, + sampling_params={ + "max_tokens": 1500, + "stop": [ + "### Response", + "\n\n\n\n\n", + "", + "# Input:", + "[INST]", + "### Instruction", + "[INST", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + ], + "temperature": 0.2, + }, + completion_mode=completion_mode, + retries=1, + engine_wrapper=engine_wrapper, + logging_level=logging_level, + output_processor=parse_validation_step, + prompt_folder=obj_conf["PATH"]["PROMPTS"], + default_prompt_folder=DEFAULT_PROMPT_PATH, + use_stop=obj_conf["SYSTEM"]["STOP"], + ) + + # NOTE Set up generate new question step + # MODIFICATION: so that the conversations make sense, we just toss failed questions, rather than regenning. They're plentiful enough. + try: + qtuple = qa_tuple + # print( + # f"\n\nStarting QUESTION loop for question: {qtuple[0]}, context: {qtuple[2]}" + # ) + run_id = question_group_id + "--subquestion--" + make_id() + passed_checks = 0 + times_checked = 0 + dissenting_reasoning = "" + if obj_conf["SKIP"]["QUESTION_CHECK"]: + print("DEBUG: Skipping question check") + return await vet_answer_accuracy_loop( + qtuple, + run_id, + engine_wrapper=engine_wrapper, + double_check_counter=double_check_counter, + completion_mode=completion_mode, + logging_level=logging_level, + file_path=file_path + ) + while times_checked < double_check_counter: + check_id = make_id() + # print( + # f"\n\nQUESTION CALL CHECK ANSWER: {qtuple[0]}, context: {qtuple[2]}, retries: {total_retries}, dissenting reasoning: {dissenting_reasoning}" + # ) + judgement, check_q_output = await question_checker.generate( + arguments={"text": qtuple[2], "question": qtuple[0], "answer": qtuple[1]} + ) + + # Now we need to put the judgement together into the format it expects it to be in + + write_output_to_file( + check_q_output, + obj_conf["PATH"]["OUTPUT"] + "/check_question_generations", + run_id + "--check--" + check_id, + ) + + # print("JUDGEMENT:") + # print(judgement) + if not judgement[0]: # if not relevant + dissenting_reasoning = judgement[1] + print("\nNegative Vote Cast! 
Here was the reasoning:\n") + print(dissenting_reasoning) + print(f"ID: {check_id}") + else: + passed_checks += 1 + times_checked += 1 + if passed_checks >= ceil(double_check_counter / 2): + break + failed_checks = times_checked - passed_checks + if failed_checks >= ceil(double_check_counter / 2): + break + + if passed_checks >= ceil( + double_check_counter / 2 + ): # if all question checks passed + # print(f"\n\nQUESTION CHECKS PASSED retries: {total_retries}") + + if obj_conf["SKIP"]["ANSWER_RELEVANCY_CHECK"]: + res = await vet_answer_accuracy_loop( + qtuple, + run_id, + engine_wrapper=engine_wrapper, + double_check_counter=double_check_counter, + completion_mode=completion_mode, + logging_level=logging_level, + file_path=file_path + ) + else: + res = await vet_answer_relevance_loop( + qtuple, + run_id, + engine_wrapper=engine_wrapper, + double_check_counter=double_check_counter, + completion_mode=completion_mode, + logging_level=logging_level, + file_path=file_path + ) + + # Return response + + vetted_qa_tuples.append(res) + if res is not None: + with open(file_path, "w") as file: + json.dump(res, file, indent=4) + return + else: # this path is probably redundant + print("Question accuracy validation failed! Tossing") + with open(file_path, "w") as file: + file.write("failed") + return + except Exception as e: + print("!!ERROR!!") + print(e) + traceback.print_exc() + with open(file_path, "w") as file: + file.write("failed") + except Exception as e: + print(f"Q ERROR: {e}") + traceback.print_exc() + + +def extract_questions_from_response( + generation, +): # TODO extract to non-controlflow file + questions = extract_qa_tuples(generation) + if len(questions) == 0: + print("FAILED TO GENERATE QUESTIONS!") + return [] + return questions + + +def extract_question_from_response( + generation, +): # TODO extract to non-controlflow file + return extract_questions_from_response(generation)[0] + + +# Question generation +async def generate_qatuples_from_para( + idx, + para, + engine_wrapper_large=None, + generated_qa_tuples=None, + qa_tuples_dir=None, + use_filenames=False, + completion_mode=None, + logging_level=None, +): + + # NOTE Set up qatuple generation step # + prompt_path_qatuples_gen = "qatuples_gen_no_filenames" + if use_filenames: + prompt_path_qatuples_gen = "qatuples_gen_filenames" + + if completion_mode: + prompt_path_qatuples_gen = prompt_path_qatuples_gen + ".txt" + else: + prompt_path_qatuples_gen = prompt_path_qatuples_gen + ".yaml" + + qatuples_gen_regex = re.compile( + r"Questions \(make 4\):\n(.+)", re.IGNORECASE | re.DOTALL + ) + qatuples_generator = GenerationStep( + prompt_path=prompt_path_qatuples_gen, + regex=qatuples_gen_regex, + sampling_params={ + "max_tokens": 2000, + "stop": [ + "### Response", + "\n\n\n\n\n", + "", + "# Input:", + "[INST]", + "### Instruction", + "[INST", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + ], + "temperature": 0.8, + # top_k=-1, + "top_p": 1, + # min_p=0.5, + }, + completion_mode=completion_mode, + retries=3, + engine_wrapper=engine_wrapper_large, + logging_level=logging_level, + output_processor=extract_questions_from_response, + prompt_folder=obj_conf["PATH"]["PROMPTS"], + default_prompt_folder=DEFAULT_PROMPT_PATH, + use_stop=obj_conf["SYSTEM"]["STOP"] + ) + # Resume normal control flow code + try: + existing_files = glob.glob( + os.path.join(qa_tuples_dir, f"para_{idx}_*.json") + ) # check if qs already exist + + if len(existing_files) > 0: # If files exist, skip this paragraph entirely + print(f"Skipping para_{idx} 
as files already exist; loading said files") + for file_path in existing_files: + with open(file_path, "r", errors="replace") as file: + qa_tuple = tuple(json.load(file)) + generated_qa_tuples.append(qa_tuple) + return + question_group_id = make_id() + # print(f"\n\n\nOUTER LOOP CALL GENERATE QPLAN para: {para}, \n\n idx: {idx}") + # print( + # f"\n\n\nOUTER LOOP CALL GENERATE Q: {para}, \n\n idx: {idx} \n\n plan: {plan}" + # ) + ( + question_answer_tuples, + question_generation_output, + ) = await qatuples_generator.generate( + arguments={ + "text": para[0], + "textdetails": para[1], + } + ) + + question_answer_tuples_more_info = [ + (qatup[0], qatup[1], para[0], para[1], question_group_id, idx, qnum) for qnum, qatup in enumerate(question_answer_tuples) + ] + write_output_to_file( + question_generation_output, + obj_conf["PATH"]["OUTPUT"] + "/question_generation_generations", + question_group_id, + ) + + for qatup in question_answer_tuples_more_info: + generated_qa_tuples.append(qatup) + if qatup[0] is not None: + file_path = os.path.join(qa_tuples_dir, f"para_{qatup[5]}_q_{qatup[6]}.json") + with open(file_path, "w") as file: + json.dump(qatup, file, indent=4) + + except Exception as e: + print(f"Q ERROR: {e}") + traceback.print_exc() + + +def filter_and_graph(tuples): + # Count the occurrences of None and non-None for each source text + source_counts = Counter() + for paragraph, source in tuples: + if paragraph is None: + source_counts[source] = source_counts.get(source, [0, 0]) + source_counts[source][0] += 1 + else: + source_counts[source] = source_counts.get(source, [0, 0]) + source_counts[source][1] += 1 + + # Filter out tuples with None and return the new list + filtered_list = [t for t in tuples if t[0] is not None] + return filtered_list + + +## Paragraph Filtering (worthy for questions?) +async def determine_worthy( + idx, + p, + judged_worthy_for_questions, + output_dir, + judge: GenerationStep, +): + # for idx, p in tqdm(enumerate(paragraphs_processed[:10])): + id = make_id() + + file_name = f"{idx}.json" + file_path = os.path.join(output_dir, file_name) + # Check if the judgement for this paragraph already exists + if os.path.isfile(file_path): + with open(file_path, "r") as file: + data = json.load(file) + print("LOADING: ", data) + if isinstance(data, str): + judged_worthy_for_questions.append( + (None, data[7:]) + ) # hacky way of appending only the text name. 
See the file output of a failed judgement for details (Takes after "failed|") + else: + judged_worthy_for_questions.append((data["paragraph"], data["metadata"])) + else: + judgement, judgement_output = await judge.generate(arguments={"text": p[0], "textname": p[1]}) + write_output_to_file(judgement_output, obj_conf["PATH"]["OUTPUT"] + "/judge_paragraph_generations", id) + to_append = (None, p[1]) + if judgement: + to_append = (p[0], p[1]) + + judged_worthy_for_questions.append(to_append) + + # Prepare the data to be written to the file + if judgement: + # The paragraph passed the judgement + data_to_write = {"paragraph": to_append[0], "metadata": to_append[1]} + else: + # The paragraph did not pass the judgement + data_to_write = f"failed|{to_append[1]}" + + # Write the judgement to a unique file as JSON + with open(file_path, "w") as file: + json.dump(data_to_write, file) + + # Debug messages + try: + if judgement: + print(f"DEBUG model decided that index {idx} was suitable") + else: + print(f"DEBUG model decided that index {idx} was not suitable") + except: + print(f"DEBUG max retries exceeded for index {idx}") + + +def judge_paragraph_processor( + determination, +): # TODO extract to separate file to avoid muddying the control flow code + if "unsuitable" in determination.lower() or "table of contents" in determination.lower(): + return False # control flow has been modified to use the information it has, based on the determination of the output processors + elif "suitable" in determination.lower(): + return True + + +# EXEMPLAR +async def filter_all_questions( + paragraphs_processed, + judged_worthy_for_questions, + engine_wrapper, + output_dir, + take_subset=False, + subset_size=None, + use_filenames=False, + rtwl=None, + completion_mode=None, + logging_level=None, +): + if use_filenames: + prompt_path = "judge_paragraph_filenames" + else: + prompt_path = "judge_paragraph_no_filenames" + + judgement_regex = re.compile( + r"Reasoning and thought process \(reason intelligently\):(.+)", + re.DOTALL | re.IGNORECASE, + ) + + if completion_mode: + prompt_path = prompt_path + ".txt" + else: + prompt_path = prompt_path + ".yaml" + + judge = GenerationStep( + prompt_path=prompt_path, + regex=judgement_regex, + sampling_params={ + "max_tokens": 1450, + # "min_p": 0.4, + "stop": [ + "### Response", + "\n\n\n\n\n\n\n\n\n\n\n\n\n", + "", + "# Input:", + "[INST]", + "### Instruction", + "[INST", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + ], + "temperature": 0.2, + }, + completion_mode=completion_mode, + retries=2, + engine_wrapper=engine_wrapper, + logging_level=logging_level, # TODO change to warning + output_processor=judge_paragraph_processor, + # return_input_too=False, + prompt_folder=obj_conf["PATH"]["PROMPTS"], + default_prompt_folder=DEFAULT_PROMPT_PATH, + use_stop=obj_conf["SYSTEM"]["STOP"] + ) + if not take_subset: + tasks = [ + determine_worthy(idx, p, judged_worthy_for_questions, output_dir, judge) + for idx, p in enumerate(paragraphs_processed) + ] + else: + tasks = [ + determine_worthy(idx, p, judged_worthy_for_questions, output_dir, judge) + for idx, p in enumerate(paragraphs_processed[:subset_size]) + ] + limited_tasks = [rtwl(task) for task in tasks] + for future in tqdmasyncio.tqdm.as_completed(limited_tasks): + await future + + +def sentence_chunking_algorithm(file_path, max_char_length=1900): + """ + This function takes a plaintext file and chunks it into paragraphs or sentences if the paragraph exceeds max_char_length. 
+ + :param file_path: Path to the plaintext file + :param max_char_length: The maximum char5acter length for a chunk + :return: List of chunks with source text information + """ + chunks_with_source = [] + current_chunk = [] + char_count = 0 + source_name = file_path.replace(".txt", "") + + + with open(file_path, 'r', encoding='utf-8', errors="replace") as file: + content = file.read() + + # try: + # with open(file_path, "r", encoding="utf-8") as f: + # content = f.read() + # except Exception as e: + # print(f"\nError reading file {file_path}: {e}\n") + # return [] + + paragraphs = content.split( + "\n\n" + ) # Assuming paragraphs are separated by two newlines # TODO change so that if the length is 1 after this, split by tabs instead + + # HOW TO DO IT probably: + # add tokens to the paragraph until we reach the max length, + # create chunks out of the remainder of the paragraph (split at max chunk length until it's done) + # if the final chunk does not have the max length, then make it the new current chunk, set the current token count to its length, and continue with the for loop. + # Ensure max_char_length is an integer + max_char_length = int(max_char_length) + + for paragraph in paragraphs: + paragraph = paragraph.strip() # Remove leading and trailing whitespace + if not paragraph: # Skip empty paragraphs + continue + + paragraph_char_count = len(paragraph) + + # Check if the paragraph itself exceeds the max token length + if paragraph_char_count > max_char_length: + + # Fallback to character chunking for this paragraph + end_index = ( + max_char_length - char_count + ) # after this we will take max_char_length chunks starting from end index until the end of the paragraph + current_chunk.append(paragraph[:end_index]) + # characters = list(paragraph) + chunks_with_source.append(("".join(current_chunk), source_name)) + current_chunk = [] + while end_index < paragraph_char_count: + current_chunk.append(paragraph[end_index : end_index + max_char_length]) + chunks_with_source.append(("".join(current_chunk), source_name)) + current_chunk = [] + end_index += max_char_length + + # # handle the remainder of the paragraph + # end_index = end_index - max_char_length + # current_chunk.append(paragraph[end_index:]) + + # char_count = paragraph_char_count - end_index + else: + if char_count + paragraph_char_count <= max_char_length: + current_chunk.append(paragraph) + char_count += paragraph_char_count + else: + chunks_with_source.append(("".join(current_chunk), source_name)) + current_chunk = [paragraph] + char_count = paragraph_char_count + + # Add the last chunk if it exists + if current_chunk: + chunks_with_source.append(("\n\n".join(current_chunk), source_name)) + + # filter out chunks with fewer than 50 characters + chunks_with_source = [chunk for chunk in chunks_with_source if len(chunk[0]) >= 50] + + return chunks_with_source + + +def fix_text(to_replace_arr, text): + for startup in to_replace_arr: + text = text.replace(startup[0], startup[1]) + return text + + +async def ensure_multiple_answers_are_same( + info, conv, multi_turn_conv_generator, completion_mode=None, conversation_instructions="For this conversation, you are generating a chat between a general-purpose AI assistant and a human." +): # why is this a whole separate function? Once upon a time, LLMs were used in validation here, too. But programmatic validation SEEMS to catch the common problems. This is here so that I can add it back in if I have to. 
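+    # (Here `info[0]` is the group of QA tuples the conversation is built from and `conv` is the
+    # tuple returned by make_multiturn_conversation; call_all_processors(c[0], info[0]) is the
+    # programmatic validation referred to in the comment above.)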
+ """Loop to ensure that the answer is consistent in the conversation and in the tuple.""" + retries = 0 + c = conv + while retries < 2: # try twice, since multiturn is an expensive operation + if process_multiturn_functions.call_all_processors( + c[0], info[0] + ): # if programmatic validation passes + return c + + retries += 1 + if retries >= 2: + return None + # If we're here, majority of relevance checks failed + print("----------------\n\n\n\nRETRYING!!!!\n\n\n\n----------------") + # Broken info is 1) rare and 2) handled by the retry limit. We don't want to waste compute on regenerating info as they take time. + retry = await make_multiturn_conversation( + info, multi_turn_conv_generator, completion_mode=completion_mode, conversation_instructions=conversation_instructions + ) + if retry is not None: # Note: retry CANNOT actually be None + c = retry + else: + # If we failed to generate a retry, don't waste compute + return None + + return None + + + +async def make_multiturn_conversation( + info, multi_turn_conv_generator, completion_mode=None, conversation_instructions="For this conversation, you are generating a chat between a general-purpose AI assistant and a human." +): + + conv, conv_output = await multi_turn_conv_generator.generate( + arguments={ + "question_answer_list": format_qatuples(info[0]).strip(), + "conversation_instructions": conversation_instructions + } + ) + write_output_to_file( + conv_output, + obj_conf["PATH"]["OUTPUT"] + "/multiturn_conversation_generations", + info[4], + ) + + return (conv, info[1], info[2], info[3], info[0]) + +async def create_info( + idx, + group, + multi_turn_convs_info, + multi_turn_convs_info_dir, +): + + file_path = os.path.join(multi_turn_convs_info_dir, f"info_{idx}.json") + + # Skip if file already exists + if not os.path.exists(file_path): + info = (group, "will", "be", "replaced", make_id()) + + with open(file_path, "w") as file: + json.dump(info, file, indent=4) + else: + with open(file_path, "r") as file: + info = json.load(file) + + multi_turn_convs_info.append( + [info] + ) # hacky-looking things because the legacy functionality was simplified. + +def read_json_files_info(directory): + # Create a list to hold the tuples + tuple_list = [] + + # Get all the .json files in the directory, sorted + json_files = sorted([f for f in os.listdir(directory) if f.endswith(".json")]) + + # Read each file and convert the contents + for file in json_files: + with open(os.path.join(directory, file), "r") as f: + data = json.load(f) + # Ensure the data is in the correct format before converting to tuple + if ( + isinstance(data, list) + and len(data) == 5 + and isinstance(data[0], list) + and all(len(item) == 7 for item in data[0]) + and all(isinstance(i, str) for i in data[1:]) + ): + tuple_list.append((data[0], data[1], data[2], data[3], data[4])) + + return tuple_list + + +async def create_conversation( + idx, + info, + engine_wrapper, + multi_turn_convs, + multi_turn_convs_dir, + completion_mode=None, + logging_level=logging.INFO, + conversation_instructions="For this conversation, you are generating a chat between a general-purpose AI assistant and a human." 
+): + file_path = os.path.join(multi_turn_convs_dir, f"conv_{idx}.json") + multi_turn_conversation_prompt_path = "multi_turn_assistant_conversation" + + conversation_regex = re.compile( + f"Conversation that answers the provided question \(be sure that you do not change the questions or answers themselves; AI Assistant will answer the questions, not ask them; the questions and answers provided should be copied word for word, and surrounded by compelling conversation\):\n(.+)", + re.IGNORECASE | re.DOTALL, + ) + + if completion_mode: + multi_turn_conversation_prompt_path = ( + multi_turn_conversation_prompt_path + ".txt" + ) + else: + multi_turn_conversation_prompt_path = ( + multi_turn_conversation_prompt_path + ".yaml" + ) + + multi_turn_conv_generator = GenerationStep( + prompt_path=multi_turn_conversation_prompt_path, + regex=conversation_regex, + sampling_params={ + "max_tokens": 2000, + "stop": [ + "### Response", + "\n\n\n\n\n", + "", + "# Input:", + "[INST]", + "### Instruction", + "### Information", + "## Information", + "## Instruction", + "Name:", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + ], + "temperature": 0.8, + # "top_k": -1, + "top_p": 1, + # "min_p": 0.6, + }, + completion_mode=completion_mode, + retries=1, + engine_wrapper=engine_wrapper, + logging_level=logging_level, + prompt_folder=obj_conf["PATH"]["PROMPTS"], + default_prompt_folder=DEFAULT_PROMPT_PATH, + use_stop=obj_conf["SYSTEM"]["STOP"], + ) + + # Skip if file already exists + if not os.path.exists(file_path): + try: + conv = await make_multiturn_conversation( + info, multi_turn_conv_generator, completion_mode=completion_mode, conversation_instructions=conversation_instructions + ) + final_conv = await ensure_multiple_answers_are_same( + info, conv, multi_turn_conv_generator, completion_mode=completion_mode, conversation_instructions=conversation_instructions + ) + + if final_conv is not None: + final_conv = ( + final_conv[0], + "AI Assistant", + "", + "N/A", + final_conv[4], + ) + with open(file_path, "w") as file: + json.dump(final_conv, file, indent=4) + + multi_turn_convs.append(final_conv) + except Exception as e: + traceback.print_exc() + print("Had an error, retrying...", e) + else: + try: + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + data = json.load(f) + multi_turn_convs.append(data) + print(f"Skipped generating {file_path} as it already exists") + except Exception as e: + print(f"Error reading {file_path}:", e) + print("Continuing...") + + +def convert_directory_to_list(directory_path): + master_list = [] + simplified_list = [] + simplified_rag_list = [] + + for filename in os.listdir(directory_path): # for each file + if filename.endswith(".json"): # if it's a conversation file + filepath = os.path.join(directory_path, filename) # get the path + with open(filepath, "r") as file: # open it + try: + data = json.load(file) # load its data + if isinstance(data, list) and all( + isinstance(item, (list, str)) + for item in data # if it has the correct format + ): + + data_dict = { + "conversation": data[0], + "qa_tuples": [ + tup[:2] for tup in data[4] + ], # only take first two items from each tuple + "rag_context": data[4][0][2], + "source_filename": data[4][0][3], + } + master_list.append( + data_dict + ) # append it as-is to the master-list + + # Extract and process conversation + conversation, primary_char_desc = ( + data[0], + data[1], + ) # first and second items are conv and char desc + dialogues = process_multiturn_functions.extract_conversation( + 
conversation + ) + + # Convert to simplified format + simplified_conversations = [] + simplified_conversations_rag = [] + + # Load system prompts + system_prompt_norag = obj_conf["SYSTEM"][ + "FINAL_ASSISTANT_PROMPT_NO_RAG" + ] + system_prompt_rag = obj_conf["SYSTEM"][ + "FINAL_ASSISTANT_PROMPT_RAG" + ] + simplified_conversations.append( + {"from": "system", "value": system_prompt_norag} + ) + + simplified_conversations_rag.append( + { + "from": "system", + "value": system_prompt_rag.replace( + "{data}", data_dict["rag_context"] + ), + } + ) + for i, (charname, message) in enumerate( + dialogues + ): # Skipping the first message + from_person = "human" if (i % 2) == 0 else "gpt" + simplified_conversations.append( + {"from": from_person, "value": f"{message}"} + ) + simplified_conversations_rag.append( + { + "from": from_person, + "value": f"{message}", + } # same as above, but for the RAG context + ) + + if simplified_conversations: # If there are any conversations + simplified_list.append( + {"conversations": simplified_conversations} + ) + simplified_rag_list.append( + {"conversations": simplified_conversations_rag} + ) + except Exception as e: + print(f"Error reading {filename}: {e}") + + + + # Write the master list to a new .jsonl file + write_1 = obj_conf["PATH"]["OUTPUT"] + "/master_list.jsonl" + with open(write_1, "w") as file: + for item in master_list: + file.write(json.dumps(item) + "\n") + + # Process and push simplified_list (no RAG) + write_2 = obj_conf["PATH"]["OUTPUT"] + "/simplified_data_no_rag.jsonl" + with open(write_2, "w") as file: + for item in simplified_list: + file.write(json.dumps(item) + "\n") + + if PUSH_TO_HUB: + # Create a temporary JSON file with train split + temp_file_no_rag = obj_conf["PATH"]["OUTPUT"] + "/temp_simplified_data_no_rag.json" + with open(temp_file_no_rag, 'w') as temp_file: + json.dump({"train": simplified_list}, temp_file) + + # Load the dataset from the temporary file + dataset_no_rag = load_dataset("json", data_files=temp_file_no_rag, split="train") + + # Push to Hugging Face Hub + dataset_no_rag.to_parquet(f"hf://datasets/{HUB_PATH}/data/train-no_rag.parquet") + + # Remove the temporary file + os.remove(temp_file_no_rag) + + # Process and push simplified_rag_list (RAG) + write_3 = obj_conf["PATH"]["OUTPUT"] + "/simplified_data_rag.jsonl" + with open(write_3, "w") as file: + for item in simplified_rag_list: + file.write(json.dumps(item) + "\n") + + if PUSH_TO_HUB: + # Create a temporary JSON file with train split + temp_file_rag = obj_conf["PATH"]["OUTPUT"] + "/temp_simplified_data_rag.json" + with open(temp_file_rag, 'w') as temp_file: + json.dump({"train": simplified_rag_list}, temp_file) + + # Load the dataset from the temporary file + dataset_rag = load_dataset("json", data_files=temp_file_rag, split="train") + + # Push to Hugging Face Hub + dataset_rag.to_parquet(f"hf://datasets/{HUB_PATH}/data/train-rag.parquet") + + # Remove the temporary file + os.remove(temp_file_rag) + + print( + f"Conversion complete. Master list written to {write_1}. Simplified data written to {write_2} (no RAG) and {write_3} (RAG)." 
+ ) + if PUSH_TO_HUB: + print("Data successfully pushed to Hugging Face Hub.") + + +def convert_directory_and_process_conversations(directory_path): + master_list = [] + + for filename in os.listdir(directory_path): + if filename.endswith(".json"): + filepath = os.path.join(directory_path, filename) + with open(filepath, "r") as file: + try: + data = json.load(file) + + if isinstance(data, list) and all( + isinstance(item, (list, str)) for item in data + ): + # Extract and process the conversation part + conversations = ( + process_multiturn_functions.extract_conversation(data[0]) + ) + # Convert tuples back to the formatted string as required + data[0] = [ + f"{charname}: {message}" + for charname, message in conversations + ] + master_list.append(data) + else: + print(f"File {filename} is not in the expected format.") + except: + print(f"Error reading {filename}") + + # Write the master list to a new file + with open(obj_conf["PATH"]["OUTPUT"] + "/processed_master_list.json", "w") as file: + json.dump(master_list, file) + + print( + "Conversion complete. The processed master list is written to 'processed_master_list.json'." + ) + +def create_pretraining_set(directory_path, json_file): + # Initialize a variable to store the combined text of all files + combined_text = "" + # Walk through all directories and files in the directory + for root, dirs, files in os.walk(directory_path): + pbar = tqdm(files) + pbar.set_description("Creating pretraining dataset (processing files; this may take a while due to encoding safety)") + for filename in pbar: + # Skip PDF files + if filename.lower().endswith('.pdf'): + continue + + file_path = os.path.join(root, filename) + # Read the contents of the file + try: + # First, detect the file encoding + with open(file_path, 'rb') as raw_file: + raw_data = raw_file.read() + result = chardet.detect(raw_data) + file_encoding = result['encoding'] + + # Now read the file with the detected encoding + with open(file_path, "r", encoding=file_encoding) as file: + file_contents = file.read() + # Append the file contents to the combined text, with a separator + if combined_text: + combined_text += "\n\n---NEW FILE---\n\n" + combined_text += file_contents + # print(f"Successfully read file: {file_path}") + except UnicodeDecodeError as e: + print(f"Error reading file {file_path}: {e}. Skipping.") + continue # Skip this file and continue with the next one + except IOError as e: + print(f"IOError reading file {file_path}: {e}. 
Skipping.") + continue # Skip this file and continue with the next one + + # Create a dictionary with the combined text + data = {"text": combined_text} + + try: + with open(json_file, "w", encoding='utf-8') as file: + json.dump(data, file, ensure_ascii=False) + print("JSON file saved successfully.") + except IOError as e: + print(f"Error saving JSON file: {e}") diff --git a/augmentoolkit/utils/pdf_to_text.py b/augmentoolkit/utils/pdf_to_text.py new file mode 100644 index 00000000..da30fbf6 --- /dev/null +++ b/augmentoolkit/utils/pdf_to_text.py @@ -0,0 +1,66 @@ +import os +from pypdf import PdfReader +import fitz # PyMuPDF +from PIL import Image +import pytesseract + + +def convert_pdf_to_text(pdf_path, output_folder): + base_name = os.path.splitext(os.path.basename(pdf_path))[0] + output_path = os.path.join(output_folder, f"{base_name}.txt") + + if os.path.exists(output_path): + print(f"Skipping already converted file: {output_path}") + return output_path + + try: + # Try to extract text directly + with open(pdf_path, 'rb') as file: + pdf_reader = PdfReader(file) + text = "" + for page in pdf_reader.pages: + try: + page_text = page.extract_text() + # Try different encodings if UTF-8 fails + encodings = ['utf-8', 'latin-1', 'ascii', 'utf-16'] + for encoding in encodings: + try: + text += page_text.encode(encoding).decode('utf-8') + "\n" + break + except UnicodeEncodeError: + continue + except UnicodeDecodeError: + continue + except Exception as e: + print(f"Error extracting text from page in {pdf_path}: {str(e)}") + continue # Skip this page and continue with the next + + if text.strip(): + with open(output_path, 'w', encoding='utf-8', errors='ignore') as out_file: + out_file.write(text) + return output_path + except Exception as e: + print(f"Error in direct text extraction for {pdf_path}: {str(e)}") + # If direct extraction fails, proceed to OCR + + # Use OCR for scanned PDFs + try: + doc = fitz.open(pdf_path) + text = "" + for page in doc: + try: + pix = page.get_pixmap() + img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + page_text = pytesseract.image_to_string(img) + text += page_text + "\n" + except Exception as e: + print(f"Error processing page in {pdf_path}: {str(e)}") + continue # Skip this page and continue with the next + + with open(output_path, 'w', encoding='utf-8', errors='ignore') as out_file: + out_file.write(text) + return output_path + except Exception as e: + print(f"Error processing PDF {pdf_path}: {str(e)}") + return None + diff --git a/config.yaml b/config.yaml new file mode 100644 index 00000000..c1801cca --- /dev/null +++ b/config.yaml @@ -0,0 +1,46 @@ +PATH: + INPUT: "./input" + OUTPUT: "./output" + DEFAULT_PROMPTS: "./prompts" # the baseline prompt folder that Augmentoolkit falls back to if it can't find a step in the PROMPTS path + PROMPTS: "./prompts" # Where Augmentoolkit first looks for prompts +API: + API_KEY: "replace with your key" # Add the API key for your favorite provider here + BASE_URL: "https://api.together.xyz" # add the base url for a provider, or local server, here. Some possible values: http://127.0.0.1:5000/ # <- local models. # https://api.together.xyz # <- together.ai, which is real cheap, real flexible, and real high-quality, if a tad unreliable. # https://api.openai.com/v1/ # <- OpenAI. Will bankrupt you very fast. # anything else that accepts OAI-style requests, so basically any API out there (openrouter, fireworks, etc etc etc...) 
+  LOGICAL_MODEL: "meta-llama/Llama-3-8b-chat-hf" # model used for everything except question generation and conversation generation, which use the large model
+  LARGE_LOGICAL_MODEL: "meta-llama/Llama-3-70b-chat-hf" # model used for question generation and conversation generation at the very end. A pretty tough task, if ASSISTANT_MODE isn't on.
+  QUANTIZATION_SMALL: "gptq" # Only use if Aphrodite mode is on.
+  QUANTIZATION_LARGE: "gptq" # Only use if Aphrodite mode is on.
+SKIP:
+  QUESTION_CHECK: False
+  ANSWER_RELEVANCY_CHECK: False # turn on if using the negative question prompt override
+  FILTER_CHUNKS: False
+SYSTEM:
+  CHUNK_SIZE: 1900
+  USE_FILENAMES: False # give the AI context from the filenames provided to it. Useful if the filenames are meaningful, otherwise turn this off.
+  DOUBLE_CHECK_COUNTER: 1 # How many times to check a question and answer pair during each validation step. Majority vote decides if it passes that step. There are three steps, so with a value of 3 most questions are checked around 9 times (fewer if the first two checks for a step pass, obviously).
+  SUBSET_SIZE: 10
+  USE_SUBSET: False # Whether to take only the first SUBSET_SIZE chunks from a text during the run. Useful for experimenting and iterating and seeing all the steps without costing too much money or time.
+  CONCURRENCY_LIMIT: 50 # Hard limit on how many calls can be run at the same time, useful for API mode (aphrodite automatically manages this and queues things, as far as I know)
+  COMPLETION_MODE: False # Set to True to use completion mode; False uses chat (instruct) mode, which requires .yaml prompt files (OpenAI-style message lists) in your chosen prompts directory. Not all APIs support completion mode.
+  MODE: "api" # can be one of "api"|"aphrodite"
+  STOP: True # True = use stop tokens, False = do not use stop tokens. OpenAI's API restricts you to four stop tokens and all steps have way more than four stop tokens, so you'll need to set this to False if you're using OAI's API. Also NOTE that if you turn this OFF while using COMPLETION MODE, EVERYTHING WILL BREAK and it will cost you money in the process. Don't do that.
+  CONVERSATION_INSTRUCTIONS: For this conversation, you are generating a chat between a generalist, generic AI assistant, and a human.
+  FINAL_ASSISTANT_PROMPT_NO_RAG: |
+    You are a helpful AI assistant.
+  FINAL_ASSISTANT_PROMPT_RAG: |
+    You are a helpful AI assistant.
+ + Context information is below: + + ---------------------- + {data} +PHASE: + WORK_IN_PHASES: False + PHASE_INDEX: 3 # index of the phase we are currently on (index 0 = filtering out chunks with no relevant context; index 1 = question generation; index 2 = question validation; index 3 = context revision and conversation generation, the final phase) +HUGGINGFACE: + HUB_PATH: "Heralax/test-atk-dataset-do-not-use-3" + PRIVATE: false + PUSH_TO_HUB: false + + +# May be much more efficient to rent H100s for large-scale inference than A100s \ No newline at end of file diff --git a/processing.py b/processing.py new file mode 100644 index 00000000..59b20175 --- /dev/null +++ b/processing.py @@ -0,0 +1,518 @@ +import sys +import os + + +from augmentoolkit.utils.pdf_to_text import convert_pdf_to_text +# Get the directory of the current script +script_dir = os.path.dirname(os.path.abspath(__file__)) +# Change the current working directory to the script directory +os.chdir(script_dir) +# Add the script directory to the Python path +sys.path.append(script_dir) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +import asyncio +import traceback + +import augmentoolkit.utils.group_by_text + +# created with nbconvert, minimally cleaned up + + +async def main(): + from tqdm import tqdm + # NOTE NOTEBOOK SETTINGS AND CONSTANTS (some script file constants are in generation_functions/constants.py) + + # Put your desired quant of your desired model in the relevant directories + + import logging + import yaml + import glob + from augmentoolkit.utils.group_by_text import group_by_text + from augmentoolkit.control_flow_functions import control_flow_functions + + with open("./config.yaml", "r") as f: + config = yaml.safe_load(f) + + if not os.path.exists(config["PATH"]["OUTPUT"]): + os.makedirs(config["PATH"]["OUTPUT"]) + + # "airoboros-l2-70b-3.1.2.Q4_K_M.gguf" <- recommended for the large logical model + # "flatorcamaid-13b-v0.2.Q8_0.gguf" <- recommended for the normal logical model + # A6000s on Vast.ai are a good choice for running this notebook + + if ( + not config["SYSTEM"]["COMPLETION_MODE"] + and config["SYSTEM"]["MODE"] == "aphrodite" + ): + raise Exception("Aphrodite engine mode MUST use completion prompts!") + + LOGICAL_MODEL = config["API"]["LOGICAL_MODEL"] + + LARGE_LOGICAL_MODEL = config["API"]["LARGE_LOGICAL_MODEL"] + + DOUBLE_CHECK_COUNTER = config["SYSTEM"][ + "DOUBLE_CHECK_COUNTER" + ] # Set to 1 to check outputs only once; set to 2 to check twice; set to 3 to check thrice, etc. Set to 0 to break everything in vet_question_loop() and elsewhere. Set to -1 and cause the universe to implode? + + USE_SUBSET = config["SYSTEM"][ + "USE_SUBSET" + ] # Set to True if you want to use only a small subset of the text, to test whether it plays nicely with the current setup of the notebook + + SUBSET_SIZE = config["SYSTEM"]["SUBSET_SIZE"] # Set to the number of chunks you want to use if you're using a subset. If you're not using a subset, this will be ignored. + + USE_FILENAMES = config["SYSTEM"][ + "USE_FILENAMES" + ] # Turn on if you want the model to use the names of your files as additional context (this is what original Augmentoolkit does). Useful if you have a small number of large input files grouped by subject matter, IE books. Turn off if you have a large number of files with meaningless names. 
+ + CONCURRENCY_LIMIT = config["SYSTEM"][ + "CONCURRENCY_LIMIT" + ] # Adjust this number based on the rate limit constraints of your api + + API_KEY = config["API"]["API_KEY"] + + BASE_URL = config["API"][ + "BASE_URL" + ] # Augmentoolkit-API should also be compatible with any other API provider that accepts OAI-style requests + + COMPLETION_MODE = config["SYSTEM"]["COMPLETION_MODE"] + + MODE = config["SYSTEM"]["MODE"] + + LOG_LEVEL = logging.INFO + + INPUT_FOLDER = config["PATH"]["INPUT"] + + CONVERSATION_INSTRUCTIONS = config["SYSTEM"][ + "CONVERSATION_INSTRUCTIONS" + ] + + # Create pretraining set from raw inputs (pretrain first, then instruct tune) + + PHASE_INDEX = config["PHASE"]["PHASE_INDEX"] + + WORK_IN_PHASES = config["PHASE"]["WORK_IN_PHASES"] + + SKIP_FILTER_CHUNKS = config["SKIP"]["FILTER_CHUNKS"] + + print("Pretraining set created.") + + extensions = [".txt", ".md", ".pdf"] + + source_texts = [] + for extension in extensions: + path = f"{INPUT_FOLDER}/**/*{extension}" + files = glob.glob(path, recursive=True) + + pbar = tqdm(files) + for file in pbar: + pbar.set_description(f"Processing {file}") + if extension == ".pdf": + converted_file = convert_pdf_to_text(file, INPUT_FOLDER) + if converted_file: + source_texts.append(converted_file) + else: + source_texts.append(file) + + if source_texts: + print(source_texts) + else: + print(f"No source texts found in: {INPUT_FOLDER}") + + control_flow_functions.create_pretraining_set( + INPUT_FOLDER, os.path.join(config["PATH"]["OUTPUT"], "pretraining.json") + ) + + # Chunking step + sentence_chunks = [] + for source_text in source_texts: + if ".pdf" not in source_text: + sentence_chunks += control_flow_functions.sentence_chunking_algorithm( + source_text, config["SYSTEM"]["CHUNK_SIZE"] + ) + + print( + "\n\n\nIMPORTANT NOTE! Augmentoolkit prints a lot of stuff when it runs. Including tracebacks caused by model errors. Most errors are the result of the models, not the code, and any tracebacks you see were almost certainly handled. So: don't panic! You're gonna make it! Alright that's the end of this PSA. Happy dataset generation!\n\n\n" + ) + + + import uuid + + # This is in no way best practices, but all my prompts being searchable and separate files is a good way to make my life easier. + import pkgutil + import importlib + import sys + from tqdm import asyncio as tqdmasyncio + import asyncio + + # Set up rate-limit-conscious functions + semaphore = asyncio.Semaphore(int(CONCURRENCY_LIMIT)) + + async def run_task_with_limit(task): + async with semaphore: + # Run your task here + return await task + + # We have to define this up here so that two-step generation works, you'll see later. 
+ multi_turn_convs_info_dir = ( + config["PATH"]["OUTPUT"] + "/multi_turn_convs_info" + ) # we generate all the information fed to the multiturn prompt, and generate the actual multiturn prompt, separately; since every step but the last is capable of being done by a 13b + + sys.path.append("./generation_functions") + sys.path.append("./control_flow_functions") + + import augmentoolkit.generation_functions as generation_functions # This is the package directory + from augmentoolkit.generation_functions.engine_wrapper_class import EngineWrapper + + engine_wrapper = EngineWrapper( + model=LOGICAL_MODEL, + api_key=API_KEY, + base_url=BASE_URL, + mode=MODE, + # quantization="gptq" # modify if you want to do stuff with the aphrodite branch + ) + + engine_wrapper_large = EngineWrapper( + model=LARGE_LOGICAL_MODEL, + api_key=API_KEY, + base_url=BASE_URL, + mode=MODE, + # quantization="gptq" # modify if you want to do stuff with the aphrodite branch + ) + + import re + from tqdm import tqdm + + + conversions = [("\n", " "), (" ", " ")] + + paragraphs_processed = [ + (control_flow_functions.fix_text(conversions, seq[0]), seq[1]) + for seq in sentence_chunks + ] + + len(paragraphs_processed) + + paragraphs_processed[0] + + print(paragraphs_processed[:3]) + + import json + + from tqdm import tqdm + import asyncio + + + if SKIP_FILTER_CHUNKS: + print("Skipping chunk filtering") + if USE_SUBSET: + filtered_worthy_for_questions = paragraphs_processed[:SUBSET_SIZE] + else: + filtered_worthy_for_questions = paragraphs_processed + else: + # Create directory if it doesn't exist + output_dir = config["PATH"]["OUTPUT"] + "/worthy_for_questions" + os.makedirs(output_dir, exist_ok=True) + + # Determine which paragraphs are worthy of making questions from + judged_worthy_for_questions = [] + + await control_flow_functions.filter_all_questions( + paragraphs_processed, + judged_worthy_for_questions, + engine_wrapper, + output_dir, + take_subset=USE_SUBSET, + subset_size=SUBSET_SIZE, + use_filenames=False, + rtwl=run_task_with_limit, + completion_mode=COMPLETION_MODE, + logging_level=LOG_LEVEL, + ) + + filtered_worthy_for_questions = control_flow_functions.filter_and_graph( + judged_worthy_for_questions + ) + + print("Converting generations to training data") + control_flow_functions.convert_logging_to_dataset("judge_paragraph_generations") + + print(filtered_worthy_for_questions[0]) + + # PHASE 0 END + print("\n\nCOMPLETED PHASE 0") + if WORK_IN_PHASES and PHASE_INDEX == 0: + sys.exit(0) + + ##### + + + + # ### The cell below begins generating questions. SOME OF THESE MAY FAIL and have to retry due to model errors (the API branch cannot use grammars). But if you let it run you will see that the vast majority eventually get through. 
+ # + + # control flow + import json + + import glob + + # Directory for QA tuples + qa_tuples_dir_unchecked = config["PATH"]["OUTPUT"] + "/qatuples_raw" + if not os.path.exists(qa_tuples_dir_unchecked): + os.makedirs(qa_tuples_dir_unchecked) + + generated_qa_tuples = [] # tuple list of qa tuples that have been judged good + + # Attempt to initialize filtered_worthy_for_questions + try: + _ = filtered_worthy_for_questions + except NameError: + filtered_worthy_for_questions = [] + + if not filtered_worthy_for_questions: + # Load all files in the qa_tuples_dir if filtered_worthy_for_questions is not initialized + existing_files = glob.glob(os.path.join(qa_tuples_dir_unchecked, "*.json")) + for file_path in existing_files: + with open(file_path, "r") as file: + qa_tuple = tuple(json.load(file)) + print(f"Loaded {file}") + generated_qa_tuples.append(qa_tuple) + else: + tasks = [ + control_flow_functions.generate_qatuples_from_para( + idx, + para, + engine_wrapper_large=engine_wrapper_large, + generated_qa_tuples=generated_qa_tuples, + qa_tuples_dir=qa_tuples_dir_unchecked, + use_filenames=USE_FILENAMES, + completion_mode=COMPLETION_MODE, + logging_level=LOG_LEVEL, + ) + for idx, para in enumerate(filtered_worthy_for_questions) + ] + limited_tasks_qgen = [run_task_with_limit(task) for task in tasks] + for future in tqdmasyncio.tqdm.as_completed(limited_tasks_qgen): + await future + + + # only convert questions to training data if they passed validation + + # for qatup in generated_qa_tuples: + # if question_answer_tuple[0] is not None: + # file_path = os.path.join(qa_tuples_dir_unchecked, f"para_{question_answer_tuple[5]}_q_{qnum}.json") + # with open(file_path, "w") as file: + # json.dump(question_answer_tuple, file, indent=4) + + # PHASE 1 END + print("COMPLETED PHASE 1") + if WORK_IN_PHASES and PHASE_INDEX == 1: + print("EXITING DUE TO config.yaml SETTINGS AROUND PHASES; SET TO ONLY EXECUTE PHASE 1 RIGHT NOW") + sys.exit(0) + #### + + vetted_qa_tuples = [] + qa_tuples_dir_checked = config["PATH"]["OUTPUT"] + "/qatuples_filtered" + if not os.path.exists(qa_tuples_dir_checked): + os.makedirs(qa_tuples_dir_checked) + + # print(generated_qa_tuples[0]) + + tasks = [ + control_flow_functions.vet_question_loop( + question_answer_tuple, + question_group_id=question_answer_tuple[4], + engine_wrapper=engine_wrapper, + qa_tuples_dir=qa_tuples_dir_checked, + vetted_qa_tuples=vetted_qa_tuples, + double_check_counter=DOUBLE_CHECK_COUNTER, + completion_mode=COMPLETION_MODE, + logging_level=LOG_LEVEL, + ) for question_answer_tuple in generated_qa_tuples + ] + limited_tasks_q_validation = [run_task_with_limit(task) for task in tasks] + for future in tqdmasyncio.tqdm.as_completed(limited_tasks_q_validation): + await future + + + if WORK_IN_PHASES and PHASE_INDEX == 2: + print("EXITING DUE TO config.yaml SETTINGS AROUND PHASES; SET TO ONLY EXECUTE PHASE 2 RIGHT NOW") + sys.exit(0) + + print( + "-------------- QUESTIONS CREATED ------------- STATS SO FAR (may be wrong if run was continued from interruption):" + ) + nones = list(filter(lambda x: x is None, vetted_qa_tuples)) + print(f"Nones: {len(nones)}") + print(f"Non-nones: {len(vetted_qa_tuples) - len(nones)}") + print(f"Total: {len(vetted_qa_tuples)}") + # filter out all None values + vetted_qa_tuples = [qa for qa in vetted_qa_tuples if qa is not None] + print("---------------- ONTO REVISION ------------------") + + # Check for and fix the common mistake: mentioning "the text". 
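+    # Illustrative example of the mistake targeted here (hypothetical Q&A, not pipeline output):
+    #   BAD:   "According to the text, the boiling point of water at sea level is 100 degrees Celsius."
+    #   FIXED: "The boiling point of water at sea level is 100 degrees Celsius."
+    # The revision step rewrites such references so the question-answer pair stands on its own
+    # instead of pointing at "the text" or "the provided passage".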
+ writepath = config["PATH"]["OUTPUT"] + "/qatuples_revised" + import json + + # Assuming vetted_qa_tuples is a list that might or might not exist + try: + _ = vetted_qa_tuples + except NameError: + vetted_qa_tuples = [] + + # Load all files at the start if vetted_qa_tuples is empty + if not vetted_qa_tuples: + print("WENT DOWN HERE") + # Check if the directory exists + if os.path.exists(writepath): + # List all files in directory + for file_name in os.listdir(writepath): + file_path = os.path.join(writepath, file_name) + try: # for each file already generated, see if it succeeded or failed; if it succeeded, append its contents; if it failed, append None for stats logging + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + content = f.read() + print(f"Loading file: {file_path}") + if content == "failed": + vetted_qa_tuples.append(None) + else: + try: + data = json.loads(content) + vetted_qa_tuples.append( + (data[0], data[1], data[2], data[3]) + ) + except json.JSONDecodeError: + print("JSON decode error with the contents:", content) + vetted_qa_tuples.append(None) + except Exception as e: + print(f"Error reading {file_path}: {e}") + else: + tasks = [ + control_flow_functions.repair_qatuple_context( # NOTE PROBLEM in that things that this writes, do not have enough items in the tuple + idx, + tup, + engine_wrapper_large, + writepath, + vetted_qa_tuples, + use_filenames=USE_FILENAMES, + ) + for idx, tup in enumerate(vetted_qa_tuples) + ] + limited_tasks_qcorrection = [run_task_with_limit(task) for task in tasks] + for future in tqdmasyncio.tqdm.as_completed(limited_tasks_qcorrection): + await future + + # Print stats related to revised qatuples, and filter out nones (questions that were unanswerable due to lack of context). + import json + + + print("-------------- QUESTIONS REVISED ------------- STATS SO FAR:") + nones = list(filter(lambda x: x is None, vetted_qa_tuples)) + print(f"Nones: {len(nones)}") + print(f"Non-nones: {len(vetted_qa_tuples) - len(nones)}") + print(f"Total: {len(vetted_qa_tuples)}") + # filter out all None values + vetted_qa_tuples = [qa for qa in vetted_qa_tuples if qa is not None] + print("---------------- ONTO EXAMPLES GENERATION-------------------") + + qa_tuples_by_paragraph = augmentoolkit.utils.group_by_text.group_by_text(vetted_qa_tuples) + + print("Creating question generation training data...") + control_flow_functions.convert_revised_questions_to_question_generation_training(qa_tuples_by_paragraph=qa_tuples_by_paragraph, use_filenames=USE_FILENAMES) + + if not os.path.exists(multi_turn_convs_info_dir): + os.makedirs(multi_turn_convs_info_dir) + + import json + import random + import itertools + + multi_turn_convs_info = [] + + tasks = [ + control_flow_functions.create_info( + idx, + group, + multi_turn_convs_info, + multi_turn_convs_info_dir + ) + for idx, group in enumerate(qa_tuples_by_paragraph) + ] + limited_tasks_infocreation = [run_task_with_limit(task) for task in tasks] + for future in tqdmasyncio.tqdm.as_completed(limited_tasks_infocreation): + await future + + + + + import json + + convs_info = control_flow_functions.read_json_files_info(multi_turn_convs_info_dir) + + + import json + import random + import itertools + import asyncio + + multi_turn_convs_dir = config["PATH"]["OUTPUT"] + "/multi_turn_convs" + if not os.path.exists(multi_turn_convs_dir): + os.makedirs(multi_turn_convs_dir) + + multi_turn_convs = [] + + tasks = [ + control_flow_functions.create_conversation( + idx, + info, + engine_wrapper_large, + 
multi_turn_convs, + multi_turn_convs_dir, + completion_mode=COMPLETION_MODE, + logging_level=LOG_LEVEL, + conversation_instructions=CONVERSATION_INSTRUCTIONS + ) + for idx, info in enumerate(convs_info) + ] + limited_tasks_convwriting = [run_task_with_limit(task) for task in tasks] + for future in tqdmasyncio.tqdm.as_completed(limited_tasks_convwriting): + await future + + print("Converting conversational data generations to training data") + control_flow_functions.convert_logging_to_dataset("multiturn_conversation_generations") + + # # Yay! Now you have a dataset! + # ### GPT wrote the cell below. I think it successfully converts things to ShareGPT format for use with axolotl, but I am not sure because I don't know that format very well and haven't used Axolotl. However, the json produced by the second function looks fine. + + + import json + + # Make ShareGPT-format dataset (I think, still need verification it actually works) + control_flow_functions.convert_directory_to_list( + config["PATH"]["OUTPUT"] + "/multi_turn_convs/" + ) + # Make dataset in a format that has all the information. See README for details on this format. + control_flow_functions.convert_directory_and_process_conversations( + config["PATH"]["OUTPUT"] + "/multi_turn_convs/" + ) + + with open(config["PATH"]["OUTPUT"] + "/processed_master_list.json", "r") as f: + first = f.read() + data = json.loads(first) + + # For curiosity's sake, you can find out how many lines of dialogue you generated + def filter_and_flatten(lst): + flat_list = [] + + # Loop through each sublist in the main list + for sublst in lst: + # Check if the first element of the sublist is itself a list (subsublist1) + if isinstance(sublst[0], list): + # Extend the flat_list with the elements from subsublist1 + flat_list.extend(sublst[0]) + + return flat_list + + len(filter_and_flatten(data)) + print("COMPLETED FINAL PHASE") + + +asyncio.run(main()) diff --git a/test_convert_logging_to_dataset.py b/test_convert_logging_to_dataset.py new file mode 100644 index 00000000..7d752782 --- /dev/null +++ b/test_convert_logging_to_dataset.py @@ -0,0 +1,50 @@ +import glob +import json +import os + +import yaml + +with open("./config.yaml", "r") as file: + obj_conf = yaml.safe_load(file) + +def convert_logging_to_dataset(directory): + print("entering saving mode") + # found a solution to overfitting on the examples: + # TRAIN WITHOUT THEM + # This will produce a WEALTH of instruct data + # fucking awesome, hopefully + # also it's also about the domain, lmao + # so more domain knowledge + + output_dir = os.path.join(obj_conf["PATH"]["OUTPUT"], directory) + + output_file_path = os.path.join(obj_conf["PATH"]["OUTPUT"], directory + "_DATAGEN_OUTPUT.jsonl") + + + + if not os.path.exists(output_dir): + raise Exception("ERROR!! 
Trying to convert a logging directory to a dataset when that directory does not exist!")
+
+    with open(output_file_path, "w", encoding="utf-8") as f:
+        # The per-call logs written during generation are YAML files, so glob for those.
+        existing_files = glob.glob(
+            os.path.join(output_dir, "*.yaml")
+        )
+
+        print(existing_files)
+
+        for file in existing_files:
+            with open(file, "r") as file2:
+                file_list_of_dicts = yaml.safe_load(file2)
+
+            # print(file_list_of_dicts)
+
+            # Keep the system prompt, the last user message, and the final model reply.
+            sysprompt = {"from": "system", "value": file_list_of_dicts[0]["content"]}
+            input_msg = {"from": "human", "value": file_list_of_dicts[-2]["content"]}
+            output_msg = {"from": "gpt", "value": file_list_of_dicts[-1]["content"]}
+
+            json_to_write = {"conversations": [sysprompt, input_msg, output_msg]}
+
+            f.write(json.dumps(json_to_write) + "\n")
+    print("...Converted successfully (we think)")
+
+convert_logging_to_dataset("judge_paragraph_generations")
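+
+# Assumed usage (for manual testing only): run a generation pass first so that the
+# judge_paragraph_generations folder exists inside your configured output path, then
+# execute this script directly, e.g.
+#   python test_convert_logging_to_dataset.py
+# and inspect the resulting judge_paragraph_generations_DATAGEN_OUTPUT.jsonl in that output path.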