diff --git a/.github/workflows/nbdev-test.yaml b/.github/workflows/nbdev-test.yaml index acc31c8..822fab5 100644 --- a/.github/workflows/nbdev-test.yaml +++ b/.github/workflows/nbdev-test.yaml @@ -4,7 +4,7 @@ on: [workflow_dispatch, pull_request] jobs: nbdev-test: runs-on: ubuntu-latest - steps: + steps: - uses: fastai/workflows/nbdev-ci@master with: skip_test: true diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 0248313..af5c3fa 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -14,15 +14,26 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + + - name: Setup Python + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - uses: actions/cache@v3 + cache: 'pip' + cache-dependency-path: | + requirements.txt + + - name: Cache pre-commit + uses: actions/cache@v4 with: - path: ${{ env.pythonLocation }} - key: ${{ runner.os }}-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }} + path: ~/.cache/pre-commit + key: pre-commit|${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('.pre-commit-config.yaml') }} + - name: Install deps - run: make install_cpu + run: | + python -m pip install --upgrade pip + make install_cpu + - name: Lint run: make check @@ -40,15 +51,20 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + + - name: Setup Python + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - uses: actions/cache@v3 - with: - path: ${{ env.pythonLocation }} - key: ${{ runner.os }}-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }} + cache: 'pip' + cache-dependency-path: | + requirements.txt + - name: Install deps - run: make install_cpu + run: | + python -m pip install --upgrade pip + make install_cpu + - name: Test env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8628541 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: no-commit-to-branch + name: No commits to master + - id: end-of-file-fixer + name: End-of-file fixer + - name: mixed-line-ending + id: mixed-line-ending + args: [--fix, lf] + - id: trailing-whitespace + name: Remove trailing whitespaces + - id: check-toml + name: Check toml + - id: check-yaml + name: Check yaml + args: [--allow-multiple-documents] + + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.10 + hooks: + - id: ruff + name: Ruff Linter + args: [--fix, --exit-non-zero-on-fix, juddges, scripts, dashboards, tests] + - id: ruff-format + name: Ruff Formatter + args: [juddges, scripts, dashboards, tests] diff --git a/Dockerfile b/Dockerfile index bf1347c..75968f9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,7 +24,7 @@ RUN apt-get update \ && gdebi -n quarto-1.5.17-linux-amd64.deb \ && apt-get clean \ && rm -rf /var/lib/apt/lists \ - && rm -rf /tmp + && rm -rf /tmp ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONUNBUFFERED 1 diff --git a/Makefile b/Makefile index fc031cf..f311356 100644 --- a/Makefile +++ b/Makefile @@ -2,12 +2,10 @@ lint_dirs := juddges scripts dashboards tests mypy_dirs := juddges scripts dashboards tests fix: - ruff check $(lint_dirs) --fix - ruff format $(lint_dirs) + pre-commit run --all-files check: - ruff check $(lint_dirs) - ruff format $(lint_dirs) --check + pre-commit run --all-files check-types: mypy --install-types --non-interactive $(mypy_dirs) @@ -25,8 +23,7 @@ install: install_cpu: pip install --find-links https://download.pytorch.org/whl/cpu -r requirements.txt -# unsloth requires python 3.10 -# requires conda environment +# unsloth requires python 3.10 and conda environment install_unsloth: conda install --yes pytorch-cuda=12.1 pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" diff --git a/configs/embedding.yaml b/configs/embedding.yaml index c9c891e..37afafe 100644 --- a/configs/embedding.yaml +++ b/configs/embedding.yaml @@ -2,8 +2,8 @@ defaults: - embedding_model: ??? - dataset: pl-court-raw - _self_ - - override hydra/hydra_logging: disabled - - override hydra/job_logging: disabled + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled length_adjust_mode: chunk chunk_config: @@ -14,7 +14,7 @@ batch_size: 64 output_dir: data/embeddings/${dataset.name}/${hydra:runtime.choices.embedding_model}/all_embeddings -hydra: - output_subdir: null - run: +hydra: + output_subdir: null + run: dir: . diff --git a/configs/fine_tuning.yaml b/configs/fine_tuning.yaml index 4cac7aa..e96d55b 100644 --- a/configs/fine_tuning.yaml +++ b/configs/fine_tuning.yaml @@ -2,8 +2,8 @@ defaults: - model: ??? - dataset: pl-court-instruct - _self_ - - override hydra/hydra_logging: disabled - - override hydra/job_logging: disabled + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled output_dir: data/experiments/fine-tune/${hydra:runtime.choices.model}/${hydra:runtime.choices.dataset} run_name: ${hydra:runtime.choices.model}_${hydra:runtime.choices.dataset}_fine_tune @@ -17,7 +17,7 @@ truncate_context: True epochs: 1 batch_size: 4 -hydra: - output_subdir: null - run: +hydra: + output_subdir: null + run: dir: . diff --git a/configs/llm_judge.yaml b/configs/llm_judge.yaml new file mode 100644 index 0000000..416519f --- /dev/null +++ b/configs/llm_judge.yaml @@ -0,0 +1,11 @@ +defaults: + - model: ??? + - _self_ + +answers_file: ??? +out_metric_file: ??? +out_predictions_file: ??? + +generate_kwargs: + max_new_tokens: 20 + do_sample: False diff --git a/configs/predict.yaml b/configs/predict.yaml index a85cf50..9dd1440 100644 --- a/configs/predict.yaml +++ b/configs/predict.yaml @@ -2,17 +2,19 @@ defaults: - model: ??? - dataset: pl-court-instruct - _self_ - - override hydra/hydra_logging: disabled - - override hydra/job_logging: disabled + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled device_map: 'auto' output_file: data/experiments/predict/${hydra:runtime.choices.dataset}/outputs_${hydra:runtime.choices.model}.json metrics_file: data/experiments/predict/${hydra:runtime.choices.dataset}/metrics_${hydra:runtime.choices.model}.json -max_new_tokens: 250 truncate_context: True +generate_kwargs: + max_new_tokens: 250 + do_sample: False -hydra: - output_subdir: null - run: +hydra: + output_subdir: null + run: dir: . diff --git "a/dashboards/pages/01_\360\237\224\215_Search_Judgements.py" "b/dashboards/pages/01_\360\237\224\215_Search_Judgements.py" index 74d86f0..28543d2 100644 --- "a/dashboards/pages/01_\360\237\224\215_Search_Judgements.py" +++ "b/dashboards/pages/01_\360\237\224\215_Search_Judgements.py" @@ -1,40 +1,41 @@ -from typing import Any -import streamlit as st - -from juddges.data.datasets import get_mongo_collection -from pymongo.collection import Collection - -TITLE = "Search for Judgements" - -st.set_page_config(page_title=TITLE, page_icon="⚖️", layout="wide") - -st.title(TITLE) - - -@st.cache_resource -def get_judgements_collection() -> Collection: - return get_mongo_collection("judgements") - - -judgements_collection = get_judgements_collection() - - -def search_data(query: str, max_judgements: int = 5) -> list[dict[str, Any]]: - items = list(judgements_collection.find({"$text": {"$search": query}}).limit(max_judgements)) - return items - - -with st.form(key="search_form"): - text = st.text_area("What you are looking for in the judgements?") - max_judgements = st.slider("Max judgements to show", min_value=1, max_value=20, value=5) - submit_button = st.form_submit_button(label="Search") - -if submit_button: - with st.spinner("Searching..."): - items = search_data(text, max_judgements) - - st.header("Judgements - Results") - for item in items: - st.header(item["signature"]) - st.subheader(item["publicationDate"]) - st.write(item["text"]) +from typing import Any + +import streamlit as st +from pymongo.collection import Collection + +from juddges.data.datasets import get_mongo_collection + +TITLE = "Search for Judgements" + +st.set_page_config(page_title=TITLE, page_icon="⚖️", layout="wide") + +st.title(TITLE) + + +@st.cache_resource +def get_judgements_collection() -> Collection: + return get_mongo_collection("judgements") + + +judgements_collection = get_judgements_collection() + + +def search_data(query: str, max_judgements: int = 5) -> list[dict[str, Any]]: + items = list(judgements_collection.find({"$text": {"$search": query}}).limit(max_judgements)) + return items + + +with st.form(key="search_form"): + text = st.text_area("What you are looking for in the judgements?") + max_judgements = st.slider("Max judgements to show", min_value=1, max_value=20, value=5) + submit_button = st.form_submit_button(label="Search") + +if submit_button: + with st.spinner("Searching..."): + items = search_data(text, max_judgements) + + st.header("Judgements - Results") + for item in items: + st.header(item["signature"]) + st.subheader(item["publicationDate"]) + st.write(item["text"]) diff --git "a/dashboards/pages/02_\360\237\224\215_Analyse_Extracted_Information.py" "b/dashboards/pages/02_\360\237\224\215_Analyse_Extracted_Information.py" index ad663b2..9aeb056 100644 --- "a/dashboards/pages/02_\360\237\224\215_Analyse_Extracted_Information.py" +++ "b/dashboards/pages/02_\360\237\224\215_Analyse_Extracted_Information.py" @@ -1,86 +1,86 @@ -import io - -import pandas as pd -import streamlit as st - -from juddges.prompts.information_extraction import EXAMPLE_SCHEMA -from juddges.settings import SAMPLE_DATA_PATH - -TITLE = "Analyse Judgements" - -st.set_page_config(page_title=TITLE, page_icon="⚖️", layout="wide") - -st.title(TITLE) - - -@st.cache_resource -def load_data(): - return pd.read_csv(SAMPLE_DATA_PATH / "judgements-100-sample-with-retrieved-informations.csv") - - -df = load_data() -extracted_keys = [line.split(":")[0] for line in EXAMPLE_SCHEMA.split("\n") if len(line) > 3] + [ - "signature", - "excerpt", - "text", - "judges", - "references", -] - -st.info( - "We sampled 100 random judgements from the dataset and extracted information from them. Below is the extracted information and the schema (questions) used to extract it." -) - -st.text_area( - "Example schema for extracted informations: ", value=EXAMPLE_SCHEMA, height=300, disabled=True -) - -st.header("Extracted Information - tabular format") -st.write(df[extracted_keys]) - - -output = io.BytesIO() -with pd.ExcelWriter(output, engine="xlsxwriter") as writer: - df.to_excel(writer, sheet_name="Sheet1", index=False) -output.seek(0) -st.download_button( - label="Download data as Excel", - data=output, - file_name="judgements.xlsx", - mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", -) - -st.header("Analyse Extracted Information") - -st.subheader("How many judgements we analyzed?") - -st.write(f"Number of judgements: {len(df)}") - -st.subheader("What courts judgement do we analyse") - -st.write(df.groupby("court")["_id"].count()) - -st.subheader("How many judgements are drug offences?") - -drug_offences = df["drug_offence"].sum() - -st.info(f"Number of drug offences: {drug_offences}") - -st.subheader("How many judgements are child offences?") - -child_offences = df["child_offence"].sum() - -st.info(f"Number of child offences: {child_offences}") - -st.subheader("Show examples of judgements that are child offences") - -drug_offences_df = df[df["child_offence"]] - -st.write("We can check the sentences of them") - -for row_id, row in drug_offences_df.iterrows(): - st.subheader(row["signature"]) - st.info(row["verdict_summary"]) - if st.toggle(key=row, label="Show judgement's text"): - st.markdown(row["text"]) - st.markdown("---") +import io + +import pandas as pd +import streamlit as st + +from juddges.prompts.information_extraction import EXAMPLE_SCHEMA +from juddges.settings import SAMPLE_DATA_PATH + +TITLE = "Analyse Judgements" + +st.set_page_config(page_title=TITLE, page_icon="⚖️", layout="wide") + +st.title(TITLE) + + +@st.cache_resource +def load_data(): + return pd.read_csv(SAMPLE_DATA_PATH / "judgements-100-sample-with-retrieved-informations.csv") + + +df = load_data() +extracted_keys = [line.split(":")[0] for line in EXAMPLE_SCHEMA.split("\n") if len(line) > 3] + [ + "signature", + "excerpt", + "text", + "judges", + "references", +] + +st.info( + "We sampled 100 random judgements from the dataset and extracted information from them. Below is the extracted information and the schema (questions) used to extract it." +) + +st.text_area( + "Example schema for extracted informations: ", value=EXAMPLE_SCHEMA, height=300, disabled=True +) + +st.header("Extracted Information - tabular format") +st.write(df[extracted_keys]) + + +output = io.BytesIO() +with pd.ExcelWriter(output, engine="xlsxwriter") as writer: + df.to_excel(writer, sheet_name="Sheet1", index=False) +output.seek(0) +st.download_button( + label="Download data as Excel", + data=output, + file_name="judgements.xlsx", + mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", +) + +st.header("Analyse Extracted Information") + +st.subheader("How many judgements we analyzed?") + +st.write(f"Number of judgements: {len(df)}") + +st.subheader("What courts judgement do we analyse") + +st.write(df.groupby("court")["_id"].count()) + +st.subheader("How many judgements are drug offences?") + +drug_offences = df["drug_offence"].sum() + +st.info(f"Number of drug offences: {drug_offences}") + +st.subheader("How many judgements are child offences?") + +child_offences = df["child_offence"].sum() + +st.info(f"Number of child offences: {child_offences}") + +st.subheader("Show examples of judgements that are child offences") + +drug_offences_df = df[df["child_offence"]] + +st.write("We can check the sentences of them") + +for row_id, row in drug_offences_df.iterrows(): + st.subheader(row["signature"]) + st.info(row["verdict_summary"]) + if st.toggle(key=row, label="Show judgement's text"): + st.markdown(row["text"]) + st.markdown("---") diff --git a/data/datasets/pl/graph/template_README.md b/data/datasets/pl/graph/template_README.md index 842b85f..8f9dd2b 100755 --- a/data/datasets/pl/graph/template_README.md +++ b/data/datasets/pl/graph/template_README.md @@ -10,11 +10,11 @@ tags: {{tags}} # Polish Court Judgments Graph ## Dataset description -We introduce a graph dataset of Polish Court Judgments. This dataset is primarily based on the [`JuDDGES/pl-court-raw`](https://huggingface.co/datasets/JuDDGES/pl-court-raw). The dataset consists of nodes representing either judgments or legal bases, and edges connecting judgments to the legal bases they refer to. Also, the graph was cleaned from small disconnected components, leaving single giant component. Consequently, the resulting graph is bipartite. We provide the dataset in both `JSON` and `PyG` formats, each has different purpose. While structurally graphs in these formats are the same, their attributes differ. +We introduce a graph dataset of Polish Court Judgments. This dataset is primarily based on the [`JuDDGES/pl-court-raw`](https://huggingface.co/datasets/JuDDGES/pl-court-raw). The dataset consists of nodes representing either judgments or legal bases, and edges connecting judgments to the legal bases they refer to. Also, the graph was cleaned from small disconnected components, leaving single giant component. Consequently, the resulting graph is bipartite. We provide the dataset in both `JSON` and `PyG` formats, each has different purpose. While structurally graphs in these formats are the same, their attributes differ. The `JSON` format is intended for analysis and contains most of the attributes available in [`JuDDGES/pl-court-raw`](https://huggingface.co/datasets/JuDDGES/pl-court-raw). We excluded some less-useful attributes and text content, which can be easily retrieved from the raw dataset and added to the graph as needed. -The `PyG` format is designed for machine learning applications, such as link prediction on graphs, and is fully compatible with the [`Pytorch Geometric`](https://github.com/pyg-team/pytorch_geometric) framework. +The `PyG` format is designed for machine learning applications, such as link prediction on graphs, and is fully compatible with the [`Pytorch Geometric`](https://github.com/pyg-team/pytorch_geometric) framework. In the following sections, we provide a more detailed explanation and use case examples for each format. @@ -28,9 +28,9 @@ In the following sections, we provide a more detailed explanation and use case e | #nodes (type=`legal_base`) | {{num_target_nodes}} | | avg(degree) | {{avg_degree}} | - + ![png](assets/degree_distribution.png) - + ## `JSON` format @@ -67,10 +67,10 @@ g = nx.node_link_graph(g_data) ## `PyG` format -The `PyTorch Geometric` format includes embeddings of the judgment content, obtained with [{{embedding_method}}](https://huggingface.co/{{embedding_method}}) for judgment nodes, -and one-hot-vector identifiers for legal-base nodes (note that for efficiency one can substitute it with random noise identifiers, +The `PyTorch Geometric` format includes embeddings of the judgment content, obtained with [{{embedding_method}}](https://huggingface.co/{{embedding_method}}) for judgment nodes, +and one-hot-vector identifiers for legal-base nodes (note that for efficiency one can substitute it with random noise identifiers, like in [(Abboud et al., 2021)](https://arxiv.org/abs/2010.01179)). - + ### Loading @@ -134,4 +134,4 @@ print(ds) ### Example usage ```python # TBD -``` \ No newline at end of file +``` diff --git a/data/experiments/llm_as_judge/pl-court-instruct/.gitignore b/data/experiments/llm_as_judge/pl-court-instruct/.gitignore new file mode 100644 index 0000000..8c7a84e --- /dev/null +++ b/data/experiments/llm_as_judge/pl-court-instruct/.gitignore @@ -0,0 +1 @@ +/judge_Unsloth-Llama-3-8B-Instruct_predictions_Unsloth-Llama-3-8B-Instruct.json diff --git a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct-fine-tuned.json b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct-fine-tuned.json index 6740087..4bfb85e 100644 --- a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct-fine-tuned.json +++ b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct-fine-tuned.json @@ -9,4 +9,4 @@ "recorder": 0.9931748509407043, "signature": 0.9937450289726257 } -} \ No newline at end of file +} diff --git a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct.json b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct.json index 0e8eb94..e474cc6 100644 --- a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct.json +++ b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct.json @@ -9,4 +9,4 @@ "recorder": 0.7640316486358643, "signature": 0.7549777626991272 } -} \ No newline at end of file +} diff --git a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json index e0a55ee..6fed7d3 100644 --- a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json +++ b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json @@ -9,4 +9,4 @@ "recorder": 0.9933416843414307, "signature": 0.9780842661857605 } -} \ No newline at end of file +} diff --git a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3.json b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3.json index c6c6bd8..184b00e 100644 --- a/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3.json +++ b/data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3.json @@ -9,4 +9,4 @@ "recorder": 0.9425673484802246, "signature": 0.56711345911026 } -} \ No newline at end of file +} diff --git a/data/experiments/predict/pl-court-instruct/metrics_summary.md b/data/experiments/predict/pl-court-instruct/metrics_summary.md index f2307b6..6225012 100644 --- a/data/experiments/predict/pl-court-instruct/metrics_summary.md +++ b/data/experiments/predict/pl-court-instruct/metrics_summary.md @@ -3,4 +3,4 @@ | Unsloth-Llama-3-8B-Instruct | 0.439 | 0.879 | 0.982 | 0.906 | 0.915 | 0.426 | 0.764 | 0.755 | | Unsloth-Llama-3-8B-Instruct-fine-tuned | 0.828 | 0.995 | 0.989 | 0.986 | 0.977 | 0.601 | 0.993 | 0.994 | | Unsloth-Mistral-7B-Instruct-v0.3 | 0.477 | 0.830 | 0.987 | 0.900 | 0.870 | 0.419 | 0.943 | 0.567 | -| Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned | 0.819 | 0.996 | 0.989 | 0.996 | 0.981 | 0.737 | 0.993 | 0.978 | \ No newline at end of file +| Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned | 0.819 | 0.996 | 0.989 | 0.996 | 0.981 | 0.737 | 0.993 | 0.978 | diff --git a/data/sample_data/.gitignore b/data/sample_data/.gitignore index fb39c27..1f07155 100644 --- a/data/sample_data/.gitignore +++ b/data/sample_data/.gitignore @@ -1,2 +1,2 @@ -/judgements-100-sample-with-retrieved-informations.csv -/judgements-100-sample.csv +/judgements-100-sample-with-retrieved-informations.csv +/judgements-100-sample.csv diff --git a/dvc.lock b/dvc.lock index b04ce8c..687e50d 100644 --- a/dvc.lock +++ b/dvc.lock @@ -65,8 +65,8 @@ stages: nfiles: 17 - path: scripts/dataset/build_instruct_dataset.py hash: md5 - md5: 9b138322059d63ce3ad1bd05c8b931f2 - size: 5461 + md5: 5038c49e847d847ea3fd05903624d5c9 + size: 5696 embed@mmlw-roberta-large: cmd: PYTHONPATH=. python scripts/embed/embed_text.py embedding_model=mmlw-roberta-large deps: @@ -94,10 +94,10 @@ stages: size: 24415235644 nfiles: 53 evaluate@Unsloth-Llama-3-8B-Instruct: - cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file + cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct.json deps: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct.json hash: md5 md5: df2f1d464152f87737c8ebb5b0673854 @@ -107,16 +107,16 @@ stages: md5: 66211e8b6f056234240f094896966a9c size: 578 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct.json hash: md5 md5: 521a731cc2c45d3eda0656a8e69d505b size: 307 evaluate@Unsloth-Llama-3-8B-Instruct-fine-tuned: - cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file + cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct-fine-tuned.json deps: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct-fine-tuned.json hash: md5 md5: 9199da7e04fb35cc1ce2bbe9dd5cd274 @@ -126,16 +126,16 @@ stages: md5: 66211e8b6f056234240f094896966a9c size: 578 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/metrics_Unsloth-Llama-3-8B-Instruct-fine-tuned.json hash: md5 md5: 6a0eb30a14687342bc86ae80253cd60c size: 306 evaluate@Unsloth-Mistral-7B-Instruct-v0.3: - cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file + cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3.json deps: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3.json hash: md5 md5: c2e03f3fbd29c744023bdac7e1007265 @@ -145,16 +145,16 @@ stages: md5: 66211e8b6f056234240f094896966a9c size: 578 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3.json hash: md5 md5: 091b8888275600052dd2dcdd36a55588 size: 305 evaluate@Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned: - cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file + cmd: PYTHONPATH=. python scripts/sft/evaluate.py --output-file data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json deps: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json hash: md5 md5: a4fda5774b367e8924cf07f3bf271922 @@ -164,7 +164,7 @@ stages: md5: 66211e8b6f056234240f094896966a9c size: 578 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/metrics_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json hash: md5 md5: 3b3589929112cb2f199044d240e87bcc @@ -192,14 +192,14 @@ stages: size: 161 - path: configs/predict.yaml hash: md5 - md5: 74ad1dc5d9f130074533078d85e55e94 - size: 504 + md5: 888667e56c54157be4d75f85657cf478 + size: 494 - path: scripts/sft/predict.py hash: md5 - md5: 59c2afb977f520c9134153def544111d + md5: 1dc3e25365c4200d1e26e04b41d6b831 size: 3188 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct.json hash: md5 md5: df2f1d464152f87737c8ebb5b0673854 @@ -213,18 +213,18 @@ stages: size: 245 - path: configs/predict.yaml hash: md5 - md5: 74ad1dc5d9f130074533078d85e55e94 - size: 504 + md5: 7422a2c12c7d31d7b68dbe89f02dab5a + size: 532 - path: scripts/sft/predict.py hash: md5 - md5: 59c2afb977f520c9134153def544111d - size: 3188 + md5: 150d40027312348c19a82ca4f89b4cc6 + size: 2735 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct-fine-tuned.json hash: md5 - md5: 9199da7e04fb35cc1ce2bbe9dd5cd274 - size: 1891254 + md5: 5c49073109ca97d16501ca74fc568df7 + size: 1742376 predict@Unsloth-Mistral-7B-Instruct-v0.3: cmd: PYTHONPATH=. python scripts/sft/predict.py model=Unsloth-Mistral-7B-Instruct-v0.3 deps: @@ -234,14 +234,14 @@ stages: size: 167 - path: configs/predict.yaml hash: md5 - md5: 74ad1dc5d9f130074533078d85e55e94 - size: 504 + md5: 888667e56c54157be4d75f85657cf478 + size: 494 - path: scripts/sft/predict.py hash: md5 - md5: 59c2afb977f520c9134153def544111d + md5: 1dc3e25365c4200d1e26e04b41d6b831 size: 3188 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3.json hash: md5 md5: c2e03f3fbd29c744023bdac7e1007265 @@ -255,14 +255,14 @@ stages: size: 256 - path: configs/predict.yaml hash: md5 - md5: 74ad1dc5d9f130074533078d85e55e94 - size: 504 + md5: 888667e56c54157be4d75f85657cf478 + size: 494 - path: scripts/sft/predict.py hash: md5 - md5: 59c2afb977f520c9134153def544111d + md5: 1dc3e25365c4200d1e26e04b41d6b831 size: 3188 outs: - - path: + - path: data/experiments/predict/pl-court-instruct/outputs_Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned.json hash: md5 md5: a4fda5774b367e8924cf07f3bf271922 @@ -340,3 +340,24 @@ stages: hash: md5 md5: 703e82986a6ae26fbc2fd0dfac7f8893 size: 989 + evaluate_llm_as_judge@Unsloth-Llama-3-8B-Instruct-Unsloth-Llama-3-8B-Instruct: + cmd: PYTHONPATH=. python scripts/sft/evaluate_llm_as_judge.py model=Unsloth-Llama-3-8B-Instruct + answers_file=data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct.json + out_metric_file=data/experiments/llm_as_judge/pl-court-instruct/judge_Unsloth-Llama-3-8B-Instruct_metrics_Unsloth-Llama-3-8B-Instruct.json + out_predictions_file=data/experiments/llm_as_judge/pl-court-instruct/judge_Unsloth-Llama-3-8B-Instruct_predictions_Unsloth-Llama-3-8B-Instruct.json + deps: + - path: + data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct.json + hash: md5 + md5: df2f1d464152f87737c8ebb5b0673854 + size: 2179383 + - path: scripts/sft/evaluate_llm_as_judge.py + hash: md5 + md5: 55ffa83e2778e921bdfc677889e45a23 + size: 3676 + outs: + - path: + data/experiments/llm_as_judge/pl-court-instruct/judge_Unsloth-Llama-3-8B-Instruct_predictions_Unsloth-Llama-3-8B-Instruct.json + hash: md5 + md5: d0be277f3585e4d71d9551cd96851183 + size: 54800 diff --git a/dvc.yaml b/dvc.yaml index 4d3f1b5..d01f8ba 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -1,9 +1,9 @@ stages: raw_dataset_readme: cmd: >- - jupyter nbconvert - --no-input - --to markdown + jupyter nbconvert + --no-input + --to markdown --execute nbs/Data/02_Dataset_Description_Raw.ipynb --output-dir data/datasets/pl/readme/raw --output README @@ -15,9 +15,9 @@ stages: instruct_dataset_readme: cmd: >- - jupyter nbconvert - --no-input - --to markdown + jupyter nbconvert + --no-input + --to markdown --execute nbs/Data/03_Dataset_Description_Instruct.ipynb --output-dir data/datasets/pl/readme/instruct --output README @@ -25,6 +25,7 @@ stages: - nbs/Data/03_Dataset_Description_Instruct.ipynb outs: - data/datasets/pl/readme/instruct/ + build_instruct_dataset: cmd: >- PYTHONPATH=. python scripts/dataset/build_instruct_dataset.py @@ -54,7 +55,7 @@ stages: model: - mmlw-roberta-large cmd: >- - PYTHONPATH=. python scripts/embed/aggregate_embeddings.py + PYTHONPATH=. python scripts/embed/aggregate_embeddings.py --embeddings-dir data/embeddings/pl-court-raw/${item.model}/all_embeddings deps: - scripts/embed/aggregate_embeddings.py @@ -62,17 +63,17 @@ stages: outs: - data/embeddings/pl-court-raw/${item.model}/agg_embeddings.pt - + build_graph_dataset: cmd: >- PYTHONPATH=. python scripts/dataset/build_graph_dataset.py - --dataset-dir data/datasets/pl/raw + --dataset-dir data/datasets/pl/raw --embeddings-root-dir data/embeddings/pl-court-raw/mmlw-roberta-large/ --target-dir data/datasets/pl/graph deps: - scripts/dataset/build_graph_dataset.py - juddges/data/pl_court_graph.py - - data/datasets/pl/raw + - data/datasets/pl/raw - data/embeddings/pl-court-raw/mmlw-roberta-large/agg_embeddings.pt - data/embeddings/pl-court-raw/mmlw-roberta-large/all_embeddings/config.yaml outs: @@ -88,7 +89,7 @@ stages: PYTHONPATH=. python scripts/sft/fine_tune_unsloth.py model=${item.model} deps: - scripts/sft/fine_tune_unsloth.py - - configs/fine_tuning.yaml + - configs/fine_tuning.yaml - configs/model/${item.model}.yaml outs: - data/experiments/fine-tune/${item.model}/pl-court-instruct @@ -117,7 +118,7 @@ stages: - Unsloth-Mistral-7B-Instruct-v0.3 - Unsloth-Mistral-7B-Instruct-v0.3-fine-tuned cmd: >- - PYTHONPATH=. python scripts/sft/evaluate.py + PYTHONPATH=. python scripts/sft/evaluate.py --output-file data/experiments/predict/pl-court-instruct/outputs_${item.model}.json deps: - scripts/sft/evaluate.py @@ -126,12 +127,31 @@ stages: - data/experiments/predict/pl-court-instruct/metrics_${item.model}.json: cache: false + evaluate_llm_as_judge: + matrix: + judge_model: + - Unsloth-Llama-3-8B-Instruct + predictions_model: + - Unsloth-Llama-3-8B-Instruct + cmd: >- + PYTHONPATH=. python scripts/sft/evaluate_llm_as_judge.py + model=${item.judge_model} + answers_file=data/experiments/predict/pl-court-instruct/outputs_${item.predictions_model}.json + out_metric_file=data/experiments/llm_as_judge/pl-court-instruct/judge_${item.judge_model}_metrics_${item.predictions_model}.json + out_predictions_file=data/experiments/llm_as_judge/pl-court-instruct/judge_${item.judge_model}_predictions_${item.predictions_model}.json + deps: + - scripts/sft/evaluate_llm_as_judge.py + - data/experiments/predict/pl-court-instruct/outputs_${item.predictions_model}.json + outs: + # - data/experiments/llm_as_judge/pl-court-instruct/judge_${item.judge_model}_metrics_${item.predictions_model}.json + - data/experiments/llm_as_judge/pl-court-instruct/judge_${item.judge_model}_predictions_${item.predictions_model}.json + summarize_metrics: matrix: dir: - data/experiments/predict/pl-court-instruct cmd: >- - PYTHONPATH=. python scripts/sft/summarize_metrics.py + PYTHONPATH=. python scripts/sft/summarize_metrics.py --root-dir ${item.dir} deps: - scripts/sft/summarize_metrics.py diff --git a/juddges/config.py b/juddges/config.py index 35eec7a..f0bf098 100644 --- a/juddges/config.py +++ b/juddges/config.py @@ -1,4 +1,5 @@ from pathlib import Path + from pydantic import BaseModel diff --git a/juddges/data/database.py b/juddges/data/database.py index 9c2dae0..4b10c2f 100644 --- a/juddges/data/database.py +++ b/juddges/data/database.py @@ -1,81 +1,81 @@ -import os -from typing import Any, Callable, Generator, Iterable, Iterator - -from loguru import logger -from pymongo import MongoClient, UpdateOne -from pymongo.collection import Collection -from pymongo.cursor import Cursor -from pymongo.errors import BulkWriteError - - -def get_mongo_collection( - mongo_uri: str | None = None, - mongo_db: str | None = None, - collection_name: str = "pl-court", -) -> Collection: - uri = mongo_uri or os.environ.get("MONGO_URI") - assert uri, "Mongo URI is required" - db_name = mongo_db or os.environ.get("MONGO_DB_NAME") - assert db_name, "Mongo DB name is required" - - client: MongoClient = MongoClient(uri) - db = client[db_name] - return db[collection_name] - - -class BatchedDatabaseCursor: - """MongoDB cursor wrapper that returns documents in batches. - - Cursor is consumed in batches of specified size. - - Prefetch option loads all documents into memory before iterating. - """ - - def __init__(self, cursor: Cursor, batch_size: int, prefetch: bool) -> None: - self.cursor = cursor - self.batch_size = batch_size - self.prefetch = prefetch - - def __iter__(self) -> Iterator[list[dict[str, Any]]]: - if self.prefetch: - iterable: Iterable = list(self.cursor) - else: - iterable = self.cursor - - def gen_batches() -> Generator[list[dict[str, Any]], None, None]: - """Credit: https://stackoverflow.com/a/61809417""" - chunk: list[dict[str, Any]] = [] - for i, row in enumerate(iterable): - if i % self.batch_size == 0 and i > 0: - yield chunk - del chunk[:] - chunk.append(row) - yield chunk - - return gen_batches() - - -class BatchDatabaseUpdate: - """Updates database in batches using provided update function. - - Update function takes document id and returns dictionary with updated fields: - def update_func (document: dict[str, Any]) -> dict[str, Any]: - - Updated document may be constrained to only necessary fields (_id must be present). - - Update fields may or may not be already present in the database. - - Update is called specified documents. - """ - - def __init__(self, mongo_uri: str, update_func: Callable[[dict[str, Any]], dict]) -> None: - self.mongo_uri = mongo_uri - self.update_func = update_func - - def __call__(self, documents: list[dict[str, Any]]) -> None: - update_batch: list[UpdateOne] = [] - - for doc in documents: - update_data = self.update_func(doc) - update_batch.append(UpdateOne({"_id": doc["_id"]}, {"$set": update_data})) - - collection = get_mongo_collection(mongo_uri=self.mongo_uri) - - try: - collection.bulk_write(update_batch, ordered=False) - except BulkWriteError as err: - logger.error(err) +import os +from typing import Any, Callable, Generator, Iterable, Iterator + +from loguru import logger +from pymongo import MongoClient, UpdateOne +from pymongo.collection import Collection +from pymongo.cursor import Cursor +from pymongo.errors import BulkWriteError + + +def get_mongo_collection( + mongo_uri: str | None = None, + mongo_db: str | None = None, + collection_name: str = "pl-court", +) -> Collection: + uri = mongo_uri or os.environ.get("MONGO_URI") + assert uri, "Mongo URI is required" + db_name = mongo_db or os.environ.get("MONGO_DB_NAME") + assert db_name, "Mongo DB name is required" + + client: MongoClient = MongoClient(uri) + db = client[db_name] + return db[collection_name] + + +class BatchedDatabaseCursor: + """MongoDB cursor wrapper that returns documents in batches. + - Cursor is consumed in batches of specified size. + - Prefetch option loads all documents into memory before iterating. + """ + + def __init__(self, cursor: Cursor, batch_size: int, prefetch: bool) -> None: + self.cursor = cursor + self.batch_size = batch_size + self.prefetch = prefetch + + def __iter__(self) -> Iterator[list[dict[str, Any]]]: + if self.prefetch: + iterable: Iterable = list(self.cursor) + else: + iterable = self.cursor + + def gen_batches() -> Generator[list[dict[str, Any]], None, None]: + """Credit: https://stackoverflow.com/a/61809417""" + chunk: list[dict[str, Any]] = [] + for i, row in enumerate(iterable): + if i % self.batch_size == 0 and i > 0: + yield chunk + del chunk[:] + chunk.append(row) + yield chunk + + return gen_batches() + + +class BatchDatabaseUpdate: + """Updates database in batches using provided update function. + - Update function takes document id and returns dictionary with updated fields: + def update_func (document: dict[str, Any]) -> dict[str, Any]: + - Updated document may be constrained to only necessary fields (_id must be present). + - Update fields may or may not be already present in the database. + - Update is called specified documents. + """ + + def __init__(self, mongo_uri: str, update_func: Callable[[dict[str, Any]], dict]) -> None: + self.mongo_uri = mongo_uri + self.update_func = update_func + + def __call__(self, documents: list[dict[str, Any]]) -> None: + update_batch: list[UpdateOne] = [] + + for doc in documents: + update_data = self.update_func(doc) + update_batch.append(UpdateOne({"_id": doc["_id"]}, {"$set": update_data})) + + collection = get_mongo_collection(mongo_uri=self.mongo_uri) + + try: + collection.bulk_write(update_batch, ordered=False) + except BulkWriteError as err: + logger.error(err) diff --git a/juddges/data/pl_court_graph.py b/juddges/data/pl_court_graph.py index b9a77d6..f1ad526 100644 --- a/juddges/data/pl_court_graph.py +++ b/juddges/data/pl_court_graph.py @@ -1,10 +1,11 @@ from pathlib import Path from typing import Any -from torch_geometric.data import HeteroData -from torch import Tensor -import polars as pl + import networkx as nx +import polars as pl import torch +from torch import Tensor +from torch_geometric.data import HeteroData from tqdm.auto import tqdm JUDGMENT_ATTRS = [ diff --git a/juddges/metrics/info_extraction.py b/juddges/metrics/info_extraction.py index 4cd75b2..1c262ee 100644 --- a/juddges/metrics/info_extraction.py +++ b/juddges/metrics/info_extraction.py @@ -1,6 +1,8 @@ -from torchmetrics.functional.text import chrf_score import datetime from collections import defaultdict + +from torchmetrics.functional.text import chrf_score + from juddges.utils.misc import parse_yaml EMPTY_ANSWER = "" diff --git a/juddges/models/factory.py b/juddges/models/factory.py index 34b0d92..eff304f 100644 --- a/juddges/models/factory.py +++ b/juddges/models/factory.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from typing import Any -from peft import PeftModel + import torch +from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from juddges.config import LLMConfig diff --git a/juddges/models/predict.py b/juddges/models/predict.py new file mode 100644 index 0000000..bbd13ac --- /dev/null +++ b/juddges/models/predict.py @@ -0,0 +1,61 @@ +import time + +from datasets import Dataset +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from juddges.models.factory import ModelForGeneration + + +def predict_with_llm( + model_pack: ModelForGeneration, + dataset: Dataset, + batch_size: int, + num_proc: int, + verbose: bool = True, +) -> list[str]: + """Generates LLM predictions for a given dataset + + Args: + llm (AutoModelForCausalLM): LLM + dataset (Dataset): dataset to make prediction for, should be tokenized and has input_ids field + + Returns: + list[str]: List of generated texts with preserved order + """ + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_proc, + pin_memory=(num_proc > 1), + shuffle=False, + ) + + model_outputs = [] + + model = model_pack.model + tokenizer = model_pack.tokenizer + device = next(model.parameters()).device + + with tqdm(dataloader, disable=not verbose) as pbar: + for batch in pbar: + model_inputs = batch["input_ids"].view(batch_size, -1) + model_inputs = model_inputs.to(device, non_blocking=True) + input_length = model_inputs.size(1) + + start_time = time.time() + generated_ids = model.generate( + model_inputs, + **model_pack.generate_kwargs, + ) + duration = time.time() - start_time + + decoded = tokenizer.batch_decode( + generated_ids[:, input_length:], + skip_special_tokens=True, + ) + model_outputs.extend(decoded) + + pbar.set_postfix_str(f"{generated_ids.numel() / duration: 0.2f} tok/sec") + + return model_outputs diff --git a/juddges/preprocessing/text_chunker.py b/juddges/preprocessing/text_chunker.py index f504f5b..25ec80f 100644 --- a/juddges/preprocessing/text_chunker.py +++ b/juddges/preprocessing/text_chunker.py @@ -1,4 +1,5 @@ from typing import Any + from langchain_text_splitters import RecursiveCharacterTextSplitter from transformers import PreTrainedTokenizer diff --git a/juddges/preprocessing/text_encoder.py b/juddges/preprocessing/text_encoder.py index 6f575d7..26ce426 100644 --- a/juddges/preprocessing/text_encoder.py +++ b/juddges/preprocessing/text_encoder.py @@ -1,6 +1,8 @@ from typing import Any + from torch import Tensor from transformers import PreTrainedTokenizer + from juddges.preprocessing.context_truncator import ContextTruncator diff --git a/juddges/prompts/information_extraction.py b/juddges/prompts/information_extraction.py index 0cc2284..18adb5e 100644 --- a/juddges/prompts/information_extraction.py +++ b/juddges/prompts/information_extraction.py @@ -1,108 +1,108 @@ -from langchain.output_parsers.json import parse_json_markdown -from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.messages import AIMessage -from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableSequence -from langchain_openai import ChatOpenAI - -SCHEMA_PROMPT_TEMPLATE = """ -Act as a assistant that prepares schema for information extraction - -Based on the user input prepare schema containing variables with their short description and type. -Be precise about variable names, format names using snake_case. -If user asks irrelevant question always return empty JSON. -As example: -User: I want extract age, gender, and plea from the judgement -Agent: - age: integer - gender: male or female - plea: string - -==== -{SCHEMA_TEXT} -==== - -Format response as JSON: -""" - -EXTRACTION_PROMPT_TEMPLATE = """Act as a legal document tool that extracts information and answer questions based on judgements. - -Instruction for extracting information from judgements: -- Judgements are in {LANGUAGE} language, please extract information in {LANGUAGE}. -- Do not provide information that are not explicitly mentioned in judgements. If you can't extract information from the text field, leave the field with empty string "". - -Follow the following YAML structure to extract information and answer questions based on judgements: -{SCHEMA} - -==== -{TEXT} -==== - -Format response as JSON: -""" - -EXAMPLE_SCHEMA = """verdict_date: date as ISO 8601 -verdict: string, text representing verdict of the judgement -verdict_summary: string, short summary of the verdict -verdict_id: string -court: string -parties: string -appeal_against: string -first_trial: boolean -drug_offence: boolean -child_offence: boolean -offence_seriousness: boolean -verdict_tags: List[string]""" - - -def prepare_information_extraction_chain_from_user_prompt() -> RunnableSequence: - schema_chain = prepare_schema_chain() - inputs = { - "SCHEMA": schema_chain, - "TEXT": RunnablePassthrough(), - "LANGUAGE": RunnablePassthrough(), - } - return inputs | RunnableLambda(route) - - -def prepare_information_extraction_chain( - model_name: str = "gpt-4-0125-preview", - log_to_mlflow: bool = False, -) -> RunnableSequence: - model = ChatOpenAI(model=model_name, temperature=0) - human_message_template = HumanMessagePromptTemplate.from_template(EXTRACTION_PROMPT_TEMPLATE) - _prompt = ChatPromptTemplate( - messages=[human_message_template], - input_variables=["TEXT", "LANGUAGE", "SCHEMA"], - ) - - if log_to_mlflow: - import mlflow - - mlflow.log_dict(_prompt.save_to_json(), "prompt.json") - - return _prompt | model | (lambda x: parse_json_markdown(x.content)) - - -def prepare_schema_chain(model_name: str = "gpt-3.5-turbo") -> RunnableSequence: - model = ChatOpenAI(model=model_name, temperature=0) - human_message_template = HumanMessagePromptTemplate.from_template(SCHEMA_PROMPT_TEMPLATE) - _prompt = ChatPromptTemplate( - messages=[human_message_template], - input_variables=["TEXT", "LANGUAGE", "SCHEMA"], - ) - - return _prompt | model | parse_schema - - -def parse_schema(ai_message: AIMessage) -> str: - response_schema = parse_json_markdown(ai_message.content) - return "\n".join(f"{key}: {val}" for key, val in response_schema.items()) - - -def route(response_schema: str) -> dict[str, str]: - if response_schema["SCHEMA"]: - return prepare_information_extraction_chain() - - raise ValueError( - "Cannot determine schema for the given input prompt. Please try different query." - ) +from langchain.output_parsers.json import parse_json_markdown +from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.messages import AIMessage +from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableSequence +from langchain_openai import ChatOpenAI + +SCHEMA_PROMPT_TEMPLATE = """ +Act as a assistant that prepares schema for information extraction + +Based on the user input prepare schema containing variables with their short description and type. +Be precise about variable names, format names using snake_case. +If user asks irrelevant question always return empty JSON. +As example: +User: I want extract age, gender, and plea from the judgement +Agent: + age: integer + gender: male or female + plea: string + +==== +{SCHEMA_TEXT} +==== + +Format response as JSON: +""" + +EXTRACTION_PROMPT_TEMPLATE = """Act as a legal document tool that extracts information and answer questions based on judgements. + +Instruction for extracting information from judgements: +- Judgements are in {LANGUAGE} language, please extract information in {LANGUAGE}. +- Do not provide information that are not explicitly mentioned in judgements. If you can't extract information from the text field, leave the field with empty string "". + +Follow the following YAML structure to extract information and answer questions based on judgements: +{SCHEMA} + +==== +{TEXT} +==== + +Format response as JSON: +""" + +EXAMPLE_SCHEMA = """verdict_date: date as ISO 8601 +verdict: string, text representing verdict of the judgement +verdict_summary: string, short summary of the verdict +verdict_id: string +court: string +parties: string +appeal_against: string +first_trial: boolean +drug_offence: boolean +child_offence: boolean +offence_seriousness: boolean +verdict_tags: List[string]""" + + +def prepare_information_extraction_chain_from_user_prompt() -> RunnableSequence: + schema_chain = prepare_schema_chain() + inputs = { + "SCHEMA": schema_chain, + "TEXT": RunnablePassthrough(), + "LANGUAGE": RunnablePassthrough(), + } + return inputs | RunnableLambda(route) + + +def prepare_information_extraction_chain( + model_name: str = "gpt-4-0125-preview", + log_to_mlflow: bool = False, +) -> RunnableSequence: + model = ChatOpenAI(model=model_name, temperature=0) + human_message_template = HumanMessagePromptTemplate.from_template(EXTRACTION_PROMPT_TEMPLATE) + _prompt = ChatPromptTemplate( + messages=[human_message_template], + input_variables=["TEXT", "LANGUAGE", "SCHEMA"], + ) + + if log_to_mlflow: + import mlflow + + mlflow.log_dict(_prompt.save_to_json(), "prompt.json") + + return _prompt | model | (lambda x: parse_json_markdown(x.content)) + + +def prepare_schema_chain(model_name: str = "gpt-3.5-turbo") -> RunnableSequence: + model = ChatOpenAI(model=model_name, temperature=0) + human_message_template = HumanMessagePromptTemplate.from_template(SCHEMA_PROMPT_TEMPLATE) + _prompt = ChatPromptTemplate( + messages=[human_message_template], + input_variables=["TEXT", "LANGUAGE", "SCHEMA"], + ) + + return _prompt | model | parse_schema + + +def parse_schema(ai_message: AIMessage) -> str: + response_schema = parse_json_markdown(ai_message.content) + return "\n".join(f"{key}: {val}" for key, val in response_schema.items()) + + +def route(response_schema: str) -> dict[str, str]: + if response_schema["SCHEMA"]: + return prepare_information_extraction_chain() + + raise ValueError( + "Cannot determine schema for the given input prompt. Please try different query." + ) diff --git a/juddges/settings.py b/juddges/settings.py index 84a2058..835d3cf 100644 --- a/juddges/settings.py +++ b/juddges/settings.py @@ -1,72 +1,72 @@ -from pathlib import Path - -import mlflow -import tiktoken -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine - -# get root path as ROOT_PATH as pathlib objects -ROOT_PATH = Path(__file__).resolve().parent.parent - -DATA_PATH = ROOT_PATH / "data" -CONFIG_PATH = ROOT_PATH / "configs" - -SAMPLE_DATA_PATH = DATA_PATH / "sample_data" - -PL_JUDGEMENTS_PATH = DATA_PATH / "datasets" / "pl" -PL_COURT_DEP_ID_2_NAME = PL_JUDGEMENTS_PATH / "court_dep_names.csv" -PL_JUDGEMENTS_PATH_RAW = PL_JUDGEMENTS_PATH / "raw" -PL_JUDGEMENTS_PATH_TEXTS = PL_JUDGEMENTS_PATH / "text" -PL_JUDGEMENTS_PATH_INSTRUCT = PL_JUDGEMENTS_PATH / "instruct" - -MLFLOW_EXP_NAME = "Juddges-Information-Extraction" - - -def num_tokens_from_string( - string: str, # The string to count tokens for - encoding_name: str = "cl100k_base", # gpt-4, gpt-3.5-turbo, text-embedding-ada-002 -) -> int: # The number of tokens in the string - """ - Returns the number of tokens in a text string. - """ - encoding = tiktoken.get_encoding(encoding_name) - num_tokens = len(encoding.encode(string)) - return num_tokens - - -LLM_TO_PRICE_INPUT = { - "gpt-4-1106-preview": 0.01 / 1000, - "gpt-4-0125-preview": 0.01 / 1000, - "gpt-3.5-turbo-1106": 0.001 / 1000, -} - -LLM_TO_PRICE_COMPLETION = { - "gpt-4-1106-preview": 0.03 / 1000, - "gpt-4-0125-preview": 0.03 / 1000, - "gpt-3.5-turbo-1106": 0.002 / 1000, -} - -LOCAL_POSTGRES = "postgresql+psycopg2://llm:llm@postgres-juddges:5432/llm" - - -def get_sqlalchemy_engine() -> Engine: - return create_engine( - LOCAL_POSTGRES, - pool_size=10, - max_overflow=2, - pool_recycle=300, - pool_pre_ping=True, - pool_use_lifo=True, - ) - - -def prepare_langchain_cache() -> None: - import langchain - from langchain.cache import SQLAlchemyMd5Cache - - langchain.llm_cache = SQLAlchemyMd5Cache(get_sqlalchemy_engine()) - - -def prepare_mlflow(experiment_name: str = MLFLOW_EXP_NAME, url: str = "postgres-juddges") -> None: - mlflow.set_tracking_uri(url) - mlflow.set_experiment(experiment_name) +from pathlib import Path + +import mlflow +import tiktoken +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine + +# get root path as ROOT_PATH as pathlib objects +ROOT_PATH = Path(__file__).resolve().parent.parent + +DATA_PATH = ROOT_PATH / "data" +CONFIG_PATH = ROOT_PATH / "configs" + +SAMPLE_DATA_PATH = DATA_PATH / "sample_data" + +PL_JUDGEMENTS_PATH = DATA_PATH / "datasets" / "pl" +PL_COURT_DEP_ID_2_NAME = PL_JUDGEMENTS_PATH / "court_dep_names.csv" +PL_JUDGEMENTS_PATH_RAW = PL_JUDGEMENTS_PATH / "raw" +PL_JUDGEMENTS_PATH_TEXTS = PL_JUDGEMENTS_PATH / "text" +PL_JUDGEMENTS_PATH_INSTRUCT = PL_JUDGEMENTS_PATH / "instruct" + +MLFLOW_EXP_NAME = "Juddges-Information-Extraction" + + +def num_tokens_from_string( + string: str, # The string to count tokens for + encoding_name: str = "cl100k_base", # gpt-4, gpt-3.5-turbo, text-embedding-ada-002 +) -> int: # The number of tokens in the string + """ + Returns the number of tokens in a text string. + """ + encoding = tiktoken.get_encoding(encoding_name) + num_tokens = len(encoding.encode(string)) + return num_tokens + + +LLM_TO_PRICE_INPUT = { + "gpt-4-1106-preview": 0.01 / 1000, + "gpt-4-0125-preview": 0.01 / 1000, + "gpt-3.5-turbo-1106": 0.001 / 1000, +} + +LLM_TO_PRICE_COMPLETION = { + "gpt-4-1106-preview": 0.03 / 1000, + "gpt-4-0125-preview": 0.03 / 1000, + "gpt-3.5-turbo-1106": 0.002 / 1000, +} + +LOCAL_POSTGRES = "postgresql+psycopg2://llm:llm@postgres-juddges:5432/llm" + + +def get_sqlalchemy_engine() -> Engine: + return create_engine( + LOCAL_POSTGRES, + pool_size=10, + max_overflow=2, + pool_recycle=300, + pool_pre_ping=True, + pool_use_lifo=True, + ) + + +def prepare_langchain_cache() -> None: + import langchain + from langchain.cache import SQLAlchemyMd5Cache + + langchain.llm_cache = SQLAlchemyMd5Cache(get_sqlalchemy_engine()) + + +def prepare_mlflow(experiment_name: str = MLFLOW_EXP_NAME, url: str = "postgres-juddges") -> None: + mlflow.set_tracking_uri(url) + mlflow.set_experiment(experiment_name) diff --git a/juddges/utils/__init__.py b/juddges/utils/__init__.py index 5a11143..effe27d 100644 --- a/juddges/utils/__init__.py +++ b/juddges/utils/__init__.py @@ -1,6 +1,6 @@ -from .versioning import bump_version, VersionBump -from .misc import parse_yaml from .config import load_and_resolve_config, resolve_config +from .misc import parse_yaml +from .versioning import VersionBump, bump_version __all__ = [ "bump_version", diff --git a/juddges/utils/misc.py b/juddges/utils/misc.py index df3bd21..a1afe98 100644 --- a/juddges/utils/misc.py +++ b/juddges/utils/misc.py @@ -1,5 +1,6 @@ import re from typing import Any + import yaml yaml_pattern: re.Pattern = re.compile(r"^```(?:ya?ml)?(?P[^`]*)", re.MULTILINE | re.DOTALL) diff --git a/nbs/Dataset Cards/03_Graph_Description.md b/nbs/Dataset Cards/03_Graph_Description.md index afb8e4e..33fd907 100644 --- a/nbs/Dataset Cards/03_Graph_Description.md +++ b/nbs/Dataset Cards/03_Graph_Description.md @@ -1,11 +1,11 @@ # Polish Court Judgments Graph ## Dataset description -We introduce a graph dataset of Polish Court Judgments. This dataset is primarily based on the [`JuDDGES/pl-court-raw`](https://huggingface.co/datasets/JuDDGES/pl-court-raw). The dataset consists of nodes representing either judgments or legal bases, and edges connecting judgments to the legal bases they refer to. Also, the graph was cleaned from small disconnected components, leaving single giant component. Consequently, the resulting graph is bipartite. We provide the dataset in both `JSON` and `PyG` formats, each has different purpose. While structurally graphs in these formats are the same, their attributes differ. +We introduce a graph dataset of Polish Court Judgments. This dataset is primarily based on the [`JuDDGES/pl-court-raw`](https://huggingface.co/datasets/JuDDGES/pl-court-raw). The dataset consists of nodes representing either judgments or legal bases, and edges connecting judgments to the legal bases they refer to. Also, the graph was cleaned from small disconnected components, leaving single giant component. Consequently, the resulting graph is bipartite. We provide the dataset in both `JSON` and `PyG` formats, each has different purpose. While structurally graphs in these formats are the same, their attributes differ. The `JSON` format is intended for analysis and contains most of the attributes available in [`JuDDGES/pl-court-raw`](https://huggingface.co/datasets/JuDDGES/pl-court-raw). We excluded some less-useful attributes and text content, which can be easily retrieved from the raw dataset and added to the graph as needed. -The `PyG` format is designed for machine learning applications, such as link prediction on graphs, and is fully compatible with the [`Pytorch Geometric`](https://github.com/pyg-team/pytorch_geometric) framework. +The `PyG` format is designed for machine learning applications, such as link prediction on graphs, and is fully compatible with the [`Pytorch Geometric`](https://github.com/pyg-team/pytorch_geometric) framework. In the following sections, we provide a more detailed explanation and use case examples for each format. @@ -19,9 +19,9 @@ In the following sections, we provide a more detailed explanation and use case e | #nodes (type=`legal_base`) | 2819 | | avg(degree) | 6.132015294025195 | - + ![png](../images/degree_distribution.png) - + ## `JSON` format @@ -58,10 +58,10 @@ g = nx.node_link_graph(g_data) ## `PyG` format -The `PyTorch Geometric` format includes embeddings of the judgment content, obtained with [sdadas/mmlw-roberta-large](https://huggingface.co/sdadas/mmlw-roberta-large) for judgment nodes, -and one-hot-vector identifiers for legal-base nodes (note that for efficiency one can substitute it with random noise identifiers, +The `PyTorch Geometric` format includes embeddings of the judgment content, obtained with [sdadas/mmlw-roberta-large](https://huggingface.co/sdadas/mmlw-roberta-large) for judgment nodes, +and one-hot-vector identifiers for legal-base nodes (note that for efficiency one can substitute it with random noise identifiers, like in [(Abboud et al., 2021)](https://arxiv.org/abs/2010.01179)). - + ### Loading @@ -125,4 +125,4 @@ print(ds) ### Example usage ```python # TBD -``` \ No newline at end of file +``` diff --git a/nbs/Presentations/00_workshop_demo.ipynb b/nbs/Presentations/00_workshop_demo.ipynb index 3b5ac83..d178f43 100644 --- a/nbs/Presentations/00_workshop_demo.ipynb +++ b/nbs/Presentations/00_workshop_demo.ipynb @@ -1,43 +1,43 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Workshop Demo\n", - "\n", - "> Jak możemy strukturyzować orzeczenia?\n", - "\n", - "## Jakie informacje chcemy/możemy ekstrahować automatycznie z orzeczeń?\n", - "\n", - "### Ogólne\n", - "\n", - "- Sygnatura sprawy\n", - "- Podstawa prawna\n", - "- Strony\n", - "- Sentencja\n", - "- Podsumowanie\n", - "- Tagi, etykiety\n", - "- ...\n", - "\n", - "### Przykłady specyficznych pytań/zagadnień\n", - "\n", - "- Czy sprawa dotyczy dzieci?\n", - "- Czy sprawa dotyczy wolności słowa?\n", - "- Czy sprawa dotyczy XXX? - każde tego typu pytanie możemy użyć\n", - "- ...\n", - "\n", - "### Czego brakuje nad w codziennych zadaniach/pracach?\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "python3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Workshop Demo\n", + "\n", + "> Jak możemy strukturyzować orzeczenia?\n", + "\n", + "## Jakie informacje chcemy/możemy ekstrahować automatycznie z orzeczeń?\n", + "\n", + "### Ogólne\n", + "\n", + "- Sygnatura sprawy\n", + "- Podstawa prawna\n", + "- Strony\n", + "- Sentencja\n", + "- Podsumowanie\n", + "- Tagi, etykiety\n", + "- ...\n", + "\n", + "### Przykłady specyficznych pytań/zagadnień\n", + "\n", + "- Czy sprawa dotyczy dzieci?\n", + "- Czy sprawa dotyczy wolności słowa?\n", + "- Czy sprawa dotyczy XXX? - każde tego typu pytanie możemy użyć\n", + "- ...\n", + "\n", + "### Czego brakuje nad w codziennych zadaniach/pracach?\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/_quarto.yml b/nbs/_quarto.yml index 0a6dfcb..006b406 100644 --- a/nbs/_quarto.yml +++ b/nbs/_quarto.yml @@ -17,4 +17,4 @@ website: sidebar: style: floating -metadata-files: [nbdev.yml, sidebar.yml] \ No newline at end of file +metadata-files: [nbdev.yml, sidebar.yml] diff --git a/nginx/Dockerfile b/nginx/Dockerfile index d61e299..d70c6ce 100644 --- a/nginx/Dockerfile +++ b/nginx/Dockerfile @@ -1,20 +1,20 @@ -FROM nginx:1.25.3 - -RUN apt-get update -y \ - && apt-get install -y \ - apache2-utils \ - && rm -rf /var/lib/apt/lists/* - -ENV LISTEN_PORT=8080 \ - AUTH_REALM="Restricted" \ - HTPASSWD_FILE="/etc/nginx/conf.d/auth.htpasswd" \ - FORWARD_PROTOCOL="http" \ - FORWARD_PORT=8501 - -WORKDIR /opt - -COPY auth.htpasswd launch.sh ./ - -RUN chmod 0755 ./launch.sh - -CMD ["./launch.sh"] +FROM nginx:1.25.3 + +RUN apt-get update -y \ + && apt-get install -y \ + apache2-utils \ + && rm -rf /var/lib/apt/lists/* + +ENV LISTEN_PORT=8080 \ + AUTH_REALM="Restricted" \ + HTPASSWD_FILE="/etc/nginx/conf.d/auth.htpasswd" \ + FORWARD_PROTOCOL="http" \ + FORWARD_PORT=8501 + +WORKDIR /opt + +COPY auth.htpasswd launch.sh ./ + +RUN chmod 0755 ./launch.sh + +CMD ["./launch.sh"] diff --git a/nginx/auth.conf b/nginx/auth.conf index 7e18251..211b0a6 100644 --- a/nginx/auth.conf +++ b/nginx/auth.conf @@ -1,31 +1,31 @@ -upstream ws-backend { - # enable sticky session based on IP - ip_hash; - - server web:8501; -} - -server { - listen 8080 default_server; - listen [::]:8080; - - # server_name web; - - location / { - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header Host $host; - - # basic auth - auth_basic "Restricted"; - auth_basic_user_file auth.htpasswd; - - # proxy pass - proxy_pass http://ws-backend; - proxy_read_timeout 900; - - proxy_http_version 1.1; - proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection "upgrade"; - - } -} \ No newline at end of file +upstream ws-backend { + # enable sticky session based on IP + ip_hash; + + server web:8501; +} + +server { + listen 8080 default_server; + listen [::]:8080; + + # server_name web; + + location / { + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header Host $host; + + # basic auth + auth_basic "Restricted"; + auth_basic_user_file auth.htpasswd; + + # proxy pass + proxy_pass http://ws-backend; + proxy_read_timeout 900; + + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + + } +} diff --git a/nginx/launch.sh b/nginx/launch.sh index fb26765..8f30390 100644 --- a/nginx/launch.sh +++ b/nginx/launch.sh @@ -7,4 +7,4 @@ htpasswd -c -b /etc/nginx/auth.htpasswd $USER $PASS echo basic-auth-pwd cat /etc/nginx/auth.htpasswd -nginx -g "daemon off;" \ No newline at end of file +nginx -g "daemon off;" diff --git a/pyproject.toml b/pyproject.toml index 0616818..366f1fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,9 @@ line-length = 100 extend-include = ["*.ipynb"] exclude = ["_modidx.py"] +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "I"] + [tool.mypy] python_version = "3.11" strict = true @@ -12,4 +15,3 @@ plugins = "numpy.typing.mypy_plugin" [[tool.mypy.overrides]] module = ["pyarrow.*", "datasets.*", "sentence_transformers.*"] ignore_missing_imports = true - diff --git a/requirements.txt b/requirements.txt index af77995..b11f9a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,47 +1,46 @@ -accelerate==0.29.3 -bitsandbytes==0.43.1 -chardet==5.2.0 -datasets==2.19.1 -langchain-openai==0.1.1 -langchain==0.1.13 -langsmith==0.1.33 -loguru==0.7.2 -mlflow==2.11.3 -mpire==2.10.0 -openpyxl==3.1.2 -pandas==2.2.1 -peft==0.10.0 -polars==0.20.15 -pydantic==2.7.1 -pyarrow==15.0.0 -pymongo==4.3.3 -python-dotenv==1.0.1 -PyYAML==6.0.1 -requests==2.31.0 -rich==13.7.0 -seaborn==0.13.2 -sentence-transformers==3.0.0 -tenacity==8.2.3 -tensorboard==2.16.2 -tiktoken==0.6.0 -torch==2.2.1 -torchmetrics==1.4.0 -torch_geometric==2.5.3 -transformers==4.40.2 -trl==0.8.6 -typer==0.9.0 -wandb==0.16.5 -xmltodict==0.13.0 -xlsxwriter==3.2.0 - -# dev -coverage==7.4.3 -dvc[s3]==3.48.3 -mlflow==2.11.3 -mypy==1.8.0 -nbdev==2.3.25 -psycopg2-binary==2.9.9 -pytest==8.0.2 -ruff==0.2.2 -streamlit==1.31.1 - +accelerate==0.29.3 +bitsandbytes==0.43.1 +chardet==5.2.0 +datasets==2.19.1 +langchain-openai==0.1.1 +langchain==0.1.13 +langsmith==0.1.33 +loguru==0.7.2 +mlflow==2.11.3 +mpire==2.10.0 +openpyxl==3.1.2 +pandas==2.2.1 +peft==0.10.0 +polars==0.20.15 +pydantic==2.7.1 +pyarrow==15.0.0 +pymongo==4.3.3 +python-dotenv==1.0.1 +PyYAML==6.0.1 +requests==2.31.0 +rich==13.7.0 +seaborn==0.13.2 +sentence-transformers==3.0.0 +tenacity==8.2.3 +tensorboard==2.16.2 +tiktoken==0.6.0 +torch==2.2.1 +torchmetrics==1.4.0 +torch_geometric==2.5.3 +transformers==4.40.2 +trl==0.8.6 +typer==0.9.0 +wandb==0.16.5 +xmltodict==0.13.0 +xlsxwriter==3.2.0 + +# dev +coverage==7.4.3 +dvc[s3]==3.48.3 +mlflow==2.11.3 +mypy==1.8.0 +nbdev==2.3.25 +pre-commit==3.7.1 +psycopg2-binary==2.9.9 +pytest==8.0.2 +streamlit==1.31.1 diff --git a/requirements_unsloth.txt b/requirements_unsloth.txt index 16eae7b..d6249b1 100644 --- a/requirements_unsloth.txt +++ b/requirements_unsloth.txt @@ -1,41 +1,41 @@ -accelerate==0.29.3 -bitsandbytes==0.43.1 -chardet==5.2.0 -datasets==2.19.1 -langchain-openai==0.1.1 -langchain==0.1.13 -langsmith==0.1.33 -loguru==0.7.2 -mlflow==2.11.3 -mpire==2.10.0 -pandas==2.2.1 -peft==0.10.0 -polars==0.20.15 -pydantic==2.7.1 -pyarrow==15.0.0 -pymongo==4.3.3 -python-dotenv==1.0.1 -PyYAML==6.0.1 -requests==2.31.0 -rich==13.7.0 -seaborn==0.13.2 -sentence-transformers==3.0.0 -tenacity==8.2.3 -tensorboard==2.16.2 -tiktoken==0.6.0 -torchmetrics==1.4.0 -trl==0.8.6 -typer==0.9.0 -wandb==0.16.5 -xmltodict==0.13.0 - -# dev -coverage==7.4.3 -dvc[s3]==3.48.3 -mlflow==2.11.3 -mypy==1.8.0 -nbdev==2.3.13 -psycopg2-binary==2.9.9 -pytest==8.0.2 -ruff==0.2.2 -streamlit==1.31.1 \ No newline at end of file +accelerate==0.29.3 +bitsandbytes==0.43.1 +chardet==5.2.0 +datasets==2.19.1 +langchain-openai==0.1.1 +langchain==0.1.13 +langsmith==0.1.33 +loguru==0.7.2 +mlflow==2.11.3 +mpire==2.10.0 +pandas==2.2.1 +peft==0.10.0 +polars==0.20.15 +pydantic==2.7.1 +pyarrow==15.0.0 +pymongo==4.3.3 +python-dotenv==1.0.1 +PyYAML==6.0.1 +requests==2.31.0 +rich==13.7.0 +seaborn==0.13.2 +sentence-transformers==3.0.0 +tenacity==8.2.3 +tensorboard==2.16.2 +tiktoken==0.6.0 +torchmetrics==1.4.0 +trl==0.8.6 +typer==0.9.0 +wandb==0.16.5 +xmltodict==0.13.0 + +# dev +coverage==7.4.3 +dvc[s3]==3.48.3 +mlflow==2.11.3 +mypy==1.8.0 +nbdev==2.3.13 +psycopg2-binary==2.9.9 +pytest==8.0.2 +ruff==0.2.2 +streamlit==1.31.1 diff --git a/scripts/README.md b/scripts/README.md index 82f914c..8893971 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -53,7 +53,7 @@ MONGO_DB_NAME="datasets" ```shell PYTHONPATH=. python scripts/dataset/dump_pl_dataset.py \ --file-name data/datasets/pl/raw/raw.parquet - dvc add data/datasets/pl/raw/raw.parquet && dvc push + dvc add data/datasets/pl/raw/raw.parquet && dvc push ``` 7. Generate dataset card for `pl-court-raw` ```shell @@ -70,12 +70,12 @@ MONGO_DB_NAME="datasets" ```shell NUM_JOBS=8 dvc repro build_instruct_dataset ``` - + 11. Generate dataset card for `pl-court-instruct` ```shell dvc repro instruct_dataset_readme && dvc push ``` - + 12. Upload `pl-court-instruct` dataset card to huggingface ```shell PYTHONPATH=. scripts/dataset/push_instruct_readme.py --repo-id JuDDGES/pl-court-instruct diff --git a/scripts/dataset/build_graph_dataset.py b/scripts/dataset/build_graph_dataset.py index 1daf05e..40c45bc 100644 --- a/scripts/dataset/build_graph_dataset.py +++ b/scripts/dataset/build_graph_dataset.py @@ -1,9 +1,10 @@ import json from pathlib import Path + import networkx as nx -from omegaconf import OmegaConf import torch import typer +from omegaconf import OmegaConf from juddges.data.pl_court_graph import ( create_judgment_legal_base_graph, diff --git a/scripts/dataset/build_instruct_dataset.py b/scripts/dataset/build_instruct_dataset.py index 853e311..60211b2 100644 --- a/scripts/dataset/build_instruct_dataset.py +++ b/scripts/dataset/build_instruct_dataset.py @@ -1,13 +1,12 @@ -from pathlib import Path -from typing import Any -from typing import Optional -from dotenv import load_dotenv -from loguru import logger -import typer -from datasets import load_dataset from datetime import datetime +from pathlib import Path +from typing import Any, Optional +import typer import yaml +from datasets import load_dataset +from dotenv import load_dotenv +from loguru import logger from juddges.settings import PL_JUDGEMENTS_PATH_INSTRUCT, PL_JUDGEMENTS_PATH_RAW @@ -76,6 +75,9 @@ def main( help="Number of parallel jobs to use", ), branch: Optional[str] = typer.Option(None, help="Branch to push the dataset to"), + commit_message: Optional[str] = typer.Option( + None, help="Commit message", envvar="COMMIT_MESSAGE" + ), ) -> None: feature_cols = ["_id"] + FEATURES logger.info("Loading dataset...") @@ -109,7 +111,12 @@ def main( logger.info("Built dataset with following parameters: {ds_info}", ds_info=str(ds)) if repo_id: - ds.push_to_hub(repo_id, max_shard_size=MAX_SHARD_SIZE, revision=branch) + ds.push_to_hub( + repo_id, + max_shard_size=MAX_SHARD_SIZE, + commit_message=commit_message, + revision=branch, + ) else: ds.save_to_disk(target_dir, max_shard_size=MAX_SHARD_SIZE, num_proc=num_jobs) @@ -156,9 +163,11 @@ def _filter(item: dict[str, Any]) -> bool: def to_instruction_fmt(item: dict[str, Any]) -> dict[str, str]: - output = SCHEMA_TEMPLATE.format( - schema=yaml.dump({k: item[SCHEMA_2_FEATURES[k]] for k in SCHEMA_DESC.keys()}).strip() - ) + yaml_output = yaml.dump( + {k: item[SCHEMA_2_FEATURES[k]] for k in SCHEMA_DESC.keys()}, + allow_unicode=True, + ).strip() + output = SCHEMA_TEMPLATE.format(schema=yaml_output) return {"prompt": PROMPT, "context": item["text"], "output": output} diff --git a/scripts/dataset/download_pl_additional_data.py b/scripts/dataset/download_pl_additional_data.py index 9019df0..87ec27c 100644 --- a/scripts/dataset/download_pl_additional_data.py +++ b/scripts/dataset/download_pl_additional_data.py @@ -6,12 +6,12 @@ import typer from dotenv import load_dotenv from loguru import logger -from requests import HTTPError, ConnectionError -from tenacity import retry, wait_random_exponential, retry_if_exception_type, stop_after_attempt +from requests import ConnectionError, HTTPError +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential from tqdm import tqdm from juddges.data.database import BatchDatabaseUpdate, BatchedDatabaseCursor, get_mongo_collection -from juddges.data.pl_court_api import PolishCourtAPI, DataNotFoundError +from juddges.data.pl_court_api import DataNotFoundError, PolishCourtAPI N_JOBS = 6 BATCH_SIZE = 100 diff --git a/scripts/dataset/extract_pl_xml.py b/scripts/dataset/extract_pl_xml.py index 4b5ee66..2d4689d 100644 --- a/scripts/dataset/extract_pl_xml.py +++ b/scripts/dataset/extract_pl_xml.py @@ -1,13 +1,13 @@ import math import multiprocessing -from typing import Optional, Any +from typing import Any, Optional import typer from dotenv import load_dotenv from loguru import logger from tqdm import tqdm -from juddges.data.database import get_mongo_collection, BatchedDatabaseCursor, BatchDatabaseUpdate +from juddges.data.database import BatchDatabaseUpdate, BatchedDatabaseCursor, get_mongo_collection from juddges.preprocessing.pl_court_parser import SimplePlJudgementsParser BATCH_SIZE = 100 diff --git a/scripts/dataset/push_instruct_readme.py b/scripts/dataset/push_instruct_readme.py index 881e33a..a5d0988 100644 --- a/scripts/dataset/push_instruct_readme.py +++ b/scripts/dataset/push_instruct_readme.py @@ -3,7 +3,7 @@ import typer from dotenv import load_dotenv -from huggingface_hub import DatasetCardData, DatasetCard, HfApi +from huggingface_hub import DatasetCard, DatasetCardData, HfApi load_dotenv() diff --git a/scripts/dataset/push_raw_dataset.py b/scripts/dataset/push_raw_dataset.py index 0700c05..c1127f9 100644 --- a/scripts/dataset/push_raw_dataset.py +++ b/scripts/dataset/push_raw_dataset.py @@ -4,7 +4,7 @@ import typer from datasets import load_dataset from dotenv import load_dotenv -from huggingface_hub import DatasetCardData, DatasetCard, HfApi +from huggingface_hub import DatasetCard, DatasetCardData, HfApi from loguru import logger from juddges.settings import PL_JUDGEMENTS_PATH_RAW diff --git a/scripts/dataset/upload_graph_dataset.py b/scripts/dataset/upload_graph_dataset.py index 2feb718..b892524 100644 --- a/scripts/dataset/upload_graph_dataset.py +++ b/scripts/dataset/upload_graph_dataset.py @@ -1,6 +1,7 @@ import json from pathlib import Path from typing import Any + import networkx as nx import typer from huggingface_hub import DatasetCard, DatasetCardData, HfApi diff --git a/scripts/embed/aggregate_embeddings.py b/scripts/embed/aggregate_embeddings.py index 5a7beed..e6b612c 100644 --- a/scripts/embed/aggregate_embeddings.py +++ b/scripts/embed/aggregate_embeddings.py @@ -1,12 +1,13 @@ from pathlib import Path -from loguru import logger + import numpy as np -from tqdm.auto import tqdm -import typer -from datasets import load_from_disk import polars as pl import torch +import typer +from datasets import load_from_disk +from loguru import logger from torch import Tensor +from tqdm.auto import tqdm def main( diff --git a/scripts/embed/embed_text.py b/scripts/embed/embed_text.py index 070d761..54bba5c 100644 --- a/scripts/embed/embed_text.py +++ b/scripts/embed/embed_text.py @@ -1,20 +1,20 @@ import os from pathlib import Path -from datasets import Dataset from typing import Any, Literal + import hydra +import torch +import yaml +from datasets import Dataset, load_dataset from loguru import logger from omegaconf import DictConfig from openai import BaseModel -import torch -from datasets import load_dataset from sentence_transformers import SentenceTransformer from transformers.utils import is_flash_attn_2_available -import yaml from juddges.config import EmbeddingModelConfig, RawDatasetConfig -from juddges.settings import CONFIG_PATH from juddges.preprocessing.text_chunker import TextSplitter +from juddges.settings import CONFIG_PATH from juddges.utils.config import resolve_config assert is_flash_attn_2_available(), "FlashAttention2 is required for this script" diff --git a/scripts/embed/ingest.py b/scripts/embed/ingest.py index 50894c5..b20f765 100644 --- a/scripts/embed/ingest.py +++ b/scripts/embed/ingest.py @@ -2,11 +2,11 @@ from pathlib import Path from typing import Any -from dotenv import load_dotenv import torch +import typer +from dotenv import load_dotenv from loguru import logger from tqdm.auto import tqdm -import typer from juddges.data.database import BatchDatabaseUpdate, BatchedDatabaseCursor, get_mongo_collection diff --git a/scripts/graph_use_case.py b/scripts/graph_use_case.py index 029e9a1..8c09994 100644 --- a/scripts/graph_use_case.py +++ b/scripts/graph_use_case.py @@ -1,5 +1,6 @@ -import torch import os + +import torch from torch_geometric.data import InMemoryDataset, download_url from torch_geometric.transforms import BaseTransform diff --git a/scripts/sft/evaluate.py b/scripts/sft/evaluate.py index 7a2462a..e5b544a 100644 --- a/scripts/sft/evaluate.py +++ b/scripts/sft/evaluate.py @@ -1,5 +1,6 @@ import json from pathlib import Path + import typer from juddges.metrics.info_extraction import evaluate_extraction diff --git a/scripts/sft/evaluate_llm_as_judge.py b/scripts/sft/evaluate_llm_as_judge.py new file mode 100644 index 0000000..f7bb52a --- /dev/null +++ b/scripts/sft/evaluate_llm_as_judge.py @@ -0,0 +1,116 @@ +import json +import os +from pathlib import Path +from pprint import pformat +from typing import Any + +import hydra +import torch +from accelerate import PartialState +from datasets import load_dataset +from loguru import logger +from omegaconf import DictConfig +from pydantic import BaseModel, Field +from torch import Tensor +from transformers import PreTrainedTokenizer + +from juddges.config import LLMConfig +from juddges.models.factory import get_model +from juddges.models.predict import predict_with_llm +from juddges.settings import CONFIG_PATH +from juddges.utils.config import resolve_config + +NUM_PROC = int(os.getenv("NUM_PROC", 1)) +if NUM_PROC > 1: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + +JUDGE_PROMPT = """ +You are evaluating information extraction system by comparing a submitted answer to an expert answer on a given question. +Data is in Polish. Here is the data: +[BEGIN DATA] +************ +[Expert]: {gold} +************ +[Submission]: {answer} +************ +[END DATA] + +Submitted answer should be formatted as YAML. If the submitted answer cannot be parsed as YAML, return incorrect. +When comparing consecutive fields, ignore order of fields, capitalization and don't be sensitive to abbreviations which preserves the meaning of the answer. +In comparison, ignore legal_bases field. +Format you answer as follows: The answer is . Don't provide any additional explanation. +""" + + +class LLMJudgeConfig(BaseModel, extra="forbid"): + model: LLMConfig + answers_file: Path + out_metric_file: Path + out_predictions_file: Path + generate_kwargs: dict[str, Any] = Field(default_factory=dict) + + +@torch.inference_mode() +@hydra.main(version_base="1.3", config_path=str(CONFIG_PATH), config_name="llm_judge.yaml") +def main(cfg: DictConfig) -> None: + cfg_dict = resolve_config(cfg) + logger.info(f"config:\n{pformat(cfg_dict)}") + config = LLMJudgeConfig(**cfg_dict) + + config.out_metric_file.parent.mkdir(parents=True, exist_ok=True) + + ds = load_dataset("json", data_files=str(config.answers_file), split="train") + ds = ds.map( + lambda x: {"input_text": JUDGE_PROMPT.format(answer=x["answer"], gold=x["gold"])}, + ) + ds.cleanup_cache_files() + + model_pack = get_model( + llm_config=config.model, + device_map={"": PartialState().process_index}, + ) + model_pack.generate_kwargs |= config.generate_kwargs + + encoder = SimpleEncoder(tokenizer=model_pack.tokenizer) + ds.set_transform(encoder, columns=["input_text"]) + + predictions = predict_with_llm( + model_pack=model_pack, + dataset=ds, + batch_size=config.model.batch_size, + num_proc=NUM_PROC, + verbose=True, + ) + + with open(config.out_predictions_file, "w") as f: + json.dump(predictions, f, indent="\t") + + +class SimpleEncoder: + def __init__(self, tokenizer: PreTrainedTokenizer): + self.tokenizer = tokenizer + + def __call__(self, batch: dict[str, list[str]]) -> dict[str, Tensor]: + # NOTE: truncation is disabled and padding is set to "longest" + input_texts = [] + for text in batch["input_text"]: + input_chat = [{"role": "user", "content": text}] + final_input = self.tokenizer.apply_chat_template( + input_chat, + add_generation_prompt=True, + tokenize=False, + ) + input_texts.append(final_input) + + return self.tokenizer( + input_texts, + padding="longest", + truncation=False, + return_tensors="pt", + return_attention_mask=False, + return_special_tokens_mask=False, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/sft/fine_tune_llm.py b/scripts/sft/fine_tune_llm.py index d9d43d1..7e96782 100644 --- a/scripts/sft/fine_tune_llm.py +++ b/scripts/sft/fine_tune_llm.py @@ -2,40 +2,38 @@ Fine-tune a large language model using SFT. the script is based on: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl """ + import os from pathlib import Path import hydra +import torch +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + load_dataset, +) from loguru import logger from omegaconf import DictConfig from openai import BaseModel from peft.tuners.lora.config import LoraConfig -from trl import SFTTrainer from transformers import ( - AutoTokenizer, AutoModelForCausalLM, + AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PreTrainedTokenizer, Trainer, + TrainingArguments, ) +from trl import SFTTrainer from juddges.config import DatasetConfig, LLMConfig from juddges.data.datasets.utils import create_chat -from juddges.settings import CONFIG_PATH - -from datasets import ( - load_dataset, - DatasetDict, - Dataset, - IterableDatasetDict, - IterableDataset, -) - -import torch -from transformers import TrainingArguments - from juddges.preprocessing.context_truncator import ContextTruncator +from juddges.settings import CONFIG_PATH from juddges.utils.config import resolve_config NUM_PROC = int(os.getenv("NUM_PROC", 1)) diff --git a/scripts/sft/fine_tune_unsloth.py b/scripts/sft/fine_tune_unsloth.py index a293bfa..d0e14d0 100644 --- a/scripts/sft/fine_tune_unsloth.py +++ b/scripts/sft/fine_tune_unsloth.py @@ -6,34 +6,31 @@ from pathlib import Path import hydra +from accelerate import PartialState +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + load_dataset, +) from loguru import logger from omegaconf import DictConfig from openai import BaseModel -from trl import SFTTrainer from transformers import ( PreTrainedModel, PreTrainedTokenizer, Trainer, + TrainingArguments, ) +from trl import SFTTrainer +from unsloth import FastLanguageModel from juddges.config import DatasetConfig, LLMConfig from juddges.data.datasets.utils import create_chat -from juddges.settings import CONFIG_PATH -from accelerate import PartialState - -from datasets import ( - load_dataset, - DatasetDict, - Dataset, - IterableDatasetDict, - IterableDataset, -) - -from transformers import TrainingArguments - from juddges.preprocessing.context_truncator import ContextTruncator +from juddges.settings import CONFIG_PATH from juddges.utils.config import resolve_config -from unsloth import FastLanguageModel NUM_PROC = int(os.getenv("NUM_PROC", 1)) diff --git a/scripts/sft/predict.py b/scripts/sft/predict.py index 89318d9..7515c87 100644 --- a/scripts/sft/predict.py +++ b/scripts/sft/predict.py @@ -1,21 +1,22 @@ import json import os from pathlib import Path -import time +from pprint import pformat +from typing import Any import hydra +import torch from datasets import load_dataset - from loguru import logger -from openai import BaseModel -import torch from omegaconf import DictConfig -from tqdm import tqdm -from torch.utils.data import DataLoader +from openai import BaseModel +from pydantic import Field + from juddges.config import DatasetConfig, LLMConfig -from juddges.settings import CONFIG_PATH from juddges.models.factory import get_model +from juddges.models.predict import predict_with_llm from juddges.preprocessing.text_encoder import TextEncoderForEval +from juddges.settings import CONFIG_PATH from juddges.utils.config import resolve_config DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -31,15 +32,15 @@ class PredictConfig(BaseModel, extra="forbid"): device_map: str output_file: Path metrics_file: Path - max_new_tokens: int truncate_context: bool + generate_kwargs: dict[str, Any] = Field(default_factory=dict) @torch.inference_mode() @hydra.main(version_base="1.3", config_path=str(CONFIG_PATH), config_name="predict.yaml") def main(cfg: DictConfig) -> None: cfg_dict = resolve_config(cfg) - logger.info(f"config:\n{cfg_dict}") + logger.info(f"config:\n{pformat(cfg_dict)}") config = PredictConfig(**cfg_dict) output_file = Path(config.output_file) @@ -49,6 +50,10 @@ def main(cfg: DictConfig) -> None: logger.info("Loading model...") model_pack = get_model(config.model, device_map=config.device_map) + + assert not any(key in model_pack.generate_kwargs for key in config.generate_kwargs.keys()) + model_pack.generate_kwargs |= config.generate_kwargs + model, tokenizer = model_pack.model, model_pack.tokenizer model.eval() if config.model.batch_size > 1 and config.model.padding is False: @@ -62,44 +67,20 @@ def main(cfg: DictConfig) -> None: padding=config.model.padding, ) ds.set_transform(encoder, columns=["prompt", "context"]) - dataloader = DataLoader( - ds, + + model_outputs = predict_with_llm( + model_pack=model_pack, + dataset=ds, batch_size=config.model.batch_size, - num_workers=NUM_PROC, - pin_memory=(NUM_PROC > 1), - shuffle=False, + num_proc=NUM_PROC, + verbose=True, ) - - model_outputs = [] - - with tqdm(dataloader) as pbar: - for batch in pbar: - model_inputs = batch["input_ids"].view(config.model.batch_size, -1) - model_inputs = model_inputs.to(DEVICE, non_blocking=True) - input_length = model_inputs.size(1) - - start_time = time.time() - generated_ids = model.generate( - model_inputs, - max_new_tokens=config.max_new_tokens, - **model_pack.generate_kwargs, - ) - duration = time.time() - start_time - - decoded = tokenizer.batch_decode( - generated_ids[:, input_length:], - skip_special_tokens=True, - ) - model_outputs.extend(decoded) - - pbar.set_postfix_str(f"{generated_ids.numel() / duration: 0.2f} tok/sec") - results = [ {"answer": ans, "gold": gold_ans} for ans, gold_ans in zip(model_outputs, gold_outputs) ] with open(output_file, "w") as f: - json.dump(results, f, indent="\t") + json.dump(results, f, indent="\t", ensure_ascii=False) if __name__ == "__main__": diff --git a/scripts/sft/summarize_metrics.py b/scripts/sft/summarize_metrics.py index e749f6a..68e8768 100644 --- a/scripts/sft/summarize_metrics.py +++ b/scripts/sft/summarize_metrics.py @@ -1,5 +1,5 @@ -from pathlib import Path import json +from pathlib import Path import pandas as pd import typer diff --git a/settings.ini b/settings.ini index ab01294..e21bcad 100644 --- a/settings.ini +++ b/settings.ini @@ -31,7 +31,7 @@ audience = Developers author = Łukasz Augustyniak author_email = aisolutions@lukaszaugustyniak.com copyright = 2024 onwards, %(author)s -description = +description = keywords = nbdev jupyter notebook python language = English status = 3 @@ -39,5 +39,5 @@ user = laugustyniak ### Optional ### # requirements = fastcore pandas -# dev_requirements = -# console_scripts = \ No newline at end of file +# dev_requirements = +# console_scripts = diff --git a/setup.py b/setup.py index e3281ae..27471d2 100644 --- a/setup.py +++ b/setup.py @@ -1,57 +1,79 @@ -from pkg_resources import parse_version +import shlex from configparser import ConfigParser -import setuptools, shlex -assert parse_version(setuptools.__version__)>=parse_version('36.2') + +import setuptools +from pkg_resources import parse_version + +assert parse_version(setuptools.__version__) >= parse_version("36.2") # note: all settings are in settings.ini; edit there, not here -config = ConfigParser(delimiters=['=']) -config.read('settings.ini', encoding='utf-8') -cfg = config['DEFAULT'] +config = ConfigParser(delimiters=["="]) +config.read("settings.ini", encoding="utf-8") +cfg = config["DEFAULT"] -cfg_keys = 'version description keywords author author_email'.split() +cfg_keys = "version description keywords author author_email".split() expected = cfg_keys + "lib_name user branch license status min_python audience language".split() -for o in expected: assert o in cfg, "missing expected setting: {}".format(o) -setup_cfg = {o:cfg[o] for o in cfg_keys} +for o in expected: + assert o in cfg, "missing expected setting: {}".format(o) +setup_cfg = {o: cfg[o] for o in cfg_keys} licenses = { - 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), - 'mit': ('MIT License', 'OSI Approved :: MIT License'), - 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), - 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), - 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), + "apache2": ("Apache Software License 2.0", "OSI Approved :: Apache Software License"), + "mit": ("MIT License", "OSI Approved :: MIT License"), + "gpl2": ( + "GNU General Public License v2", + "OSI Approved :: GNU General Public License v2 (GPLv2)", + ), + "gpl3": ( + "GNU General Public License v3", + "OSI Approved :: GNU General Public License v3 (GPLv3)", + ), + "bsd3": ("BSD License", "OSI Approved :: BSD License"), } -statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', - '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] -py_versions = '3.6 3.7 3.8 3.9 3.10'.split() +statuses = [ + "1 - Planning", + "2 - Pre-Alpha", + "3 - Alpha", + "4 - Beta", + "5 - Production/Stable", + "6 - Mature", + "7 - Inactive", +] +py_versions = "3.6 3.7 3.8 3.9 3.10".split() -requirements = shlex.split(cfg.get('requirements', '')) -if cfg.get('pip_requirements'): requirements += shlex.split(cfg.get('pip_requirements', '')) -min_python = cfg['min_python'] -lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) -dev_requirements = (cfg.get('dev_requirements') or '').split() +requirements = shlex.split(cfg.get("requirements", "")) +if cfg.get("pip_requirements"): + requirements += shlex.split(cfg.get("pip_requirements", "")) +min_python = cfg["min_python"] +lic = licenses.get(cfg["license"].lower(), (cfg["license"], None)) +dev_requirements = (cfg.get("dev_requirements") or "").split() setuptools.setup( - name = cfg['lib_name'], - license = lic[0], - classifiers = [ - 'Development Status :: ' + statuses[int(cfg['status'])], - 'Intended Audience :: ' + cfg['audience'].title(), - 'Natural Language :: ' + cfg['language'].title(), - ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), - url = cfg['git_url'], - packages = setuptools.find_packages(), - include_package_data = True, - install_requires = requirements, - extras_require={ 'dev': dev_requirements }, - dependency_links = cfg.get('dep_links','').split(), - python_requires = '>=' + cfg['min_python'], - long_description = open('README.md', encoding='utf-8').read(), - long_description_content_type = 'text/markdown', - zip_safe = False, - entry_points = { - 'console_scripts': cfg.get('console_scripts','').split(), - 'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'] + name=cfg["lib_name"], + license=lic[0], + classifiers=[ + "Development Status :: " + statuses[int(cfg["status"])], + "Intended Audience :: " + cfg["audience"].title(), + "Natural Language :: " + cfg["language"].title(), + ] + + [ + "Programming Language :: Python :: " + o + for o in py_versions[py_versions.index(min_python) :] + ] + + (["License :: " + lic[1]] if lic[1] else []), + url=cfg["git_url"], + packages=setuptools.find_packages(), + include_package_data=True, + install_requires=requirements, + extras_require={"dev": dev_requirements}, + dependency_links=cfg.get("dep_links", "").split(), + python_requires=">=" + cfg["min_python"], + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + zip_safe=False, + entry_points={ + "console_scripts": cfg.get("console_scripts", "").split(), + "nbdev": [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'], }, - **setup_cfg) - - + **setup_cfg, +)