From dea9a2fffce499ba6eb23705e386432115ab005b Mon Sep 17 00:00:00 2001
From: "jakub.binkowski" 
Date: Wed, 12 Jun 2024 16:49:00 +0000
Subject: [PATCH] Apply codestyle changes and optimizations

---
 .github/workflows/python.yaml | 4 +-
 README.md | 37 +++++++++----
 configs/dataset/pl-court-instruct.yaml | 2 +-
 ...nsloth-Llama-3-8B-Instruct-fine-tuned.yaml | 2 +-
 .../model/Unsloth-Llama-3-8B-Instruct.yaml | 2 +-
 dvc.lock | 36 ++++++-------
 juddges/preprocessing/context_truncator.py | 2 +-
 nbs/index.ipynb | 54 ++++++++++---------
 scripts/dataset/build_instruct_dataset.py | 10 ++--
 .../dataset/download_pl_additional_data.py | 4 +-
 scripts/sft/fine_tune_llm.py | 4 ++
 scripts/sft/predict.py | 4 +-
 scripts/sft/summarize_metrics.py | 8 +--
 scripts/utils/clean_dvc_lock.py | 4 +-
 14 files changed, 99 insertions(+), 74 deletions(-)

diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml
index e8d9a81..0248313 100644
--- a/.github/workflows/python.yaml
+++ b/.github/workflows/python.yaml
@@ -20,7 +20,7 @@ jobs:
       - uses: actions/cache@v3
         with:
           path: ${{ env.pythonLocation }}
-          key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}
+          key: ${{ runner.os }}-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}
       - name: Install deps
         run: make install_cpu
       - name: Lint
@@ -46,7 +46,7 @@ jobs:
       - uses: actions/cache@v3
         with:
          path: ${{ env.pythonLocation }}
-          key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}
+          key: ${{ runner.os }}-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}
       - name: Install deps
         run: make install_cpu
       - name: Test
diff --git a/README.md b/README.md
index fbc03c5..93e52a4 100644
--- a/README.md
+++ b/README.md
@@ -23,12 +23,35 @@ fostering cross-disciplinary and cross-jurisdictional collaboration.
 
 ![baner](https://raw.githubusercontent.com/pwr-ai/JuDDGES/bffb1d75ba7c78f101fc94bd9086499886b2c128/nbs/images/baner.png)
 
+## Usage
+
+### Installation
+
+- To install the necessary dependencies, use the provided `Makefile`
+  (use `python>=3.10`):
+
+  ```shell
+  make install
+  ```
+
+- To run evaluation and fine-tuning with `unsloth`, use the following
+  command inside a conda environment with `python=3.10`:
+
+  ```shell
+  make install_unsloth
+  ```
+
+### Dataset creation
+
+The specific details of dataset creation are available in
+[scripts/README.md](scripts/README.md).
+
+### Fine-tuning
+
+To run evaluation or fine-tuning, run the appropriate stages declared
+in [`dvc.yaml`](dvc.yaml), as sketched below (see the [DVC docs for
+details](https://dvc.org/doc/user-guide)).
+
+## Project details
+
 The JuDDGES project encompasses several Work Packages (WPs) designed to
 cover all aspects of its objectives, from project management to the
 open science practices and engaging early career researchers. Below is
 an overview of the project’s WPs based on the provided information:
 
-## WP1: Project Management
+### WP1: Project Management
 
 **Duration**: 24 Months
 
@@ -37,7 +60,7 @@
 **Main Aim**: To ensure the project's successful completion on time and
 within budget. This includes administrative management, scientific and
 technological management, quality innovation and risk management,
 ethical and legal consideration, and facilitating open science.
 
-## WP2: Gathering and Human Encoding of Judicial Decision Data
+### WP2: Gathering and Human Encoding of Judicial Decision Data
 
 **Duration**: 22 Months
 
@@ -48,7 +71,7 @@
 **Main Aim**: To establish the data foundation for developing and
 testing the project's tools. This involves collating/gathering legal
 case records and judgments, developing a coding scheme, training human
 coders, making human-coded data available for WP3, facilitating
 human-in-loop coding for WP3, and enabling WP4 to make data open and
 reusable beyond the project team.
 
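As a sketch of the `Fine-tuning` workflow referenced in the README section above, reproducing a single pipeline stage could look like the following; the stage name is hypothetical, so list the stages declared in `dvc.yaml` first:

```python
import subprocess

# Show the stages declared in dvc.yaml, then reproduce one of them.
# "predict@Unsloth-Llama-3-8B-Instruct" is a hypothetical stage name;
# substitute one of the names printed by `dvc stage list`.
subprocess.run(["dvc", "stage", "list"], check=True)
subprocess.run(["dvc", "repro", "predict@Unsloth-Llama-3-8B-Instruct"], check=True)
```

Reproducing one stage at a time keeps iteration cheap, since DVC skips stages whose inputs are unchanged.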
-## WP3: NLP and HITL Machine Learning Methodological Development
+### WP3: NLP and HITL Machine Learning Methodological Development
 
 **Duration**: 24 Months
 
@@ -59,7 +82,7 @@
 **Main Aim**: To create a bridge between machine learning (led by WUST
 and MUHEC) and Open Science facilitation (by ELICO), focusing on the
 development and deployment of annotation methodologies. This includes
 baseline information extraction, intelligent inference methods for
 legal corpus data, and constructing an annotation tool through active
 learning and human-in-the-loop annotation methods.
 
-## WP4: Open Science Practices & Engaging Early Career Researchers
+### WP4: Open Science Practices & Engaging Early Career Researchers
 
 **Duration**: 12 Months
 
@@ -73,12 +96,6 @@
 Each WP includes specific tasks aimed at achieving its goals, involving
 collaboration among project partners and contributing to the
 overarching aim of the JuDDGES project.
 
-## Install
-
-``` sh
-pip install juddges
-```
-
 ## Acknowledgements
 
 The universities involved in the JuDDGES project are:
diff --git a/configs/dataset/pl-court-instruct.yaml b/configs/dataset/pl-court-instruct.yaml
index 644a76b..d9e0a27 100644
--- a/configs/dataset/pl-court-instruct.yaml
+++ b/configs/dataset/pl-court-instruct.yaml
@@ -1,4 +1,4 @@
 name: JuDDGES/pl-court-instruct
 prompt_field: prompt
 context_field: context
-output_field: output
\ No newline at end of file
+output_field: output
diff --git a/configs/model/Unsloth-Llama-3-8B-Instruct-fine-tuned.yaml b/configs/model/Unsloth-Llama-3-8B-Instruct-fine-tuned.yaml
index cd8f7f8..81acb05 100644
--- a/configs/model/Unsloth-Llama-3-8B-Instruct-fine-tuned.yaml
+++ b/configs/model/Unsloth-Llama-3-8B-Instruct-fine-tuned.yaml
@@ -7,4 +7,4 @@ max_seq_length: 7_900
 
 batch_size: 1
 padding: longest
-use_unsloth: true
\ No newline at end of file
+use_unsloth: true
diff --git a/configs/model/Unsloth-Llama-3-8B-Instruct.yaml b/configs/model/Unsloth-Llama-3-8B-Instruct.yaml
index b9b331a..9c251ff 100644
--- a/configs/model/Unsloth-Llama-3-8B-Instruct.yaml
+++ b/configs/model/Unsloth-Llama-3-8B-Instruct.yaml
@@ -7,4 +7,4 @@ max_seq_length: 7_900
 
 batch_size: 1
 padding: longest
-use_unsloth: true
\ No newline at end of file
+use_unsloth: true
diff --git a/dvc.lock b/dvc.lock
index 8c09a5f..b04ce8c 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -65,8 +65,8 @@ stages:
       nfiles: 17
     - path: scripts/dataset/build_instruct_dataset.py
       hash: md5
-      md5: f07a9d106853e74be4d0e4807b33ff3d
-      size: 5393
+      md5: 9b138322059d63ce3ad1bd05c8b931f2
+      size: 5461
   embed@mmlw-roberta-large:
     cmd: PYTHONPATH=.
      python scripts/embed/embed_text.py embedding_model=mmlw-roberta-large
     deps:
@@ -188,16 +188,16 @@ stages:
     deps:
     - path: configs/model/Unsloth-Llama-3-8B-Instruct.yaml
       hash: md5
-      md5: e97bb2e6bf39f75edea7714d6ba58b77
-      size: 160
+      md5: 1b4c0353b8c41fd3656ec5cf15eb6c2b
+      size: 161
     - path: configs/predict.yaml
       hash: md5
       md5: 74ad1dc5d9f130074533078d85e55e94
       size: 504
     - path: scripts/sft/predict.py
       hash: md5
-      md5: 85a61d28419afaf276ea17863205aa2a
-      size: 3203
+      md5: 59c2afb977f520c9134153def544111d
+      size: 3188
     outs:
     - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct.json
@@ -209,16 +209,16 @@ stages:
     deps:
     - path: configs/model/Unsloth-Llama-3-8B-Instruct-fine-tuned.yaml
       hash: md5
-      md5: c9fab7cd7a4b0159d13ba61e3c516c0a
-      size: 244
+      md5: dd00fc3994bdc95baf1f17de7b026a0f
+      size: 245
     - path: configs/predict.yaml
       hash: md5
       md5: 74ad1dc5d9f130074533078d85e55e94
       size: 504
     - path: scripts/sft/predict.py
       hash: md5
-      md5: 85a61d28419afaf276ea17863205aa2a
-      size: 3203
+      md5: 59c2afb977f520c9134153def544111d
+      size: 3188
     outs:
     - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct-fine-tuned.json
@@ -238,8 +238,8 @@ stages:
       size: 504
     - path: scripts/sft/predict.py
       hash: md5
-      md5: 85a61d28419afaf276ea17863205aa2a
-      size: 3203
+      md5: 59c2afb977f520c9134153def544111d
+      size: 3188
     outs:
     - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3.json
@@ -259,8 +259,8 @@ stages:
       size: 504
     - path: scripts/sft/predict.py
       hash: md5
-      md5: 85a61d28419afaf276ea17863205aa2a
-      size: 3203
+      md5: 59c2afb977f520c9134153def544111d
+      size: 3188
     outs:
     - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json
@@ -295,8 +295,8 @@ stages:
       size: 528
     - path: configs/model/Unsloth-Llama-3-8B-Instruct.yaml
       hash: md5
-      md5: e97bb2e6bf39f75edea7714d6ba58b77
-      size: 160
+      md5: 1b4c0353b8c41fd3656ec5cf15eb6c2b
+      size: 161
     - path: scripts/sft/fine_tune_unsloth.py
       hash: md5
       md5: 2c3ca748d6f7bf76e92eb7bc8ded5f37
@@ -333,8 +333,8 @@ stages:
     deps:
     - path: scripts/sft/summarize_metrics.py
       hash: md5
-      md5: 8993c84349eab010b5e484fbf43fd8ff
-      size: 737
+      md5: 482321d8b7291cbed1e80bed7b685b46
+      size: 782
     outs:
     - path: data/experiments/predict/pl-court-instruct/metrics_summary.md
       hash: md5
diff --git a/juddges/preprocessing/context_truncator.py b/juddges/preprocessing/context_truncator.py
index bb9d935..7bc41af 100644
--- a/juddges/preprocessing/context_truncator.py
+++ b/juddges/preprocessing/context_truncator.py
@@ -10,8 +10,8 @@ def __init__(self, tokenizer: BaseTokenizer, max_length: int):
 
         empty_messages = [
             {"role": "user", "content": ""},
+            {"role": "assistant", "content": ""},
         ]
-        empty_messages.append({"role": "assistant", "content": ""})
 
         self.empty_messages_length = len(
             self.tokenizer.apply_chat_template(empty_messages, tokenize=True)
diff --git a/nbs/index.ipynb b/nbs/index.ipynb
index 4d8e6a9..2f640e2 100644
--- a/nbs/index.ipynb
+++ b/nbs/index.ipynb
@@ -27,27 +27,52 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
+    "## Usage\n",
+    "\n",
+    "### Installation\n",
+    "- To install the necessary dependencies, use the provided `Makefile` (use `python>=3.10`):\n",
+    "  ```shell\n",
+    "  make install\n",
+    "  ```\n",
+    "- To run evaluation and fine-tuning with `unsloth`, use the following command inside a conda environment with `python=3.10`:\n",
+    "  ```shell\n",
+    "  make install_unsloth\n",
+    "  ```\n",
+    "\n",
+    "### Dataset creation\n",
+    "The specific details of 
dataset creation are available in [scripts/README.md](scripts/README.md).\n",
+    "\n",
+    "### Fine-tuning\n",
+    "To run evaluation or fine-tuning, run the appropriate stages declared in [`dvc.yaml`](dvc.yaml) (see [DVC docs for details](https://dvc.org/doc/user-guide))."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Project details\n",
+    "\n",
     "The JuDDGES project encompasses several Work Packages (WPs) designed to cover all aspects of its objectives, from project management to the open science practices and engaging early career researchers. Below is an overview of the project's WPs based on the provided information:\n",
     "\n",
-    "## WP1: Project Management\n",
+    "### WP1: Project Management\n",
     "\n",
     "**Duration**: 24 Months\n",
     "\n",
     "**Main Aim**: To ensure the project's successful completion on time and within budget. This includes administrative management, scientific and technological management, quality innovation and risk management, ethical and legal consideration, and facilitating open science.\n",
     "\n",
-    "## WP2: Gathering and Human Encoding of Judicial Decision Data\n",
+    "### WP2: Gathering and Human Encoding of Judicial Decision Data\n",
     "\n",
     "**Duration**: 22 Months\n",
     "\n",
     "**Main Aim**: To establish the data foundation for developing and testing the project's tools. This involves collating/gathering legal case records and judgments, developing a coding scheme, training human coders, making human-coded data available for WP3, facilitating human-in-loop coding for WP3, and enabling WP4 to make data open and reusable beyond the project team.\n",
     "\n",
-    "## WP3: NLP and HITL Machine Learning Methodological Development\n",
+    "### WP3: NLP and HITL Machine Learning Methodological Development\n",
     "\n",
     "**Duration**: 24 Months\n",
     "\n",
     "**Main Aim**: To create a bridge between machine learning (led by WUST and MUHEC) and Open Science facilitation (by ELICO), focusing on the development and deployment of annotation methodologies. This includes baseline information extraction, intelligent inference methods for legal corpus data, and constructing an annotation tool through active learning and human-in-the-loop annotation methods.\n",
     "\n",
-    "## WP4: Open Science Practices & Engaging Early Career Researchers\n",
+    "### WP4: Open Science Practices & Engaging Early Career Researchers\n",
     "\n",
     "**Duration**: 12 Months\n",
     "\n",
@@ -56,22 +81,6 @@
     "Each WP includes specific tasks aimed at achieving its goals, involving collaboration among project partners and contributing to the overarching aim of the JuDDGES project.\n"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Install\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "```sh\n",
-    "pip install juddges\n",
-    "```\n"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -84,11 +93,6 @@
     "2. Middlesex University London (UK)\n",
     "3. 
University of Lyon 1 (France).\n"
    ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": []
   }
  ],
  "metadata": {
diff --git a/scripts/dataset/build_instruct_dataset.py b/scripts/dataset/build_instruct_dataset.py
index 1c81c83..853e311 100644
--- a/scripts/dataset/build_instruct_dataset.py
+++ b/scripts/dataset/build_instruct_dataset.py
@@ -14,6 +14,8 @@
 load_dotenv()
 
 MAX_SHARD_SIZE = "4GB"
+DATE_FORMAT = "%Y-%m-%d"
+JUDGE_SEPARATOR = " i "
 
 SCHEMA_TEMPLATE = "```yaml\n{schema}\n```"
 INSTRUCTION_TEMPLATE = """
@@ -113,7 +115,7 @@ def main(
 
 
 def _pre_filter(item: dict[str, Any]) -> bool:
-    return all(item[feat] is not None for feat in FEATURES)
+    return not any(item[feat] is None for feat in FEATURES)
 
 
 def _preprocess(item: dict[str, Any]) -> dict[str, Any]:
@@ -125,7 +127,7 @@ def _preprocess(item: dict[str, Any]) -> dict[str, Any]:
 
 
 def _simplify_date(item: dict[str, Any]) -> dict[str, Any]:
     item["date"], *_ = item["date"].split()
-    datetime.strptime(item["date"], "%Y-%m-%d")  # raises ValueError on invalid format
+    datetime.strptime(item["date"], DATE_FORMAT)  # raises ValueError on invalid format
     return item
 
@@ -133,7 +135,7 @@ def _split_multiple_names(item: dict[str, Any]) -> dict[str, Any]:
     """Splits judges' names that are joined by 'i'."""
     judges = []
     for j in item["judges"]:
-        for j_part in j.split(" i "):
+        for j_part in j.split(JUDGE_SEPARATOR):
             judges.append(j_part)
 
     item["judges"] = judges
@@ -146,7 +148,7 @@ def _legal_bases_to_texts(item: dict[str, Any]) -> dict[str, Any]:
 
 
 def _filter(item: dict[str, Any]) -> bool:
-    all_judges_in_text = all(j in item["text"] for j in item["judges"])
+    all_judges_in_text = not any(j not in item["text"] for j in item["judges"])
     recorder_in_text = item["recorder"] in item["text"]
     signature_in_text = item["signature"] in item["text"]
 
diff --git a/scripts/dataset/download_pl_additional_data.py b/scripts/dataset/download_pl_additional_data.py
index 4d3f65b..9019df0 100644
--- a/scripts/dataset/download_pl_additional_data.py
+++ b/scripts/dataset/download_pl_additional_data.py
@@ -43,7 +43,7 @@ def main(
     cursor = collection.find(query, {"_id": 1}, batch_size=batch_size)
     batched_cursor = BatchedDatabaseCursor(cursor=cursor, batch_size=batch_size, prefetch=True)
 
-    download_data = DownloadAdditionalData(data_type)
+    download_data = AdditionalDataDownloader(data_type)
     download_data_and_update_db = BatchDatabaseUpdate(mongo_uri, download_data)
 
     with multiprocessing.Pool(n_jobs) as pool:
@@ -60,7 +60,7 @@ def main(
     assert collection.count_documents(query) == 0
 
 
-class DownloadAdditionalData:
+class AdditionalDataDownloader:
     def __init__(self, data_type: DataType):
         self.data_type = data_type
         self.api = PolishCourtAPI()
diff --git a/scripts/sft/fine_tune_llm.py b/scripts/sft/fine_tune_llm.py
index d562378..d9d43d1 100644
--- a/scripts/sft/fine_tune_llm.py
+++ b/scripts/sft/fine_tune_llm.py
@@ -1,3 +1,7 @@
+"""
+Fine-tune a large language model using SFT.
+The script is based on: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl
+"""
 import os
 from pathlib import Path
diff --git a/scripts/sft/predict.py b/scripts/sft/predict.py
index e26bcd8..89318d9 100644
--- a/scripts/sft/predict.py
+++ b/scripts/sft/predict.py
@@ -45,7 +45,7 @@ def main(cfg: DictConfig) -> None:
     output_file = Path(config.output_file)
     output_file.parent.mkdir(parents=True, exist_ok=True)
 
-    ds = load_dataset("JuDDGES/pl-court-instruct")
+    ds = load_dataset(config.dataset.name, split="test")
 
     logger.info("Loading model...")
     model_pack = get_model(config.model, device_map=config.device_map)
@@ -54,8 +54,6 @@ def main(cfg: DictConfig) -> None:
     if config.model.batch_size > 1 and config.model.padding is False:
         raise ValueError("Padding has to be enabled if batch size > 1.")
 
-    ds = ds["test"]
-
     gold_outputs = [item["output"] for item in ds]
 
     encoder = TextEncoderForEval(
diff --git a/scripts/sft/summarize_metrics.py b/scripts/sft/summarize_metrics.py
index 991dae4..e749f6a 100644
--- a/scripts/sft/summarize_metrics.py
+++ b/scripts/sft/summarize_metrics.py
@@ -14,9 +14,11 @@ def main(
         with f.open() as file:
             m_res = json.load(file)
         results.append(
-            {"llm": model_name}
-            | {"full_text_chrf": m_res["full_text_chrf"]}
-            | m_res["field_chrf"]
+            {
+                "llm": model_name,
+                "full_text_chrf": m_res["full_text_chrf"],
+                **m_res["field_chrf"],
+            }
         )
 
     summary_file = root_dir / "metrics_summary.md"
diff --git a/scripts/utils/clean_dvc_lock.py b/scripts/utils/clean_dvc_lock.py
index 558c755..95edb58 100644
--- a/scripts/utils/clean_dvc_lock.py
+++ b/scripts/utils/clean_dvc_lock.py
@@ -16,8 +16,6 @@ def main(
     for stage, *_ in repo.index.graph.nodes(data=True):
         if stage.path_in_repo == "dvc.yaml":
             stages.add(stage.name)
-        else:
-            continue
 
     with dvc_lock.open() as file:
         lock_file = yaml.safe_load(file)
@@ -25,7 +23,7 @@ def main(
     lock_stages = set(lock_file["stages"].keys())
     to_remove = lock_stages.difference(stages)
-    print(to_remove)
+    print(f"Removing stages from lock:\n{to_remove}")
 
     if typer.confirm("Are you sure you want to delete?"):
         lock_file["stages"] = {
             key: val for key, val in lock_file["stages"].items() if key not in to_remove
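A closing note on the `summarize_metrics.py` hunk above: a single dict literal with `**` unpacking builds the same mapping as the chained `|` unions it replaces (`|` requires Python 3.9+, while `**` unpacking in literals works from 3.5), so the rewrite is purely stylistic. A minimal self-check, with made-up metric values:

```python
# Stand-ins for one model's metrics file; the values are made up.
model_name = "example-llm"
m_res = {"full_text_chrf": 55.3, "field_chrf": {"date": 71.2, "judges": 64.8}}

# Old form: successive dict unions (requires Python >= 3.9).
row_union = (
    {"llm": model_name}
    | {"full_text_chrf": m_res["full_text_chrf"]}
    | m_res["field_chrf"]
)

# New form from the patch: one literal with ** unpacking.
row_literal = {
    "llm": model_name,
    "full_text_chrf": m_res["full_text_chrf"],
    **m_res["field_chrf"],
}

assert row_union == row_literal
```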