Skip to content

Commit

Permalink
get_items return numpy array
Browse files Browse the repository at this point in the history
  • Loading branch information
dyashuni committed Aug 12, 2023
1 parent f6d170c commit 4f7b192
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib.

* `set_num_threads(num_threads)` set the default number of cpu threads used during data insertion/querying.

* `get_items(ids)` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`). Note that for cosine similarity it currently returns **normalized** vectors.
* `get_items(ids, return_type = 'numpy')` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`) if `return_type` is `list` return list of lists. Note that for cosine similarity it currently returns **normalized** vectors.

* `get_ids_list()` - returns a list of all elements' ids.

Expand Down
15 changes: 12 additions & 3 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,11 @@ class Index {
}


std::vector<std::vector<data_t>> getDataReturnList(py::object ids_ = py::none()) {
py::object getData(py::object ids_ = py::none(), std::string return_type = "numpy") {
std::vector<std::string> return_types{"numpy", "list"};
if (std::find(std::begin(return_types), std::end(return_types), return_type) == std::end(return_types)) {
throw std::invalid_argument("return_type should be \"numpy\" or \"list\"");
}
std::vector<size_t> ids;
if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
Expand All @@ -325,7 +329,12 @@ class Index {
for (auto id : ids) {
data.push_back(appr_alg->template getDataByLabel<data_t>(id));
}
return data;
if (return_type == "list") {
return py::cast(data);
}
if (return_type == "numpy") {
return py::array_t< data_t, py::array::c_style | py::array::forcecast >(py::cast(data));
}
}


Expand Down Expand Up @@ -925,7 +934,7 @@ PYBIND11_PLUGIN(hnswlib) {
py::arg("ids") = py::none(),
py::arg("num_threads") = -1,
py::arg("replace_deleted") = false)
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
.def("get_items", &Index<float>::getData, py::arg("ids") = py::none(), py::arg("return_type") = "numpy")
.def("get_ids_list", &Index<float>::getIdsList)
.def("set_ef", &Index<float>::set_ef, py::arg("ef"))
.def("set_num_threads", &Index<float>::set_num_threads, py::arg("num_threads"))
Expand Down

0 comments on commit 4f7b192

Please sign in to comment.