From e01511268a400400b1ae7721375190b8ab38160b Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Thu, 16 Nov 2023 03:16:08 +0100 Subject: [PATCH] create_collection handles nonvector collections; added requirements-dev.txt --- astrapy/db.py | 59 ++++++++++++++++++++++------------------ requirements-dev.txt | 3 ++ tests/astrapy/test_db.py | 23 ++++++++++++++++ 3 files changed, 58 insertions(+), 27 deletions(-) create mode 100644 requirements-dev.txt diff --git a/astrapy/db.py b/astrapy/db.py index 792b107c..3fd5e221 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -740,7 +740,7 @@ def get_collections(self): return response def create_collection( - self, collection_name, *, options=None, dimension=None, metric="" + self, collection_name, *, options=None, dimension=None, metric=None ): """ Create a new collection in the database. @@ -748,40 +748,45 @@ def create_collection( collection_name (str): The name of the collection to create. options (dict, optional): Options for the collection. dimension (int, optional): Dimension for vector search. - metric (str, optional): Metric type for vector search. + metric (str, optional): Metric choice for vector search. Returns: AstraDBCollection: The created collection object. """ - # Make sure we provide a collection name if not collection_name: raise ValueError("Must provide a collection name") - - # Initialize options if not passed - if not options: - options = {"vector": {}} - elif "vector" not in options: - options["vector"] = {} - - # Now check the remaining parameters - dimension - if dimension: + # options from named params + vector_options = { + k: v + for k, v in { + "dimension": dimension, + "metric": metric, + }.items() + if v is not None + } + # overlap/merge with stuff in options.vector + dup_params = set((options or {}).get("vector", {}).keys()) & set( + vector_options.keys() + ) + if dup_params: + dups = ", ".join(sorted(dup_params)) + raise ValueError( + f"Parameter(s) {dups} passed both to the method and in the options" + ) + if vector_options: + options = options or {} + options["vector"] = { + **options.get("vector", {}), + **vector_options, + } if "dimension" not in options["vector"]: - options["vector"]["dimension"] = dimension - else: - raise ValueError( - "dimension parameter provided both in options and as function parameter." - ) - - # Check the metric parameter - if metric: - if "metric" not in options["vector"]: - options["vector"]["metric"] = metric - else: - raise ValueError( - "metric parameter provided both in options as function parameter." - ) + raise ValueError("Must pass dimension for vector collections") # Build the final json payload - jsondata = {"name": collection_name, "options": options} + jsondata = { + k: v + for k, v in {"name": collection_name, "options": options}.items() + if v is not None + } # Make the request to the endpoitn self._request( diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..31dd823d --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +black==23.11.0 +mypy==1.7.0 +ruff==0.1.5 diff --git a/tests/astrapy/test_db.py b/tests/astrapy/test_db.py index 919f48e7..362b80c5 100644 --- a/tests/astrapy/test_db.py +++ b/tests/astrapy/test_db.py @@ -38,6 +38,7 @@ TEST_COLLECTION_NAME = "test_collection" TEST_FIXTURE_COLLECTION_NAME = "test_fixture_collection" TEST_FIXTURE_PROJECTION_COLLECTION_NAME = "test_projection_collection" +TEST_NONVECTOR_COLLECTION_NAME = "test_nonvector_collection" @pytest.fixture(scope="module") @@ -165,6 +166,28 @@ def test_create_collection(db): assert isinstance(res, AstraDBCollection) +@pytest.mark.describe("should create and use a non-vector collection") +def test_nonvector_collection(db): + col = db.create_collection(TEST_NONVECTOR_COLLECTION_NAME) + col.insert_one({"_id": "first", "name": "a"}) + col.insert_many( + [ + {"_id": "second", "name": "b", "room": 7}, + {"name": "c", "room": 7}, + {"_id": "last", "type": "unnamed", "room": 7}, + ] + ) + docs = col.find(filter={"room": 7}, projection={"name": 1}) + ids = [doc["_id"] for doc in docs["data"]["documents"]] + assert len(ids) == 3 + assert "second" in ids + assert "first" not in ids + auto_id = [id for id in ids if id not in {"second", "last"}][0] + col.delete(auto_id) + assert col.find_one(filter={"name": "c"})["data"]["document"] is None + db.delete_collection(TEST_NONVECTOR_COLLECTION_NAME) + + @pytest.mark.describe("should get all collections") def test_get_collections(db): res = db.get_collections()