Skip to content

Commit

Permalink
Merge pull request #1390 from weaviate/multi2vec-cohere
Browse files Browse the repository at this point in the history
Add support for multi2vec-cohere
  • Loading branch information
dirkkul authored Nov 15, 2024
2 parents 1392587 + 76cd752 commit 8f62889
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,22 @@ def test_basic_config():
}
},
),
(
Configure.Vectorizer.multi2vec_cohere(
model="embed-multilingual-v2.0",
truncate="NONE",
vectorize_collection_name=False,
base_url="https://api.cohere.ai",
),
{
"multi2vec-cohere": {
"model": "embed-multilingual-v2.0",
"truncate": "NONE",
"vectorizeClassName": False,
"baseURL": "https://api.cohere.ai/",
}
},
),
(
Configure.Vectorizer.text2vec_gpt4all(),
{
Expand Down Expand Up @@ -1219,6 +1235,20 @@ def test_vector_config_flat_pq() -> None:
}
},
),
(
[Configure.NamedVectors.multi2vec_cohere(name="test", text_fields=["prop"])],
{
"test": {
"vectorizer": {
"multi2vec-cohere": {
"vectorizeClassName": True,
"textFields": ["prop"],
}
},
"vectorIndexType": "hnsw",
}
},
),
(
[Configure.NamedVectors.text2vec_gpt4all(name="test", source_properties=["prop"])],
{
Expand Down
54 changes: 54 additions & 0 deletions weaviate/collections/classes/config_named_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
_VectorizerCustomConfig,
_Text2VecDatabricksConfig,
_Text2VecVoyageConfig,
_Multi2VecCohereConfig,
)
from ...warnings import _Warnings

Expand Down Expand Up @@ -196,6 +197,59 @@ def text2vec_cohere(
vector_index_config=vector_index_config,
)

@staticmethod
def multi2vec_cohere(
name: str,
*,
vector_index_config: Optional[_VectorIndexConfigCreate] = None,
vectorize_collection_name: bool = True,
base_url: Optional[AnyHttpUrl] = None,
model: Optional[Union[CohereModel, str]] = None,
truncate: Optional[CohereTruncation] = None,
image_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
text_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
) -> _NamedVectorConfigCreate:
"""Create a named vector using the `multi2vec_cohere` model.
See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-cohere)
for detailed usage.
Arguments:
`name`
The name of the named vector.
`vector_index_config`
The configuration for Weaviate's vector index. Use wvc.config.Configure.VectorIndex to create a vector index configuration. None by default
`vectorize_collection_name`
Whether to vectorize the collection name. Defaults to `True`.
`model`
The model to use. Defaults to `None`, which uses the server-defined default.
`truncate`
The truncation strategy to use. Defaults to `None`, which uses the server-defined default.
`vectorize_collection_name`
Whether to vectorize the collection name. Defaults to `True`.
`base_url`
The base URL to use where API requests should go. Defaults to `None`, which uses the server-defined default.
`image_fields`
The image fields to use in vectorization.
`text_fields`
The text fields to use in vectorization.
Raises:
`pydantic.ValidationError` if `truncate` is not a valid value from the `CohereModel` type.
"""
return _NamedVectorConfigCreate(
name=name,
vectorizer=_Multi2VecCohereConfig(
baseURL=base_url,
model=model,
truncate=truncate,
vectorizeClassName=vectorize_collection_name,
imageFields=_map_multi2vec_fields(image_fields),
textFields=_map_multi2vec_fields(text_fields),
),
vector_index_config=vector_index_config,
)

@staticmethod
def text2vec_contextionary(
name: str,
Expand Down
57 changes: 57 additions & 0 deletions weaviate/collections/classes/config_vectorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class Vectorizers(str, Enum):
TEXT2VEC_VOYAGEAI = "text2vec-voyageai"
IMG2VEC_NEURAL = "img2vec-neural"
MULTI2VEC_CLIP = "multi2vec-clip"
MULTI2VEC_COHERE = "multi2vec-cohere"
MULTI2VEC_BIND = "multi2vec-bind"
MULTI2VEC_PALM = "multi2vec-palm" # change to google once 1.27 is the lowest supported version
REF2VEC_CENTROID = "ref2vec-centroid"
Expand Down Expand Up @@ -374,6 +375,21 @@ def _to_dict(self) -> Dict[str, Any]:
return ret_dict


class _Multi2VecCohereConfig(_Multi2VecBase):
vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
default=Vectorizers.MULTI2VEC_COHERE, frozen=True, exclude=True
)
baseURL: Optional[AnyHttpUrl]
model: Optional[str]
truncate: Optional[CohereTruncation]

def _to_dict(self) -> Dict[str, Any]:
ret_dict = super()._to_dict()
if self.baseURL is not None:
ret_dict["baseURL"] = self.baseURL.unicode_string()
return ret_dict


class _Multi2VecClipConfig(_Multi2VecBase):
vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
default=Vectorizers.MULTI2VEC_CLIP, frozen=True, exclude=True
Expand Down Expand Up @@ -698,6 +714,47 @@ def text2vec_cohere(
vectorizeClassName=vectorize_collection_name,
)

@staticmethod
def multi2vec_cohere(
*,
model: Optional[Union[CohereModel, str]] = None,
truncate: Optional[CohereTruncation] = None,
vectorize_collection_name: bool = True,
base_url: Optional[AnyHttpUrl] = None,
image_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
text_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
) -> _VectorizerConfigCreate:
"""Create a `_Multi2VecCohereConfig` object for use when vectorizing using the `multi2vec-cohere` model.
See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-cohere)
for detailed usage.
Arguments:
`model`
The model to use. Defaults to `None`, which uses the server-defined default.
`truncate`
The truncation strategy to use. Defaults to `None`, which uses the server-defined default.
`vectorize_collection_name`
Whether to vectorize the collection name. Defaults to `True`.
`base_url`
The base URL to use where API requests should go. Defaults to `None`, which uses the server-defined default.
`image_fields`
The image fields to use in vectorization.
`text_fields`
The text fields to use in vectorization.
Raises:
`pydantic.ValidationError` if `truncate` is not a valid value from the `CohereModel` type.
"""
return _Multi2VecCohereConfig(
baseURL=base_url,
model=model,
truncate=truncate,
vectorizeClassName=vectorize_collection_name,
imageFields=_map_multi2vec_fields(image_fields),
textFields=_map_multi2vec_fields(text_fields),
)

@staticmethod
def text2vec_databricks(
*,
Expand Down

0 comments on commit 8f62889

Please sign in to comment.