diff --git a/test/collection/test_config.py b/test/collection/test_config.py index 87f102b19..4b384b4d5 100644 --- a/test/collection/test_config.py +++ b/test/collection/test_config.py @@ -82,6 +82,22 @@ def test_basic_config(): } }, ), + ( + Configure.Vectorizer.multi2vec_cohere( + model="embed-multilingual-v2.0", + truncate="NONE", + vectorize_collection_name=False, + base_url="https://api.cohere.ai", + ), + { + "multi2vec-cohere": { + "model": "embed-multilingual-v2.0", + "truncate": "NONE", + "vectorizeClassName": False, + "baseURL": "https://api.cohere.ai/", + } + }, + ), ( Configure.Vectorizer.text2vec_gpt4all(), { @@ -1219,6 +1235,20 @@ def test_vector_config_flat_pq() -> None: } }, ), + ( + [Configure.NamedVectors.multi2vec_cohere(name="test", text_fields=["prop"])], + { + "test": { + "vectorizer": { + "multi2vec-cohere": { + "vectorizeClassName": True, + "textFields": ["prop"], + } + }, + "vectorIndexType": "hnsw", + } + }, + ), ( [Configure.NamedVectors.text2vec_gpt4all(name="test", source_properties=["prop"])], { diff --git a/weaviate/collections/classes/config_named_vectors.py b/weaviate/collections/classes/config_named_vectors.py index 8c93040d0..a8dd09185 100644 --- a/weaviate/collections/classes/config_named_vectors.py +++ b/weaviate/collections/classes/config_named_vectors.py @@ -50,6 +50,7 @@ _VectorizerCustomConfig, _Text2VecDatabricksConfig, _Text2VecVoyageConfig, + _Multi2VecCohereConfig, ) from ...warnings import _Warnings @@ -196,6 +197,59 @@ def text2vec_cohere( vector_index_config=vector_index_config, ) + @staticmethod + def multi2vec_cohere( + name: str, + *, + vector_index_config: Optional[_VectorIndexConfigCreate] = None, + vectorize_collection_name: bool = True, + base_url: Optional[AnyHttpUrl] = None, + model: Optional[Union[CohereModel, str]] = None, + truncate: Optional[CohereTruncation] = None, + image_fields: Optional[Union[List[str], List[Multi2VecField]]] = None, + text_fields: Optional[Union[List[str], List[Multi2VecField]]] = 
None, + ) -> _NamedVectorConfigCreate: + """Create a named vector using the `multi2vec_cohere` model. + + See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-cohere) + for detailed usage. + + Arguments: + `name` + The name of the named vector. + `vector_index_config` + The configuration for Weaviate's vector index. Use wvc.config.Configure.VectorIndex to create a vector index configuration. None by default + `vectorize_collection_name` + Whether to vectorize the collection name. Defaults to `True`. + `model` + The model to use. Defaults to `None`, which uses the server-defined default. + `truncate` + The truncation strategy to use. Defaults to `None`, which uses the server-defined default. + `vectorize_collection_name` + Whether to vectorize the collection name. Defaults to `True`. + `base_url` + The base URL to use where API requests should go. Defaults to `None`, which uses the server-defined default. + `image_fields` + The image fields to use in vectorization. + `text_fields` + The text fields to use in vectorization. + + Raises: + `pydantic.ValidationError` if `truncate` is not a valid value from the `CohereTruncation` type. 
+ """ + return _NamedVectorConfigCreate( + name=name, + vectorizer=_Multi2VecCohereConfig( + baseURL=base_url, + model=model, + truncate=truncate, + vectorizeClassName=vectorize_collection_name, + imageFields=_map_multi2vec_fields(image_fields), + textFields=_map_multi2vec_fields(text_fields), + ), + vector_index_config=vector_index_config, + ) + @staticmethod def text2vec_contextionary( name: str, diff --git a/weaviate/collections/classes/config_vectorizers.py b/weaviate/collections/classes/config_vectorizers.py index 93626f248..f0b81c1e3 100644 --- a/weaviate/collections/classes/config_vectorizers.py +++ b/weaviate/collections/classes/config_vectorizers.py @@ -112,6 +112,7 @@ class Vectorizers(str, Enum): TEXT2VEC_VOYAGEAI = "text2vec-voyageai" IMG2VEC_NEURAL = "img2vec-neural" MULTI2VEC_CLIP = "multi2vec-clip" + MULTI2VEC_COHERE = "multi2vec-cohere" MULTI2VEC_BIND = "multi2vec-bind" MULTI2VEC_PALM = "multi2vec-palm" # change to google once 1.27 is the lowest supported version REF2VEC_CENTROID = "ref2vec-centroid" @@ -374,6 +375,21 @@ def _to_dict(self) -> Dict[str, Any]: return ret_dict +class _Multi2VecCohereConfig(_Multi2VecBase): + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.MULTI2VEC_COHERE, frozen=True, exclude=True + ) + baseURL: Optional[AnyHttpUrl] + model: Optional[str] + truncate: Optional[CohereTruncation] + + def _to_dict(self) -> Dict[str, Any]: + ret_dict = super()._to_dict() + if self.baseURL is not None: + ret_dict["baseURL"] = self.baseURL.unicode_string() + return ret_dict + + class _Multi2VecClipConfig(_Multi2VecBase): vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( default=Vectorizers.MULTI2VEC_CLIP, frozen=True, exclude=True @@ -698,6 +714,47 @@ def text2vec_cohere( vectorizeClassName=vectorize_collection_name, ) + @staticmethod + def multi2vec_cohere( + *, + model: Optional[Union[CohereModel, str]] = None, + truncate: Optional[CohereTruncation] = None, + vectorize_collection_name: bool = True, + 
base_url: Optional[AnyHttpUrl] = None, + image_fields: Optional[Union[List[str], List[Multi2VecField]]] = None, + text_fields: Optional[Union[List[str], List[Multi2VecField]]] = None, + ) -> _VectorizerConfigCreate: + """Create a `_Multi2VecCohereConfig` object for use when vectorizing using the `multi2vec-cohere` model. + + See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-cohere) + for detailed usage. + + Arguments: + `model` + The model to use. Defaults to `None`, which uses the server-defined default. + `truncate` + The truncation strategy to use. Defaults to `None`, which uses the server-defined default. + `vectorize_collection_name` + Whether to vectorize the collection name. Defaults to `True`. + `base_url` + The base URL to use where API requests should go. Defaults to `None`, which uses the server-defined default. + `image_fields` + The image fields to use in vectorization. + `text_fields` + The text fields to use in vectorization. + + Raises: + `pydantic.ValidationError` if `truncate` is not a valid value from the `CohereTruncation` type. + """ + return _Multi2VecCohereConfig( + baseURL=base_url, + model=model, + truncate=truncate, + vectorizeClassName=vectorize_collection_name, + imageFields=_map_multi2vec_fields(image_fields), + textFields=_map_multi2vec_fields(text_fields), + ) + @staticmethod def text2vec_databricks( *,