diff --git a/tests/unit/test_hnswlib_vectordb.py b/tests/unit/test_hnswlib_vectordb.py index b70c8b0..fe6aa62 100644 --- a/tests/unit/test_hnswlib_vectordb.py +++ b/tests/unit/test_hnswlib_vectordb.py @@ -169,4 +169,11 @@ def test_hnswlib_vectordb_restore(docs_to_index, tmpdir): assert len(res.matches) == 10 # assert res.id == res.matches[0].id # assert res.text == res.matches[0].text - # assert res.scores[0] < 0.001 # some precision issues, should be 0 \ No newline at end of file + # assert res.scores[0] < 0.001 # some precision issues, should be 0 + +def test_hnswlib_num_dos(tmpdir): + db = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) + doc_list = [MyDoc(text=f'toy doc {i}', embedding=np.random.rand(128)) for i in range(1000)] + db.index(inputs=DocList[MyDoc](doc_list)) + x=db.num_docs() + assert x['num_docs']==1000 diff --git a/tests/unit/test_inmemory_vectordb.py b/tests/unit/test_inmemory_vectordb.py index 7b4f787..cc812e4 100644 --- a/tests/unit/test_inmemory_vectordb.py +++ b/tests/unit/test_inmemory_vectordb.py @@ -172,3 +172,10 @@ def test_inmemory_vectordb_restore(docs_to_index, tmpdir): assert res.id == res.matches[0].id assert res.text == res.matches[0].text assert res.scores[0] > 0.99 # some precision issues, should be 1 + +def test_inmemory_num_dos(tmpdir): + db = InMemoryExactNNVectorDB[MyDoc](workspace=str(tmpdir)) + doc_list = [MyDoc(text=f'toy doc {i}', embedding=np.random.rand(128)) for i in range(1000)] + db.index(inputs=DocList[MyDoc](doc_list)) + x=db.num_docs() + assert x['num_docs']==1000 diff --git a/vectordb/db/base.py b/vectordb/db/base.py index 72518ad..8338889 100644 --- a/vectordb/db/base.py +++ b/vectordb/db/base.py @@ -227,6 +227,9 @@ async def _deploy(): ret = asyncio.run(_deploy()) return ret + def num_docs(self, **kwargs): + return self._executor.num_docs() + @pass_kwargs_as_params @unify_input_output def index(self, docs: 'DocList[TSchema]', parameters: Optional[Dict] = None, **kwargs): diff --git a/vectordb/db/executors/hnsw_indexer.py b/vectordb/db/executors/hnsw_indexer.py index d4298c5..c6447bd 100644 --- a/vectordb/db/executors/hnsw_indexer.py +++ b/vectordb/db/executors/hnsw_indexer.py @@ -105,7 +105,7 @@ async def async_update(self, docs, *args, **kwargs): return self.update(docs, *args, **kwargs) def num_docs(self, **kwargs): - return {'num_docs': self._index.num_docs()} + return {'num_docs': self._indexer.num_docs()} def snapshot(self, snapshot_dir): # TODO: Maybe copy the work_dir to workspace if `handle` is False diff --git a/vectordb/db/executors/inmemory_exact_indexer.py b/vectordb/db/executors/inmemory_exact_indexer.py index 2c45f99..aec7ab8 100644 --- a/vectordb/db/executors/inmemory_exact_indexer.py +++ b/vectordb/db/executors/inmemory_exact_indexer.py @@ -71,7 +71,7 @@ def update(self, docs, *args, **kwargs): return self._index(docs) def num_docs(self, *args, **kwargs): - return {'num_docs': self._index.num_docs()} + return {'num_docs': self._indexer.num_docs()} def snapshot(self, snapshot_dir): snapshot_file = f'{snapshot_dir}/index.bin'