From 981501c13fb3bd135546fc017a4f0440a02390d4 Mon Sep 17 00:00:00 2001 From: Jakub Binkowski Date: Wed, 3 Apr 2024 12:45:00 +0200 Subject: [PATCH] Parse pl judgements (#4) * Improve error handling in download_pl_content.py * Add dataset dump scrip * Add pl dataset to DVC * Add simple data analysis notebook * Extract text from pl judgements * Refine text extraction and add analysis * Add addtional details download and ingest * Refine extraction and ingest extracted data to mongo * Add script for chunked embeddings --------- Co-authored-by: Jakub Binkowski --- data/.gitignore | 1 - data/datasets/pl/.gitignore | 2 + data/datasets/pl/raw.dvc | 6 + data/datasets/pl/text.dvc | 6 + data/dummy_file.txt.dvc | 5 - juddges/_modidx.py | 2 +- juddges/data/pl_court_api.py | 69 ++++++- .../preprocessing/__init__.py | 0 juddges/preprocessing/parser_base.py | 18 ++ juddges/preprocessing/pl_court_parser.py | 80 ++++++++ notebooks/1_analyse_dataset.ipynb | 144 ++++++++++++++ notebooks/2_analyse_text.ipynb | 185 ++++++++++++++++++ pyproject.toml | 5 + requirements.txt | 9 + scripts/download_pl_additional_data.py | 107 ++++++++++ scripts/download_pl_content.py | 102 ---------- scripts/dump_pl_dataset.py | 74 +++++++ scripts/embed_text.py | 77 ++++++++ scripts/extract_pl_xml.py | 106 ++++++++++ 19 files changed, 886 insertions(+), 112 deletions(-) delete mode 100644 data/.gitignore create mode 100644 data/datasets/pl/.gitignore create mode 100644 data/datasets/pl/raw.dvc create mode 100644 data/datasets/pl/text.dvc delete mode 100644 data/dummy_file.txt.dvc rename notebooks/.gitkeep => juddges/preprocessing/__init__.py (100%) create mode 100644 juddges/preprocessing/parser_base.py create mode 100644 juddges/preprocessing/pl_court_parser.py create mode 100644 notebooks/1_analyse_dataset.ipynb create mode 100644 notebooks/2_analyse_text.ipynb create mode 100644 scripts/download_pl_additional_data.py delete mode 100644 scripts/download_pl_content.py create mode 100644 scripts/dump_pl_dataset.py create mode 100644 scripts/embed_text.py create mode 100644 scripts/extract_pl_xml.py diff --git a/data/.gitignore b/data/.gitignore deleted file mode 100644 index 3dd093b..0000000 --- a/data/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/dummy_file.txt diff --git a/data/datasets/pl/.gitignore b/data/datasets/pl/.gitignore new file mode 100644 index 0000000..f8bb874 --- /dev/null +++ b/data/datasets/pl/.gitignore @@ -0,0 +1,2 @@ +/raw +/text diff --git a/data/datasets/pl/raw.dvc b/data/datasets/pl/raw.dvc new file mode 100644 index 0000000..054bf5c --- /dev/null +++ b/data/datasets/pl/raw.dvc @@ -0,0 +1,6 @@ +outs: +- md5: 801ebfe4c29d0564abfce7006536adc8.dir + size: 5466475038 + nfiles: 9 + hash: md5 + path: raw diff --git a/data/datasets/pl/text.dvc b/data/datasets/pl/text.dvc new file mode 100644 index 0000000..788086f --- /dev/null +++ b/data/datasets/pl/text.dvc @@ -0,0 +1,6 @@ +outs: +- md5: cac0bc44e36e68d606eff7500d627bd1.dir + size: 22741832080 + nfiles: 11 + hash: md5 + path: text diff --git a/data/dummy_file.txt.dvc b/data/dummy_file.txt.dvc deleted file mode 100644 index 9347ba2..0000000 --- a/data/dummy_file.txt.dvc +++ /dev/null @@ -1,5 +0,0 @@ -outs: -- md5: d41d8cd98f00b204e9800998ecf8427e - size: 0 - hash: md5 - path: dummy_file.txt diff --git a/juddges/_modidx.py b/juddges/_modidx.py index dc205b6..09d9d33 100644 --- a/juddges/_modidx.py +++ b/juddges/_modidx.py @@ -5,4 +5,4 @@ 'doc_host': 'https://laugustyniak.github.io', 'git_url': 'https://github.com/laugustyniak/juddges', 'lib_path': 'juddges'}, - 'syms': 
{'juddges.data.pl_court_api': {}}} + 'syms': {'juddges.data.pl_court_api': {}, 'juddges.preprocessing.parser_base': {}, 'juddges.preprocessing.pl_court_parser': {}}} diff --git a/juddges/data/pl_court_api.py b/juddges/data/pl_court_api.py index 10e857e..c3e18d1 100644 --- a/juddges/data/pl_court_api.py +++ b/juddges/data/pl_court_api.py @@ -3,12 +3,41 @@ import requests import xmltodict from loguru import logger +from requests import HTTPError class PolishCourtAPI: def __init__(self) -> None: self.url = "https://apiorzeczenia.wroclaw.sa.gov.pl/ncourt-api" + @property + def schema(self) -> dict[str, list[str]]: + return { + "judgement": [ + "_id", + "signature", + "date", + "publicationDate", + "lastUpdate", + "courtId", + "departmentId", + "type", + "excerpt", + ], + "content": ["content"], + "details": [ + "chairman", + "judges", + "themePhrases", + "references", + "legalBases", + "recorder", + "decision", + "reviser", + "publisher", + ], + } + def get_number_of_judgements(self, params: dict[str, Any] | None = None) -> int: if params is None: params = {} @@ -33,11 +62,45 @@ def get_judgements(self, params: dict[str, Any]) -> list[dict[str, Any]]: return judgements - def get_content(self, id: str) -> str: + def get_content(self, id: str) -> dict[str, Any]: params = {"id": id} endpoint = f"{self.url}/judgement/content" res = requests.get(endpoint, params=params) - res.raise_for_status() + + try: + res.raise_for_status() + except HTTPError as err: + if err.response.status_code == 404: + raise DataNotFoundError(f"Not found content for document: {id}") + raise + content = res.content.decode("utf-8") - return content + return {"content": content} + + def get_cleaned_details(self, id: str) -> dict[str, Any]: + """Downloads details without repeating fields retrieved in get_judgements.""" + details = self.get_details(id) + return {k: v for k, v in details.items() if k in self.schema["details"]} + + def get_details(self, id: str) -> dict[str, Any]: + params = {"id": id} + endpoint = f"{self.url}/judgement/details" + res = requests.get(endpoint, params=params) + res.raise_for_status() + + # for details, API returns XML with error info instead of 404 status code + data = xmltodict.parse(res.content.decode("utf-8")) + try: + details = data["judgement"] + except KeyError: + if "error" in data.keys(): + raise DataNotFoundError(f"Not found details for document: {id}") + raise + else: + assert isinstance(details, dict) + return details + + +class DataNotFoundError(Exception): + pass diff --git a/notebooks/.gitkeep b/juddges/preprocessing/__init__.py similarity index 100% rename from notebooks/.gitkeep rename to juddges/preprocessing/__init__.py diff --git a/juddges/preprocessing/parser_base.py b/juddges/preprocessing/parser_base.py new file mode 100644 index 0000000..24e26dc --- /dev/null +++ b/juddges/preprocessing/parser_base.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class DocParserBase(ABC): + """Base class for parser retrieving data from a document.""" + + def __call__(self, document: str) -> dict[str, Any]: + return self.parse(document) + + @property + @abstractmethod + def schema(self) -> list[str]: + pass + + @abstractmethod + def parse(self, document: str) -> dict[str, Any]: + pass diff --git a/juddges/preprocessing/pl_court_parser.py b/juddges/preprocessing/pl_court_parser.py new file mode 100644 index 0000000..34864f1 --- /dev/null +++ b/juddges/preprocessing/pl_court_parser.py @@ -0,0 +1,80 @@ +import re +from typing import Any, Generator +from 
xml.etree import ElementTree +from xml.etree.ElementTree import Element + +from juddges.preprocessing.parser_base import DocParserBase + +MULTIPLE_NEWLINES = re.compile(r"(\n\s*)+\n+") + + +class SimplePlJudgementsParser(DocParserBase): + """The simplest parser for the simple XML format used by the Polish courts. + + It extracts the text from XML file, without adhering to any specific structure. + + """ + + @property + def schema(self) -> list[str]: + return ["num_pages", "vol_number", "vol_type", "text"] + + def parse(self, document: str) -> dict[str, Any]: + et = ElementTree.fromstring(document) + + xblock_elements = et.findall("xBlock") + assert len(xblock_elements) == 1, "There should be only one xBlock element" + content_root, *_ = xblock_elements + + return { + "num_pages": int(et.attrib["xToPage"]), + "vol_number": int(et.attrib["xVolNmbr"]), + "vol_type": et.attrib["xVolType"], + "text": self.extract_text(content_root), + } + + @staticmethod + def extract_text(element: Element) -> str: + text = "" + for elem_txt in element.itertext(): + if elem_txt is None: + continue + if txt := elem_txt.strip(" "): + text += txt + + text = re.sub(MULTIPLE_NEWLINES, "\n\n", text).strip() + + return text + + +def itertext(element: Element, prefix: str = "") -> Generator[str, None, None]: + """Extension of the Element.itertext method to handle special tags in pl court XML.""" + tag = element.tag + if not isinstance(tag, str) and tag is not None: + return + + t: str | None + match (tag, element.attrib): + case ("xName", {"xSffx": suffix}): + element.tail = element.tail.strip() if element.tail else None + t = f"{element.text}{suffix} " + case ("xEnum", _): + bullet_elem = element.find("xBullet") + if bullet_elem: + prefix = bullet_elem.text or "" + element.remove(bullet_elem) + t = "" + case ("xEnumElem", _): + t = prefix + case _: + t = element.text + + if t: + yield t + + for e in element: + yield from itertext(e, prefix) + t = e.tail + + if t: + yield t diff --git a/notebooks/1_analyse_dataset.ipynb b/notebooks/1_analyse_dataset.ipynb new file mode 100644 index 0000000..8f2b090 --- /dev/null +++ b/notebooks/1_analyse_dataset.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-15T11:07:39.123324510Z", + "start_time": "2024-03-15T11:07:39.065139618Z" + } + }, + "outputs": [], + "source": [ + "import polars as pl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8a2c7d4858169a2", + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-15T11:12:01.007272795Z", + "start_time": "2024-03-15T11:11:47.709404815Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "ds = pl.read_parquet(\"../data/datasets/pl/raw\", use_pyarrow=True)\n", + "\n", + "dt_fmt = \"%Y-%m-%d %H:%M:%S%.f %Z\"\n", + "dt_unit = \"ms\" # due to https://github.com/pola-rs/polars/issues/13592\n", + "\n", + "ds = ds.with_columns(\n", + " ds[\"date\"].str.to_datetime(format=dt_fmt, time_unit=dt_unit),\n", + " ds[\"publicationDate\"].str.to_datetime(format=dt_fmt, time_unit=dt_unit),\n", + " ds[\"lastUpdate\"].str.to_datetime(format=dt_fmt, time_unit=dt_unit),\n", + " ds[\"courtId\"].cast(pl.Int32),\n", + " ds[\"departmentId\"].cast(pl.Int32),\n", + " ds[\"type\"].cast(pl.Categorical),\n", + ")\n", + "\n", + "ds.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35e65fe2dd9a4bce", + "metadata": { + 
"ExecuteTime": { + "end_time": "2024-03-15T11:12:01.074286418Z", + "start_time": "2024-03-15T11:12:00.912825131Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "ds.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab23ff37327a377a", + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-15T11:15:27.800725934Z", + "start_time": "2024-03-15T11:15:27.753971240Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "ds[\"type\"].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11446c299cdf1700", + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-15T11:15:38.059155473Z", + "start_time": "2024-03-15T11:15:38.053450756Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "print(f\"Missing content: {ds['content'].null_count() / len(ds)}\")\n", + "print(f\"Missing theis: {ds['thesis'].null_count() / len(ds)}\")\n", + "print(f\"Missing excerpt: {ds['excerpt'].null_count() / len(ds)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "891ffbad", + "metadata": {}, + "outputs": [], + "source": [ + "ds[\"excerpt\"].str.strip_chars().str.len_chars().to_pandas().plot.hist(\n", + " bins=50, log=True, title=\"Excerpt #chars distribution\"\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/2_analyse_text.ipynb b/notebooks/2_analyse_text.ipynb new file mode 100644 index 0000000..38c9a43 --- /dev/null +++ b/notebooks/2_analyse_text.ipynb @@ -0,0 +1,185 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6b666da3-f393-4d88-8036-e818937d2305", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_from_disk\n", + "import string\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1f37c21-de73-48ee-8cc3-8f4f2d4ce735", + "metadata": {}, + "outputs": [], + "source": [ + "ds = load_from_disk(dataset_path=\"../data/datasets/pl/text/\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c49a038b-3bd5-4124-89c2-a019c364fd22", + "metadata": {}, + "outputs": [], + "source": [ + "def tagger(item):\n", + " text = item[\"content\"]\n", + " dummy_tokens = text.split()\n", + "\n", + " item[\"chars\"] = len(text)\n", + " item[\"num_dummy_tokens\"] = len(dummy_tokens)\n", + " item[\"num_non_ws_tokens\"] = sum(\n", + " 1 for tok in dummy_tokens if any(char not in string.punctuation for char in tok.strip())\n", + " )\n", + "\n", + " return item\n", + "\n", + "\n", + "ds = ds.map(tagger, num_proc=20)\n", + "ds.cleanup_cache_files()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7531b42-d802-41b6-b699-d1b16361e8ff", + "metadata": {}, + "outputs": [], + "source": [ + "stats = (\n", + " ds.select_columns([\"_id\", \"type\", \"chars\", \"num_dummy_tokens\", \"num_non_ws_tokens\"])\n", + " .to_pandas()\n", + " 
.convert_dtypes(dtype_backend=\"pyarrow\")\n", + ")\n", + "stats[\"type\"] = stats[\"type\"].astype(\"category\")\n", + "stats.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e4a7432-12fe-4b8b-a9c4-77a69a780ec2", + "metadata": {}, + "outputs": [], + "source": [ + "ax = sns.histplot(\n", + " x=stats[\"num_non_ws_tokens\"],\n", + " log_scale=True,\n", + " bins=50,\n", + ")\n", + "ax.set(title=\"#tokens distribution\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "480766df-9c10-4e15-808c-85a76d15166e", + "metadata": {}, + "outputs": [], + "source": [ + "card_order = stats[\"type\"].value_counts().index.tolist()\n", + "data = stats[\"type\"].value_counts().plot.barh(logx=True, title=\"Types cardinality\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c295c4a0-6926-43c2-bb51-77e8a98ff4da", + "metadata": {}, + "outputs": [], + "source": [ + "# sns.displot(data=stats, x=\"num_non_ws_tokens\", col=\"type\", col_wrap=3, log_scale=(True, False), facet_kws=dict(sharey=False, sharex=False), kind=\"hist\", bins=25)\n", + "\n", + "_, ax = plt.subplots(figsize=(8, 12))\n", + "ax.set(title=\"Per type text length ditribution\")\n", + "sns.boxenplot(data=stats, y=\"type\", x=\"num_non_ws_tokens\", order=card_order, log_scale=True)" + ] + }, + { + "cell_type": "markdown", + "id": "ea06ef3f-c12d-4da6-9fc6-45f1809dabad", + "metadata": {}, + "source": [ + "# Tokenize " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08c70fdc-0b03-4983-8da9-8d065161d3e7", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0af8c3ba-aa89-4e1a-bfcb-65b618c4559e", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(\"intfloat/multilingual-e5-large\")\n", + "ds = ds.map(\n", + " lambda examples: tokenizer(examples[\"content\"], padding=False, truncation=False),\n", + " batched=True,\n", + " num_proc=20,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f822fae-f91c-4ee1-a114-97a021bf1e81", + "metadata": {}, + "outputs": [], + "source": [ + "tokenized = []\n", + "for item in ds:\n", + " tokenized.append({\"num_tokens\": len(item[\"input_ids\"])})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdac696f-056a-4b12-a48e-ac8f8dac9eeb", + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(tokenized, bins=50)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 423fc39..0616818 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,3 +8,8 @@ python_version = "3.11" strict = true untyped_calls_exclude = ["pymongo"] plugins = "numpy.typing.mypy_plugin" + +[[tool.mypy.overrides]] +module = ["pyarrow.*", "datasets.*", "sentence_transformers.*"] +ignore_missing_imports = true + diff --git a/requirements.txt b/requirements.txt index 36d28e4..8928208 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,15 @@ tenacity==8.2.3 loguru==0.7.2 typer==0.9.0 
+datasets==2.18.0 +pyarrow==15.0.0 +pandas==2.2.1 +polars==0.20.15 +torch==2.2.1 +transformers==4.38.2 +langchain==0.1.13 +sentence-transformers==2.5.1 + # dev nbdev==2.3.13 streamlit==1.31.1 diff --git a/scripts/download_pl_additional_data.py b/scripts/download_pl_additional_data.py new file mode 100644 index 0000000..5408d09 --- /dev/null +++ b/scripts/download_pl_additional_data.py @@ -0,0 +1,107 @@ +import math +from enum import Enum +from typing import Any + +import typer +from dotenv import load_dotenv +from loguru import logger +from mpire.pool import WorkerPool +from pymongo import MongoClient, UpdateOne +from pymongo.errors import BulkWriteError +from pymongo.server_api import ServerApi +from requests import HTTPError, ConnectionError +from tenacity import retry, wait_random_exponential, retry_if_exception_type, stop_after_attempt +from tqdm import tqdm + +from juddges.data.pl_court_api import PolishCourtAPI, DataNotFoundError + +N_JOBS = 6 +BATCH_SIZE = 100 + +load_dotenv("secrets.env", verbose=True) + + +class DataType(Enum): + CONTENT = "content" + DETAILS = "details" + + +def main( + mongo_uri: str = typer.Option(..., envvar="MONGO_URI"), + data_type: DataType = typer.Option(DataType.CONTENT), + batch_size: int = typer.Option(BATCH_SIZE), + n_jobs: int = typer.Option(N_JOBS), +) -> None: + client: MongoClient[dict[str, Any]] = MongoClient(mongo_uri, server_api=ServerApi("1")) + collection = client["juddges"]["judgements"] + client.admin.command("ping") + + api = PolishCourtAPI() + + # find rows which are missing at least one field + query = {"$or": [{field: {"$exists": False}} for field in api.schema[data_type.value]]} + + num_docs_to_update = collection.count_documents(query) + logger.info(f"There are {num_docs_to_update} documents to update") + + # fetch all ids at once to avoid cursor timeout + cursor = collection.find(query, {"_id": 1}, batch_size=batch_size) + docs_to_update: list[str] = [] + for doc in tqdm(cursor, total=num_docs_to_update, desc="Fetching doc list"): + docs_to_update.append(str(doc["_id"])) + + batched_docs_to_update = ( + docs_to_update[i : i + batch_size] for i in range(0, len(docs_to_update), batch_size) + ) + + download_update = DownloadDataAndUpdateDatabase(mongo_uri, data_type) + with WorkerPool(n_jobs=n_jobs) as pool: + pool.map_unordered( + download_update, + batched_docs_to_update, + progress_bar=True, + iterable_len=math.ceil(num_docs_to_update / batch_size), + ) + + +class DownloadDataAndUpdateDatabase: + def __init__(self, mongo_uri: str, data_type: DataType): + self.mongo_uri = mongo_uri + self.data_type = data_type + self.api = PolishCourtAPI() + + def __call__(self, *doc_ids: str) -> None: + data_batch: list[UpdateOne] = [] + + for d_id in doc_ids: + fetched_data = self._download_data(d_id) + data_batch.append(UpdateOne({"_id": d_id}, {"$set": fetched_data})) + + client: MongoClient[dict[str, Any]] = MongoClient(self.mongo_uri) + collection = client["juddges"]["judgements"] + + try: + collection.bulk_write(data_batch) + except BulkWriteError as err: + logger.error(err) + + @retry( + wait=wait_random_exponential(multiplier=1, min=4, max=30), + retry=retry_if_exception_type((HTTPError, ConnectionError)), + stop=stop_after_attempt(5), + ) + def _download_data(self, doc_id: str) -> dict[str, Any]: + try: + if self.data_type == DataType.CONTENT: + return self.api.get_content(doc_id) + elif self.data_type == DataType.DETAILS: + return self.api.get_cleaned_details(doc_id) + else: + raise ValueError(f"Invalid field: {self.data_type.value}") + 
except DataNotFoundError as err: + logger.warning(err) + return dict.fromkeys(self.api.schema[self.data_type.value], None) + + +if __name__ == "__main__": + typer.run(main) diff --git a/scripts/download_pl_content.py b/scripts/download_pl_content.py deleted file mode 100644 index aeaa3e8..0000000 --- a/scripts/download_pl_content.py +++ /dev/null @@ -1,102 +0,0 @@ -import math -from typing import Generator, Any - -import typer -from dotenv import load_dotenv -from loguru import logger -from mpire.pool import WorkerPool -from pymongo import MongoClient, UpdateOne -from pymongo.cursor import Cursor -from pymongo.errors import BulkWriteError -from pymongo.server_api import ServerApi -from requests import HTTPError -from tenacity import retry, wait_random_exponential, retry_if_exception_type, stop_after_attempt - -from juddges.data.pl_court_api import PolishCourtAPI - -N_JOBS = 8 -BATCH_SIZE = 100 - -load_dotenv("secrets.env", verbose=True) - - -def main( - mongo_uri: str = typer.Option(..., envvar="MONGO_URI"), - batch_size: int = typer.Option(BATCH_SIZE), - n_jobs: int = typer.Option(N_JOBS), -) -> None: - client: MongoClient[dict[str, Any]] = MongoClient(mongo_uri, server_api=ServerApi("1")) - collection = client["juddges"]["judgements"] - client.admin.command("ping") - - query = {"content": {"$exists": False}} - num_docs_without_content = collection.count_documents(query) - logger.info(f"There are {num_docs_without_content} documents without content") - - cursor = collection.find(query, batch_size=batch_size) - - docs_to_update = yield_batches(cursor, batch_size) - download_content = ContentDownloader(mongo_uri) - with WorkerPool(n_jobs=n_jobs) as pool: - pool.map_unordered( - download_content, - docs_to_update, - progress_bar=True, - iterable_len=math.ceil(num_docs_without_content / batch_size), - ) - - -class ContentDownloader: - def __init__(self, mongo_uri: str): - self.mongo_uri = mongo_uri - - def __call__(self, *doc_ids: str) -> None: - data_batch: list[UpdateOne] = [] - - for d_id in doc_ids: - content = self._download_content(d_id) - data_batch.append(UpdateOne({"_id": d_id}, {"$set": {"content": content}})) - - client: MongoClient[dict[str, Any]] = MongoClient(self.mongo_uri) - collection = client["juddges"]["judgements"] - - try: - collection.bulk_write(data_batch) - except BulkWriteError as err: - logger.error(err) - - @retry( - wait=wait_random_exponential(multiplier=1, min=4, max=30), - retry=retry_if_exception_type(HTTPError), - stop=stop_after_attempt(5), - ) - def _download_content(self, doc_id: str) -> str | None: - api = PolishCourtAPI() - try: - return api.get_content(doc_id) - except HTTPError as err: - if err.response.status_code == 404: - logger.warning("Found no content for judgement {id}", id=doc_id) - return None - else: - raise - - -def yield_batches( - cursor: Cursor[dict[str, Any]], batch_size: int -) -> Generator[list[str], None, None]: - """Generates batches of data from pymongo.Cursor. 
- Credit: https://stackoverflow.com/a/61809417 - """ - - batch: list[str] = [] - for i, row in enumerate(cursor): - if i % batch_size == 0 and i > 0: - yield batch - del batch[:] - batch.append(str(row["_id"])) - yield batch - - -if __name__ == "__main__": - typer.run(main) diff --git a/scripts/dump_pl_dataset.py b/scripts/dump_pl_dataset.py new file mode 100644 index 0000000..f856fc6 --- /dev/null +++ b/scripts/dump_pl_dataset.py @@ -0,0 +1,74 @@ +import sys +from pathlib import Path +from typing import Any + +import pandas as pd +import typer +from dotenv import load_dotenv +from loguru import logger +from pymongo import MongoClient +from pymongo.server_api import ServerApi +from tqdm import tqdm, trange +from pyarrow.parquet import ParquetDataset + +BATCH_SIZE = 100 +CHUNK_SIZE = 50_000 + +load_dotenv("secrets.env", verbose=True) + + +def main( + mongo_uri: str = typer.Option(..., envvar="MONGO_URI"), + batch_size: int = typer.Option(BATCH_SIZE), + chunk_size: int = typer.Option(CHUNK_SIZE), + file_name: Path = typer.Option(..., exists=False), + filter_empty_content: bool = typer.Option(False), +) -> None: + file_name.parent.mkdir(exist_ok=True, parents=True) + + client: MongoClient[dict[str, Any]] = MongoClient(mongo_uri, server_api=ServerApi("1")) + collection = client["juddges"]["judgements"] + client.admin.command("ping") + + if filter_empty_content: + query = {"content": {"$ne": True}} + else: + query = {} + + num_docs = collection.count_documents(query) + + dumped_data = list(file_name.parent.glob("*.parquet")) + start_offset = 0 + if dumped_data: + logger.warning(f"Found {len(dumped_data)} files in {file_name.parent}") + if typer.confirm("Do you want to continue previous dump?"): + dataset = ParquetDataset(file_name.parent) + start_offset = sum(p.count_rows() for p in dataset.fragments) + else: + logger.error("Delete data to start a new data dump") + sys.exit(1) + + logger.info(f"Starting from {start_offset}-th document, batch no. 
{start_offset // chunk_size}") + for offset in trange(start_offset, num_docs, chunk_size, desc="Chunks"): + docs = list( + tqdm( + collection.find(query, batch_size=batch_size).skip(offset).limit(chunk_size), + total=chunk_size, + leave=False, + desc="Documents in chunk", + ) + ) + i = offset // chunk_size + dumped_f_name = save_docs(docs, file_name, i) + logger.info(f"Dumped {i}-th batch of documents to {dumped_f_name}") + + +def save_docs(docs: list[dict[str, Any]], file_name: Path, i: int | None) -> Path: + if i is not None: + file_name = file_name.with_name(f"{file_name.stem}_{i:02d}{file_name.suffix}") + pd.DataFrame(docs).to_parquet(file_name) + return file_name + + +if __name__ == "__main__": + typer.run(main) diff --git a/scripts/embed_text.py b/scripts/embed_text.py new file mode 100644 index 0000000..fab0ae2 --- /dev/null +++ b/scripts/embed_text.py @@ -0,0 +1,77 @@ +from pathlib import Path +from typing import Any +from langchain_text_splitters import RecursiveCharacterTextSplitter +from loguru import logger +import torch +import typer +from datasets import load_from_disk +from sentence_transformers import SentenceTransformer + +MODEL = "sdadas/mmlw-roberta-large" +MAX_CHUNK_SIZE = 500 +MIN_SPLIT_CHARS = 10 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def main( + dataset_dir: Path = typer.Option(..., help="Path to the dataset directory"), + model: str = typer.Option(MODEL, help="Name of the model from HF hub"), + max_chunk_size: int = typer.Option(MAX_CHUNK_SIZE, help="Maximum number of chars in a chunk"), + min_split_chars: int = typer.Option( + MIN_SPLIT_CHARS, + help="Minimum number of chars to keep a chunk", + ), + batch_size: int = typer.Option(..., help="Batch size for tokenization"), + num_jobs: int = typer.Option(..., help="Number of parallel jobs to use"), + device: str = typer.Option(DEVICE, help="Device to use for the model"), +) -> None: + dataset = load_from_disk(dataset_dir) + + split_worker = TextSplitter(chunk_size=max_chunk_size, min_split_chars=min_split_chars) + ds = dataset.select_columns(["_id", "text"]).map( + split_worker, + batched=True, + num_proc=num_jobs, + remove_columns=["_id", "text"], + ) + logger.info(f"Dataset split into {ds.num_rows} chunks") + + model = SentenceTransformer(model).to(device) + ds = ds.map(Encoder(model), batched=True, batch_size=batch_size, num_proc=None) + ds.save_to_disk(dataset_dir.parent / "embeddings", num_shards=8) + + +class TextSplitter: + def __init__(self, chunk_size: int, min_split_chars: int | None = None) -> None: + self.splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size) + self.min_split_chars = min_split_chars + + def __call__(self, txt: dict[str, Any]) -> dict[str, Any]: + ids, chunks = [], [] + + for id_, text in zip(txt["_id"], txt["text"]): + current_chunks = self._split_text(text) + chunks.extend(current_chunks) + ids.extend([id_] * len(current_chunks)) + + return {"_id": ids, "text_chunk": chunks} + + def _split_text(self, text: str) -> list[str]: + chunks = self.splitter.split_text(text) + + if self.min_split_chars: + chunks = [split for split in chunks if len(split) >= self.min_split_chars] + + return chunks + + +class Encoder: + def __init__(self, model: SentenceTransformer) -> None: + self.model = model + + def __call__(self, items: dict[str, Any]) -> Any: + return {"embeddings": self.model.encode(items["text_chunk"])} + + +if __name__ == "__main__": + typer.run(main) diff --git a/scripts/extract_pl_xml.py b/scripts/extract_pl_xml.py new file mode 100644 index 
0000000..78e81c5 --- /dev/null +++ b/scripts/extract_pl_xml.py @@ -0,0 +1,106 @@ +import math +from multiprocessing import Pool +from pathlib import Path +from typing import Optional, Any + +import typer +from datasets import load_dataset, Dataset +from dotenv import load_dotenv +from loguru import logger +from pymongo import MongoClient, UpdateOne +from pymongo.errors import BulkWriteError, ConfigurationError +from tenacity import ( + wait_random_exponential, + retry_if_exception_type, + stop_after_attempt, + retry, + retry_if_exception_message, + retry_all, +) +from tqdm import tqdm + +from juddges.preprocessing.pl_court_parser import SimplePlJudgementsParser + +BATCH_SIZE = 100 +INGEST_JOBS = 6 + +load_dotenv("secrets.env", verbose=True) + + +def main( + dataset_dir: Path = typer.Option(..., help="Path to the dataset directory"), + target_dir: Path = typer.Option(..., help="Path to the target directory"), + num_proc: Optional[int] = typer.Option(None, help="Number of processes to use"), + ingest: bool = typer.Option(False, help="Ingest the dataset to MongoDB"), + mongo_uri: Optional[str] = typer.Option(None, envvar="MONGO_URI"), + mongo_batch_size: int = typer.Option(BATCH_SIZE), + ingest_jobs: int = typer.Option(INGEST_JOBS), +) -> None: + target_dir.parent.mkdir(exist_ok=True, parents=True) + ds = load_dataset("parquet", name="pl_judgements", data_dir=dataset_dir) + num_shards = len(ds["train"].info.splits["train"].shard_lengths) + parser = SimplePlJudgementsParser() + ds = ( + ds["train"] + .select_columns( + ["_id", "date", "type", "excerpt", "content"] + ) # leave only most important columns + .filter(lambda x: x["content"] is not None) + .map(parser, input_columns="content", num_proc=num_proc) + ) + ds.save_to_disk(target_dir, num_shards=num_shards) + + if ingest: + assert mongo_uri is not None + ds = ds.with_format(columns=["_id"] + parser.schema) + _ingest_dataset(ds, mongo_uri, mongo_batch_size, ingest_jobs) + + +def _ingest_dataset(dataset: Dataset, mongo_uri: str, batch_size: int, num_jobs: int) -> None: + """Uploads the dataset to MongoDB.""" + num_batches = math.ceil(dataset.num_rows / batch_size) + + worker = IngestWorker(mongo_uri) + with Pool(num_jobs) as pool: + list( + tqdm( + pool.imap_unordered( + worker, + dataset.iter(batch_size=batch_size), + ), + total=num_batches, + desc="Ingesting", + ) + ) + + +class IngestWorker: + def __init__(self, mongo_uri: str): + self.mongo_uri = mongo_uri + + @retry( + wait=wait_random_exponential(multiplier=1, min=4, max=30), + retry=retry_all( + retry_if_exception_type(ConfigurationError), + retry_if_exception_message(match="DNS operation timed out"), + ), + stop=stop_after_attempt(5), + ) + def __call__(self, batch: dict[str, list[Any]]) -> None: + client: MongoClient[dict[str, Any]] = MongoClient(self.mongo_uri) + collection = client["juddges"]["judgements"] + + ids = batch.pop("_id") + ingest_batch = [ + UpdateOne({"_id": ids[i]}, {"$set": {col: batch[col][i] for col in batch.keys()}}) + for i in range(len(ids)) + ] + + try: + collection.bulk_write(ingest_batch, ordered=False) + except BulkWriteError as err: + logger.error(err) + + +if __name__ == "__main__": + typer.run(main)
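
Reviewer note: a minimal sketch of how the extended PolishCourtAPI surface added in juddges/data/pl_court_api.py can be exercised. The document id below is a hypothetical placeholder; the field lists come from the schema property introduced in this patch.

# Sketch only: uses the schema / get_content / get_cleaned_details API added here.
# "12345" is a hypothetical judgement id, normally obtained via get_judgements(...).
from juddges.data.pl_court_api import DataNotFoundError, PolishCourtAPI

api = PolishCourtAPI()
print(api.schema["details"])           # fields fetched by get_cleaned_details
print(api.get_number_of_judgements())  # total count reported by the court API

doc_id = "12345"  # hypothetical id
try:
    record: dict = {}
    record.update(api.get_content(doc_id))          # {"content": "<xml ...>"}
    record.update(api.get_cleaned_details(doc_id))  # details w/o fields already in get_judgements
except DataNotFoundError as err:
    # both endpoints raise DataNotFoundError when the API has no data for the id
    print(f"skipping {doc_id}: {err}")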
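
A self-contained sketch of the SimplePlJudgementsParser contract (schema plus parse). The XML below is a synthetic stand-in for the court format: a root element carrying xToPage / xVolNmbr / xVolType and exactly one xBlock child; the attribute values and root tag name are made up.

# Synthetic example of the XML shape SimplePlJudgementsParser expects.
from juddges.preprocessing.pl_court_parser import SimplePlJudgementsParser

document = """
<xPart xToPage="3" xVolNmbr="2" xVolType="I C">
  <xBlock>
    <xText>Sygn. akt I C 123/20</xText>
    <xText>WYROK</xText>
  </xBlock>
</xPart>
"""

parser = SimplePlJudgementsParser()
parsed = parser(document)  # DocParserBase.__call__ delegates to parse()
print(parsed["num_pages"], parsed["vol_number"], parsed["vol_type"])
print(parsed["text"])      # concatenated, newline-normalised judgement text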
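
The download script builds its work queue from documents missing at least one field of the requested schema and writes results back with bulk UpdateOne operations. The snippet restates that query/update pattern in isolation, assuming MONGO_URI is set and the juddges.judgements collection used by the scripts exists; error handling with DataNotFoundError and retries are omitted for brevity.

# Standalone restatement of the pattern in scripts/download_pl_additional_data.py.
import os
from typing import Any

from pymongo import MongoClient, UpdateOne

from juddges.data.pl_court_api import PolishCourtAPI

api = PolishCourtAPI()
fields = api.schema["details"]

client: MongoClient[dict[str, Any]] = MongoClient(os.environ["MONGO_URI"])
collection = client["juddges"]["judgements"]

# documents missing at least one of the detail fields
query = {"$or": [{field: {"$exists": False}} for field in fields]}
doc_ids = [str(doc["_id"]) for doc in collection.find(query, {"_id": 1}).limit(100)]

updates = [
    UpdateOne({"_id": doc_id}, {"$set": api.get_cleaned_details(doc_id)})
    for doc_id in doc_ids
]
if updates:
    collection.bulk_write(updates)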
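
scripts/embed_text.py chunks the extracted text with RecursiveCharacterTextSplitter before encoding each chunk with a SentenceTransformer. A compact sketch of that flow on a single in-memory string; the model name is the script's default, the chunk-size and minimum-length values mirror its defaults, and the input text is a made-up example.

# Minimal chunk-then-embed flow mirroring scripts/embed_text.py.
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer

text = "Sąd Okręgowy po rozpoznaniu sprawy... " * 50  # stand-in judgement text

splitter = RecursiveCharacterTextSplitter(chunk_size=500)
chunks = [chunk for chunk in splitter.split_text(text) if len(chunk) >= 10]

model = SentenceTransformer("sdadas/mmlw-roberta-large")
embeddings = model.encode(chunks)  # array of shape (num_chunks, hidden_dim)
print(len(chunks), embeddings.shape)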