Commit
Showing 9 changed files with 664 additions and 980 deletions.
@@ -0,0 +1,49 @@
import re
from typing import Any, List, Callable
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers.json import parse_partial_json, _parse_json
from langchain_core.outputs import Generation
from langchain_core.exceptions import OutputParserException
from json.decoder import JSONDecodeError

from juddges.data.synthetic.patterns import CUSTOM_PARSE_JSON_MARKDOWN


class QAPairsJsonParser(JsonOutputParser):

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        text = result[0].text
        text = text.strip()
        if partial:
            try:
                return parse_json_markdown(text)
            except JSONDecodeError:
                return None
        else:
            try:
                return parse_json_markdown(text)
            except JSONDecodeError as e:
                msg = f"Invalid json output: {text}"
                raise OutputParserException(msg, llm_output=text) from e


def parse_json_markdown(
    json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
    """Modified version of `langchain_core.output_parsers.json:parse_json_markdown`.
    Fixes: JSONDecodeError when parsing CoT-like output that contains multiple JSON strings.
    """
    try:
        return _parse_json(json_string, parser=parser)
    except JSONDecodeError:
        # Try to find the last JSON string within triple backticks
        matches = CUSTOM_PARSE_JSON_MARKDOWN.findall(json_string)

        # `findall` returns an empty list (not None) when nothing matches;
        # in that case assume the entire string is a JSON string
        if not matches:
            json_str = json_string
        else:
            # Otherwise use the content of the last fenced block
            json_str = matches[-1]
        return _parse_json(json_str, parser=parser)
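
A minimal usage sketch of the helper above (the CoT-style model output is made up, and `_parse_json` is a private langchain_core helper, so exact behaviour may vary between versions); only the last fenced JSON block should end up parsed:

cot_output = (
    "Step 1: identified spans ...\n"
    "Step 2: generated questions ...\n"
    "Step 3:\n"
    "```json\n"
    '{"questions": ["Who issued the ruling?"], "answers": ["The district court"]}\n'
    "```"
)
print(parse_json_markdown(cot_output))
# expected: {'questions': ['Who issued the ruling?'], 'answers': ['The district court']}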
Empty file.
@@ -0,0 +1,17 @@
JUDGEMENTS_QA_COT_PROMPT_V1 = '''\
You are a question-answer generator. Your goal is to generate question-answer pairs given the Context.
Do not translate the Context; generate questions and answers in the original language.

Context: {context}

Step 1: Identify spans that are likely to be answers to questions, identify as many as possible.
Step 2: For each identified span, generate a question.
Step 3: Respond to each question concisely, in only a few tokens.

Ensure that you distinctly label and delineate Steps 1, 2 and 3.

{format_instructions}

Output:
```{format_md_ext}
'''
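
A quick way to inspect the rendered prompt (the filler values below are placeholders; in the generation script the instructions come from `QAPairsJsonParser.get_format_instructions()` and the context from a dumped judgement):

rendered = JUDGEMENTS_QA_COT_PROMPT_V1.format(
    context="<judgement text>",
    format_instructions="<JSON format instructions>",
    format_md_ext="json",
)
print(rendered)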
@@ -0,0 +1,7 @@
import re


CUSTOM_PARSE_JSON_MARKDOWN = re.compile(
    pattern=r"```(?:json)?([^`]+)(?:```)?\s*$",
    flags=re.IGNORECASE,
)
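
A small, hypothetical check of the pattern above; the closing fence is optional, so the JSON body is still recovered when the model stops before emitting the trailing backticks:

text = "Step 3:\n```json\n" + '{"questions": ["..."], "answers": ["..."]}'
matches = CUSTOM_PARSE_JSON_MARKDOWN.findall(text)
print(matches[-1].strip())  # {"questions": ["..."], "answers": ["..."]}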
@@ -0,0 +1,25 @@
from pathlib import Path
from typing import Iterable, Generator
from jsonlines import jsonlines
from datetime import datetime


def save_jsonl(records: Iterable[dict], out: Path | str):
    """Save a list of dictionaries to a jsonl file."""
    with jsonlines.open(out, mode="w") as writer:
        writer.write_all(records)


def read_jsonl(path: Path | str) -> Generator[dict, None, None]:
    """Read a jsonl file and yield dictionaries."""
    with jsonlines.open(path) as reader:
        yield from reader


def path_safe_udate() -> str:
    """Generate a unique timestamp string for file naming.

    Returns:
        str: A string with the current date and time in the %Y%m%d_%H%M%Sf%f format
    """
    return datetime.now().strftime("%Y%m%d_%H%M%Sf%f")
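
A minimal round-trip sketch with the helpers above (the records and the file name are arbitrary examples):

records = [{"_id": "1", "text": "first judgement"}, {"_id": "2", "text": "second judgement"}]
out_file = f"judgements_sample2_{path_safe_udate()}.jsonl"
save_jsonl(records, out_file)
assert list(read_jsonl(out_file)) == records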
Large diffs are not rendered by default.
@@ -0,0 +1,128 @@
import typer
from typing import List
from loguru import logger
from dotenv import load_dotenv
import json

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from juddges.data.utils import read_jsonl
from juddges.data.synthetic.generation_prompt import JUDGEMENTS_QA_COT_PROMPT_V1
from juddges.data.qa_pairs_json_parser import QAPairsJsonParser

load_dotenv("secrets.env", verbose=True)


class SyntheticLegisQAPairs(BaseModel):
    questions: List[str] = Field(description="List of generated questions")
    answers: List[str] = Field(description="List of generated answers")

    def test_empty(self):
        assert len(self.questions) > 0, "At least one question should be generated"

    def test_equal_length(self):
        assertion_msg = "Number of questions and answers should be equal"
        assert len(self.questions) == len(self.answers), assertion_msg

    def test_q_duplicates(self):
        assertion_msg = "Questions should be unique"
        assert len(set(self.questions)) == len(self.questions), assertion_msg

    def test_a_duplicates(self):
        assertion_msg = "Answers should be unique"
        assert len(set(self.answers)) == len(self.answers), assertion_msg

    def test_duplicates(self):
        self.test_q_duplicates()
        self.test_a_duplicates()

    def test(self):
        self.test_empty()
        self.test_equal_length()
        self.test_duplicates()


def main(
    judgements_fpath: str = typer.Option(default=None, help="Dumped `judgements` collection file path"),
    out: str = typer.Option(default=None, help="Output file path"),
    hf_model: str = typer.Option(
        help="Hugging Face model name or path",
        default="TheBloke/CapybaraHermes-2.5-Mistral-7B-GPTQ",
    ),
    max_input_length: int = typer.Option(default=3551, help="Maximum number of tokens in input text"),
):
    if judgements_fpath is None:
        # FIXME
        default_judgements_fpath = "/app/data/datasets/pl/judgements_sample10_20240427_094707f595590.jsonl"
        logger.warning(
            "Dumped `judgements` collection file path not provided."
            f" Using the default `judgements` path: {default_judgements_fpath}"
        )
        judgements_fpath = default_judgements_fpath

    if out is None:
        # FIXME
        default_out = "/app/data/datasets/pl/synthetic_judgements_qa.jsonl"
        logger.warning(f"Output file path not provided. Using the default `out`: {default_out}")
        out = default_out

    qa_parser = QAPairsJsonParser(pydantic_object=SyntheticLegisQAPairs)
    logger.debug(f"{qa_parser.get_format_instructions()=}")

    prompt = ChatPromptTemplate.from_template(
        template=JUDGEMENTS_QA_COT_PROMPT_V1,
        partial_variables={"format_instructions": qa_parser.get_format_instructions()},
    )

    # For example: revision="gptq-4bit-32g-actorder_True"
    model = AutoModelForCausalLM.from_pretrained(
        hf_model,
        device_map="auto",
        trust_remote_code=True,
        revision="main",
    )
    tokenizer = AutoTokenizer.from_pretrained(hf_model, use_fast=True)
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
    )
    hf_pipeline = HuggingFacePipeline(pipeline=text_gen_pipeline)

    gen_chain = prompt | hf_pipeline | qa_parser

    logger.info("Generating QA pairs from provided collection data...")
    for judgement in read_jsonl(judgements_fpath):
        num_text_tokens = tokenizer.encode(judgement["text"], return_tensors="pt").shape[1]
        if num_text_tokens > 0.95 * max_input_length:
            logger.warning(
                f"Skipping judgement with id: {judgement['_id']} due to text"
                f" length > {max_input_length} ({num_text_tokens})..."
            )
            continue

        chain_input = {
            "context": judgement["text"],
            "format_md_ext": "json",
        }
        qa_pairs = gen_chain.invoke(chain_input)
        logger.debug(json.dumps(qa_pairs, indent=2, ensure_ascii=False))

        dto = SyntheticLegisQAPairs(**qa_pairs)
        dto.test()

        break


if __name__ == "__main__":
    typer.run(main)
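
To make the expected chain output concrete, here is a hypothetical parsed result together with the validation it has to pass (the keys follow the `SyntheticLegisQAPairs` model above):

qa_pairs = {
    "questions": ["Which court issued the judgement?", "What was the outcome of the appeal?"],
    "answers": ["The district court", "The appeal was dismissed"],
}
dto = SyntheticLegisQAPairs(**qa_pairs)
dto.test()  # non-empty, equal lengths, no duplicate questions or answers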
@@ -0,0 +1,73 @@
import typer
from pathlib import Path
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
from dotenv import load_dotenv
from loguru import logger

from juddges.data.utils import save_jsonl, path_safe_udate

load_dotenv("secrets.env", verbose=True)


def agg_sample(collection: Collection, size: int) -> CommandCursor:
    """Aggregate a sample of records from a MongoDB collection.

    Args:
        collection (Collection): MongoDB collection
        size (int): Number of records to sample

    Returns:
        CommandCursor: MongoDB cursor with sampled records
    """
    logger.info(f"Sampling {size} records...")
    agg = [
        {
            "$match": {
                "text": {
                    "$exists": True
                }
            }
        },
        {
            "$sample": {
                "size": size
            }
        },
    ]
    return collection.aggregate(agg)


def main(
    size: int = typer.Option(help="Number of records to sample"),
    # FIXME: `seed` is unused; MongoDB's `$sample` stage does not support seeding
    seed: int = typer.Option(default=None, help="Random seed"),
    mongo_uri: str = typer.Option(..., envvar="MONGO_URI"),
    out: str = typer.Option(default=None, help="Output file path"),
):
    if out is None:
        # FIXME: data dir from settings when available, fix logger info
        # from juddges.settings import PL_JUDGEMENTS_PATH
        # out = PL_JUDGEMENTS_PATH / f"judgements_sample{size}_{path_safe_udate()}.jsonl"
        logger.warning("Output file path not provided, using a generated file name in the current directory")
        out = f"judgements_sample{size}_{path_safe_udate()}.jsonl"

    logger.info("Connecting to MongoDB...")
    client = MongoClient(mongo_uri)

    logger.info("Fetching the `judgements` collection...")
    collection = client["juddges"]["judgements"]

    sample_records = agg_sample(collection, size=size)

    out_path = Path(out)
    if ".jsonl" in out_path.suffixes or ".jsonlines" in out_path.suffixes:
        logger.info(f"Saving sample to {out_path}...")
        save_jsonl(sample_records, out_path)
    else:
        raise NotImplementedError("Only JSONL output is supported")


if __name__ == "__main__":
    typer.run(main)
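
For reference, a hedged sketch of calling `agg_sample` directly, mirroring what `main` does (the URI is a placeholder; as in the script, `_id` values must be JSON-serialisable for the dump to succeed):

client = MongoClient("mongodb://localhost:27017")  # placeholder URI
cursor = agg_sample(client["juddges"]["judgements"], size=10)
save_jsonl(cursor, f"judgements_sample10_{path_safe_udate()}.jsonl")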