Encode minicoil vectors to Qdrant #1

Open · wants to merge 5 commits into master
98 changes: 90 additions & 8 deletions minicoil_demo/model/sparse_vector.py
@@ -1,32 +1,114 @@

 from typing import List
+import mmh3
 from minicoil_demo.model.mini_coil import MiniCOIL
+from minicoil_demo.model.stopwords import english_stopwords
+from fastembed.common.utils import get_all_punctuation
 
 from qdrant_client import models
 
+GAP = 32000
+INT32_MAX = 2**31 - 1
+english_stopwords = set(english_stopwords)
+punctuation = set(get_all_punctuation())
+special_tokens = set(['[CLS]', '[SEP]', '[PAD]', '[UNK]', '[MASK]'])  # TBD do better
 
 def normalize_vector(vector: List[float]) -> List[float]:
     norm = sum([x ** 2 for x in vector]) ** 0.5
     return [x / norm for x in vector]
 
+def unkn_word_token_id(word: str, shift: int) -> int:  # 2-3 words can collide in 1 index with this mapping, not counting mmh3 collisions
+    hash = mmh3.hash(word)
+
+    if hash < 0:
+        unsigned_hash = hash + 2**32
+    else:
+        unsigned_hash = hash
+
+    range_size = INT32_MAX - shift
+    remapped_hash = shift + (unsigned_hash % range_size)
+
+    return remapped_hash
+
-def embedding_to_vector(model: MiniCOIL, sentence_embedding: List[dict]) -> models.SparseVector:
-    indicies = []
+def bm25_tf(num_occurrences: int, sentence_len: int, k: float = 1.2, b: float = 0.75, avg_len: float = 6.0) -> float:  # avg_len 25 for quora
+    # omitted checking token_max_length
+    res = num_occurrences * (k + 1)
+    res /= num_occurrences + k * (1 - b + b * sentence_len / avg_len)
+    return res
+
+def embedding_to_vector(model: MiniCOIL, sentence_embedding: dict) -> models.SparseVector:
+    indices = []
     values = []
 
     embedding_size = model.output_dim
+    vocab_size = model.vocab_resolver.vocab_size()  # vocab_size() returns "vocab_size + 1" (ID -1 is reserved for any unknown word)
+
+    # still dependent on vocab_size :(
+    unknown_words_shift = ((vocab_size * embedding_size) // GAP + 2) * GAP  # miniCOIL vocab + a gap of at least (32000 // embedding_size) + 1 new words
+
+    # we can't use fastembed's remove_non_alphanumeric(text) unless it is propagated down to vocab_resolver
+    sentence_len = 0
+    for embedding in sentence_embedding.values():
+        if embedding["word"] not in punctuation | english_stopwords | special_tokens:
+            sentence_len += embedding["count"]
+
+    # BM25 always produces a positive value; miniCOIL does not.
+    # So if a word known to miniCOIL carries a "+" sign in some of its 4 dims in one text,
+    # but a "-" sign in the same dim in another, the match between those two documents is
+    # penalized compared to documents where the word is absent at all. Maybe that's not so good(?)
 
     for embedding in sentence_embedding.values():
         word_id = embedding["word_id"]
 
-        if word_id >= 0:
+        num_occurrences = embedding["count"]
+
+        if word_id >= 0:  # miniCOIL IDs start with 1
+            # print(f"""We counted {num_occurrences} occurrences of "{embedding["word"]}" """)
             embedding = embedding["embedding"]
             normalized_embedding = normalize_vector(embedding)
             for val_id, value in enumerate(normalized_embedding):
-                indicies.append(word_id * embedding_size + val_id)
-                values.append(value)
+                indices.append((word_id - 1) * embedding_size + val_id)  # since miniCOIL IDs start with 1
+                # TBD: perhaps only if it's positive
+                values.append(value * bm25_tf(num_occurrences, sentence_len))
Review comment (author), on the line above: perhaps apply this only to positive values, or remap the miniCOIL output to the range [0, 1] with norm 1, or use some other guard against penalizing a match when a word's "meaning" (a dimension) has opposite signs in two documents.
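A minimal sketch of the remap that comment suggests (a hypothetical helper, not part of this PR; it assumes the input is already L2-normalized, so every component lies in [-1, 1]):

    # Hypothetical: shift each component from [-1, 1] into [0, 1], then
    # re-normalize, so a shared word never contributes a negative score.
    def remap_to_unit_range(normalized_embedding: list) -> list:
        shifted = [(x + 1.0) / 2.0 for x in normalized_embedding]
        norm = sum(x ** 2 for x in shifted) ** 0.5
        return [x / norm for x in shifted] if norm > 0 else shifted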

+        if word_id == -1:  # unknown word
+            if embedding["word"] not in punctuation | english_stopwords | special_tokens:
+                # print(f"""We counted {num_occurrences} occurrences of "{embedding["word"]}" """)
+                indices.append(unkn_word_token_id(embedding["word"], unknown_words_shift))
+                values.append(bm25_tf(num_occurrences, sentence_len))
 
     return models.SparseVector(
-        indices=indicies,
+        indices=indices,
         values=values,
     )
 
+def query_embedding_to_vector(model: MiniCOIL, sentence_embedding: dict) -> models.SparseVector:
+    indices = []
+    values = []
+
+    embedding_size = model.output_dim
+    vocab_size = model.vocab_resolver.vocab_size()  # vocab_size() returns "vocab_size + 1" (ID -1 is reserved for any unknown word)
+
+    # still dependent on vocab_size :(
+    unknown_words_shift = ((vocab_size * embedding_size) // GAP + 2) * GAP  # miniCOIL vocab + a gap of at least (32000 // embedding_size) + 1 new words
+
+    for embedding in sentence_embedding.values():
+        word_id = embedding["word_id"]
+
+        if word_id >= 0:  # miniCOIL IDs start with 1
+            embedding = embedding["embedding"]
+            normalized_embedding = normalize_vector(embedding)
+            for val_id, value in enumerate(normalized_embedding):
+                indices.append((word_id - 1) * embedding_size + val_id)  # since miniCOIL IDs start with 1
+                # TBD: perhaps only if it's positive
+                values.append(value)
+        if word_id == -1:  # unknown word
+            if embedding["word"] not in punctuation | english_stopwords | special_tokens:
+                indices.append(unkn_word_token_id(embedding["word"], unknown_words_shift))
+                values.append(1)
+
+    return models.SparseVector(
+        indices=indices,
+        values=values,
+    )
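As a sanity check of the index layout and weighting above, a small worked sketch (vocab_size = 20000 and output_dim = 4 are illustrative assumptions, not values from the model):

    embedding_size = 4
    vocab_size = 20000  # illustrative
    GAP = 32000

    # miniCOIL IDs start at 1: word_id=1 occupies indices 0..3, word_id=2
    # occupies 4..7, and the vocabulary ends at vocab_size * 4 - 1 = 79999.
    # Unknown words are hashed into [unknown_words_shift, INT32_MAX):
    unknown_words_shift = ((vocab_size * embedding_size) // GAP + 2) * GAP
    assert unknown_words_shift == 128000  # safely above the miniCOIL range

    # bm25_tf saturates with repeated occurrences (k=1.2, b=0.75, avg_len=6.0):
    # bm25_tf(1, 6) = 1.0, bm25_tf(2, 6) = 1.375, bm25_tf(5, 6) ≈ 1.774,
    # so a word repeated five times scores less than twice a single occurrence.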
3 changes: 2 additions & 1 deletion minicoil_demo/model/vocab_resolver.py
@@ -197,7 +197,8 @@ def resolve_tokens(self, token_ids: np.ndarray) -> Tuple[np.ndarray, dict, dict,
             token_ids[token_id] = vocab_id
 
             if vocab_id == 0:
-                oov_count[token] += 1
+                # oov_count[token] += 1
+                oov_count[self.stemmer.stem_word(token)] += 1  # we need to stem for bm25 as well; stemmer disabling is not handled
             else:
                 counts[vocab_id] += 1
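A quick illustration of why the OOV counts are stemmed (a sketch assuming the py-rust-stemmers package the project already depends on):

    from py_rust_stemmers import SnowballStemmer

    stemmer = SnowballStemmer("english")
    # Without stemming, "running" in a query and "runs" in a document hash to
    # different unknown-word indices and never match; stemmed, both become "run".
    print(stemmer.stem_word("running"))  # run
    print(stemmer.stem_word("runs"))     # run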
141 changes: 115 additions & 26 deletions minicoil_demo/tools/encode_to_qdrant.py
@@ -1,8 +1,9 @@
 import argparse
 
 import os
-from typing import Iterable, List
+from typing import Iterable, Tuple
 import uuid
+import json
 
 from qdrant_client import QdrantClient, models

@@ -11,21 +12,41 @@
 from minicoil_demo.model.mini_coil import MiniCOIL
 from minicoil_demo.model.sparse_vector import embedding_to_vector
 
+from fastembed import SparseTextEmbedding
 
 DEFAULT_MODEL_NAME = os.getenv("MODEL_NAME", "minicoil.model")
 
 
-def read_file(file_path):
+def read_file(file_path: str):
     with open(file_path, "r") as f:
         for line in f:
             yield line.strip()
 
 
-def embedding_stream(model: MiniCOIL, file_path) -> Iterable[dict]:
+def read_file_beir(file_path: str) -> Iterable[Tuple[str, str]]:  # yields (_id, text) pairs
+    with open(file_path, "r") as file:
+        for line in file:
+            row = json.loads(line)
+            yield row["_id"], row["text"]
+
+def read_texts_beir(file_path: str) -> Iterable[str]:
+    with open(file_path, "r") as file:
+        for line in file:
+            row = json.loads(line)
+            yield row["text"]
+
+def embedding_stream(model: MiniCOIL, file_path: str) -> Iterable[dict]:
     stream = read_file(file_path)
     for sentence_embeddings in model.encode_steam(stream, parallel=4):
         yield sentence_embeddings
 
+
+def embedding_stream_beir(model: MiniCOIL, file_path: str) -> Iterable[dict]:
+    stream = read_texts_beir(file_path)
+    for sentence_embeddings in model.encode_steam(stream, parallel=4):
+        yield sentence_embeddings
 
 
 def read_points(model: MiniCOIL, file_path: str):
     sentences = read_file(file_path)
     embeddings = embedding_stream(model, file_path=file_path)
@@ -41,48 +62,116 @@ def read_points(model: MiniCOIL, file_path: str):
"sentence": sentence
}
)


def read_points_beir(model: MiniCOIL, file_path: str, total_points: int = 523000) -> Iterable[models.PointStruct]:
embeddings = embedding_stream_beir(model, file_path=file_path)
sparse_vectors = map(lambda x: embedding_to_vector(model, x), embeddings)

for ((id_text, text), sparse_vector) in tqdm.tqdm(zip(read_file_beir(file_path), sparse_vectors), total=total_points, desc="Processing points, BEIR"): #quora
yield models.PointStruct(
id=int(id_text),
vector={
"minicoil": sparse_vector,
},
payload={
"sentence": text
}
)

def read_points_beir_bm25(model: SparseTextEmbedding, file_path: str, total_points: int = 523000) -> Iterable[models.PointStruct]:
for ((id_text, text), embedding) in zip(read_file_beir(file_path), model.embed(tqdm.tqdm(read_texts_beir(file_path), total=total_points, desc="Processing points, BEIR"), batch_size=32)):
yield models.PointStruct(
id=int(id_text),
vector={
"bm25": models.SparseVector(
values=embedding.values.tolist(),
indices=embedding.indices.tolist()
)
},
payload={
"sentence": text
}
)
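For reference, the BEIR readers above expect a corpus.jsonl with one JSON object per line; the row below is illustrative, not taken from the dataset:

    import json

    # read_file_beir uses only the "_id" and "text" fields; id=int(id_text)
    # above additionally assumes the ids are numeric strings, as in quora.
    row = json.loads('{"_id": "12345", "text": "What is the best way to learn Python?"}')
    assert (row["_id"], row["text"]) == ("12345", "What is the best way to learn Python?")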


 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-name", type=str)
     parser.add_argument("--input-path", type=str)
+    parser.add_argument("--is-beir-dataset", action="store_true")
     parser.add_argument("--collection-name", type=str, default="minicoil-demo")
 
     args = parser.parse_args()
 
     model_name = args.model_name or DEFAULT_MODEL_NAME
-    vocab_path = os.path.join(DATA_DIR, f"{model_name}.vocab")
-    model_path = os.path.join(DATA_DIR, f"{model_name}.npy")
-
-    transformer_model = "jinaai/jina-embeddings-v2-small-en-tokens"
-
-    mini_coil = MiniCOIL(
-        vocab_path=vocab_path,
-        word_encoder_path=model_path,
-        sentence_encoder_model=transformer_model
-    )
+    if model_name == 'bm25':
+        model = SparseTextEmbedding(
+            model_name="Qdrant/bm25",
+            avg_len=6.0,  # if DATASET == "quora" else 256.0
+        )
+    elif model_name == 'minicoil.model':
+        vocab_path = os.path.join(DATA_DIR, f"{model_name}.vocab")
+        model_path = os.path.join(DATA_DIR, f"{model_name}.npy")
+        transformer_model = "jinaai/jina-embeddings-v2-small-en-tokens"
+        model = MiniCOIL(
+            vocab_path=vocab_path,
+            word_encoder_path=model_path,
+            sentence_encoder_model=transformer_model
+        )
+    else:
+        raise ValueError(f"{model_name} is not supported")  # model would be undefined below
 
-    qdrant_cleint = QdrantClient(
+    qdrant_client = QdrantClient(
         url=QDRANT_URL,
         api_key=QDRANT_API_KEY
     )
 
-    if not qdrant_cleint.collection_exists(args.collection_name):
-        qdrant_cleint.create_collection(
-            collection_name=args.collection_name,
-            vectors_config={},
-            sparse_vectors_config={
-                "minicoil": models.SparseVectorParams()
-            }
-        )
+    if model_name == 'bm25':
+        if not qdrant_client.collection_exists(args.collection_name):
+            qdrant_client.create_collection(
+                collection_name=args.collection_name,
+                vectors_config={},
+                sparse_vectors_config={
+                    "bm25": models.SparseVectorParams()
+                }
+            )
+    elif model_name == 'minicoil.model':
+        if not qdrant_client.collection_exists(args.collection_name):
+            qdrant_client.create_collection(
+                collection_name=args.collection_name,
+                vectors_config={},
+                sparse_vectors_config={
+                    "minicoil": models.SparseVectorParams()
+                }
+            )
 
     import ipdb
-    with ipdb.launch_ipdb_on_exception():
-        qdrant_cleint.upload_points(
-            collection_name=args.collection_name,
-            points=tqdm.tqdm(read_points(mini_coil, args.input_path))
-        )
+    if args.is_beir_dataset:
+        if model_name == 'bm25':
+            with ipdb.launch_ipdb_on_exception():
+                qdrant_client.upload_points(
+                    collection_name=args.collection_name,
+                    points=tqdm.tqdm(read_points_beir_bm25(model, args.input_path))
+                )
+        elif model_name == "minicoil.model":
+            with ipdb.launch_ipdb_on_exception():
+                qdrant_client.upload_points(
+                    collection_name=args.collection_name,
+                    points=tqdm.tqdm(read_points_beir(model, args.input_path))
+                )
+    else:
+        with ipdb.launch_ipdb_on_exception():
+            qdrant_client.upload_points(
+                collection_name=args.collection_name,
+                points=tqdm.tqdm(read_points(model, args.input_path))
+            )
 
 
 if __name__ == '__main__':
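Assuming the script is run as a module (the module path, input path, and collection name below are illustrative), an upload might be started with:

    python -m minicoil_demo.tools.encode_to_qdrant \
        --model-name bm25 \
        --input-path data/quora/corpus.jsonl \
        --is-beir-dataset \
        --collection-name quora-bm25

QDRANT_URL and QDRANT_API_KEY come from the project config, presumably environment-driven.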
8 changes: 4 additions & 4 deletions poetry.lock

(Generated file; diff not rendered.)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@
 py-rust-stemmers = "^0.1.3"
 qdrant-client = "^1.12.0"
 tokenizers = ">=0.15,<1.0"
 ipdb = "^0.13.13"
-
+mmh3 = "^4.1.0"
 
 [tool.poetry.dev-dependencies]