Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/dynamic rag support #1377

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ env:
WEAVIATE_124: 1.24.26
WEAVIATE_125: 1.25.24
WEAVIATE_126: 1.26.8
WEAVIATE_127: stable-v1.27-372397e
WEAVIATE_127: 1.27.1-8030027

jobs:
lint-and-format:
Expand Down
42 changes: 37 additions & 5 deletions integration/test_collection_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Property,
)
from weaviate.collections.classes.data import DataObject
from weaviate.collections.classes.generative import GenerativeProvider
from weaviate.collections.classes.grpc import GroupBy, Rerank
from weaviate.exceptions import WeaviateQueryError, WeaviateUnsupportedFeatureError
from weaviate.util import _ServerVersion
Expand Down Expand Up @@ -340,8 +341,6 @@ def test_near_object_generate_with_everything(openai_collection: OpenAICollectio
assert res.generated == "apples cats"
assert res.objects[0].generated is not None
assert res.objects[1].generated is not None
assert res.objects[0].generated.lower() == "yes"
assert res.objects[1].generated.lower() == "no"


def test_near_object_generate_and_group_by_with_everything(
Expand All @@ -355,7 +354,7 @@ def test_near_object_generate_and_group_by_with_everything(
[
DataObject(
properties={
"text": "apples are big. you cna eat apples",
"text": "apples are big. you can eat apples",
"content": "Teddy is the biggest and bigger than everything else",
}
),
Expand All @@ -380,8 +379,6 @@ def test_near_object_generate_and_group_by_with_everything(
groups = list(res.groups.values())
assert groups[0].generated is not None
assert groups[1].generated is not None
assert groups[0].generated.lower() == "no"
assert groups[1].generated.lower() == "yes"


def test_near_text_generate_with_everything(openai_collection: OpenAICollection) -> None:
Expand Down Expand Up @@ -644,3 +641,38 @@ def test_queries_with_rerank_and_generative(collection_factory: CollectionFactor
][
0
].metadata.rerank_score


def test_near_text_generate_with_dynamic_rag(openai_collection: OpenAICollection) -> None:
collection = openai_collection(
vectorizer_config=Configure.Vectorizer.text2vec_openai(vectorize_collection_name=False),
)

collection.data.insert_many(
[
DataObject(
properties={
"text": "melons are big",
"content": "Teddy is the biggest and bigger than everything else. Teddy is not a fruit",
}
),
DataObject(
properties={
"text": "cats are small. You cannot eat cats. Cats are not fruit",
"content": "bananas are the smallest and smaller than everything else",
}
),
]
)

res = collection.generate.near_text(
query="small fruit",
single_prompt="Is there something to eat in {text} of the given object? Only answer yes if there is something to eat and no if not. Dont use punctuation",
grouped_task="Write out the fruit in alphabetical order. Only write the names separated by a space",
generative_provider=GenerativeProvider.openai(
temperature=0.1,
),
)
assert res.generated == "bananas melons"
assert res.objects[0].generated is not None
assert res.objects[1].generated is not None
2 changes: 2 additions & 0 deletions weaviate/classes/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from weaviate.collections.classes.aggregate import Metrics
from weaviate.collections.classes.filters import Filter
from weaviate.collections.classes.generative import GenerativeProvider
from weaviate.collections.classes.grpc import (
HybridFusion,
GroupBy,
Expand All @@ -18,6 +19,7 @@
__all__ = [
"Filter",
"GeoCoordinate",
"GenerativeProvider",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have this in wvc.generative.XXX?

"GroupBy",
"HybridFusion",
"HybridVector",
Expand Down
Loading