diff --git a/integration/test_collection.py b/integration/test_collection.py index 561ae6c70..00d4554bb 100644 --- a/integration/test_collection.py +++ b/integration/test_collection.py @@ -1,10 +1,12 @@ import datetime import io import pathlib +import struct import time import uuid from typing import Any, Callable, Dict, List, Optional, Sequence, TypedDict, Union +import numpy as np import pytest from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name @@ -293,6 +295,63 @@ class TestInsertManyWithTypedDict(TypedDict): assert obj2.properties["name"] == "some other name" +@pytest.mark.parametrize( + "objects, should_error", + [ + ( + [ + DataObject( + properties={"name": "some numpy one"}, vector=np.array([1, 2, 3]) + ), + ], + False, + ), + ( + [ + DataObject( + properties={"name": "some numpy one"}, vector=np.array([1, 2, 3]) + ), + DataObject( + properties={"name": "some numpy two"}, vector=np.array([11, 12, 13]) + ), + ], + False, + ), + ( + [ + DataObject( + properties={"name": "some numpy 2d"}, + vector=np.array([[1, 2, 3], [11, 12, 13]]), + ), + ], + True, + ), + ], +) +def test_insert_many_with_numpy( + collection_factory: CollectionFactory, + objects: Sequence[DataObject[WeaviateProperties, Any]], + should_error: bool, +) -> None: + if isinstance(objects[0].vector, list): + pytest.skip("numpy not available") + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.none(), + ) + if not should_error: + ret = collection.data.insert_many(objects) + for idx, uuid_ in ret.uuids.items(): + obj1 = collection.query.fetch_object_by_id(uuid_, include_vector=True) + inserted = objects[idx] + assert inserted.properties["name"] == obj1.properties["name"] + assert inserted.vector.tolist() == obj1.vector["default"] # type: ignore[union-attr] + else: + with pytest.raises(struct.error) as e: + collection.data.insert_many(objects) + assert str(e.value) == "required argument is not a float" + + def test_insert_many_with_refs(collection_factory: CollectionFactory) -> None: ref_collection = collection_factory( name="target", vectorizer_config=Configure.Vectorizer.none() diff --git a/weaviate/collections/batch/grpc_batch_objects.py b/weaviate/collections/batch/grpc_batch_objects.py index 84d7fddaa..a743c9086 100644 --- a/weaviate/collections/batch/grpc_batch_objects.py +++ b/weaviate/collections/batch/grpc_batch_objects.py @@ -27,11 +27,15 @@ from weaviate.util import _datetime_to_string, _get_vector_v4 +def _pack_vector(vector: Any) -> bytes: + return struct.pack("{}f".format(len(vector)), *vector) + + def _pack_named_vectors(vectors: Dict[str, List[float]]) -> List[base_pb2.Vectors]: return [ base_pb2.Vectors( name=name, - vector_bytes=struct.pack("{}f".format(len(vector)), *vector), + vector_bytes=_pack_vector(vector), ) for name, vector in vectors.items() ] @@ -48,18 +52,9 @@ def __init__(self, connection: ConnectionV4, consistency_level: Optional[Consist super().__init__(connection, consistency_level) def __grpc_objects(self, objects: List[_BatchObject]) -> List[batch_pb2.BatchObject]: - def pack_vector(vector: Any) -> bytes: - vector_list = _get_vector_v4(vector) - return struct.pack("{}f".format(len(vector_list)), *vector_list) - return [ batch_pb2.BatchObject( collection=obj.collection, - vector_bytes=( - pack_vector(obj.vector) - if obj.vector is not None and isinstance(obj.vector, list) - else None - ), uuid=str(obj.uuid) if obj.uuid is not None else str(uuid_package.uuid4()), properties=( self.__translate_properties_from_python_to_grpc( @@ -70,6 +65,11 @@ def pack_vector(vector: Any) -> bytes: else None ), tenant=obj.tenant, + vector_bytes=( + _pack_vector(obj.vector) + if obj.vector is not None and isinstance(obj.vector, list) + else None + ), vectors=( _pack_named_vectors(obj.vector) if obj.vector is not None and isinstance(obj.vector, dict) diff --git a/weaviate/collections/data/data.py b/weaviate/collections/data/data.py index b86d0ac94..2e64dd84a 100644 --- a/weaviate/collections/data/data.py +++ b/weaviate/collections/data/data.py @@ -363,7 +363,12 @@ async def insert_many( ( _BatchObject( collection=self.name, - vector=obj.vector, + vector=( + obj.vector + if obj.vector is None + or isinstance(obj.vector, dict) + else _get_vector_v4(obj.vector) + ), uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()), properties=cast(dict, obj.properties), tenant=self._tenant,