From 28512a8ce82d09395493c2cba8812a343ef0011c Mon Sep 17 00:00:00 2001
From: Ravi Shankar <42587315+ravishankar63@users.noreply.github.com>
Date: Tue, 19 Sep 2023 12:07:42 +0530
Subject: [PATCH 1/2] added bigram tokenization

---
 daras_ai_v2/vector_search.py | 18 ++++++++++++++----
 recipes/VideoBots.py         |  4 ++--
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py
index bf6f6e61c..978f74e31 100644
--- a/daras_ai_v2/vector_search.py
+++ b/daras_ai_v2/vector_search.py
@@ -124,12 +124,13 @@ def get_top_k_references(
     if sparse_weight:
         # get sparse scores
         tokenized_corpus = [
-            bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"])
+            bm25_tokenizer(ref["title"], n_gram=2)
+            + bm25_tokenizer(ref["snippet"], n_gram=2)
             for ref, _ in embeds
         ]
         bm25 = BM25Okapi(tokenized_corpus, k1=2, b=0.3)
         sparse_query_tokenized = bm25_tokenizer(
-            request.keyword_query or request.search_query
+            request.keyword_query or request.search_query, n_gram=2
         )
         if sparse_query_tokenized:
             sparse_scores = np.array(bm25.get_scores(sparse_query_tokenized))
@@ -175,8 +176,17 @@ def get_top_k_references(
 bm25_split_re = re.compile(rf"[{puncts}\s]")
 
 
-def bm25_tokenizer(text: str) -> list[str]:
-    return [t for t in bm25_split_re.split(text.lower()) if t]
+def bm25_tokenizer(text: str, n_gram=2) -> list[str]:
+    tokens = bm25_split_re.split(text.lower())
+    n_grams = []
+
+    n_grams.extend(tokens)
+    if n_gram == 2:
+        for i in range(len(tokens) - 1):
+            n_gram_text = " ".join(tokens[i:i + 2])
+            n_grams.append(n_gram_text)
+
+    return n_grams
 
 
 def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str:
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 334e1fd46..dea915049 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -619,11 +619,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
 
         keyword_instructions = (request.keyword_instructions or "").strip()
         if keyword_instructions:
-            yield "Exctracting keywords..."
+            yield "Extracting keywords..."
             state["final_keyword_query"] = generate_final_search_query(
                 request=request,
                 instructions=keyword_instructions,
-                context={**state, "messages": chat_history},
+                context={**state, "messages": f'{user_prompt["role"]}: """{user_prompt["content"]}"""'},
             )
 
         # perform doc search

From 67571bad992f0ab880035f7ebcd082dae3278de4 Mon Sep 17 00:00:00 2001
From: Ravi Shankar <42587315+ravishankar63@users.noreply.github.com>
Date: Tue, 19 Sep 2023 12:29:56 +0530
Subject: [PATCH 2/2] linting fix

---
 daras_ai_v2/vector_search.py | 4 ++--
 recipes/VideoBots.py         | 5 ++++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py
index 978f74e31..ca57b6c11 100644
--- a/daras_ai_v2/vector_search.py
+++ b/daras_ai_v2/vector_search.py
@@ -183,9 +183,9 @@ def bm25_tokenizer(text: str, n_gram=2) -> list[str]:
     n_grams.extend(tokens)
     if n_gram == 2:
         for i in range(len(tokens) - 1):
-            n_gram_text = " ".join(tokens[i:i + 2])
+            n_gram_text = " ".join(tokens[i : i + 2])
             n_grams.append(n_gram_text)
-
+
     return n_grams
 
 
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index dea915049..9331294b9 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -623,7 +623,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
             state["final_keyword_query"] = generate_final_search_query(
                 request=request,
                 instructions=keyword_instructions,
-                context={**state, "messages": f'{user_prompt["role"]}: """{user_prompt["content"]}"""'},
+                context={
+                    **state,
+                    "messages": f'{user_prompt["role"]}: """{user_prompt["content"]}"""',
+                },
             )
 
         # perform doc search