Add weaviate vectors #31

Merged
merged 61 commits into from
Dec 1, 2024
Commits
a3999bd
Add pre-commit and integrate it with CI
binkjakub Jun 24, 2024
8f189b3
Apply pre-commit style fixes
binkjakub Jun 24, 2024
1b8e84e
Add sorting imports to ruff
binkjakub Jun 24, 2024
96d6ca1
Fix encoding in instruction dataset
binkjakub Jun 24, 2024
80f7002
Refactor and make deterministic predictions
binkjakub Jul 3, 2024
04b8c47
Add llm-as-judge baseline results
binkjakub Jul 3, 2024
4f79883
Make llm-as-judge using API instead of local LLM
binkjakub Jul 3, 2024
54b9c8c
Fix missing attention_mask in prediction and add sampled predictions
binkjakub Jul 4, 2024
f85c48f
Update data description with output_size quantile
binkjakub Jul 4, 2024
1405116
Update packages
binkjakub Jul 4, 2024
c9b7bf2
Refactor fine-tuning scripts and change hparams, make common code for…
binkjakub Jul 8, 2024
3a06f31
Refactor evaluation code
binkjakub Jul 8, 2024
ada17a3
Fix unsloth peft
binkjakub Jul 11, 2024
f2af0f7
Add optimized unsloth inference and inference telemetry
binkjakub Jul 15, 2024
513343f
Reproduce SFT on update and fixed code
binkjakub Aug 8, 2024
3c2cf3e
Fix function evaluating info-extraction metrics
binkjakub Aug 8, 2024
4ca650e
Evaluate updated Mistral and summarize metrics with mean and std
binkjakub Aug 8, 2024
fadb582
Fix computing maximum sequence length (account for output tokens duri…
binkjakub Aug 9, 2024
1d8df86
Add llm-as-judge preliminary results (too much non-evaluable)
binkjakub Aug 11, 2024
4161ba2
Add unit test for parsing prediction output
binkjakub Aug 11, 2024
d056947
Fix bug with CHRF computation
binkjakub Aug 11, 2024
00bffb3
Add unit tests for chrf and fix metric params
binkjakub Aug 11, 2024
c46c476
Rollback to legacy fine-tuning config due to unsloth compatibility issue
binkjakub Aug 11, 2024
ff2b3e3
Add mistral-nemo LLM
binkjakub Aug 12, 2024
f29864f
Add cuda version argument in Makefile
binkjakub Aug 12, 2024
7824769
Fix llm-as-judge implementation
binkjakub Aug 12, 2024
d104981
Disable CI on windows for now (utf-8 bugs)
binkjakub Aug 12, 2024
8ab3078
Reproduce Mistral-Nemo and summarize all results
binkjakub Aug 12, 2024
f95ec86
Add prediction with OpenAI GPT models
binkjakub Aug 23, 2024
57ea86d
Update README.md with reproduction instruction
binkjakub Aug 25, 2024
47a8f6b
Migrate OpenAI interface to langchain to enable request caching
binkjakub Aug 25, 2024
12c1bb9
Add results for gpt-4o and gpt-4o-mini
binkjakub Aug 25, 2024
ceaed8b
Add Bielik v0.1 LLM
binkjakub Aug 25, 2024
7dfd948
Add building english instruct dataset and infer over it with gpt-4o-mini
binkjakub Aug 26, 2024
5665e75
Add optional chat templating
binkjakub Aug 27, 2024
7d40fc5
Add new polish LLMs and separate configs for en and pl fine-tuned models
binkjakub Aug 27, 2024
5afd7e3
Fix pl-court-raw dataset by filtering records with empty content from
binkjakub Aug 27, 2024
e9a1df5
Add null-count info to pl-court-raw README
binkjakub Aug 28, 2024
bf0c034
Reproduce Polish and English LLMs
binkjakub Aug 28, 2024
6e18f71
Fix pre-commit exclude
binkjakub Aug 28, 2024
89f5199
Add bielik v2
binkjakub Aug 28, 2024
d8b83c8
Fix missing gpt-4o outputs and reproduce English data on gpt
binkjakub Aug 28, 2024
f92a4e0
Add llm-as-judge evaluation for English and Polish using gpt-4o-mini
binkjakub Aug 29, 2024
593439d
Fix evaluation parser to account for non-yaml text at the beginning o…
binkjakub Aug 29, 2024
a9cc9b3
Reproduce evaluation on fixed parsing
binkjakub Aug 29, 2024
5f79cce
Fix langchain versions
binkjakub Aug 29, 2024
e27516c
Reproduce llm-as-judge with fixed prompt
binkjakub Aug 29, 2024
3d2d38d
Fix default num_proc in structured evaluator
binkjakub Aug 30, 2024
774048b
Reproduce Bielik v2
binkjakub Aug 31, 2024
12374a5
Add text to chunk embeddings data and weaviate basic ingestion
binkjakub Sep 18, 2024
147815e
Add weaviate deployment
binkjakub Sep 18, 2024
4002989
Add two modes for weaviate ingest
binkjakub Sep 20, 2024
ce619a7
Add hf-transformers vectorizer to weaviate collection
binkjakub Sep 20, 2024
40171ae
Add weaviate example
binkjakub Sep 20, 2024
6fdcee4
Add hf-transformer module to weaviate deployment
binkjakub Sep 20, 2024
b305904
Add weaviate deployment info
binkjakub Sep 20, 2024
27c7eca
Add usage info to weaviate README.md
binkjakub Sep 20, 2024
ef6170e
Fix env variables naming
binkjakub Sep 21, 2024
13ac088
Fix env variables naming
binkjakub Sep 23, 2024
8510775
Merge branch 'master' into add-weaviate-vectors
binkjakub Oct 2, 2024
e245236
Fix _moddix
binkjakub Oct 2, 2024
1 change: 1 addition & 0 deletions configs/embedding.yaml
@@ -10,6 +10,7 @@ chunk_config:
   chunk_size: ${embedding_model.max_seq_length}
   min_split_chars: 10
   take_n_first_chunks: 16
+  chunk_overlap: 32
 batch_size: 64

 output_dir: data/embeddings/${dataset.name}/${hydra:runtime.choices.embedding_model}/all_embeddings
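The new `chunk_overlap: 32` setting makes consecutive chunks share part of their content. As a rough illustration of what that means, here is a minimal sketch of a fixed-size sliding window over a token list. `sliding_chunks` is a hypothetical helper, not part of the repository, and the splitter used in the PR also splits on separators, so its real output will differ.

```python
def sliding_chunks(tokens: list[str], chunk_size: int, chunk_overlap: int) -> list[list[str]]:
    """Split tokens into windows of chunk_size that share chunk_overlap tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")
    # Each new window starts this many tokens after the previous one.
    step = chunk_size - chunk_overlap
    chunks: list[list[str]] = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start : start + chunk_size])
        # Stop once a window has reached the end of the input.
        if start + chunk_size >= len(tokens):
            break
    return chunks


if __name__ == "__main__":
    for chunk in sliding_chunks(list("abcdefghij"), chunk_size=4, chunk_overlap=2):
        print("".join(chunk))
```

With an overlap of 0 the windows are disjoint; a larger overlap keeps more shared context at chunk boundaries at the cost of more chunks to embed.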
594 changes: 242 additions & 352 deletions dvc.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions juddges/_modidx.py
@@ -11,6 +11,7 @@
 'juddges.data.datasets.utils': {},
 'juddges.data.pl_court_api': {},
 'juddges.data.pl_court_graph': {},
+'juddges.data.weaviate_db': {},
 'juddges.evaluation.eval_full_text': {},
 'juddges.evaluation.eval_structured': {},
 'juddges.evaluation.eval_structured_llm_judge': {},
112 changes: 112 additions & 0 deletions juddges/data/weaviate_db.py
@@ -0,0 +1,112 @@
import re
from abc import ABC, abstractmethod
from typing import Any, ClassVar

import weaviate
import weaviate.classes.config as wvcc
from weaviate.auth import Auth, _APIKey


class WeaviateDatabase(ABC):
    def __init__(self, host: str, port: str, grpc_port: str, api_key: str | None):
        self.host = host
        self.port = port
        self.grpc_port = grpc_port
        self.__api_key = api_key

        self.client: weaviate.WeaviateClient

    def __enter__(self) -> "WeaviateDatabase":
        self.client = weaviate.connect_to_local(
            host=self.host,
            port=self.port,
            grpc_port=self.grpc_port,
            auth_credentials=self.api_key,
        )
        self.create_collections()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        if hasattr(self, "client"):
            self.client.close()

    def __del__(self) -> None:
        self.__exit__(None, None, None)

    @property
    def api_key(self) -> _APIKey | None:
        if self.__api_key is not None:
            return Auth.api_key(self.__api_key)
        return None

    @abstractmethod
    def create_collections(self) -> None:
        pass

    def insert_batch(
        self,
        collection: weaviate.collections.Collection,
        objects: list[dict[str, Any]],
    ) -> None:
        with collection.batch.dynamic() as wv_batch:
            for obj in objects:
                wv_batch.add_object(**obj)
                if wv_batch.number_errors > 0:
                    break
        if wv_batch.number_errors > 0:
            errors = [err.message for err in collection.batch.results.objs.errors.values()]
            raise ValueError(f"Error ingesting batch: {errors}")

    def get_uuids(self, collection: weaviate.collections.Collection) -> list[str]:
        return [str(obj.uuid) for obj in collection.iterator(return_properties=[])]

    def _safe_create_collection(self, *args: Any, **kwargs: Any) -> None:
        try:
            self.client.collections.create(*args, **kwargs)
        except weaviate.exceptions.UnexpectedStatusCodeError as err:
            if (
                re.search(r"class name (\w+?) already exists", err.message)
                and err.status_code == 422
            ):
                pass
            else:
                raise


class WeaviateJudgementsDatabase(WeaviateDatabase):
    JUDGMENTS_COLLECTION: ClassVar[str] = "judgements"
    JUDGMENT_CHUNKS_COLLECTION: ClassVar[str] = "judgement_chunks"

    @property
    def judgements_collection(self) -> weaviate.collections.Collection:
        return self.client.collections.get(self.JUDGMENTS_COLLECTION)

    @property
    def judgement_chunks_collection(self) -> weaviate.collections.Collection:
        return self.client.collections.get(self.JUDGMENT_CHUNKS_COLLECTION)

    def create_collections(self) -> None:
        self._safe_create_collection(
            name=self.JUDGMENTS_COLLECTION,
            properties=[
                wvcc.Property(name="judgement_id", data_type=wvcc.DataType.TEXT),
            ],
        )
        self._safe_create_collection(
            name=self.JUDGMENT_CHUNKS_COLLECTION,
            properties=[
                wvcc.Property(name="chunk_id", data_type=wvcc.DataType.INT),
                wvcc.Property(name="chunk_text", data_type=wvcc.DataType.TEXT),
            ],
            vectorizer_config=wvcc.Configure.Vectorizer.text2vec_transformers(),
            references=[
                wvcc.ReferenceProperty(
                    name="judgementChunk",
                    target_collection=self.JUDGMENTS_COLLECTION,
                )
            ],
        )

    @staticmethod
    def uuid_from_judgement_chunk_id(judgement_id: str, chunk_id: int) -> str:
        return weaviate.util.generate_uuid5(f"{judgement_id}_chunk_{chunk_id}")
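The chunk UUID in `uuid_from_judgement_chunk_id` is derived only from the judgement ID and the chunk index, so the same chunk always maps to the same object ID across runs. A minimal stdlib sketch of that idea follows. `chunk_uuid` is a hypothetical helper, and the `uuid.NAMESPACE_DNS` namespace is an assumption of this sketch, so its output is not guaranteed to match `weaviate.util.generate_uuid5`.

```python
import uuid


def chunk_uuid(judgement_id: str, chunk_id: int) -> str:
    """Derive a stable, name-based (version 5) UUID for one chunk."""
    # Same key format as in uuid_from_judgement_chunk_id.
    key = f"{judgement_id}_chunk_{chunk_id}"
    # NAMESPACE_DNS is an assumption made for this sketch only.
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, key))


if __name__ == "__main__":
    first = chunk_uuid("example-judgement", 0)
    again = chunk_uuid("example-judgement", 0)
    other = chunk_uuid("example-judgement", 1)
    print(first == again, first == other)
```

Because the ID is a pure function of its inputs, no lookup table is needed to tell whether a given chunk has been uploaded before.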
2 changes: 2 additions & 0 deletions juddges/preprocessing/text_chunker.py
@@ -8,6 +8,7 @@ class TextSplitter:
     def __init__(
         self,
         chunk_size: int,
+        chunk_overlap: int | None = None,
         min_split_chars: int | None = None,
         take_n_first_chunks: int | None = None,
         tokenizer: PreTrainedTokenizer | None = None,
@@ -16,6 +17,7 @@ def __init__(
             self.splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                 tokenizer,
                 chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
             )
         else:
             self.splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size)
1 change: 1 addition & 0 deletions requirements.txt
@@ -33,6 +33,7 @@ transformers==4.42.3
 trl==0.9.4
 typer==0.9.0
 wandb==0.16.5
+weaviate-client==4.8.1
 xmltodict==0.13.0
 xlsxwriter==3.2.0

25 changes: 15 additions & 10 deletions scripts/embed/embed_text.py
@@ -10,6 +10,7 @@
 from omegaconf import DictConfig
 from openai import BaseModel
 from sentence_transformers import SentenceTransformer
+from transformers import PreTrainedTokenizer
 from transformers.utils import is_flash_attn_2_available

 from juddges.config import EmbeddingModelConfig, RawDatasetConfig
@@ -21,6 +22,7 @@

 NUM_PROC = int(os.getenv("NUM_PROC", 1))
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+os.environ["TOKENIZERS_PARALLELISM"] = "false" if (NUM_PROC > 1) else "true"


 class EmbeddingConfig(BaseModel, extra="forbid"):
@@ -51,19 +53,19 @@ def main(cfg: DictConfig) -> None:
     )
     ds = ds.filter(lambda item: item["text"] is not None)

-    if config.chunk_config is not None:
-        ds = chunk_dataset(ds, config)
-        text_column = "text_chunk"
-    else:
-        text_column = "text"
-
     model = SentenceTransformer(
         config.embedding_model.name,
         device=DEVICE,
         model_kwargs=dict(torch_dtype=torch.bfloat16),
     )
     model.compile()

+    if config.chunk_config is not None:
+        ds = chunk_dataset(dataset=ds, config=config, tokenizer=model.tokenizer)
+        text_column = "text_chunk"
+    else:
+        text_column = "text"
+
     if config.truncation_tokens is not None:
         assert config.truncation_tokens <= config.embedding_model.max_seq_length
         model.max_seq_length = config.truncation_tokens
@@ -74,19 +76,22 @@ def main(cfg: DictConfig) -> None:
         batched=True,
         batch_size=config.batch_size,
         num_proc=None,
-        remove_columns=[text_column],
         desc="Embedding chunks",
     )
-    ds.save_to_disk(config.output_dir)
+    ds.save_to_disk(str(config.output_dir))

     with open(config.output_dir / "config.yaml", "w") as f:
         yaml.dump(config.model_dump(), f)


-def chunk_dataset(dataset: Dataset, config: EmbeddingConfig) -> Dataset:
+def chunk_dataset(
+    dataset: Dataset,
+    config: EmbeddingConfig,
+    tokenizer: PreTrainedTokenizer | None = None,
+) -> Dataset:
     # todo: To be verified
     assert config.chunk_config is not None
-    split_worker = TextSplitter(**config.chunk_config)
+    split_worker = TextSplitter(**config.chunk_config, tokenizer=tokenizer)
     ds = dataset.select_columns(["_id", "text"]).map(
         split_worker,
         batched=True,
File renamed without changes.
79 changes: 79 additions & 0 deletions scripts/embed/ingest_weaviate.py
@@ -0,0 +1,79 @@
import math
import os
from pathlib import Path

import typer
from datasets import load_dataset
from dotenv import load_dotenv
from loguru import logger
from tqdm.auto import tqdm

from juddges.data.weaviate_db import WeaviateJudgementsDatabase
from weaviate.util import generate_uuid5

load_dotenv()
WV_HOST = os.getenv("WV_HOST", "localhost")
WV_PORT = os.getenv("WV_PORT", "8080")
WV_GRPC_PORT = os.getenv("WV_GRPC_PORT", "50051")
WV_API_KEY = os.getenv("WV_API_KEY", None)

BATCH_SIZE = 64
NUM_PROC = int(os.getenv("NUM_PROC", 1))

logger.info(f"Connecting to Weaviate at {WV_HOST}:{WV_PORT} (gRPC: {WV_GRPC_PORT})")


def main(
    embeddings_dir: Path = typer.Option(...),
    batch_size: int = typer.Option(BATCH_SIZE),
    upsert: bool = typer.Option(False),
) -> None:
    logger.warning(
        "The script will upload local embeddings to the database, "
        "make sure they are the same as in the inference module of the database."
    )
    embs = load_dataset(str(embeddings_dir))["train"]
    embs = embs.map(
        lambda item: {
            "uuid": WeaviateJudgementsDatabase.uuid_from_judgement_chunk_id(
                judgement_id=item["_id"], chunk_id=item["chunk_id"]
            )
        },
        num_proc=NUM_PROC,
        desc="Generating UUIDs",
    )
    with WeaviateJudgementsDatabase(WV_HOST, WV_PORT, WV_GRPC_PORT, WV_API_KEY) as db:
        if not upsert:
            logger.info("upsert disabled - uploading only new embeddings")
            uuids = set(db.get_uuids(db.judgement_chunks_collection))
            embs = embs.filter(lambda item: item["uuid"] not in uuids)
        else:
            logger.info(
                "upsert enabled - uploading all embeddings (automatically updating already uploaded)"
            )

        for batch in tqdm(
            embs.iter(batch_size=batch_size),
            total=math.ceil(len(embs) / batch_size),
            desc="Uploading batches",
        ):
            objects = [
                {
                    "properties": {
                        "judgment_id": batch["_id"][i],
                        "chunk_id": batch["chunk_id"][i],
                        "chunk_text": batch["text_chunk"][i],
                    },
                    "uuid": generate_uuid5(f"{batch['_id'][i]}_chunk_{batch['chunk_id'][i]}"),
                    "vector": batch["embedding"][i],
                }
                for i in range(len(batch["_id"]))
            ]
            db.insert_batch(
                collection=db.judgement_chunks_collection,
                objects=objects,
            )


if __name__ == "__main__":
    typer.run(main)
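The ingest script has two modes: with `upsert` disabled it drops rows whose UUID is already in the collection, and in both modes it uploads the remaining rows in batches of `batch_size`, with `math.ceil(len(embs) / batch_size)` as the progress total. The sketch below reproduces only that control flow on plain Python lists, with no database or `datasets` dependency. `select_new` and `iter_batches` are hypothetical names introduced for illustration.

```python
import math
from typing import Iterator


def select_new(rows: list[dict], existing_uuids: set[str]) -> list[dict]:
    """Keep only rows whose UUID has not been uploaded yet (the non-upsert mode)."""
    return [row for row in rows if row["uuid"] not in existing_uuids]


def iter_batches(rows: list[dict], batch_size: int) -> Iterator[list[dict]]:
    """Yield consecutive slices of at most batch_size rows."""
    for start in range(0, len(rows), batch_size):
        yield rows[start : start + batch_size]


if __name__ == "__main__":
    rows = [{"uuid": f"u{i}"} for i in range(10)]
    pending = select_new(rows, existing_uuids={"u0", "u1", "u2"})
    expected_batches = math.ceil(len(pending) / 3)
    sizes = [len(batch) for batch in iter_batches(pending, batch_size=3)]
    print(expected_batches, sizes)
```

Filtering before batching keeps the progress total accurate, since the batch count is computed from the rows that will actually be sent.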
38 changes: 38 additions & 0 deletions scripts/embed/weaviate_example.py
@@ -0,0 +1,38 @@
import os
from pprint import pprint

from dotenv import load_dotenv

import weaviate
from weaviate.collections.classes.grpc import MetadataQuery

load_dotenv()
WV_HOST = os.getenv("WV_URL", "localhost")
WV_PORT = int(os.getenv("WV_PORT", 8080))
WV_GRPC_PORT = int(os.getenv("WV_GRPC_PORT", 50051))
WV_API_KEY = os.getenv("WV_API_KEY", None)

QUERY_PROMPT = "zapytanie: {query}"

# NOTE: This is a standalone example; for convenience you can use juddges/data/weaviate_db.py
with weaviate.connect_to_local(
    host=WV_HOST,
    port=WV_PORT,
    grpc_port=WV_GRPC_PORT,
    auth_credentials=weaviate.auth.Auth.api_key(WV_API_KEY),
) as client:
    coll = client.collections.get("judgement_chunks")
    response = coll.query.hybrid(
        query=QUERY_PROMPT.format(query="oskarżony handlował narkotykami"),
        limit=2,
        return_metadata=MetadataQuery(distance=True),
    )

for o in response.objects:
    print(
        f"{o.properties['judgment_id']} - {o.properties['chunk_id']}".center(
            100,
            "=",
        )
    )
    pprint(o.properties["chunk_text"])
16 changes: 16 additions & 0 deletions weaviate/README.md
@@ -0,0 +1,16 @@
# Weaviate deployment

## Instructions
1. Prepare a `.env` file with the proper user names and API tokens:
   ```bash
   cp example.env .env
   ```
2. Start the containers with Docker Compose:
   ```bash
   docker compose up -d
   ```

## Remarks
* Persistent data is stored in the mounted `./weaviate_data` directory.
* The deployment was tested on a machine with 16 CPUs and 64 GB of memory, without a GPU (vectors were computed outside the Weaviate instance; `t2v-transformers` was used only for inference).
* See [scripts/embed/weaviate_example.py](../scripts/embed/weaviate_example.py) for a search usage example.
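The scripts in this PR read their connection settings from the `WV_HOST`, `WV_PORT`, `WV_GRPC_PORT` and `WV_API_KEY` environment variables, with `localhost`, `8080` and `50051` as defaults. The contents of `example.env` are not shown here, so the following is only a sketch of how those variables could be collected into one typed object. `WeaviateSettings` and `load_weaviate_settings` are hypothetical names, and treating an empty `WV_API_KEY` as "no key" is an assumption of this sketch.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class WeaviateSettings:
    host: str
    port: int
    grpc_port: int
    api_key: Optional[str]


def load_weaviate_settings(env: Mapping[str, str] = os.environ) -> WeaviateSettings:
    """Read the WV_* variables, applying the defaults used by the scripts."""
    return WeaviateSettings(
        host=env.get("WV_HOST", "localhost"),
        # Ports arrive as strings from the environment, so convert them once here.
        port=int(env.get("WV_PORT", "8080")),
        grpc_port=int(env.get("WV_GRPC_PORT", "50051")),
        # Assumption: an empty value means that no API key is configured.
        api_key=env.get("WV_API_KEY") or None,
    )


if __name__ == "__main__":
    print(load_weaviate_settings({"WV_PORT": "9090"}))
```

Passing the mapping explicitly makes the parsing easy to check without touching the real process environment.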
33 changes: 33 additions & 0 deletions weaviate/docker-compose.yaml
@@ -0,0 +1,33 @@
name: weaviate
services:
  weaviate:
    command:
      - --host
      - 0.0.0.0
      - --port
      - '8080'
      - --scheme
      - http
    image: cr.weaviate.io/semitechnologies/weaviate:1.26.4
    depends_on:
      - t2v-transformers
    ports:
      - 8080:8080
      - 50051:50051
    volumes:
      - ./weaviate_data:/var/lib/weaviate
    restart: on-failure:0
    env_file:
      - path: .env
        required: true
    cpu_count: 14
    mem_limit: 60g

  t2v-transformers:
    build:
      context: .
      dockerfile: hf_transformers.dockerfile
      args:
        - MODEL_NAME=sdadas/mmlw-roberta-large
    environment:
      ENABLE_CUDA: 0 # Set to 1 to enable