Skip to content

Commit

Permalink
feat: qa pairs generation script
Browse files Browse the repository at this point in the history
  • Loading branch information
jamnicki committed Apr 27, 2024
1 parent 103a2e2 commit 41749dc
Show file tree
Hide file tree
Showing 9 changed files with 664 additions and 980 deletions.
49 changes: 49 additions & 0 deletions juddges/data/qa_pairs_json_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import re
from typing import Any, List, Callable
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers.json import parse_partial_json, _parse_json
from langchain_core.outputs import Generation
from langchain_core.exceptions import OutputParserException
from json.decoder import JSONDecodeError

from juddges.data.synthetic.patterns import CUSTOM_PARSE_JSON_MARKDOWN


class QAPairsJsonParser(JsonOutputParser):
    """JSON output parser that delegates to this module's `parse_json_markdown`."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse the text of the first generation into a JSON object.

        Args:
            result: Generations produced by the model; only ``result[0]`` is read.
            partial: When True, a decoding failure yields ``None`` instead of
                raising.

        Returns:
            The parsed JSON value, or ``None`` if ``partial`` is set and the
            text cannot be decoded.

        Raises:
            OutputParserException: If ``partial`` is False and the text cannot
                be decoded as JSON.
        """
        raw_output = result[0].text.strip()
        try:
            return parse_json_markdown(raw_output)
        except JSONDecodeError as err:
            if partial:
                return None
            raise OutputParserException(
                f"Invalid json output: {raw_output}", llm_output=raw_output
            ) from err


def parse_json_markdown(
    json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
    """Modified version of `langchain_core.output_parsers.json:parse_json_markdown`

    Fixes: JSONDecodeError when parsing CoT like output that contains multiple JSON strings

    The whole string is parsed first. If that fails, the last fenced block
    matched by ``CUSTOM_PARSE_JSON_MARKDOWN`` is parsed instead; when there is
    no fenced block the whole string is retried, so the failure surfaces as a
    ``JSONDecodeError``.

    Args:
        json_string: Raw model output.
        parser: Callable used to turn a JSON string into a Python object.

    Returns:
        The parsed JSON object.

    Raises:
        JSONDecodeError: If neither the full string nor the last fenced block
            can be parsed.
    """
    try:
        return _parse_json(json_string, parser=parser)
    except JSONDecodeError:
        # `findall` returns a (possibly empty) list, never None, so test for
        # emptiness; indexing an empty list would raise IndexError, which
        # callers catching JSONDecodeError would not handle.
        fenced_blocks = CUSTOM_PARSE_JSON_MARKDOWN.findall(json_string)

        # Use the last fenced block when there is one, otherwise assume the
        # entire string is a JSON string.
        json_str = fenced_blocks[-1] if fenced_blocks else json_string
        return _parse_json(json_str, parser=parser)
Empty file.
17 changes: 17 additions & 0 deletions juddges/data/synthetic/generation_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Chain-of-thought prompt template for generating question-answer pairs.
# Placeholders: {context}, {format_instructions}, {format_md_ext}.
# The template ends with an opening code fence tagged with {format_md_ext},
# presumably so the model continues directly with the structured payload.
JUDGEMENTS_QA_COT_PROMPT_V1 = '''\
You are a question-answer generator. Your goal is to generate question-answer pairs given the Context.
Do not translate the Context, generate questions and answers in original language.
Context: {context}
Step 1: Identify spans that are likely to be answers to questions, identify as many as possible.
Step 2: For each identified span, generate a question.
Step 3: Respond to the question in only a few tokens concisely.
Ensure that you distinctly label and delineate Steps 1, 2 and 3.
{format_instructions}
Output:
```{format_md_ext}
'''
7 changes: 7 additions & 0 deletions juddges/data/synthetic/patterns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import re


# Matches a fenced block (optionally tagged "json", case-insensitive) that runs
# to the end of the string; the closing fence is optional. Group 1 captures the
# block content, which may not contain backticks.
CUSTOM_PARSE_JSON_MARKDOWN = re.compile(
    r"```(?:json)?([^`]+)(?:```)?\s*$",
    re.IGNORECASE,
)
25 changes: 25 additions & 0 deletions juddges/data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path
from typing import Iterable, Generator
from jsonlines import jsonlines
from datetime import datetime


def save_jsonl(records: Iterable[dict], out: Path | str):
    """Write every record from ``records`` to ``out`` in JSON Lines format.

    Args:
        records: Iterable of dictionaries to serialize.
        out: Destination file path; it is opened in write mode (``"w"``).
    """
    with jsonlines.open(out, mode="w") as jsonl_writer:
        jsonl_writer.write_all(records)


def read_jsonl(path: Path | str) -> Generator[dict, None, None]:
    """Lazily yield the records stored in a JSON Lines file.

    The file is opened when iteration starts and stays open until the
    generator is exhausted or closed.

    Args:
        path: Location of the JSON Lines file to read.

    Yields:
        Each record of the file, in order.
    """
    with jsonlines.open(path) as jsonl_reader:
        yield from jsonl_reader


def path_safe_udate() -> str:
    """Build a timestamp string suitable for use in file names.

    Returns:
        str: The current local date and time rendered with the
            ``%Y%m%d_%H%M%Sf%f`` format (a literal ``f`` separates the seconds
            from the microseconds).
    """
    timestamp_format = "%Y%m%d_%H%M%Sf%f"
    current_time = datetime.now()
    return current_time.strftime(timestamp_format)
1,344 changes: 364 additions & 980 deletions notebooks/3_lcel_synth_legis_data_gen.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mpire==2.10.0
tenacity==8.2.3
loguru==0.7.2
typer==0.9.0
jsonlines==4.0.0

datasets==2.18.0
pyarrow==15.0.0
Expand Down
128 changes: 128 additions & 0 deletions scripts/gen_synthetic_judgements_qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import typer
from typing import List
from loguru import logger
from dotenv import load_dotenv
import json

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from juddges.data.utils import read_jsonl
from juddges.data.synthetic.generation_prompt import JUDGEMENTS_QA_COT_PROMPT_V1
from juddges.data.qa_pairs_json_parser import QAPairsJsonParser

# Load environment variables from the local `secrets.env` file at import time.
# NOTE(review): nothing in this script reads them directly; presumably they are
# consumed by the model/tokenizer download (e.g. access tokens) — confirm.
load_dotenv("secrets.env", verbose=True)



# Schema of the generated question-answer pairs, plus sanity checks on a
# parsed result.
# Documented with comments rather than a class docstring on purpose: pydantic
# copies a class docstring into the generated schema, which would change the
# format instructions rendered into the prompt.
class SyntheticLegisQAPairs(BaseModel):
    questions: List[str] = Field(description="List of generated questions")
    answers: List[str] = Field(description="List of generated answers")

    # The checks raise AssertionError explicitly instead of using `assert`
    # statements, so they still run when Python is started with `-O`.

    def test_empty(self):
        """Raise ``AssertionError`` unless at least one question was generated."""
        if not self.questions:
            raise AssertionError("At least one question should be generated")

    def test_equal_length(self):
        """Raise ``AssertionError`` unless questions and answers have equal length."""
        if len(self.questions) != len(self.answers):
            raise AssertionError("Number of questions and answers should be equal")

    def test_q_duplicates(self):
        """Raise ``AssertionError`` if any question appears more than once."""
        if len(set(self.questions)) != len(self.questions):
            raise AssertionError("Questions should be unique")

    def test_a_duplicates(self):
        """Raise ``AssertionError`` if any answer appears more than once."""
        if len(set(self.answers)) != len(self.answers):
            raise AssertionError("Answers should be unique")

    def test_duplicates(self):
        """Run the duplicate checks for questions, then for answers."""
        self.test_q_duplicates()
        self.test_a_duplicates()

    def test(self):
        """Run all checks in order: non-empty, equal length, no duplicates."""
        self.test_empty()
        self.test_equal_length()
        self.test_duplicates()


def main(
    judgements_fpath: str = typer.Option(default=None, help="Dumped `judgements` collection file path"),
    out: str = typer.Option(default=None, help="Output file path"),
    hf_model: str = typer.Option(
        help="Hugging Face model name or path",
        default="TheBloke/CapybaraHermes-2.5-Mistral-7B-GPTQ"
    ),
    max_input_length: int = typer.Option(default=3551, help="Maximum number of tokens in input text"),
):
    """Generate synthetic question-answer pairs from dumped judgements.

    Builds a `prompt | HuggingFacePipeline | QAPairsJsonParser` chain, feeds it
    the ``text`` of judgements read from a JSONL dump, and validates the parsed
    output with `SyntheticLegisQAPairs`.

    Args:
        judgements_fpath: Path to the JSONL dump of the `judgements`
            collection; a hard-coded default is used (with a warning) if omitted.
        out: Output file path; a hard-coded default is used (with a warning)
            if omitted.
        hf_model: Hugging Face model name or path used for both the model and
            the tokenizer.
        max_input_length: Token budget; judgements whose text exceeds 95% of
            it are skipped.

    Raises:
        AssertionError: If the generated pairs fail the `SyntheticLegisQAPairs`
            checks.
    """
    if judgements_fpath is None:
        # FIXME
        default_judgements_fpath = "/app/data/datasets/pl/judgements_sample10_20240427_094707f595590.jsonl"
        logger.warning(
            "Dumped `judgements` collection file path not provided."
            f" Using the default `judgements` path: {default_judgements_fpath}"
        )
        judgements_fpath = default_judgements_fpath

    if out is None:
        # FIXME
        default_out = "/app/data/datasets/pl/synthetic_judgements_qa.jsonl"
        # The f-prefix is required: without it the placeholder is logged literally.
        logger.warning(f"Output file path not provided. Using the default `out`: {default_out}")
        out = default_out
    # NOTE(review): `out` is resolved above but nothing is written to it yet.

    qa_parser = QAPairsJsonParser(pydantic_object=SyntheticLegisQAPairs)
    logger.debug(f"{qa_parser.get_format_instructions()=}")

    prompt = ChatPromptTemplate.from_template(
        template=JUDGEMENTS_QA_COT_PROMPT_V1,
        partial_variables={"format_instructions": qa_parser.get_format_instructions()},
    )

    # For example: revision="gptq-4bit-32g-actorder_True"
    model = AutoModelForCausalLM.from_pretrained(
        hf_model,
        device_map="auto",
        trust_remote_code=True,
        revision="main"
    )
    tokenizer = AutoTokenizer.from_pretrained(hf_model, use_fast=True)
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1
    )
    hf_pipeline = HuggingFacePipeline(pipeline=text_gen_pipeline)

    gen_chain = prompt | hf_pipeline | qa_parser

    # Loop-invariant skip threshold: 95% of the budget (presumably leaving
    # headroom for the prompt template itself — confirm).
    token_limit = 0.95 * max_input_length

    logger.info("Generating QA pairs from provided collection data...")
    for judgement in read_jsonl(judgements_fpath):
        num_text_tokens = tokenizer.encode(judgement["text"], return_tensors="pt").shape[1]
        if num_text_tokens > token_limit:
            logger.warning(
                f"Skipping judgement with id: {judgement['_id']} due to text"
                f" length > {max_input_length} ({num_text_tokens})..."
            )
            continue

        chain_input = {
            "context": judgement["text"],
            "format_md_ext": "json"
        }
        qa_pairs = gen_chain.invoke(chain_input)
        logger.debug(json.dumps(qa_pairs, indent=2, ensure_ascii=False))

        dto = SyntheticLegisQAPairs(**qa_pairs)
        dto.test()

        # NOTE(review): stops after the first judgement that is not skipped.
        break


# Script entry point: let typer build the command-line interface from `main`'s
# signature and run it.
if __name__ == "__main__":
    typer.run(main)
73 changes: 73 additions & 0 deletions scripts/get_juddgements_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import typer
from pathlib import Path
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
from dotenv import load_dotenv
from loguru import logger

from juddges.data.utils import save_jsonl, path_safe_udate

# Load environment variables from the local `secrets.env` file at import time;
# `main` reads `MONGO_URI` from the environment through its typer option.
load_dotenv("secrets.env", verbose=True)


def agg_sample(collection: Collection, size: int) -> CommandCursor:
    """Draw a random sample of documents that have a ``text`` field.

    Args:
        collection (Collection): MongoDB collection to sample from
        size (int): Number of records to sample

    Returns:
        CommandCursor: Cursor over the sampled records
    """
    logger.info(f"Sampling {size} records...")
    # Stage 1: keep only documents in which the `text` field exists.
    has_text_stage = {"$match": {"text": {"$exists": True}}}
    # Stage 2: randomly pick `size` of the remaining documents.
    random_sample_stage = {"$sample": {"size": size}}
    return collection.aggregate([has_text_stage, random_sample_stage])


def main(
    size: int = typer.Option(help="Number of records to sample"),
    # FIXME: MongoDB's `$sample` stage does not support seeding
    seed: int = typer.Option(default=None, help="Random seed"),
    mongo_uri: str = typer.Option(..., envvar="MONGO_URI"),
    out: str = typer.Option(default=None, help="Output file path"),
):
    """Dump a random sample of the `judgements` collection to a JSONL file.

    Args:
        size: Number of records to sample.
        seed: Accepted for interface compatibility but ignored; a warning is
            logged when it is provided.
        mongo_uri: MongoDB connection URI (may come from ``MONGO_URI``).
        out: Output file path; defaults to a timestamped file name in the
            current working directory.

    Raises:
        NotImplementedError: If ``out`` has neither a ``.jsonl`` nor a
            ``.jsonlines`` suffix.
    """
    if seed is not None:
        # Make the silent no-op visible instead of implying reproducibility.
        logger.warning("`seed` is ignored: MongoDB `$sample` does not support seeding")

    if out is None:
        # FIXME: data dir from settings when available
        # from juddges.settings import PL_JUDGEMENTS_PATH
        # out = PL_JUDGEMENTS_PATH / f"judgements_sample{size}_{path_safe_udate()}.jsonl"
        out = f"judgements_sample{size}_{path_safe_udate()}.jsonl"
        logger.warning(f"Output file path not provided, using the default `out`: ./{out}")

    # Validate the output format before touching the database so that an
    # unsupported path fails fast, without connecting or sampling.
    out_path = Path(out)
    if ".jsonl" not in out_path.suffixes and ".jsonlines" not in out_path.suffixes:
        raise NotImplementedError("Only JSONL output is supported")

    logger.info("Connecting to MongoDB...")
    # The context manager closes the client even if sampling or saving fails.
    with MongoClient(mongo_uri) as client:
        logger.info("Fetching the `judgements` collection...")
        collection = client["juddges"]["judgements"]

        sample_records = agg_sample(collection, size=size)

        # The cursor is consumed while saving, so this must stay inside the
        # `with` block, while the client is still open.
        logger.info(f"Saving sample to {out_path}...")
        save_jsonl(sample_records, out_path)


# Script entry point: let typer build the command-line interface from `main`'s
# signature and run it.
if __name__ == "__main__":
    typer.run(main)

0 comments on commit 41749dc

Please sign in to comment.