Skip to content

Commit

Permalink
create_collection handles nonvector collections; added requirements-d…
Browse files Browse the repository at this point in the history
…ev.txt
  • Loading branch information
hemidactylus committed Nov 16, 2023
1 parent 5bce510 commit e015112
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 27 deletions.
59 changes: 32 additions & 27 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,48 +740,53 @@ def get_collections(self):
return response

def create_collection(
self, collection_name, *, options=None, dimension=None, metric=""
self, collection_name, *, options=None, dimension=None, metric=None
):
"""
Create a new collection in the database.
Args:
collection_name (str): The name of the collection to create.
options (dict, optional): Options for the collection.
dimension (int, optional): Dimension for vector search.
metric (str, optional): Metric type for vector search.
metric (str, optional): Metric choice for vector search.
Returns:
AstraDBCollection: The created collection object.
"""
# Make sure we provide a collection name
if not collection_name:
raise ValueError("Must provide a collection name")

# Initialize options if not passed
if not options:
options = {"vector": {}}
elif "vector" not in options:
options["vector"] = {}

# Now check the remaining parameters - dimension
if dimension:
# options from named params
vector_options = {
k: v
for k, v in {
"dimension": dimension,
"metric": metric,
}.items()
if v is not None
}
# overlap/merge with stuff in options.vector
dup_params = set((options or {}).get("vector", {}).keys()) & set(
vector_options.keys()
)
if dup_params:
dups = ", ".join(sorted(dup_params))
raise ValueError(
f"Parameter(s) {dups} passed both to the method and in the options"
)
if vector_options:
options = options or {}
options["vector"] = {
**options.get("vector", {}),
**vector_options,
}
if "dimension" not in options["vector"]:
options["vector"]["dimension"] = dimension
else:
raise ValueError(
"dimension parameter provided both in options and as function parameter."
)

# Check the metric parameter
if metric:
if "metric" not in options["vector"]:
options["vector"]["metric"] = metric
else:
raise ValueError(
"metric parameter provided both in options as function parameter."
)
raise ValueError("Must pass dimension for vector collections")

# Build the final json payload
jsondata = {"name": collection_name, "options": options}
jsondata = {
k: v
for k, v in {"name": collection_name, "options": options}.items()
if v is not None
}

# Make the request to the endpoitn
self._request(
Expand Down
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
black==23.11.0
mypy==1.7.0
ruff==0.1.5
23 changes: 23 additions & 0 deletions tests/astrapy/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TEST_COLLECTION_NAME = "test_collection"
TEST_FIXTURE_COLLECTION_NAME = "test_fixture_collection"
TEST_FIXTURE_PROJECTION_COLLECTION_NAME = "test_projection_collection"
TEST_NONVECTOR_COLLECTION_NAME = "test_nonvector_collection"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -165,6 +166,28 @@ def test_create_collection(db):
assert isinstance(res, AstraDBCollection)


@pytest.mark.describe("should create and use a non-vector collection")
def test_nonvector_collection(db):
col = db.create_collection(TEST_NONVECTOR_COLLECTION_NAME)
col.insert_one({"_id": "first", "name": "a"})
col.insert_many(
[
{"_id": "second", "name": "b", "room": 7},
{"name": "c", "room": 7},
{"_id": "last", "type": "unnamed", "room": 7},
]
)
docs = col.find(filter={"room": 7}, projection={"name": 1})
ids = [doc["_id"] for doc in docs["data"]["documents"]]
assert len(ids) == 3
assert "second" in ids
assert "first" not in ids
auto_id = [id for id in ids if id not in {"second", "last"}][0]
col.delete(auto_id)
assert col.find_one(filter={"name": "c"})["data"]["document"] is None
db.delete_collection(TEST_NONVECTOR_COLLECTION_NAME)


@pytest.mark.describe("should get all collections")
def test_get_collections(db):
res = db.get_collections()
Expand Down

0 comments on commit e015112

Please sign in to comment.