Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bm25 & keyword search #564

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,9 @@ def render_variables(self):
if not self.functions_in_settings:
functions_input(self.request.user)
variables_input(
template_keys=self.template_keys, allow_add=is_functions_enabled()
template_keys=self.template_keys,
allow_add=is_functions_enabled(),
exclude=self.fields_to_save(),
)

@classmethod
Expand Down
11 changes: 6 additions & 5 deletions daras_ai_v2/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ def generate_final_search_query(
context: dict = None,
response_format_type: typing.Literal["text", "json_object"] = None,
):
if context is None:
context = request.dict()
if response:
context |= response.dict()
instructions = render_prompt_vars(instructions, context).strip()
state = request.dict()
if response:
state |= response.dict()
if context:
state |= context
instructions = render_prompt_vars(instructions, state).strip()
if not instructions:
return ""
return run_language_model(
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/variables_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def variables_input(
description: str = "Variables let you pass custom parameters to your workflow. Access a variable in your instruction prompt with <a href='https://jinja.palletsprojects.com/en/3.1.x/templates/' target='_blank'>Jinja</a>, e.g. `{{ my_variable }}`\n ",
key: str = "variables",
allow_add: bool = False,
exclude: typing.Iterable[str] = (),
):
from recipes.BulkRunner import list_view_editor

Expand All @@ -45,7 +46,7 @@ def variables_input(
var_names = (
(template_var_names | set(variables.keys()))
- set(context_globals().keys()) # dont show global context variables
- set(gui.session_state.keys()) # dont show other session state variables
- set(exclude) # used for hiding request/response fields
)
pressed_add = False
if var_names or allow_add:
Expand Down
101 changes: 68 additions & 33 deletions daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import tempfile
import typing
import unicodedata
from functools import partial
from time import time

Expand Down Expand Up @@ -197,6 +198,7 @@ def get_top_k_references(
s = time()
search_result = query_vespa(
request.search_query,
request.keyword_query,
file_ids=vespa_file_ids,
limit=request.max_references or 100,
embedding_model=embedding_model,
Expand Down Expand Up @@ -245,34 +247,63 @@ def vespa_search_results_to_refs(

def query_vespa(
search_query: str,
keyword_query: str | list[str] | None,
file_ids: list[str],
limit: int,
embedding_model: EmbeddingModels,
semantic_weight: float = 1.0,
threshold: float = 0.7,
rerank_count: int = 1000,
) -> dict:
query_embedding = create_embeddings_cached([search_query], model=embedding_model)[0]
if query_embedding is None or not file_ids:
if not file_ids:
return {"root": {"children": []}}
file_ids_str = ", ".join(map(repr, file_ids))
query = f"select * from {settings.VESPA_SCHEMA} where file_id in (@fileIds) and (userQuery() or ({{targetHits: {limit}}}nearestNeighbor(embedding, q))) limit {limit}"
logger.debug(f"Vespa query: {query!r}")
if semantic_weight == 1.0:
ranking = "semantic"
elif semantic_weight == 0.0:

yql = "select * from %(schema)s where file_id in (@fileIds) and " % dict(
schema=settings.VESPA_SCHEMA
)
bm25_yql = "( {targetHits: %(hits)i} userInput(@bm25Query) )"
semantic_yql = "( {targetHits: %(hits)i, distanceThreshold: %(threshold)f} nearestNeighbor(embedding, queryEmbedding) )"

if semantic_weight == 0.0:
yql += bm25_yql % dict(hits=limit)
ranking = "bm25"
elif semantic_weight == 1.0:
yql += semantic_yql % dict(hits=limit, threshold=threshold)
ranking = "semantic"
else:
yql += (
"( "
+ bm25_yql % dict(hits=rerank_count)
+ " or "
+ semantic_yql % dict(hits=rerank_count, threshold=threshold)
+ " )"
)
ranking = "fusion"
response = get_vespa_app().query(
yql=query,
query=search_query,
ranking=ranking,
body={
"ranking.features.query(q)": padded_embedding(query_embedding),
"ranking.features.query(semanticWeight)": semantic_weight,
"fileIds": file_ids_str,
},

body = {"yql": yql, "ranking": ranking, "hits": limit}

if ranking in ("bm25", "fusion"):
if isinstance(keyword_query, list):
keyword_query = " ".join(keyword_query)
body["bm25Query"] = remove_control_characters(keyword_query or search_query)

logger.debug(
"vespa query " + " ".join(repr(f"{k}={v}") for k, v in body.items()) + " ..."
)

if ranking in ("semantic", "fusion"):
query_embedding = create_embeddings_cached(
[search_query], model=embedding_model
)[0]
if query_embedding is None:
return {"root": {"children": []}}
body["input.query(queryEmbedding)"] = padded_embedding(query_embedding)

body["fileIds"] = ", ".join(map(repr, file_ids))

response = get_vespa_app().query(body)
assert response.is_successful()

return response.get_json()


Expand Down Expand Up @@ -601,6 +632,23 @@ def _sha256(x) -> str:
return hashlib.sha256(str(x).encode()).hexdigest()


def format_embedding_row(
    doc_id: str,
    file_id: str,
    ref: SearchReference,
    embedding: np.ndarray,
    created_at: datetime.datetime,
):
    """Build the Vespa feed row for one embedded document chunk.

    Returns a plain dict matching the Vespa schema: the embedding is padded
    to the schema's fixed size, the timestamp is epoch milliseconds, and the
    title/snippet text is sanitized of unicode control characters (Vespa
    rejects them in document fields).
    """
    # Vespa expects epoch milliseconds for the created_at long field.
    timestamp_ms = int(created_at.timestamp() * 1000)
    return {
        "id": doc_id,
        "file_id": file_id,
        "embedding": padded_embedding(embedding),
        "created_at": timestamp_ms,
        "title": remove_control_characters(ref["title"]),
        "snippet": remove_control_characters(ref["snippet"]),
    }


def get_embeds_for_doc(
*,
f_url: str,
Expand Down Expand Up @@ -1063,22 +1111,9 @@ def render_sources_widget(refs: list[SearchReference]):
)


def format_embedding_row(
doc_id: str,
file_id: str,
ref: SearchReference,
embedding: np.ndarray,
created_at: datetime.datetime,
):
return dict(
id=doc_id,
file_id=file_id,
embedding=padded_embedding(embedding),
created_at=int(created_at.timestamp() * 1000),
# url=ref["url"].encode("unicode-escape").decode(),
# title=ref["title"].encode("unicode-escape").decode(),
# snippet=ref["snippet"].encode("unicode-escape").decode(),
)
def remove_control_characters(s):
    """Strip all unicode control characters (category "C*") from *s*.

    Vespa rejects documents containing control characters; see
    https://docs.vespa.ai/en/troubleshooting-encoding.html
    """
    kept = []
    for ch in s:
        # unicodedata.category returns e.g. "Cc" for control chars;
        # drop anything whose major category is "C".
        if not unicodedata.category(ch).startswith("C"):
            kept.append(ch)
    return "".join(kept)
devxpy marked this conversation as resolved.
Show resolved Hide resolved


EMBEDDING_SIZE = 3072
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ loguru = "^0.7.2"
aifail = "^0.3.0"
pytest-playwright = "^0.4.3"
emoji = "^2.10.1"
pyvespa = "^0.39.0"
pyvespa = "^0.51.0"
anthropic = "^0.34.1"
azure-cognitiveservices-speech = "^1.37.0"
twilio = "^9.2.3"
Expand Down
1 change: 1 addition & 0 deletions recipes/Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def render_variables(self):
allow_add=True,
description="Pass custom parameters to your function and access the parent workflow data. "
"Variables will be passed down as the first argument to your anonymous JS function.",
exclude=self.fields_to_save(),
)

options = set(gui.session_state.get("secrets") or [])
Expand Down
9 changes: 6 additions & 3 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,8 +979,9 @@ def search_step(self, request, response, user_input, model):
yield "Creating search query..."
response.final_search_query = generate_final_search_query(
request=request,
response=response,
instructions=query_instructions,
context={**gui.session_state, "messages": chat_history},
context={"messages": chat_history},
)
else:
query_msgs.reverse()
Expand All @@ -998,8 +999,9 @@ def search_step(self, request, response, user_input, model):
keyword_query = json.loads(
generate_final_search_query(
request=k_request,
response=response,
instructions=keyword_instructions,
context={**gui.session_state, "messages": chat_history},
context={"messages": chat_history},
response_format_type="json_object",
),
)
Expand All @@ -1011,7 +1013,8 @@ def search_step(self, request, response, user_input, model):
response.references = yield from get_top_k_references(
DocSearchRequest.parse_obj(
{
**gui.session_state,
**request.dict(),
**response.dict(),
"search_query": response.final_search_query,
"keyword_query": response.final_keyword_query,
},
Expand Down
92 changes: 26 additions & 66 deletions scripts/setup_vespa_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
Schema,
Document,
Field,
FieldSet,
HNSW,
RankProfile,
FieldSet,
Function,
GlobalPhaseRanking,
QueryTypeField,
)

from daras_ai_v2 import settings
Expand All @@ -35,110 +34,71 @@
rank="filter",
),
Field(
name="url",
name="file_id",
type="string",
indexing=["attribute", "summary"],
),
Field(
name="title",
type="string",
indexing=["index", "summary"],
index="enable-bm25",
),
Field(
name="snippet",
type="string",
indexing=["index", "summary"],
index="enable-bm25",
attribute=["fast-search"],
rank="filter",
),
Field(
name="embedding",
type=EMBEDDING_TYPE,
indexing=["index", "attribute"],
ann=HNSW(distance_metric="dotproduct"),
),
Field(
name="file_id",
type="string",
indexing=["attribute", "summary"],
attribute=["fast-search"],
rank="filter",
),
Field(
name="created_at",
type="long",
indexing=["attribute"],
attribute=["fast-access"],
),
Field(
name="title",
type="string",
indexing=["index", "summary"],
index="enable-bm25",
),
Field(
name="snippet",
type="string",
indexing=["index", "summary"],
index="enable-bm25",
),
]
),
fieldsets=[FieldSet(name="default", fields=["title", "snippet"])],
rank_profiles=[
RankProfile(
name="bm25",
inputs=[
("query(q)", EMBEDDING_TYPE),
],
functions=[
Function(
name="bm25sum", expression="bm25(title) + bm25(snippet)"
)
],
first_phase="bm25sum",
first_phase="bm25(title) + bm25(snippet)",
),
RankProfile(
name="semantic",
inputs=[
("query(q)", EMBEDDING_TYPE),
],
inputs=[("query(queryEmbedding)", EMBEDDING_TYPE)],
first_phase="closeness(field, embedding)",
devxpy marked this conversation as resolved.
Show resolved Hide resolved
),
RankProfile(
name="fusion",
inherits="bm25",
inputs=[
("query(q)", EMBEDDING_TYPE),
("query(queryEmbedding)", EMBEDDING_TYPE),
("query(semanticWeight)", "double"),
],
first_phase="closeness(field, embedding)",
global_phase=GlobalPhaseRanking(
expression="""
if (closeness(field, embedding)>0.6,
reciprocal_rank(bm25sum) * (1 - query(semanticWeight)) +
reciprocal_rank(closeness(field, embedding)) * query(semanticWeight),
0)
""",
rerank_count=1000,
),
),
RankProfile(
name="fusion2", # with bm25 first
inherits="bm25",
inputs=[
("query(q)", EMBEDDING_TYPE),
("query(semanticWeight)", "double"),
functions=[
Function(
name="bm25sum",
expression="bm25(title) + bm25(snippet)",
),
],
first_phase="closeness(field, embedding)",
first_phase="bm25sum",
global_phase=GlobalPhaseRanking(
expression="""
if (bm25sum>0.6,
reciprocal_rank(bm25sum) * (1 - query(semanticWeight)) +
reciprocal_rank(closeness(field, embedding)) * query(semanticWeight),
0)
""",
expression="reciprocal_rank(bm25sum) * (1 - query(semanticWeight)) + reciprocal_rank(closeness(field, embedding)) * query(semanticWeight)",
rerank_count=1000,
),
),
],
)
],
)
package.query_profile_type.add_fields(
QueryTypeField(
name="ranking.features.query(q)",
type=EMBEDDING_TYPE,
),
)


def run():
Expand Down
Loading