Skip to content

Commit

Permalink
Merge branch 'multiword_search_issue' into test
Browse files Browse the repository at this point in the history
  • Loading branch information
ravishankar63 committed Sep 20, 2023
2 parents 4dd6522 + 67571ba commit 52497d8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
18 changes: 14 additions & 4 deletions daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,13 @@ def get_top_k_references(
if sparse_weight:
# get sparse scores
tokenized_corpus = [
bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"])
bm25_tokenizer(ref["title"], n_gram=2)
+ bm25_tokenizer(ref["snippet"], n_gram=2)
for ref, _ in embeds
]
bm25 = BM25Okapi(tokenized_corpus, k1=2, b=0.3)
sparse_query_tokenized = bm25_tokenizer(
request.keyword_query or request.search_query
request.keyword_query or request.search_query, n_gram=2
)
if sparse_query_tokenized:
sparse_scores = np.array(bm25.get_scores(sparse_query_tokenized))
Expand Down Expand Up @@ -175,8 +176,17 @@ def get_top_k_references(
bm25_split_re = re.compile(rf"[{puncts}\s]")


def bm25_tokenizer(text: str) -> list[str]:
return [t for t in bm25_split_re.split(text.lower()) if t]
def bm25_tokenizer(text: str, n_gram: int = 2) -> list[str]:
    """Tokenize *text* into lowercase word tokens plus space-joined n-grams for BM25.

    The text is lowercased and split on the module-level ``bm25_split_re``
    pattern (punctuation and whitespace). Empty strings produced by
    consecutive delimiters are discarded — without this filter, BM25 sees
    spurious ``""`` tokens and malformed n-grams that pair words across
    punctuation gaps.

    Args:
        text: raw text to tokenize.
        n_gram: maximum n-gram size to emit. ``1`` yields unigrams only;
            ``2`` (default) adds bigrams; larger values add each size up
            to ``n_gram``. Backward-compatible with the previous behavior
            for ``n_gram`` in {1, 2}.

    Returns:
        Unigrams followed by all n-grams of each size from 2 up to
        ``n_gram``, each n-gram joined with a single space.
    """
    # filter out empty tokens from runs of adjacent delimiters
    tokens = [t for t in bm25_split_re.split(text.lower()) if t]
    n_grams = list(tokens)
    # emit every n-gram size from 2 up to the requested maximum
    for size in range(2, n_gram + 1):
        for i in range(len(tokens) - size + 1):
            n_grams.append(" ".join(tokens[i : i + size]))
    return n_grams


def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str:
Expand Down
7 changes: 5 additions & 2 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,14 @@ def run(self, state: dict) -> typing.Iterator[str | None]:

keyword_instructions = (request.keyword_instructions or "").strip()
if keyword_instructions:
yield "Exctracting keywords..."
yield "Extracting keywords..."
state["final_keyword_query"] = generate_final_search_query(
request=request,
instructions=keyword_instructions,
context={**state, "messages": chat_history},
context={
**state,
"messages": f'{user_prompt["role"]}: """{user_prompt["content"]}"""',
},
)

# perform doc search
Expand Down

0 comments on commit 52497d8

Please sign in to comment.