
Commit

nbdev
laugustyniak committed Dec 2, 2024
1 parent 0e9ae83 commit 5cc97e0
Showing 4 changed files with 369 additions and 182 deletions.
7 changes: 6 additions & 1 deletion juddges/_modidx.py
@@ -5,18 +5,22 @@
'doc_host': 'https://laugustyniak.github.io',
'git_url': 'https://github.com/laugustyniak/juddges',
'lib_path': 'juddges'},
'syms': { 'juddges.config': {},
'syms': { 'juddges.case_law_trends.constants': {},
'juddges.case_law_trends.visualisations': {},
'juddges.config': {},
'juddges.data.database': {},
'juddges.data.datasets.dummy_dataset': {},
'juddges.data.datasets.utils': {},
'juddges.data.pl_court_api': {},
'juddges.data.pl_court_graph': {},
'juddges.data.weaviate_db': {},
'juddges.data_models': {},
'juddges.evaluation.eval_full_text': {},
'juddges.evaluation.eval_structured': {},
'juddges.evaluation.eval_structured_llm_judge': {},
'juddges.evaluation.info_extraction': {},
'juddges.evaluation.parse': {},
'juddges.llms': {},
'juddges.models.factory': {},
'juddges.models.predict': {},
'juddges.preprocessing.context_truncator': {},
@@ -26,6 +30,7 @@
'juddges.preprocessing.text_encoder': {},
'juddges.prompts.information_extraction': {},
'juddges.retrieval.mongo_hybrid_search': {},
'juddges.retrieval.mongo_term_based_search': {},
'juddges.settings': {},
'juddges.utils.config': {},
'juddges.utils.misc': {},
215 changes: 183 additions & 32 deletions nbs/Data/01_Dataset_Description.ipynb
@@ -168,9 +168,21 @@
],
"source": [
"# | eval: false\n",
"court_distribution = raw_ds.drop_nulls(subset=\"court_name\").select(\"court_name\").group_by(\"court_name\").len().sort(\"len\", descending=True).collect().to_pandas()\n",
"court_distribution = (\n",
" raw_ds.drop_nulls(subset=\"court_name\")\n",
" .select(\"court_name\")\n",
" .group_by(\"court_name\")\n",
" .len()\n",
" .sort(\"len\", descending=True)\n",
" .collect()\n",
" .to_pandas()\n",
")\n",
"ax = sns.histplot(data=court_distribution, x=\"len\", log_scale=True, kde=True)\n",
"ax.set(title=\"Distribution of judgments per court\", xlabel=\"#Judgements in single court\", ylabel=\"Count\")\n",
"ax.set(\n",
" title=\"Distribution of judgments per court\",\n",
" xlabel=\"#Judgements in single court\",\n",
" ylabel=\"Count\",\n",
")\n",
"plt.show()"
]
},
@@ -193,12 +205,29 @@
],
"source": [
"# | eval: false\n",
"judgements_per_year = raw_ds.select(\"date\").collect()[\"date\"].str.split(\" \").list.get(0).str.to_date().dt.year().value_counts().sort(\"date\").to_pandas()\n",
"judgements_per_year = (\n",
" raw_ds.select(\"date\")\n",
" .collect()[\"date\"]\n",
" .str.split(\" \")\n",
" .list.get(0)\n",
" .str.to_date()\n",
" .dt.year()\n",
" .value_counts()\n",
" .sort(\"date\")\n",
" .to_pandas()\n",
")\n",
"judgements_per_year = judgements_per_year[judgements_per_year[\"date\"] < 2024]\n",
"\n",
"_, ax = plt.subplots(1, 1, figsize=(10, 5))\n",
"ax = sns.pointplot(data=judgements_per_year, x=\"date\", y=\"count\", linestyles=\"--\", ax=ax)\n",
"ax.set(xlabel=\"Year\", ylabel=\"Number of Judgements\", title=\"Yearly Number of Judgements\", yscale=\"log\")\n",
"ax = sns.pointplot(\n",
" data=judgements_per_year, x=\"date\", y=\"count\", linestyles=\"--\", ax=ax\n",
")\n",
"ax.set(\n",
" xlabel=\"Year\",\n",
" ylabel=\"Number of Judgements\",\n",
" title=\"Yearly Number of Judgements\",\n",
" yscale=\"log\",\n",
")\n",
"plt.xticks(rotation=90)\n",
"plt.show()"
]
@@ -222,7 +251,15 @@
],
"source": [
"# | eval: false\n",
"types = raw_ds.fill_null(value=\"<null>\").select(\"type\").group_by(\"type\").len().sort(\"len\", descending=True).collect().to_pandas()\n",
"types = (\n",
" raw_ds.fill_null(value=\"<null>\")\n",
" .select(\"type\")\n",
" .group_by(\"type\")\n",
" .len()\n",
" .sort(\"len\", descending=True)\n",
" .collect()\n",
" .to_pandas()\n",
")\n",
"\n",
"_, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
"ax = sns.barplot(data=types, x=\"len\", y=\"type\", errorbar=None, ax=ax)\n",
@@ -249,9 +286,22 @@
],
"source": [
"# | eval: false\n",
"num_judges = raw_ds.with_columns([pl.col(\"judges\").list.len().alias(\"num_judges\")]).select(\"num_judges\").sort(\"num_judges\").collect().to_pandas()\n",
"ax = sns.histplot(data=num_judges, x=\"num_judges\", bins=num_judges[\"num_judges\"].nunique())\n",
"ax.set(xlabel=\"#Judges per judgement\", ylabel=\"Count\", yscale=\"log\", title=\"#Judges per single judgement\")\n",
"num_judges = (\n",
" raw_ds.with_columns([pl.col(\"judges\").list.len().alias(\"num_judges\")])\n",
" .select(\"num_judges\")\n",
" .sort(\"num_judges\")\n",
" .collect()\n",
" .to_pandas()\n",
")\n",
"ax = sns.histplot(\n",
" data=num_judges, x=\"num_judges\", bins=num_judges[\"num_judges\"].nunique()\n",
")\n",
"ax.set(\n",
" xlabel=\"#Judges per judgement\",\n",
" ylabel=\"Count\",\n",
" yscale=\"log\",\n",
" title=\"#Judges per single judgement\",\n",
")\n",
"plt.show()"
]
},
@@ -274,9 +324,20 @@
],
"source": [
"# | eval: false\n",
"num_lb = raw_ds.with_columns([pl.col(\"legalBases\").list.len().alias(\"num_lb\")]).select(\"num_lb\").sort(\"num_lb\").collect().to_pandas()\n",
"num_lb = (\n",
" raw_ds.with_columns([pl.col(\"legalBases\").list.len().alias(\"num_lb\")])\n",
" .select(\"num_lb\")\n",
" .sort(\"num_lb\")\n",
" .collect()\n",
" .to_pandas()\n",
")\n",
"ax = sns.histplot(data=num_lb, x=\"num_lb\", bins=num_lb[\"num_lb\"].nunique())\n",
"ax.set(xlabel=\"#Legal bases\", ylabel=\"Count\", yscale=\"log\", title=\"#Legal bases per judgement\")\n",
"ax.set(\n",
" xlabel=\"#Legal bases\",\n",
" ylabel=\"Count\",\n",
" yscale=\"log\",\n",
" title=\"#Legal bases per judgement\",\n",
")\n",
"plt.show()"
]
},
@@ -303,7 +364,9 @@
],
"source": [
"# | eval: false\n",
"raw_text_ds = load_dataset(\"parquet\", data_dir=\"../../data/datasets/pl/raw/\", columns=[\"_id\", \"text\"])\n",
"raw_text_ds = load_dataset(\n",
" \"parquet\", data_dir=\"../../data/datasets/pl/raw/\", columns=[\"_id\", \"text\"]\n",
")\n",
"raw_text_ds = raw_text_ds.filter(lambda x: x[\"text\"] is not None)"
]
},
@@ -340,11 +403,21 @@
"# | eval: false\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n",
"\n",
"def tokenize(batch: dict[str, list]) -> list[int]: \n",
" tokenized = tokenizer(batch[\"text\"], add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False, return_length=True)\n",
"\n",
"def tokenize(batch: dict[str, list]) -> list[int]:\n",
" tokenized = tokenizer(\n",
" batch[\"text\"],\n",
" add_special_tokens=False,\n",
" return_attention_mask=False,\n",
" return_token_type_ids=False,\n",
" return_length=True,\n",
" )\n",
" return {\"length\": tokenized[\"length\"]}\n",
"\n",
"raw_text_ds = raw_text_ds.map(tokenize, batched=True, batch_size=16, remove_columns=[\"text\"], num_proc=20)\n",
"\n",
"raw_text_ds = raw_text_ds.map(\n",
" tokenize, batched=True, batch_size=16, remove_columns=[\"text\"], num_proc=20\n",
")\n",
"raw_text_ds"
]
},
@@ -370,8 +443,13 @@
"judgement_len = raw_text_ds[\"train\"].to_pandas()\n",
"\n",
"ax = sns.histplot(data=judgement_len, x=\"length\", bins=50)\n",
"ax.set(xlabel=\"#Tokens\", ylabel=\"Count\", title=\"#Tokens distribution in judgements (llama-3 tokenizer)\", yscale=\"log\")\n",
"ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{int(x/1_000)}k'))\n",
"ax.set(\n",
" xlabel=\"#Tokens\",\n",
" ylabel=\"Count\",\n",
" title=\"#Tokens distribution in judgements (llama-3 tokenizer)\",\n",
" yscale=\"log\",\n",
")\n",
"ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f\"{int(x/1_000)}k\"))\n",
"plt.show()"
]
},
@@ -394,11 +472,23 @@
],
"source": [
"# | eval: false\n",
"per_type_tokens = raw_ds.fill_null(value=\"<null>\").select([\"_id\", \"type\"]).collect().to_pandas().set_index(\"_id\").join(judgement_len.set_index(\"_id\"))\n",
"per_type_tokens = (\n",
" raw_ds.fill_null(value=\"<null>\")\n",
" .select([\"_id\", \"type\"])\n",
" .collect()\n",
" .to_pandas()\n",
" .set_index(\"_id\")\n",
" .join(judgement_len.set_index(\"_id\"))\n",
")\n",
"\n",
"_, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
"ax = sns.boxenplot(data=per_type_tokens, y=\"type\", x=\"length\")\n",
"ax.set(xscale=\"log\", title=\"Judgement token count per type\", xlabel=\"#Tokens\", ylabel=\"Type\")\n",
"ax.set(\n",
" xscale=\"log\",\n",
" title=\"Judgement token count per type\",\n",
" xlabel=\"#Tokens\",\n",
" ylabel=\"Type\",\n",
")\n",
"plt.show()"
]
},
@@ -573,7 +663,7 @@
],
"source": [
"# | eval: false\n",
"df = pd.DataFrame([{\"Split\":k, \"#\": len(v)} for k, v in instruct_ds.items()])\n",
"df = pd.DataFrame([{\"Split\": k, \"#\": len(v)} for k, v in instruct_ds.items()])\n",
"df[\"%\"] = df[\"#\"] / df[\"#\"].sum() * 100\n",
"df.round(2)"
]
@@ -686,13 +776,43 @@
"\n",
"\n",
"def tokenize_and_comp_length_instruct_ds(batch: dict[str, list]) -> dict[str, list]:\n",
" tokenized_ctx = tokenizer(batch[\"context\"], add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False, return_length=True)\n",
" tokenized_out = tokenizer(batch[\"output\"], add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False, return_length=True)\n",
" tokenized_ctx = tokenizer(\n",
" batch[\"context\"],\n",
" add_special_tokens=False,\n",
" return_attention_mask=False,\n",
" return_token_type_ids=False,\n",
" return_length=True,\n",
" )\n",
" tokenized_out = tokenizer(\n",
" batch[\"output\"],\n",
" add_special_tokens=False,\n",
" return_attention_mask=False,\n",
" return_token_type_ids=False,\n",
" return_length=True,\n",
" )\n",
"\n",
" return {\"context_num_tokens\": tokenized_ctx[\"length\"], \"output_num_tokens\": tokenized_out[\"length\"]}\n",
" return {\n",
" \"context_num_tokens\": tokenized_ctx[\"length\"],\n",
" \"output_num_tokens\": tokenized_out[\"length\"],\n",
" }\n",
"\n",
"instruct_ds_tok = instruct_ds.map(tokenize_and_comp_length_instruct_ds, batched=True, batch_size=32, remove_columns=[\"prompt\", \"context\", \"output\"], num_proc=20)\n",
"instruct_ds_tok = pd.concat([instruct_ds_tok[\"train\"].to_pandas(), instruct_ds_tok[\"test\"].to_pandas()], axis=0, keys=[\"train\", \"test\"]).reset_index(level=0).rename(columns={\"level_0\": \"split\"})\n",
"\n",
"instruct_ds_tok = instruct_ds.map(\n",
" tokenize_and_comp_length_instruct_ds,\n",
" batched=True,\n",
" batch_size=32,\n",
" remove_columns=[\"prompt\", \"context\", \"output\"],\n",
" num_proc=20,\n",
")\n",
"instruct_ds_tok = (\n",
" pd.concat(\n",
" [instruct_ds_tok[\"train\"].to_pandas(), instruct_ds_tok[\"test\"].to_pandas()],\n",
" axis=0,\n",
" keys=[\"train\", \"test\"],\n",
" )\n",
" .reset_index(level=0)\n",
" .rename(columns={\"level_0\": \"split\"})\n",
")\n",
"instruct_ds_tok.head()"
]
},
@@ -711,7 +831,10 @@
}
],
"source": [
"print(f\"0.95 quantile of maximum output: {instruct_ds_tok['output_num_tokens'].quantile(0.95)}\")"
"# | eval: false\n",
"print(\n",
" f\"0.95 quantile of maximum output: {instruct_ds_tok['output_num_tokens'].quantile(0.95)}\"\n",
")"
]
},
{
@@ -733,11 +856,30 @@
],
"source": [
"# | eval: false\n",
"tok_melt = instruct_ds_tok.melt(id_vars=[\"split\"], value_vars=[\"context_num_tokens\", \"output_num_tokens\"], var_name=\"Text\", value_name=\"#Tokens\")\n",
"tok_melt[\"Text\"] = tok_melt[\"Text\"].map({\"context_num_tokens\": \"Context\", \"output_num_tokens\": \"Output\"})\n",
"tok_melt = instruct_ds_tok.melt(\n",
" id_vars=[\"split\"],\n",
" value_vars=[\"context_num_tokens\", \"output_num_tokens\"],\n",
" var_name=\"Text\",\n",
" value_name=\"#Tokens\",\n",
")\n",
"tok_melt[\"Text\"] = tok_melt[\"Text\"].map(\n",
" {\"context_num_tokens\": \"Context\", \"output_num_tokens\": \"Output\"}\n",
")\n",
"\n",
"g = sns.displot(data=tok_melt, x=\"#Tokens\", col=\"Text\", hue=\"split\", kind=\"kde\", fill=True, log_scale=True, common_norm=False, facet_kws=dict(sharex=False, sharey=False))\n",
"g.figure.suptitle(\"Distribution of #Tokens (llama-3 tokenizer) in Context and Output in instruct dataset\")\n",
"g = sns.displot(\n",
" data=tok_melt,\n",
" x=\"#Tokens\",\n",
" col=\"Text\",\n",
" hue=\"split\",\n",
" kind=\"kde\",\n",
" fill=True,\n",
" log_scale=True,\n",
" common_norm=False,\n",
" facet_kws=dict(sharex=False, sharey=False),\n",
")\n",
"g.figure.suptitle(\n",
" \"Distribution of #Tokens (llama-3 tokenizer) in Context and Output in instruct dataset\"\n",
")\n",
"g.figure.tight_layout()\n",
"plt.show()"
]
Expand All @@ -762,8 +904,17 @@
"source": [
"# | eval: false\n",
"_, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
"ax = sns.countplot(data=per_type_tokens.join(instruct_ds_tok.set_index(\"_id\"), how=\"right\"), y=\"type\", hue=\"split\")\n",
"ax.set(xscale=\"log\", title=\"Distribution of types in dataset splits\", xlabel=\"Count\", ylabel=\"Type\")\n",
"ax = sns.countplot(\n",
" data=per_type_tokens.join(instruct_ds_tok.set_index(\"_id\"), how=\"right\"),\n",
" y=\"type\",\n",
" hue=\"split\",\n",
")\n",
"ax.set(\n",
" xscale=\"log\",\n",
" title=\"Distribution of types in dataset splits\",\n",
" xlabel=\"Count\",\n",
" ylabel=\"Type\",\n",
")\n",
"plt.show()"
]
}