Skip to content

Commit

Permalink
dashboard ukpl data with hybrid search (#28)
Browse files Browse the repository at this point in the history
* hybrid search works

* fix text area

* fix linter

* fix nbdev
  • Loading branch information
laugustyniak authored Jul 6, 2024
1 parent c70a092 commit 1641858
Show file tree
Hide file tree
Showing 15 changed files with 418 additions and 14 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ fostering cross-disciplinary and cross-jurisdictional collaboration.
### Installation

- to install the necessary dependencies, use the available `Makefile`; you can
use `python>=3.10`: `shell make install`
use `python>=3.10`: `make install`
- if you want to run evaluation and fine-tuning with `unsloth`, use the
following command with `python=3.10` inside a conda environment:
`shell make install_unsloth`
`make install_unsloth`

### Dataset creation

Expand Down
44 changes: 32 additions & 12 deletions dashboards/pages/01_🔍_Search_Judgements.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,60 @@
from typing import Any
import streamlit as st

from juddges.data.datasets import get_mongo_collection
import streamlit as st
from pymongo.collection import Collection
from sentence_transformers import SentenceTransformer

from juddges.data.database import get_mongo_collection
from juddges.retrieval.mongo_hybrid_search import run_hybrid_search
from juddges.settings import TEXT_EMBEDDING_MODEL

TITLE = "Search for Judgements"

st.set_page_config(page_title=TITLE, page_icon="⚖️", layout="wide")

st.title(TITLE)
st.header(
"Search is based on hybrid search using text and vector search with the same priority for both."
)

judgement_country = st.sidebar.selectbox("Select judgement country", ["pl", "uk"])
judgement_collection_name = f"{judgement_country}-court"
st.sidebar.info(f"Selected country: {judgement_collection_name}")


@st.cache_resource
def get_judgements_collection() -> Collection:
return get_mongo_collection("judgements")
def get_judgements_collection(collection_name: str = "pl-court") -> Collection:
    """Return the MongoDB collection with the given name (e.g. "pl-court" or "uk-court")."""
    return get_mongo_collection(collection_name=collection_name)


judgements_collection = get_judgements_collection()
@st.cache_resource
def get_embedding_model() -> Any:
    """Load the sentence-transformer text-embedding model (TEXT_EMBEDDING_MODEL), cached across reruns."""
    return SentenceTransformer(TEXT_EMBEDDING_MODEL)


def search_data(query: str, max_judgements: int = 5) -> list[dict[str, Any]]:
items = list(judgements_collection.find({"$text": {"$search": query}}).limit(max_judgements))
return items
judgements_collection = get_judgements_collection(judgement_collection_name)

model = get_embedding_model()

with st.form(key="search_form"):
text = st.text_area("What you are looking for in the judgements?")
query = st.text_area("What you are looking for in the judgements?")
max_judgements = st.slider("Max judgements to show", min_value=1, max_value=20, value=5)
submit_button = st.form_submit_button(label="Search")

if submit_button:
with st.spinner("Searching..."):
items = search_data(text, max_judgements)
items = run_hybrid_search(
collection=judgements_collection,
collection_name=judgement_collection_name,
embedding=model.encode(query).tolist(),
query=query,
limit=max_judgements,
)

st.header("Judgements - Results")
for item in items:
st.header(item["signature"])
st.subheader(item["publicationDate"])
st.write(item["text"])
st.info(f"Department: {item['department_name']}")
st.info(f"Score: {item['score']}")
st.subheader(item["excerpt"])
st.text_area(label="Judgement text", value=item["text"], height=200)
1 change: 1 addition & 0 deletions juddges/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'juddges.preprocessing.text_chunker': {},
'juddges.preprocessing.text_encoder': {},
'juddges.prompts.information_extraction': {},
'juddges.retrieval.mongo_hybrid_search': {},
'juddges.settings': {},
'juddges.utils.config': {},
'juddges.utils.misc': {},
Expand Down
Empty file added juddges/retrieval/__init__.py
Empty file.
110 changes: 110 additions & 0 deletions juddges/retrieval/mongo_hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
def run_hybrid_search(
collection,
collection_name: str,
embedding,
query: str,
limit: int = 10,
vector_priority: float = 1,
text_priority: float = 1,
):
num_candidates = limit * 10

vector_search = {
"$vectorSearch": {
"index": "vector_index",
"path": "embedding",
"queryVector": embedding,
"numCandidates": num_candidates,
"limit": limit,
}
}

make_array = {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}

add_rank = {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}

def make_compute_score_doc(priority, score_field_name):
return {
"$addFields": {score_field_name: {"$divide": [1.0, {"$add": ["$rank", priority, 1]}]}}
}

def make_projection_doc(score_field_name):
return {
"$project": {
score_field_name: 1,
"_id": "$docs._id",
"excerpt": "$docs.excerpt",
"text": "$docs.text",
"department_name": "$docs.department_name",
"signature": "$docs.signature",
}
}

text_search = {
"$search": {
"index": "text_index",
"text": {"query": query, "path": "text"},
}
}

limit_results = {"$limit": limit}

combine_search_results = {
"$group": {
"_id": "$_id",
"vs_score": {"$max": "$vs_score"},
"ts_score": {"$max": "$ts_score"},
"excerpt": {"$first": "$excerpt"},
"text": {"$first": "$text"},
"department_name": {"$first": "$department_name"},
"signature": {"$first": "$signature"},
}
}

project_combined_results = {
"$project": {
"_id": 1,
"excerpt": 1,
"text": 1,
"department_name": 1,
"signature": 1,
"score": {
"$let": {
"vars": {
"vs_score": {"$ifNull": ["$vs_score", 0]},
"ts_score": {"$ifNull": ["$ts_score", 0]},
},
"in": {"$add": ["$$vs_score", "$$ts_score"]},
}
},
}
}

sort_results = {"$sort": {"score": -1}}

pipeline = [
vector_search,
make_array,
add_rank,
make_compute_score_doc(vector_priority, "vs_score"),
make_projection_doc("vs_score"),
{
"$unionWith": {
"coll": collection_name,
"pipeline": [
text_search,
limit_results,
make_array,
add_rank,
make_compute_score_doc(text_priority, "ts_score"),
make_projection_doc("ts_score"),
],
}
},
combine_search_results,
project_combined_results,
sort_results,
limit_results,
]

return collection.aggregate(pipeline)
2 changes: 2 additions & 0 deletions juddges/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

MLFLOW_EXP_NAME = "Juddges-Information-Extraction"

TEXT_EMBEDDING_MODEL = "sdadas/mmlw-roberta-large"


def num_tokens_from_string(
string: str, # The string to count tokens for
Expand Down
1 change: 1 addition & 0 deletions nbs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/.quarto/
16 changes: 16 additions & 0 deletions nbs/Data/01_Dataset_Description.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"# | eval: false\n",
"from datasets import load_dataset\n",
"import pandas as pd\n",
"import polars as pl\n",
Expand Down Expand Up @@ -71,6 +72,7 @@
}
],
"source": [
"# | eval: false\n",
"raw_ds = pl.scan_parquet(source=\"../../data/datasets/pl/raw/*\")\n",
"raw_ds.columns"
]
Expand Down Expand Up @@ -165,6 +167,7 @@
}
],
"source": [
"# | eval: false\n",
"court_distribution = raw_ds.drop_nulls(subset=\"court_name\").select(\"court_name\").group_by(\"court_name\").len().sort(\"len\", descending=True).collect().to_pandas()\n",
"ax = sns.histplot(data=court_distribution, x=\"len\", log_scale=True, kde=True)\n",
"ax.set(title=\"Distribution of judgments per court\", xlabel=\"#Judgements in single court\", ylabel=\"Count\")\n",
Expand All @@ -189,6 +192,7 @@
}
],
"source": [
"# | eval: false\n",
"judgements_per_year = raw_ds.select(\"date\").collect()[\"date\"].str.split(\" \").list.get(0).str.to_date().dt.year().value_counts().sort(\"date\").to_pandas()\n",
"judgements_per_year = judgements_per_year[judgements_per_year[\"date\"] < 2024]\n",
"\n",
Expand Down Expand Up @@ -217,6 +221,7 @@
}
],
"source": [
"# | eval: false\n",
"types = raw_ds.fill_null(value=\"<null>\").select(\"type\").group_by(\"type\").len().sort(\"len\", descending=True).collect().to_pandas()\n",
"\n",
"_, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
Expand All @@ -243,6 +248,7 @@
}
],
"source": [
"# | eval: false\n",
"num_judges = raw_ds.with_columns([pl.col(\"judges\").list.len().alias(\"num_judges\")]).select(\"num_judges\").sort(\"num_judges\").collect().to_pandas()\n",
"ax = sns.histplot(data=num_judges, x=\"num_judges\", bins=num_judges[\"num_judges\"].nunique())\n",
"ax.set(xlabel=\"#Judges per judgement\", ylabel=\"Count\", yscale=\"log\", title=\"#Judges per single judgement\")\n",
Expand All @@ -267,6 +273,7 @@
}
],
"source": [
"# | eval: false\n",
"num_lb = raw_ds.with_columns([pl.col(\"legalBases\").list.len().alias(\"num_lb\")]).select(\"num_lb\").sort(\"num_lb\").collect().to_pandas()\n",
"ax = sns.histplot(data=num_lb, x=\"num_lb\", bins=num_lb[\"num_lb\"].nunique())\n",
"ax.set(xlabel=\"#Legal bases\", ylabel=\"Count\", yscale=\"log\", title=\"#Legal bases per judgement\")\n",
Expand Down Expand Up @@ -309,6 +316,7 @@
}
],
"source": [
"# | eval: false\n",
"raw_text_ds = load_dataset(\"parquet\", data_dir=\"../../data/datasets/pl/raw/\", columns=[\"_id\", \"text\"])\n",
"raw_text_ds = raw_text_ds.filter(lambda x: x[\"text\"] is not None)"
]
Expand Down Expand Up @@ -357,6 +365,7 @@
}
],
"source": [
"# | eval: false\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n",
"\n",
"def tokenize(batch: dict[str, list]) -> list[int]: \n",
Expand Down Expand Up @@ -385,6 +394,7 @@
}
],
"source": [
"# | eval: false\n",
"judgement_len = raw_text_ds[\"train\"].to_pandas()\n",
"\n",
"ax = sns.histplot(data=judgement_len, x=\"length\", bins=50)\n",
Expand All @@ -411,6 +421,7 @@
}
],
"source": [
"# | eval: false\n",
"per_type_tokens = raw_ds.fill_null(value=\"<null>\").select([\"_id\", \"type\"]).collect().to_pandas().set_index(\"_id\").join(judgement_len.set_index(\"_id\"))\n",
"\n",
"_, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
Expand Down Expand Up @@ -515,6 +526,7 @@
}
],
"source": [
"# | eval: false\n",
"instruct_ds = load_dataset(\"JuDDGES/pl-court-instruct\")\n",
"instruct_ds"
]
Expand Down Expand Up @@ -588,6 +600,7 @@
}
],
"source": [
"# | eval: false\n",
"df = pd.DataFrame([{\"Split\":k, \"#\": len(v)} for k, v in instruct_ds.items()])\n",
"df[\"%\"] = df[\"#\"] / df[\"#\"].sum() * 100\n",
"df.round(2)"
Expand Down Expand Up @@ -696,6 +709,7 @@
}
],
"source": [
"# | eval: false\n",
"from torch import le\n",
"\n",
"\n",
Expand Down Expand Up @@ -728,6 +742,7 @@
}
],
"source": [
"# | eval: false\n",
"tok_melt = instruct_ds_tok.melt(id_vars=[\"split\"], value_vars=[\"context_num_tokens\", \"output_num_tokens\"], var_name=\"Text\", value_name=\"#Tokens\")\n",
"tok_melt[\"Text\"] = tok_melt[\"Text\"].map({\"context_num_tokens\": \"Context\", \"output_num_tokens\": \"Output\"})\n",
"\n",
Expand Down Expand Up @@ -755,6 +770,7 @@
}
],
"source": [
"# | eval: false\n",
"_, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
"ax = sns.countplot(data=per_type_tokens.join(instruct_ds_tok.set_index(\"_id\"), how=\"right\"), y=\"type\", hue=\"split\")\n",
"ax.set(xscale=\"log\", title=\"Distribution of types in dataset splits\", xlabel=\"Count\", ylabel=\"Type\")\n",
Expand Down
5 changes: 5 additions & 0 deletions nbs/Data/02_Analyse_sft.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"# | eval: false\n",
"import warnings\n",
"import json\n",
"from multiprocessing import Pool\n",
Expand Down Expand Up @@ -176,6 +177,7 @@
}
],
"source": [
"# | eval: false\n",
"results = []\n",
"for f in Path(\"../../data/experiments/predict/pl-court-instruct\").glob(\"metrics_*.json\"):\n",
" model_name = f.stem.replace(\"metrics_\", \"\")\n",
Expand Down Expand Up @@ -203,6 +205,7 @@
"metadata": {},
"outputs": [],
"source": [
"# | eval: false\n",
"OUTPUTS_PATH = \"../../data/experiments/predict/pl-court-instruct/outputs_Unsloth-Llama-3-8B-Instruct-fine-tuned.json\"\n",
"\n",
"with open(OUTPUTS_PATH) as file:\n",
Expand Down Expand Up @@ -237,6 +240,7 @@
}
],
"source": [
"# | eval: false\n",
"def eval_item(item: dict[str, Any]) -> dict[str, Any]:\n",
" item[\"metrics\"] = evaluate_extraction([item])\n",
" item[\"metrics\"][\"mean_field\"] = mean(item[\"metrics\"][\"field_chrf\"].values())\n",
Expand Down Expand Up @@ -279,6 +283,7 @@
}
],
"source": [
"# | eval: false\n",
"data_valid = [item for item in results if item[\"answer\"] is not None]\n",
"data_valid = sorted(data_valid, key=lambda x: x[\"metrics\"][\"mean_field\"])\n",
"\n",
Expand Down
Loading

0 comments on commit 1641858

Please sign in to comment.