Skip to content

Commit

Permalink
Merge branch 'master' into add-schema-chain
Browse files Browse the repository at this point in the history
  • Loading branch information
laugustyniak committed Apr 9, 2024
2 parents 82f4449 + 80893fa commit a3f62ac
Show file tree
Hide file tree
Showing 24 changed files with 1,442 additions and 121 deletions.
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ fix:
check:
ruff check $(lint_dirs)
ruff format $(lint_dirs) --check
mypy --install-types --non-interactive $(mypy_dirs)

test:
coverage run -m pytest
Expand Down
1 change: 0 additions & 1 deletion data/.gitignore

This file was deleted.

2 changes: 2 additions & 0 deletions data/datasets/pl/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/raw
/text
6 changes: 6 additions & 0 deletions data/datasets/pl/raw.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
outs:
- md5: 801ebfe4c29d0564abfce7006536adc8.dir
size: 5466475038
nfiles: 9
hash: md5
path: raw
6 changes: 6 additions & 0 deletions data/datasets/pl/text.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
outs:
- md5: cac0bc44e36e68d606eff7500d627bd1.dir
size: 22741832080
nfiles: 11
hash: md5
path: text
5 changes: 0 additions & 5 deletions data/dummy_file.txt.dvc

This file was deleted.

2 changes: 2 additions & 0 deletions juddges/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@
'lib_path': 'juddges'},
'syms': { 'juddges.data.models': {},
'juddges.data.pl_court_api': {},
'juddges.preprocessing.parser_base': {},
'juddges.preprocessing.pl_court_parser': {},
'juddges.prompts.information_extraction': {},
'juddges.settings': {}}}
3 changes: 2 additions & 1 deletion juddges/data/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from pymongo import MongoClient
from pymongo.collection import Collection

if os.environ.get("MONGO_URI", None) is None:
raise Exception("Missing `MONGO_URI` environment variable.")
Expand All @@ -10,7 +11,7 @@
raise Exception("Missing `MONGO_DB_NAME` environment variable.")


def get_mongo_collection(collection_name: str):
def get_mongo_collection(collection_name: str) -> Collection:
    """Return a handle to *collection_name* in the configured MongoDB database.

    Connection parameters come from the ``MONGO_URI`` and ``MONGO_DB_NAME``
    environment variables (validated at module import time).
    """
    mongo_client = MongoClient(os.environ["MONGO_URI"])
    database = mongo_client[os.environ["MONGO_DB_NAME"]]
    return database[collection_name]
69 changes: 66 additions & 3 deletions juddges/data/pl_court_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,41 @@
import requests
import xmltodict
from loguru import logger
from requests import HTTPError


class PolishCourtAPI:
    """Thin HTTP client for the public Polish court judgements API."""

    def __init__(self) -> None:
        # Base URL of the public judgements API endpoint.
        self.url = "https://apiorzeczenia.wroclaw.sa.gov.pl/ncourt-api"

@property
def schema(self) -> dict[str, list[str]]:
return {
"judgement": [
"_id",
"signature",
"date",
"publicationDate",
"lastUpdate",
"courtId",
"departmentId",
"type",
"excerpt",
],
"content": ["content"],
"details": [
"chairman",
"judges",
"themePhrases",
"references",
"legalBases",
"recorder",
"decision",
"reviser",
"publisher",
],
}

def get_number_of_judgements(self, params: dict[str, Any] | None = None) -> int:
if params is None:
params = {}
Expand All @@ -33,11 +62,45 @@ def get_judgements(self, params: dict[str, Any]) -> list[dict[str, Any]]:

return judgements

def get_content(self, id: str) -> str:
def get_content(self, id: str) -> dict[str, Any]:
    """Download the raw judgement content for document *id*.

    Returns:
        A single-key mapping ``{"content": <decoded body>}``.

    Raises:
        DataNotFoundError: when the API answers 404 for this document.
        HTTPError: for any other non-success status code.
    """
    response = requests.get(f"{self.url}/judgement/content", params={"id": id})

    try:
        response.raise_for_status()
    except HTTPError as err:
        # The content endpoint signals a missing document with a plain 404.
        if err.response.status_code == 404:
            raise DataNotFoundError(f"Not found content for document: {id}")
        raise

    return {"content": response.content.decode("utf-8")}

def get_cleaned_details(self, id: str) -> dict[str, Any]:
    """Downloads details without repeating fields retrieved in get_judgements."""
    allowed_fields = self.schema["details"]
    full_details = self.get_details(id)
    return {field: value for field, value in full_details.items() if field in allowed_fields}

def get_details(self, id: str) -> dict[str, Any]:
    """Download the full details record for document *id*.

    Raises:
        DataNotFoundError: when the API reports the document as missing.
        KeyError: when the XML payload has neither a judgement nor an error node.
    """
    response = requests.get(f"{self.url}/judgement/details", params={"id": id})
    response.raise_for_status()

    # for details, API returns XML with error info instead of 404 status code
    payload = xmltodict.parse(response.content.decode("utf-8"))
    try:
        details = payload["judgement"]
    except KeyError:
        if "error" in payload:
            raise DataNotFoundError(f"Not found details for document: {id}")
        raise

    assert isinstance(details, dict)
    return details


class DataNotFoundError(Exception):
    """Raised when the court API has no content or details for a document id."""
File renamed without changes.
18 changes: 18 additions & 0 deletions juddges/preprocessing/parser_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from typing import Any


class DocParserBase(ABC):
    """Base class for parser retrieving data from a document.

    Subclasses declare the keys they produce via :attr:`schema` and implement
    the extraction itself in :meth:`parse`.
    """

    def __call__(self, document: str) -> dict[str, Any]:
        """Make the parser callable; delegates straight to :meth:`parse`."""
        return self.parse(document)

    @property
    @abstractmethod
    def schema(self) -> list[str]:
        """Names of the keys present in the mapping returned by ``parse``."""

    @abstractmethod
    def parse(self, document: str) -> dict[str, Any]:
        """Extract data from *document*, keyed according to ``schema``."""
80 changes: 80 additions & 0 deletions juddges/preprocessing/pl_court_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import re
from typing import Any, Generator
from xml.etree import ElementTree
from xml.etree.ElementTree import Element

from juddges.preprocessing.parser_base import DocParserBase

MULTIPLE_NEWLINES = re.compile(r"(\n\s*)+\n+")


class SimplePlJudgementsParser(DocParserBase):
    """The simplest parser for the simple XML format used by the Polish courts.

    It extracts the text from the XML file, without adhering to any specific
    structure.
    """

    @property
    def schema(self) -> list[str]:
        """Keys present in the mapping returned by :meth:`parse`."""
        return ["num_pages", "vol_number", "vol_type", "text"]

    def parse(self, document: str) -> dict[str, Any]:
        """Parse a judgement XML document into volume metadata and plain text."""
        root = ElementTree.fromstring(document)

        blocks = root.findall("xBlock")
        assert len(blocks) == 1, "There should be only one xBlock element"
        content_root = blocks[0]

        return {
            "num_pages": int(root.attrib["xToPage"]),
            "vol_number": int(root.attrib["xVolNmbr"]),
            "vol_type": root.attrib["xVolType"],
            "text": self.extract_text(content_root),
        }

    @staticmethod
    def extract_text(element: Element) -> str:
        """Concatenate all text fragments under *element*, collapsing runs of blank lines."""
        fragments: list[str] = []
        for fragment in element.itertext():
            if fragment is None:
                continue
            stripped = fragment.strip(" ")
            if stripped:
                fragments.append(stripped)

        combined = "".join(fragments)
        return re.sub(MULTIPLE_NEWLINES, "\n\n", combined).strip()


def itertext(element: Element, prefix: str = "") -> Generator[str, None, None]:
"""Extension of the Element.itertext method to handle special tags in pl court XML."""
tag = element.tag
if not isinstance(tag, str) and tag is not None:
return

t: str | None
match (tag, element.attrib):
case ("xName", {"xSffx": suffix}):
element.tail = element.tail.strip() if element.tail else None
t = f"{element.text}{suffix} "
case ("xEnum", _):
bullet_elem = element.find("xBullet")
if bullet_elem:
prefix = bullet_elem.text or ""
element.remove(bullet_elem)
t = ""
case ("xEnumElem", _):
t = prefix
case _:
t = element.text

if t:
yield t

for e in element:
yield from itertext(e, prefix)
t = e.tail

if t:
yield t
74 changes: 74 additions & 0 deletions juddges/prompts/information_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,77 @@ def route(response_schema: str) -> dict[str, str]:
raise ValueError(
"Cannot determine schema for the given input prompt. Please try different query."
)
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.json import parse_json_markdown
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

# Chat prompt filled with LANGUAGE / SCHEMA / TEXT at runtime; instructs the
# model to follow the YAML schema and reply with JSON.
PROMPT_TEMPLATE = """Act as a legal document tool that extracts information and answer questions based on judgements.
Instruction for extracting information from judgements:
- Judgements are in {LANGUAGE} language, please extract information in {LANGUAGE}.
- Do not provide information that are not explicitly mentioned in judgements. If you can't extract information from the text field, leave the field with empty string "".
Follow the following YAML structure to extract information and answer questions based on judgements:
{SCHEMA}
====
{TEXT}
====
Format response as JSON:
"""

# Default extraction schema ("field: type" lines, YAML-ish). NOTE(review): the
# fields reference Crown/magistrates' courts and juries, so this appears
# tailored to English criminal appeal judgements -- confirm intended use.
SCHEMA = """
defendant_gender: string
defendant_age: integer
defendant_relationship_status: string
defendant_has_children: boolean
defendant_homeless: boolean
appellant: string
appeal_against: string
defendant_plead_or_convicted: string "guilty plea" or "convicted at trial"
jury_unanimous: string "unanimous" or "other"
first_trial: boolean
drug_offence: boolean
original_sentence: string
tried_court_type: string "Crown" or "magistrates'"
single_transaction_multiple_offence: boolean
multiple_transactions_period: string "within 1 year" or "more than 1 year"
concurrent_or_consecutive_sentence: string "concurrently" or "consecutively"
sentence_on_top_existing: boolean
sentence_adding_up: boolean
sentence_leniency: string "unduly lenient" or "too excessive"
guilty_plea_reduction_reason: string
sentence_discount_mention: boolean
totality_issues_similar_offences: boolean
sentence_proportionality_mention: boolean
sentence_type_issues_totality: boolean
sentence_adjustment_mention: boolean
offender_culpability_determination: boolean
harm_caused_determination: boolean
offence_seriousness: boolean
aggravating_factors: string list of factors excluding previous convictions
previous_convictions_similarity: string "similar" or "dissimilar"
mitigating_factors: string list of factors
immediate_sentence_concurrency: boolean
totality_conflicting_guidelines: boolean
totality_principle_misapplication: boolean
"""


def prepare_information_extraction_chain(
    model_name: str = "gpt-4-0125-preview", log_to_mlflow: bool = False
):
    """Build an LCEL chain that extracts structured judgement data as JSON.

    Args:
        model_name: OpenAI chat model to run the extraction prompt against.
        log_to_mlflow: When True, log the prompt template to the active MLflow run.

    Returns:
        A runnable expecting ``TEXT``, ``LANGUAGE`` and ``SCHEMA`` inputs and
        returning the model's reply parsed from JSON markdown.
    """
    llm = ChatOpenAI(model=model_name, temperature=0)
    prompt = ChatPromptTemplate(
        messages=[HumanMessagePromptTemplate.from_template(PROMPT_TEMPLATE)],
        input_variables=["TEXT", "LANGUAGE", "SCHEMA"],
    )

    if log_to_mlflow:
        import mlflow

        # NOTE(review): `save_to_json` is not a documented ChatPromptTemplate
        # method -- verify this call against the pinned langchain version.
        mlflow.log_dict(prompt.save_to_json(), "prompt.json")

    return prompt | llm | (lambda message: parse_json_markdown(message.content))
15 changes: 8 additions & 7 deletions juddges/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pathlib import Path

import mlflow
import tiktoken
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

# get root path as ROOT_PATH as pathlib objects
ROOT_PATH = Path(__file__).resolve().parent.parent
Expand Down Expand Up @@ -37,9 +40,7 @@ def num_tokens_from_string(
LOCAL_POSTGRES = "postgresql+psycopg2://llm:llm@postgres-juddges:5432/llm"


def get_sqlalchemy_engine():
from sqlalchemy import create_engine

def get_sqlalchemy_engine() -> Engine:
return create_engine(
LOCAL_POSTGRES,
pool_size=10,
Expand All @@ -50,15 +51,15 @@ def get_sqlalchemy_engine():
)


def prepare_langchain_cache():
def prepare_langchain_cache() -> None:
    """Install an SQLAlchemy-backed MD5 LLM cache on the global langchain config."""
    import langchain
    from langchain.cache import SQLAlchemyMd5Cache

    cache_backend = SQLAlchemyMd5Cache(get_sqlalchemy_engine())
    langchain.llm_cache = cache_backend


def prepare_mlflow(experiment_name: str = MLFLOW_EXP_NAME, url="http://host.docker.internal"):
import mlflow

def prepare_mlflow(
    experiment_name: str = MLFLOW_EXP_NAME, url: str = "http://host.docker.internal"
) -> None:
    """Point MLflow at the tracking server at *url* and select *experiment_name*."""
    mlflow.set_tracking_uri(url)
    mlflow.set_experiment(experiment_name)
Loading

0 comments on commit a3f62ac

Please sign in to comment.